transformers-rb 0.1.0 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +7 -0
- data/README.md +94 -14
- data/lib/transformers/configuration_utils.rb +7 -0
- data/lib/transformers/feature_extraction_utils.rb +1 -0
- data/lib/transformers/modeling_utils.rb +63 -11
- data/lib/transformers/models/bert/modeling_bert.rb +3 -3
- data/lib/transformers/pipelines/_init.rb +19 -1
- data/lib/transformers/pipelines/embedding.rb +46 -0
- data/lib/transformers/pipelines/image_classification.rb +1 -1
- data/lib/transformers/pipelines/image_feature_extraction.rb +1 -1
- data/lib/transformers/pipelines/text_classification.rb +2 -2
- data/lib/transformers/pipelines/token_classification.rb +76 -1
- data/lib/transformers/sentence_transformer.rb +3 -20
- data/lib/transformers/tokenization_utils_fast.rb +8 -0
- data/lib/transformers/version.rb +1 -1
- data/lib/transformers.rb +7 -0
- metadata +5 -4
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 2b7441c0cc400d85d3d0775e6c53a631c83887e481588148a7cb3242e50c0f08
+  data.tar.gz: df639eb5c3231d344ec8516ec861cb40239ac5ff0310faa5b1b102a79af97972
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 566a03e8a65255d3eb7e87bb3e68c0ac4d228ba4bc55e62eaad98f6677b8a5f29ee9201efaefc3edb309d279c1a986a664dfffbf3cc292114394be1c10fa949a
+  data.tar.gz: fdae84f4fa8aac90f1a627d738fb701ef91d384414173f4042bac56bfe5c78c6ceb98aaea8ef0f2158d324ec60d18cbd91f84bd30f42b49d63193392e0105e96
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -2,6 +2,8 @@
 
 :slightly_smiling_face: State-of-the-art [transformers](https://github.com/huggingface/transformers) for Ruby
 
+For fast inference, check out [Informers](https://github.com/ankane/informers) :fire:
+
 [](https://github.com/ankane/transformers-ruby/actions)
 
 ## Installation
@@ -21,6 +23,17 @@ gem "transformers-rb"
 
 ## Models
 
+Embedding
+
+- [sentence-transformers/all-MiniLM-L6-v2](#sentence-transformersall-MiniLM-L6-v2)
+- [sentence-transformers/multi-qa-MiniLM-L6-cos-v1](#sentence-transformersmulti-qa-MiniLM-L6-cos-v1)
+- [mixedbread-ai/mxbai-embed-large-v1](#mixedbread-aimxbai-embed-large-v1)
+- [thenlper/gte-small](#thenlpergte-small)
+- [intfloat/e5-base-v2](#intfloate5-base-v2)
+- [BAAI/bge-base-en-v1.5](#baaibge-base-en-v15)
+- [Snowflake/snowflake-arctic-embed-m-v1.5](#snowflakesnowflake-arctic-embed-m-v15)
+- [opensearch-project/opensearch-neural-sparse-encoding-v1](#opensearch-projectopensearch-neural-sparse-encoding-v1)
+
 ### sentence-transformers/all-MiniLM-L6-v2
 
 [Docs](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
@@ -28,8 +41,8 @@ gem "transformers-rb"
 ```ruby
 sentences = ["This is an example sentence", "Each sentence is converted"]
 
-model = Transformers
-embeddings = model.
+model = Transformers.pipeline("embedding", "sentence-transformers/all-MiniLM-L6-v2")
+embeddings = model.(sentences)
 ```
 
 ### sentence-transformers/multi-qa-MiniLM-L6-cos-v1
@@ -40,10 +53,10 @@ embeddings = model.encode(sentences)
 query = "How many people live in London?"
 docs = ["Around 9 Million people live in London", "London is known for its financial district"]
 
-model = Transformers
-
-
-scores =
+model = Transformers.pipeline("embedding", "sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
+query_embedding = model.(query)
+doc_embeddings = model.(docs)
+scores = doc_embeddings.map { |e| e.zip(query_embedding).sum { |d, q| d * q } }
 doc_score_pairs = docs.zip(scores).sort_by { |d, s| -s }
 ```
 
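One detail worth calling out: the embedding pipeline L2-normalizes its output by default (see `EmbeddingPipeline#postprocess` later in this diff), so the plain dot product in the example above already behaves as cosine similarity. A small standalone helper, shown only for illustration:

```ruby
# Cosine similarity between two plain Ruby arrays; for unit-length vectors
# (the pipeline's default output) this reduces to the dot product used above.
def cosine_similarity(a, b)
  dot = a.zip(b).sum { |x, y| x * y }
  norm = ->(v) { Math.sqrt(v.sum { |x| x * x }) }
  dot / (norm.(a) * norm.(b))
end
```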
@@ -52,18 +65,78 @@ doc_score_pairs = docs.zip(scores).sort_by { |d, s| -s }
 [Docs](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1)
 
 ```ruby
-
-"Represent this sentence for searching relevant passages: #{query}"
-end
+query_prefix = "Represent this sentence for searching relevant passages: "
 
-
-transform_query("puppy"),
+input = [
   "The dog is barking",
-  "The cat is purring"
+  "The cat is purring",
+  query_prefix + "puppy"
+]
+
+model = Transformers.pipeline("embedding", "mixedbread-ai/mxbai-embed-large-v1")
+embeddings = model.(input)
+```
+
+### thenlper/gte-small
+
+[Docs](https://huggingface.co/thenlper/gte-small)
+
+```ruby
+sentences = ["That is a happy person", "That is a very happy person"]
+
+model = Transformers.pipeline("embedding", "thenlper/gte-small")
+embeddings = model.(sentences)
+```
+
+### intfloat/e5-base-v2
+
+[Docs](https://huggingface.co/intfloat/e5-base-v2)
+
+```ruby
+doc_prefix = "passage: "
+query_prefix = "query: "
+
+input = [
+  doc_prefix + "Ruby is a programming language created by Matz",
+  query_prefix + "Ruby creator"
 ]
 
-model = Transformers
-embeddings = model.
+model = Transformers.pipeline("embedding", "intfloat/e5-base-v2")
+embeddings = model.(input)
+```
+
+### BAAI/bge-base-en-v1.5
+
+[Docs](https://huggingface.co/BAAI/bge-base-en-v1.5)
+
+```ruby
+query_prefix = "Represent this sentence for searching relevant passages: "
+
+input = [
+  "The dog is barking",
+  "The cat is purring",
+  query_prefix + "puppy"
+]
+
+model = Transformers.pipeline("embedding", "BAAI/bge-base-en-v1.5")
+embeddings = model.(input)
+```
+
+### Snowflake/snowflake-arctic-embed-m-v1.5
+
+[Docs](https://huggingface.co/Snowflake/snowflake-arctic-embed-m-v1.5)
+
+```ruby
+query_prefix = "Represent this sentence for searching relevant passages: "
+
+input = [
+  "The dog is barking",
+  "The cat is purring",
+  query_prefix + "puppy"
+]
+
+model = Transformers.pipeline("embedding", "Snowflake/snowflake-arctic-embed-m-v1.5")
+embeddings = model.(input, pooling: "cls")
 ```
 
 ### opensearch-project/opensearch-neural-sparse-encoding-v1
@@ -89,6 +162,13 @@ embeddings = values.to_a
 
 ## Pipelines
 
+Embedding
+
+```ruby
+embed = Transformers.pipeline("embedding")
+embed.("We are very happy to show you the 🤗 Transformers library.")
+```
+
 Named-entity recognition
 
 ```ruby
data/lib/transformers/configuration_utils.rb
CHANGED
@@ -207,10 +207,17 @@ module Transformers
 
       config = new(**config_dict)
 
+      to_remove = []
       kwargs.each do |key, value|
         if config.respond_to?("#{key}=")
           config.public_send("#{key}=", value)
         end
+        if key != :torch_dtype
+          to_remove << key
+        end
+      end
+      to_remove.each do |key|
+        kwargs.delete(key)
       end
 
       Transformers.logger.info("Model config #{config}")
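The new lines collect the consumed keys first and delete them afterwards, so `kwargs` is not mutated while it is being iterated, and only `:torch_dtype` survives for later use. A minimal standalone sketch of the same pattern on a plain hash (the keys are illustrative, not the gem's API):

```ruby
kwargs = { num_labels: 2, torch_dtype: "float32" }

# Collect keys to drop, then delete them outside the iteration.
to_remove = []
kwargs.each_key do |key|
  to_remove << key if key != :torch_dtype
end
to_remove.each { |key| kwargs.delete(key) }

kwargs # => { torch_dtype: "float32" }
```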
data/lib/transformers/modeling_utils.rb
CHANGED
@@ -14,6 +14,47 @@
 # limitations under the License.
 
 module Transformers
+  module ModelingUtils
+    TORCH_INIT_FUNCTIONS = {
+      "uniform!" => Torch::NN::Init.method(:uniform!),
+      "normal!" => Torch::NN::Init.method(:normal!),
+      # "trunc_normal!" => Torch::NN::Init.method(:trunc_normal!),
+      "constant!" => Torch::NN::Init.method(:constant!),
+      "xavier_uniform!" => Torch::NN::Init.method(:xavier_uniform!),
+      "xavier_normal!" => Torch::NN::Init.method(:xavier_normal!),
+      "kaiming_uniform!" => Torch::NN::Init.method(:kaiming_uniform!),
+      "kaiming_normal!" => Torch::NN::Init.method(:kaiming_normal!),
+      # "uniform" => Torch::NN::Init.method(:uniform),
+      # "normal" => Torch::NN::Init.method(:normal),
+      # "xavier_uniform" => Torch::NN::Init.method(:xavier_uniform),
+      # "xavier_normal" => Torch::NN::Init.method(:xavier_normal),
+      # "kaiming_uniform" => Torch::NN::Init.method(:kaiming_uniform),
+      # "kaiming_normal" => Torch::NN::Init.method(:kaiming_normal)
+    }
+
+    # private
+    # note: this improves loading time significantly, but is not thread-safe!
+    def self.no_init_weights
+      return yield unless Transformers.fast_init
+
+      _skip_init = lambda do |*args, **kwargs|
+        # pass
+      end
+      # Save the original initialization functions
+      TORCH_INIT_FUNCTIONS.each do |name, init_func|
+        Torch::NN::Init.singleton_class.undef_method(name)
+        Torch::NN::Init.define_singleton_method(name, &_skip_init)
+      end
+      yield
+    ensure
+      # Restore the original initialization functions
+      TORCH_INIT_FUNCTIONS.each do |name, init_func|
+        Torch::NN::Init.singleton_class.undef_method(name)
+        Torch::NN::Init.define_singleton_method(name, init_func)
+      end
+    end
+  end
+
   module ModuleUtilsMixin
     def get_extended_attention_mask(
       attention_mask,
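`ModelingUtils.no_init_weights` only swaps out the Torch init functions when the new `Transformers.fast_init` flag (defined in `transformers.rb` later in this diff) is enabled; otherwise it simply yields. A minimal sketch of how the experimental flag might be used, keeping in mind the comment above that it is not thread-safe:

```ruby
# Opt in to skipping random weight initialization while pretrained weights are loaded.
Transformers.fast_init = true

model = Transformers::AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
```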
@@ -138,7 +179,11 @@ module Transformers
     end
 
     def _initialize_weights(mod)
+      if mod.instance_variable_defined?(:@is_hf_initialized)
+        return
+      end
       _init_weights(mod)
+      mod.instance_variable_set(:@is_hf_initialized, true)
     end
 
     def tie_weights
@@ -166,7 +211,9 @@ module Transformers
       prune_heads(@config.pruned_heads)
     end
 
-
+    # TODO implement no_init_weights context manager
+    _init_weights = false
+    if _init_weights
       # Initialize weights
       apply(method(:_initialize_weights))
 
@@ -512,8 +559,8 @@ module Transformers
 
       config.name_or_path = pretrained_model_name_or_path
 
-
-      model = new(config, *model_args, **model_kwargs)
+      # Instantiate model.
+      model = ModelingUtils.no_init_weights { new(config, *model_args, **model_kwargs) }
 
       # make sure we use the model's config since the __init__ call might have copied it
       config = model.config
@@ -683,6 +730,10 @@ module Transformers
         end
       end
 
+      if _fast_init
+        # TODO
+      end
+
       # Make sure we are able to load base models as well as derived models (with heads)
       start_prefix = ""
       model_to_load = model
@@ -756,28 +807,29 @@ module Transformers
         raise Todo
       end
 
+      model_class_name = model.class.name.split("::").last
       if unexpected_keys.length > 0
         archs = model.config.architectures.nil? ? [] : model.config.architectures
-        warner = archs.include?(
+        warner = archs.include?(model_class_name) ? Transformers.logger.method(:warn) : Transformers.logger.method(:info)
         warner.(
           "Some weights of the model checkpoint at #{pretrained_model_name_or_path} were not used when" +
-          " initializing #{
-          " initializing #{
+          " initializing #{model_class_name}: #{unexpected_keys}\n- This IS expected if you are" +
+          " initializing #{model_class_name} from the checkpoint of a model trained on another task or" +
           " with another architecture (e.g. initializing a BertForSequenceClassification model from a" +
           " BertForPreTraining model).\n- This IS NOT expected if you are initializing" +
-          " #{
+          " #{model_class_name} from the checkpoint of a model that you expect to be exactly identical" +
           " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
         )
       else
-        Transformers.logger.info("All model checkpoint weights were used when initializing #{
+        Transformers.logger.info("All model checkpoint weights were used when initializing #{model_class_name}.\n")
       end
       if missing_keys.length > 0
-        Transformers.logger.info("Some weights of #{
+        Transformers.logger.info("Some weights of #{model_class_name} were not initialized from the model checkpoint at #{pretrained_model_name_or_path} and are newly initialized: #{missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.")
       elsif mismatched_keys.length == 0
         Transformers.logger.info(
-          "All the weights of #{
+          "All the weights of #{model_class_name} were initialized from the model checkpoint at" +
           " #{pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" +
-          " was trained on, you can already use #{
+          " was trained on, you can already use #{model_class_name} for predictions without further" +
           " training."
         )
       end
data/lib/transformers/models/bert/modeling_bert.rb
CHANGED
@@ -97,7 +97,7 @@ module Transformers
       @position_embedding_type = position_embedding_type || config.position_embedding_type || "absolute"
       if @position_embedding_type == "relative_key" || @position_embedding_type == "relative_key_query"
         @max_position_embeddings = config.max_position_embeddings
-        @distance_embedding = Torch
+        @distance_embedding = Torch::NN::Embedding.new(2 * config.max_position_embeddings - 1, @attention_head_size)
       end
 
       @is_decoder = config.is_decoder
@@ -639,8 +639,8 @@ module Transformers
         extended_attention_mask = get_extended_attention_mask(attention_mask, input_shape)
       end
 
-      #
-      #
+      # If a 2D or 3D attention mask is provided for the cross-attention
+      # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
       if @config.is_decoder && !encoder_hidden_states.nil?
         encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size
         encoder_hidden_shape = [encoder_batch_size, encoder_sequence_length]
data/lib/transformers/pipelines/_init.rb
CHANGED
@@ -78,7 +78,17 @@ module Transformers
       }
     },
     "type" => "image"
-  }
+  },
+  "embedding" => {
+    "impl" => EmbeddingPipeline,
+    "pt" => [AutoModel],
+    "default" => {
+      "model" => {
+        "pt" => ["sentence-transformers/all-MiniLM-L6-v2", "8b3219a"]
+      }
+    },
+    "type" => "text"
+  },
 }
 
 PIPELINE_REGISTRY = PipelineRegistry.new(supported_tasks: SUPPORTED_TASKS, task_aliases: TASK_ALIASES)
@@ -86,6 +96,7 @@ module Transformers
   class << self
     def pipeline(
       task,
+      model_arg = nil,
       model: nil,
       config: nil,
       tokenizer: nil,
@@ -103,6 +114,13 @@ module Transformers
       pipeline_class: nil,
       **kwargs
     )
+      if !model_arg.nil?
+        if !model.nil?
+          raise ArgumentError, "Cannot pass multiple models"
+        end
+        model = model_arg
+      end
+
       model_kwargs ||= {}
       # Make sure we only pass use_auth_token once as a kwarg (it used to be possible to pass it in model_kwargs,
       # this is to keep BC).
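With the new positional `model_arg`, a model can be passed either positionally or via the `model:` keyword, and passing both raises `ArgumentError`; if neither is given for the embedding task, the registry entry shown earlier falls back to `sentence-transformers/all-MiniLM-L6-v2` (revision `8b3219a`). A short sketch:

```ruby
# Equivalent ways to build the same pipeline:
Transformers.pipeline("embedding", "sentence-transformers/all-MiniLM-L6-v2")
Transformers.pipeline("embedding", model: "sentence-transformers/all-MiniLM-L6-v2")

# Uses the registry default for the task:
Transformers.pipeline("embedding")
```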
data/lib/transformers/pipelines/embedding.rb
ADDED
@@ -0,0 +1,46 @@
+module Transformers
+  class EmbeddingPipeline < Pipeline
+    def _sanitize_parameters(**kwargs)
+      [{}, {}, kwargs]
+    end
+
+    def preprocess(inputs)
+      @tokenizer.(inputs, return_tensors: @framework)
+    end
+
+    def _forward(model_inputs)
+      {
+        last_hidden_state: @model.(**model_inputs)[0],
+        attention_mask: model_inputs[:attention_mask]
+      }
+    end
+
+    def postprocess(model_outputs, pooling: "mean", normalize: true)
+      output = model_outputs[:last_hidden_state]
+
+      case pooling
+      when "none"
+        # do nothing
+      when "mean"
+        output = mean_pooling(output, model_outputs[:attention_mask])
+      when "cls"
+        output = output[0.., 0]
+      else
+        raise Error, "Pooling method '#{pooling}' not supported."
+      end
+
+      if normalize
+        output = Torch::NN::Functional.normalize(output, p: 2, dim: 1)
+      end
+
+      output[0].to_a
+    end
+
+    private
+
+    def mean_pooling(output, attention_mask)
+      input_mask_expanded = attention_mask.unsqueeze(-1).expand(output.size).float
+      Torch.sum(output * input_mask_expanded, 1) / Torch.clamp(input_mask_expanded.sum(1), min: 1e-9)
+    end
+  end
+end
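Since `_sanitize_parameters` forwards all call-time keyword arguments to `postprocess`, pooling and normalization can be chosen per call; the defaults are mean pooling with L2 normalization. A brief sketch:

```ruby
embed = Transformers.pipeline("embedding")

embed.("Hello world")                                    # mean pooling, normalized (defaults)
embed.("Hello world", pooling: "cls", normalize: false)  # CLS token, unnormalized
embed.("Hello world", pooling: "none", normalize: false) # per-token hidden states
```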
@@ -34,7 +34,7 @@ module Transformers
       end
 
       if function_to_apply.is_a?(String)
-        function_to_apply = ClassificationFunction.new(function_to_apply.
+        function_to_apply = ClassificationFunction.new(function_to_apply.downcase).to_s
       end
 
       if !function_to_apply.nil?
@@ -62,7 +62,7 @@ module Transformers
     end
 
     def _forward(model_inputs)
-      @model.(**model_inputs
+      @model.(**model_inputs)
     end
 
     def postprocess(model_outputs, function_to_apply: nil, top_k: 1, _legacy: true)
data/lib/transformers/pipelines/token_classification.rb
CHANGED
@@ -62,7 +62,7 @@ module Transformers
 
       if !aggregation_strategy.nil?
         if aggregation_strategy.is_a?(String)
-          aggregation_strategy = AggregationStrategy.new(aggregation_strategy.
+          aggregation_strategy = AggregationStrategy.new(aggregation_strategy.downcase).to_s
         end
         if (
           [AggregationStrategy::FIRST, AggregationStrategy::MAX, AggregationStrategy::AVERAGE].include?(aggregation_strategy) &&
@@ -278,5 +278,80 @@ module Transformers
       end
       group_entities(entities)
     end
+
+    def aggregate_word(entities, aggregation_strategy)
+      raise Todo
+    end
+
+    def aggregate_words(entities, aggregation_strategy)
+      raise Todo
+    end
+
+    def group_sub_entities(entities)
+      # Get the first entity in the entity group
+      entity = entities[0][:entity].split("-", 2)[-1]
+      scores = entities.map { |entity| entity[:score] }
+      tokens = entities.map { |entity| entity[:word] }
+
+      entity_group = {
+        entity_group: entity,
+        score: scores.sum / scores.count.to_f,
+        word: @tokenizer.convert_tokens_to_string(tokens),
+        start: entities[0][:start],
+        end: entities[-1][:end]
+      }
+      entity_group
+    end
+
+    def get_tag(entity_name)
+      if entity_name.start_with?("B-")
+        bi = "B"
+        tag = entity_name[2..]
+      elsif entity_name.start_with?("I-")
+        bi = "I"
+        tag = entity_name[2..]
+      else
+        # It's not in B-, I- format
+        # Default to I- for continuation.
+        bi = "I"
+        tag = entity_name
+      end
+      [bi, tag]
+    end
+
+    def group_entities(entities)
+      entity_groups = []
+      entity_group_disagg = []
+
+      entities.each do |entity|
+        if entity_group_disagg.empty?
+          entity_group_disagg << entity
+          next
+        end
+
+        # If the current entity is similar and adjacent to the previous entity,
+        # append it to the disaggregated entity group
+        # The split is meant to account for the "B" and "I" prefixes
+        # Shouldn't merge if both entities are B-type
+        bi, tag = get_tag(entity[:entity])
+        _last_bi, last_tag = get_tag(entity_group_disagg[-1][:entity])
+
+        if tag == last_tag && bi != "B"
+          # Modify subword type to be previous_type
+          entity_group_disagg << entity
+        else
+          # If the current entity is different from the previous entity
+          # aggregate the disaggregated entity group
+          entity_groups << group_sub_entities(entity_group_disagg)
+          entity_group_disagg = [entity]
+        end
+      end
+      if entity_group_disagg.any?
+        # it's the last entity, add it to the entity groups
+        entity_groups << group_sub_entities(entity_group_disagg)
+      end
+
+      entity_groups
+    end
   end
 end
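These helpers merge adjacent B-/I- word pieces that share a tag into a single span with an averaged score. A hedged usage sketch; `ner` is assumed to alias token classification and the strategy name follows the Python library, so treat both as illustrative:

```ruby
ner = Transformers.pipeline("ner")
ner.("Ruby is a programming language created by Matz", aggregation_strategy: "simple")
# => e.g. [{entity_group: "PER", word: "Matz", score: ..., start: ..., end: ...}]
```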
data/lib/transformers/sentence_transformer.rb
CHANGED
@@ -2,36 +2,19 @@ module Transformers
   class SentenceTransformer
     def initialize(model_id)
       @model_id = model_id
-      @
-      @model = Transformers::AutoModel.from_pretrained(model_id)
+      @model = Transformers.pipeline("embedding", model_id)
     end
 
     def encode(sentences)
-      singular = sentences.is_a?(String)
-      sentences = [sentences] if singular
-
-      input = @tokenizer.(sentences, padding: true, truncation: true, return_tensors: "pt")
-      output = Torch.no_grad { @model.(**input) }[0]
-
       # TODO check modules.json
       if [
         "sentence-transformers/all-MiniLM-L6-v2",
         "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
       ].include?(@model_id)
-
-        output = Torch::NN::Functional.normalize(output, p: 2, dim: 1).to_a
+        @model.(sentences)
       else
-
+        @model.(sentences, pooling: "cls", normalize: false)
       end
-
-      singular ? output[0] : output
-    end
-
-    private
-
-    def mean_pooling(output, attention_mask)
-      input_mask_expanded = attention_mask.unsqueeze(-1).expand(output.size).float
-      Torch.sum(output * input_mask_expanded, 1) / Torch.clamp(input_mask_expanded.sum(1), min: 1e-9)
     end
   end
 end
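`SentenceTransformer#encode` keeps its public interface but now delegates to the embedding pipeline, falling back to CLS pooling without normalization for model ids it does not recognize. The caller-facing usage is unchanged:

```ruby
model = Transformers::SentenceTransformer.new("sentence-transformers/all-MiniLM-L6-v2")
embeddings = model.encode(["This is an example sentence", "Each sentence is converted"])
```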
data/lib/transformers/tokenization_utils_fast.rb
CHANGED
@@ -91,6 +91,10 @@ module Transformers
       get_vocab
     end
 
+    def backend_tokenizer
+      @tokenizer
+    end
+
     def convert_tokens_to_ids(tokens)
       if tokens.nil?
         return nil
@@ -130,6 +134,10 @@ module Transformers
       tokens
     end
 
+    def convert_tokens_to_string(tokens)
+      backend_tokenizer.decoder.decode(tokens)
+    end
+
     private
 
     def set_truncation_and_padding(
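The new `convert_tokens_to_string` (used by `group_sub_entities` in the token-classification pipeline above) simply hands the tokens to the fast tokenizer's decoder. A hedged sketch; the checkpoint and the word-piece tokens are illustrative:

```ruby
tokenizer = Transformers::AutoTokenizer.from_pretrained("bert-base-cased")
tokenizer.convert_tokens_to_string(["Sarah", "Con", "##nor"]) # => "Sarah Connor"
```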
data/lib/transformers/version.rb
CHANGED
data/lib/transformers.rb
CHANGED
@@ -75,6 +75,7 @@ require_relative "transformers/models/vit/modeling_vit"
 # pipelines
 require_relative "transformers/pipelines/base"
 require_relative "transformers/pipelines/feature_extraction"
+require_relative "transformers/pipelines/embedding"
 require_relative "transformers/pipelines/image_classification"
 require_relative "transformers/pipelines/image_feature_extraction"
 require_relative "transformers/pipelines/pt_utils"
@@ -97,4 +98,10 @@ module Transformers
       "not implemented yet"
     end
   end
+
+  class << self
+    # experimental
+    attr_accessor :fast_init
+  end
+  self.fast_init = false
 end
metadata
CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: transformers-rb
 version: !ruby/object:Gem::Version
-  version: 0.1.0
+  version: 0.1.1
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2024-08-
+date: 2024-08-30 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: numo-narray
@@ -44,14 +44,14 @@ dependencies:
     requirements:
     - - ">="
       - !ruby/object:Gem::Version
-        version:
+        version: 0.5.2
   type: :runtime
   prerelease: false
   version_requirements: !ruby/object:Gem::Requirement
     requirements:
     - - ">="
      - !ruby/object:Gem::Version
-        version:
+        version: 0.5.2
 - !ruby/object:Gem::Dependency
   name: torch-rb
   requirement: !ruby/object:Gem::Requirement
@@ -113,6 +113,7 @@ files:
 - lib/transformers/models/vit/modeling_vit.rb
 - lib/transformers/pipelines/_init.rb
 - lib/transformers/pipelines/base.rb
+- lib/transformers/pipelines/embedding.rb
 - lib/transformers/pipelines/feature_extraction.rb
 - lib/transformers/pipelines/image_classification.rb
 - lib/transformers/pipelines/image_feature_extraction.rb