transformers-rb 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +203 -0
  4. data/README.md +163 -0
  5. data/lib/transformers/activations.rb +57 -0
  6. data/lib/transformers/configuration_utils.rb +285 -0
  7. data/lib/transformers/convert_slow_tokenizer.rb +90 -0
  8. data/lib/transformers/data/processors/squad.rb +115 -0
  9. data/lib/transformers/dynamic_module_utils.rb +25 -0
  10. data/lib/transformers/feature_extraction_utils.rb +110 -0
  11. data/lib/transformers/hf_hub/constants.rb +71 -0
  12. data/lib/transformers/hf_hub/errors.rb +11 -0
  13. data/lib/transformers/hf_hub/file_download.rb +764 -0
  14. data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
  15. data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
  16. data/lib/transformers/image_processing_base.rb +169 -0
  17. data/lib/transformers/image_processing_utils.rb +63 -0
  18. data/lib/transformers/image_transforms.rb +208 -0
  19. data/lib/transformers/image_utils.rb +165 -0
  20. data/lib/transformers/modeling_outputs.rb +81 -0
  21. data/lib/transformers/modeling_utils.rb +888 -0
  22. data/lib/transformers/models/auto/auto_factory.rb +138 -0
  23. data/lib/transformers/models/auto/configuration_auto.rb +61 -0
  24. data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
  25. data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
  26. data/lib/transformers/models/auto/modeling_auto.rb +80 -0
  27. data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
  28. data/lib/transformers/models/bert/configuration_bert.rb +65 -0
  29. data/lib/transformers/models/bert/modeling_bert.rb +836 -0
  30. data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
  31. data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
  32. data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
  33. data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
  34. data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
  35. data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
  36. data/lib/transformers/models/vit/configuration_vit.rb +60 -0
  37. data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
  38. data/lib/transformers/models/vit/modeling_vit.rb +506 -0
  39. data/lib/transformers/pipelines/_init.rb +348 -0
  40. data/lib/transformers/pipelines/base.rb +301 -0
  41. data/lib/transformers/pipelines/feature_extraction.rb +47 -0
  42. data/lib/transformers/pipelines/image_classification.rb +110 -0
  43. data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
  44. data/lib/transformers/pipelines/pt_utils.rb +53 -0
  45. data/lib/transformers/pipelines/question_answering.rb +508 -0
  46. data/lib/transformers/pipelines/text_classification.rb +123 -0
  47. data/lib/transformers/pipelines/token_classification.rb +282 -0
  48. data/lib/transformers/ruby_utils.rb +33 -0
  49. data/lib/transformers/sentence_transformer.rb +37 -0
  50. data/lib/transformers/tokenization_utils.rb +152 -0
  51. data/lib/transformers/tokenization_utils_base.rb +937 -0
  52. data/lib/transformers/tokenization_utils_fast.rb +386 -0
  53. data/lib/transformers/torch_utils.rb +25 -0
  54. data/lib/transformers/utils/_init.rb +31 -0
  55. data/lib/transformers/utils/generic.rb +107 -0
  56. data/lib/transformers/utils/hub.rb +209 -0
  57. data/lib/transformers/utils/import_utils.rb +45 -0
  58. data/lib/transformers/utils/logging.rb +52 -0
  59. data/lib/transformers/version.rb +3 -0
  60. data/lib/transformers-rb.rb +1 -0
  61. data/lib/transformers.rb +100 -0
  62. data/licenses/LICENSE-huggingface-hub.txt +201 -0
  63. data/licenses/LICENSE-sentence-transformers.txt +201 -0
  64. data/licenses/NOTICE-sentence-transformers.txt +5 -0
  65. metadata +161 -0
data/lib/transformers/pipelines/token_classification.rb
@@ -0,0 +1,282 @@
+ module Transformers
+   class TokenClassificationArgumentHandler < ArgumentHandler
+   end
+
+   class AggregationStrategy < ExplicitEnum
+     NONE = "none"
+     SIMPLE = "simple"
+     FIRST = "first"
+     AVERAGE = "average"
+     MAX = "max"
+   end
+
+   class TokenClassificationPipeline < ChunkPipeline
+     extend ClassAttribute
+
+     class_attribute :default_input_names, "sequences"
+
+     def initialize(*args, args_parser: TokenClassificationArgumentHandler.new, **kwargs)
+       super(*args, **kwargs)
+       check_model_type(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES)
+
+       @basic_tokenizer = Bert::BertTokenizer::BasicTokenizer.new(do_lower_case: false)
+       @args_parser = args_parser
+     end
+
+     def _sanitize_parameters(
+       ignore_labels: nil,
+       grouped_entities: nil,
+       ignore_subwords: nil,
+       aggregation_strategy: nil,
+       offset_mapping: nil,
+       stride: nil
+     )
+       preprocess_params = {}
+       if !offset_mapping.nil?
+         preprocess_params[:offset_mapping] = offset_mapping
+       end
+
+       postprocess_params = {}
+       if !grouped_entities.nil? || !ignore_subwords.nil?
+         if grouped_entities && ignore_subwords
+           aggregation_strategy = AggregationStrategy::FIRST
+         elsif grouped_entities && !ignore_subwords
+           aggregation_strategy = AggregationStrategy::SIMPLE
+         else
+           aggregation_strategy = AggregationStrategy::NONE
+         end
+
+         if !grouped_entities.nil?
+           warn(
+             "`grouped_entities` is deprecated and will be removed in version v5.0.0, defaulted to" +
+             " `aggregation_strategy=\"#{aggregation_strategy}\"` instead."
+           )
+         end
+         if !ignore_subwords.nil?
+           warn(
+             "`ignore_subwords` is deprecated and will be removed in version v5.0.0, defaulted to" +
+             " `aggregation_strategy=\"#{aggregation_strategy}\"` instead."
+           )
+         end
+       end
+
+       if !aggregation_strategy.nil?
+         if aggregation_strategy.is_a?(String)
+           aggregation_strategy = AggregationStrategy.new(aggregation_strategy.upcase).to_s
+         end
+         if (
+           [AggregationStrategy::FIRST, AggregationStrategy::MAX, AggregationStrategy::AVERAGE].include?(aggregation_strategy) &&
+           !@tokenizer.is_fast
+         )
+           raise ArgumentError,
+             "Slow tokenizers cannot handle subwords. Please set the `aggregation_strategy` option" +
+             ' to `"simple"` or use a fast tokenizer.'
+         end
+         postprocess_params[:aggregation_strategy] = aggregation_strategy
+       end
+       if !ignore_labels.nil?
+         postprocess_params[:ignore_labels] = ignore_labels
+       end
+       if !stride.nil?
+         if stride >= @tokenizer.model_max_length
+           raise ArgumentError,
+             "`stride` must be less than `tokenizer.model_max_length` (or even lower if the tokenizer adds special tokens)"
+         end
+         if aggregation_strategy == AggregationStrategy::NONE
+           raise ArgumentError,
+             "`stride` was provided to process all the text but `aggregation_strategy=" +
+             "\"#{aggregation_strategy}\"`, please select another one instead."
+         else
+           if @tokenizer.is_fast
+             tokenizer_params = {
+               return_overflowing_tokens: true,
+               padding: true,
+               stride: stride
+             }
+             preprocess_params[:tokenizer_params] = tokenizer_params
+           else
+             raise ArgumentError,
+               "`stride` was provided to process all the text but you're using a slow tokenizer." +
+               " Please use a fast tokenizer."
+           end
+         end
+       end
+       [preprocess_params, {}, postprocess_params]
+     end
+
+     def preprocess(sentence, offset_mapping: nil, **preprocess_params)
+       tokenizer_params = preprocess_params.delete(:tokenizer_params) { {} }
+       truncation = @tokenizer.model_max_length && @tokenizer.model_max_length > 0
+       inputs = @tokenizer.(
+         sentence,
+         return_tensors: @framework,
+         truncation: truncation,
+         return_special_tokens_mask: true,
+         return_offsets_mapping: @tokenizer.is_fast,
+         **tokenizer_params
+       )
+       inputs.delete(:overflow_to_sample_mapping)
+       num_chunks = inputs[:input_ids].length
+
+       num_chunks.times do |i|
+         if @framework == "tf"
+           raise Todo
+         else
+           model_inputs = inputs.to_h { |k, v| [k, v[i].unsqueeze(0)] }
+         end
+         if !offset_mapping.nil?
+           model_inputs[:offset_mapping] = offset_mapping
+         end
+         model_inputs[:sentence] = i == 0 ? sentence : nil
+         model_inputs[:is_last] = (i == num_chunks - 1)
+
+         yield model_inputs
+       end
+     end
+
+     def _forward(model_inputs)
+       # Forward
+       special_tokens_mask = model_inputs.delete(:special_tokens_mask)
+       offset_mapping = model_inputs.delete(:offset_mapping)
+       sentence = model_inputs.delete(:sentence)
+       is_last = model_inputs.delete(:is_last)
+       if @framework == "tf"
+         logits = @model.(**model_inputs)[0]
+       else
+         output = @model.(**model_inputs)
+         logits = output.is_a?(Hash) ? output[:logits] : output[0]
+       end
+
+       {
+         logits: logits,
+         special_tokens_mask: special_tokens_mask,
+         offset_mapping: offset_mapping,
+         sentence: sentence,
+         is_last: is_last,
+         **model_inputs
+       }
+     end
+
+     def postprocess(all_outputs, aggregation_strategy: AggregationStrategy::NONE, ignore_labels: nil)
+       if ignore_labels.nil?
+         ignore_labels = ["O"]
+       end
+       all_entities = []
+       all_outputs.each do |model_outputs|
+         logits = model_outputs[:logits][0].numo
+         sentence = all_outputs[0][:sentence]
+         input_ids = model_outputs[:input_ids][0]
+         offset_mapping = (
+           !model_outputs[:offset_mapping].nil? ? model_outputs[:offset_mapping][0] : nil
+         )
+         special_tokens_mask = model_outputs[:special_tokens_mask][0].numo
+
+         maxes = logits.max(axis: -1).expand_dims(-1)
+         shifted_exp = Numo::NMath.exp(logits - maxes)
+         scores = shifted_exp / shifted_exp.sum(axis: -1).expand_dims(-1)
+
+         if @framework == "tf"
+           raise Todo
+         end
+
+         pre_entities = gather_pre_entities(
+           sentence, input_ids, scores, offset_mapping, special_tokens_mask, aggregation_strategy
+         )
+         grouped_entities = aggregate(pre_entities, aggregation_strategy)
+         # Filter anything that is in self.ignore_labels
+         entities =
+           grouped_entities.select do |entity|
+             !ignore_labels.include?(entity[:entity]) && !ignore_labels.include?(entity[:entity_group])
+           end
+         all_entities.concat(entities)
+       end
+       num_chunks = all_outputs.length
+       if num_chunks > 1
+         all_entities = aggregate_overlapping_entities(all_entities)
+       end
+       all_entities
+     end
+
+     def gather_pre_entities(
+       sentence,
+       input_ids,
+       scores,
+       offset_mapping,
+       special_tokens_mask,
+       aggregation_strategy
+     )
+       pre_entities = []
+       scores.each_over_axis(0).with_index do |token_scores, idx|
+         # Filter special_tokens
+         if special_tokens_mask[idx] != 0
+           next
+         end
+
+         word = @tokenizer.convert_ids_to_tokens(input_ids[idx].to_i)
+         if !offset_mapping.nil?
+           start_ind, end_ind = offset_mapping[idx].to_a
+           if !start_ind.is_a?(Integer)
+             if @framework == "pt"
+               start_ind = start_ind.item
+               end_ind = end_ind.item
+             end
+           end
+           word_ref = sentence[start_ind...end_ind]
+           if @tokenizer.instance_variable_get(:@tokenizer).respond_to?(:continuing_subword_prefix)
+             # This is a BPE, word aware tokenizer, there is a correct way
+             # to fuse tokens
+             is_subword = word.length != word_ref.length
+           else
+             is_subword = start_ind > 0 && !sentence[(start_ind - 1)...(start_ind + 1)].include?(" ")
+           end
+
+           if input_ids[idx].to_i == @tokenizer.unk_token_id
+             word = word_ref
+             is_subword = false
+           end
+         else
+           start_ind = nil
+           end_ind = nil
+           is_subword = nil
+         end
+
+         pre_entity = {
+           word: word,
+           scores: token_scores,
+           start: start_ind,
+           end: end_ind,
+           index: idx,
+           is_subword: is_subword
+         }
+         pre_entities << pre_entity
+       end
+       pre_entities
+     end
+
+     def aggregate(pre_entities, aggregation_strategy)
+       if [AggregationStrategy::NONE, AggregationStrategy::SIMPLE].include?(aggregation_strategy)
+         entities = []
+         pre_entities.each do |pre_entity|
+           entity_idx = pre_entity[:scores].argmax
+           score = pre_entity[:scores][entity_idx]
+           entity = {
+             entity: @model.config.id2label[entity_idx],
+             score: score,
+             index: pre_entity[:index],
+             word: pre_entity[:word],
+             start: pre_entity[:start],
+             end: pre_entity[:end]
+           }
+           entities << entity
+         end
+       else
+         entities = aggregate_words(pre_entities, aggregation_strategy)
+       end
+
+       if aggregation_strategy == AggregationStrategy::NONE
+         return entities
+       end
+       group_entities(entities)
+     end
+   end
+ end
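
For orientation, here is a minimal usage sketch of the pipeline defined above. It assumes the task is reached through the `Transformers.pipeline` factory (data/lib/transformers/pipelines/_init.rb in the file list); the task name, the model id, and the forwarding of `aggregation_strategy:` through the factory are illustrative assumptions, not something this diff asserts.

    require "transformers-rb"

    # Hypothetical model id; any token-classification checkpoint with a fast tokenizer should behave similarly.
    ner = Transformers.pipeline(
      "token-classification",
      model: "dslim/bert-base-NER",
      aggregation_strategy: "simple"
    )
    ner.("Sarah lives in Berlin").each do |entity|
      # with a non-"none" strategy the grouped hashes carry :entity_group, :score, :word, :start, :end
      puts "#{entity[:word]} -> #{entity[:entity_group]} (#{entity[:score].round(3)})"
    end
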
data/lib/transformers/ruby_utils.rb
@@ -0,0 +1,33 @@
+ module Transformers
+   module ClassAttribute
+     def class_attribute(name, default = nil)
+       singleton_class.attr_writer name
+       var = "@#{name}"
+       instance_variable_set(var, default)
+       singleton_class.define_method(name) do
+         # ancestors includes current module
+         ancestors.find { |c| c.instance_variable_defined?(var) }.instance_variable_get(var)
+       end
+       define_method(name) do
+         self.class.send(name)
+       end
+     end
+   end
+
+   module Copy
+     def self.deepcopy(value, memo = {})
+       key = value.object_id
+       if !memo.key?(key)
+         copy = value.dup
+         memo[key] = copy
+         if value.is_a?(Hash)
+           copy.transform_keys! { |k| deepcopy(k, memo) }
+           copy.transform_values! { |v| deepcopy(v, memo) }
+         elsif value.is_a?(Array)
+           copy.map! { |v| deepcopy(v, memo) }
+         end
+       end
+       memo[key]
+     end
+   end
+ end
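
The two helpers above are small but load-bearing: `class_attribute` defines a class-level value that subclasses see through the `ancestors` lookup until they override it, and `Copy.deepcopy` recursively duplicates nested hashes and arrays. A hypothetical illustration follows; the `ExampleBase` and `ExampleChild` classes are not part of the gem.

    module Transformers
      class ExampleBase
        extend ClassAttribute
        class_attribute :default_input_names, "sequences"
      end

      class ExampleChild < ExampleBase
      end

      ExampleChild.default_input_names              # => "sequences" (found on ExampleBase via ancestors)
      ExampleChild.default_input_names = "tokens"   # writer comes from singleton_class.attr_writer
      ExampleChild.default_input_names              # => "tokens"
      ExampleBase.default_input_names               # => "sequences" (parent value is untouched)

      original = { "a" => [1, 2], "b" => { "c" => 3 } }
      copied = Copy.deepcopy(original)
      copied["a"] << 99
      original["a"]                                 # => [1, 2] (nested arrays were duplicated, not shared)
    end
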
data/lib/transformers/sentence_transformer.rb
@@ -0,0 +1,37 @@
+ module Transformers
+   class SentenceTransformer
+     def initialize(model_id)
+       @model_id = model_id
+       @tokenizer = Transformers::AutoTokenizer.from_pretrained(model_id)
+       @model = Transformers::AutoModel.from_pretrained(model_id)
+     end
+
+     def encode(sentences)
+       singular = sentences.is_a?(String)
+       sentences = [sentences] if singular
+
+       input = @tokenizer.(sentences, padding: true, truncation: true, return_tensors: "pt")
+       output = Torch.no_grad { @model.(**input) }[0]
+
+       # TODO check modules.json
+       if [
+         "sentence-transformers/all-MiniLM-L6-v2",
+         "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
+       ].include?(@model_id)
+         output = mean_pooling(output, input[:attention_mask])
+         output = Torch::NN::Functional.normalize(output, p: 2, dim: 1).to_a
+       else
+         output = output[0.., 0].to_a
+       end
+
+       singular ? output[0] : output
+     end
+
+     private
+
+     def mean_pooling(output, attention_mask)
+       input_mask_expanded = attention_mask.unsqueeze(-1).expand(output.size).float
+       Torch.sum(output * input_mask_expanded, 1) / Torch.clamp(input_mask_expanded.sum(1), min: 1e-9)
+     end
+   end
+ end
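
A short usage sketch of the class above. The model id is one of the two ids that `encode` special-cases for mean pooling; the 384-dimensional output is a property of the MiniLM-L6 models rather than something this diff states.

    require "transformers-rb"

    model = Transformers::SentenceTransformer.new("sentence-transformers/all-MiniLM-L6-v2")

    embedding = model.encode("This is a sentence")    # String in, one embedding (Array of Floats) out
    embeddings = model.encode(["first", "second"])    # Array in, one embedding per sentence out
    embedding.length                                  # => 384 for the MiniLM-L6 checkpoints
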
data/lib/transformers/tokenization_utils.rb
@@ -0,0 +1,152 @@
+ # Copyright 2020 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   class PreTrainedTokenizer < PreTrainedTokenizerBase
+     def initialize(**kwargs)
+
+       # 2. init `_added_tokens_decoder` if child class did not
+       if !instance_variable_defined?(:@added_tokens_decoder)
+         @added_tokens_decoder = {}
+       end
+
+       # 3. if a `added_tokens_decoder` is passed, we are loading from a saved tokenizer, we overwrite
+       @added_tokens_decoder.merge!(kwargs.delete(:added_tokens_decoder) { {} })
+       @added_tokens_encoder = @added_tokens_decoder.to_h { |k, v| [k.content, v] }
+
+       # 4 init the parent class
+       super(**kwargs)
+     end
+
+     def is_fast
+       false
+     end
+
+     def vocab_size
+       raise NotImplementedError
+     end
+
+     def tokenize(text, **kwargs)
+       raise Todo
+     end
+
+     def _encode_plus(
+       text:,
+       text_pair: nil,
+       add_special_tokens: true,
+       padding_strategy: PaddingStrategy::DO_NOT_PAD,
+       truncation_strategy: TruncationStrategy::DO_NOT_TRUNCATE,
+       max_length: nil,
+       stride: 0,
+       is_split_into_words: false,
+       pad_to_multiple_of: nil,
+       return_tensors: nil,
+       return_token_type_ids: nil,
+       return_attention_mask: nil,
+       return_overflowing_tokens: false,
+       return_special_tokens_mask: false,
+       return_offsets_mapping: false,
+       return_length: false,
+       verbose: true,
+       **kwargs
+     )
+       get_input_ids = lambda do |text|
+         if text.is_a?(String)
+           tokens = tokenize(text, **kwargs)
+           convert_tokens_to_ids(tokens)
+         elsif text.is_a?(Array) && text.length > 0 && text[0].is_a?(String)
+           if is_split_into_words
+             raise Todo
+           else
+             convert_tokens_to_ids(text)
+           end
+         elsif text.is_a?(Array) && text.length > 0 && text[0].is_a?(Integer)
+           text
+         else
+           if is_split_into_words
+             raise ArgumentError,
+               "Input #{text} is not valid. Should be a string or a list/tuple of strings when" +
+               " `is_split_into_words=True`."
+           else
+             raise ArgumentError,
+               "Input #{text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of" +
+               " integers."
+           end
+         end
+       end
+
+       if return_offsets_mapping
+         raise RuntimeError,
+           "return_offset_mapping is not available when using Ruby tokenizers. " +
+           "To use this feature, change your tokenizer to one deriving from " +
+           "Transformers::PreTrainedTokenizerFast. " +
+           "More information on available tokenizers at " +
+           "https://github.com/huggingface/transformers/pull/2674"
+       end
+
+       first_ids = get_input_ids.(text)
+       second_ids = !text_pair.nil? ? get_input_ids.(text_pair) : nil
+
+       prepare_for_model(
+         first_ids,
+         pair_ids: second_ids,
+         add_special_tokens: add_special_tokens,
+         padding: padding_strategy,
+         truncation: truncation_strategy,
+         max_length: max_length,
+         stride: stride,
+         pad_to_multiple_of: pad_to_multiple_of,
+         return_tensors: return_tensors,
+         prepend_batch_axis: true,
+         return_attention_mask: return_attention_mask,
+         return_token_type_ids: return_token_type_ids,
+         return_overflowing_tokens: return_overflowing_tokens,
+         return_special_tokens_mask: return_special_tokens_mask,
+         return_length: return_length,
+         verbose: verbose
+       )
+     end
+
+     def convert_tokens_to_ids(tokens)
+       if tokens.nil?
+         return nil
+       end
+
+       if tokens.is_a?(String)
+         return _convert_token_to_id_with_added_voc(tokens)
+       end
+
+       ids = []
+       tokens.each do |token|
+         ids << _convert_token_to_id_with_added_voc(token)
+       end
+       ids
+     end
+
+     def _convert_token_to_id_with_added_voc(token)
+       if token.nil?
+         return nil
+       end
+
+       if @added_tokens_encoder.include?(token)
+         return @added_tokens_encoder[token]
+       end
+       _convert_token_to_id(token)
+     end
+
+     def _convert_token_to_id(token)
+       raise NotImplementedError
+     end
+   end
+ end
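
PreTrainedTokenizer is the base class for "slow" (pure-Ruby) tokenizers: `tokenize` and `_convert_token_to_id` are intentionally left as `raise Todo` / `raise NotImplementedError` above. Below is a hypothetical minimal subclass, only to show which hooks a concrete tokenizer fills in; the class name, the vocab handling, and any extra kwargs expected by `PreTrainedTokenizerBase#initialize` are assumptions.

    module Transformers
      class WhitespaceTokenizer < PreTrainedTokenizer
        def initialize(vocab:, **kwargs)
          @vocab = vocab              # token => id mapping, e.g. {"[UNK]" => 0, "hello" => 1}
          super(**kwargs)
        end

        def vocab_size
          @vocab.size
        end

        def tokenize(text, **kwargs)
          text.split                  # naive whitespace splitting stands in for a real subword algorithm
        end

        def _convert_token_to_id(token)
          @vocab.fetch(token, @vocab["[UNK]"])
        end
      end
    end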