transformers-rb 0.1.0

Files changed (65)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +203 -0
  4. data/README.md +163 -0
  5. data/lib/transformers/activations.rb +57 -0
  6. data/lib/transformers/configuration_utils.rb +285 -0
  7. data/lib/transformers/convert_slow_tokenizer.rb +90 -0
  8. data/lib/transformers/data/processors/squad.rb +115 -0
  9. data/lib/transformers/dynamic_module_utils.rb +25 -0
  10. data/lib/transformers/feature_extraction_utils.rb +110 -0
  11. data/lib/transformers/hf_hub/constants.rb +71 -0
  12. data/lib/transformers/hf_hub/errors.rb +11 -0
  13. data/lib/transformers/hf_hub/file_download.rb +764 -0
  14. data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
  15. data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
  16. data/lib/transformers/image_processing_base.rb +169 -0
  17. data/lib/transformers/image_processing_utils.rb +63 -0
  18. data/lib/transformers/image_transforms.rb +208 -0
  19. data/lib/transformers/image_utils.rb +165 -0
  20. data/lib/transformers/modeling_outputs.rb +81 -0
  21. data/lib/transformers/modeling_utils.rb +888 -0
  22. data/lib/transformers/models/auto/auto_factory.rb +138 -0
  23. data/lib/transformers/models/auto/configuration_auto.rb +61 -0
  24. data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
  25. data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
  26. data/lib/transformers/models/auto/modeling_auto.rb +80 -0
  27. data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
  28. data/lib/transformers/models/bert/configuration_bert.rb +65 -0
  29. data/lib/transformers/models/bert/modeling_bert.rb +836 -0
  30. data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
  31. data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
  32. data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
  33. data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
  34. data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
  35. data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
  36. data/lib/transformers/models/vit/configuration_vit.rb +60 -0
  37. data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
  38. data/lib/transformers/models/vit/modeling_vit.rb +506 -0
  39. data/lib/transformers/pipelines/_init.rb +348 -0
  40. data/lib/transformers/pipelines/base.rb +301 -0
  41. data/lib/transformers/pipelines/feature_extraction.rb +47 -0
  42. data/lib/transformers/pipelines/image_classification.rb +110 -0
  43. data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
  44. data/lib/transformers/pipelines/pt_utils.rb +53 -0
  45. data/lib/transformers/pipelines/question_answering.rb +508 -0
  46. data/lib/transformers/pipelines/text_classification.rb +123 -0
  47. data/lib/transformers/pipelines/token_classification.rb +282 -0
  48. data/lib/transformers/ruby_utils.rb +33 -0
  49. data/lib/transformers/sentence_transformer.rb +37 -0
  50. data/lib/transformers/tokenization_utils.rb +152 -0
  51. data/lib/transformers/tokenization_utils_base.rb +937 -0
  52. data/lib/transformers/tokenization_utils_fast.rb +386 -0
  53. data/lib/transformers/torch_utils.rb +25 -0
  54. data/lib/transformers/utils/_init.rb +31 -0
  55. data/lib/transformers/utils/generic.rb +107 -0
  56. data/lib/transformers/utils/hub.rb +209 -0
  57. data/lib/transformers/utils/import_utils.rb +45 -0
  58. data/lib/transformers/utils/logging.rb +52 -0
  59. data/lib/transformers/version.rb +3 -0
  60. data/lib/transformers-rb.rb +1 -0
  61. data/lib/transformers.rb +100 -0
  62. data/licenses/LICENSE-huggingface-hub.txt +201 -0
  63. data/licenses/LICENSE-sentence-transformers.txt +201 -0
  64. data/licenses/NOTICE-sentence-transformers.txt +5 -0
  65. metadata +161 -0
data/lib/transformers/pipelines/token_classification.rb
@@ -0,0 +1,282 @@
+ module Transformers
+   class TokenClassificationArgumentHandler < ArgumentHandler
+   end
+
+   class AggregationStrategy < ExplicitEnum
+     NONE = "none"
+     SIMPLE = "simple"
+     FIRST = "first"
+     AVERAGE = "average"
+     MAX = "max"
+   end
+
+   class TokenClassificationPipeline < ChunkPipeline
+     extend ClassAttribute
+
+     class_attribute :default_input_names, "sequences"
+
+     def initialize(*args, args_parser: TokenClassificationArgumentHandler.new, **kwargs)
+       super(*args, **kwargs)
+       check_model_type(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES)
+
+       @basic_tokenizer = Bert::BertTokenizer::BasicTokenizer.new(do_lower_case: false)
+       @args_parser = args_parser
+     end
+
+     def _sanitize_parameters(
+       ignore_labels: nil,
+       grouped_entities: nil,
+       ignore_subwords: nil,
+       aggregation_strategy: nil,
+       offset_mapping: nil,
+       stride: nil
+     )
+       preprocess_params = {}
+       if !offset_mapping.nil?
+         preprocess_params[:offset_mapping] = offset_mapping
+       end
+
+       postprocess_params = {}
+       if !grouped_entities.nil? || !ignore_subwords.nil?
+         if grouped_entities && ignore_subwords
+           aggregation_strategy = AggregationStrategy::FIRST
+         elsif grouped_entities && !ignore_subwords
+           aggregation_strategy = AggregationStrategy::SIMPLE
+         else
+           aggregation_strategy = AggregationStrategy::NONE
+         end
+
+         if !grouped_entities.nil?
+           warn(
+             "`grouped_entities` is deprecated and will be removed in version v5.0.0, defaulted to" +
+             " `aggregation_strategy=\"#{aggregation_strategy}\"` instead."
+           )
+         end
+         if !ignore_subwords.nil?
+           warn(
+             "`ignore_subwords` is deprecated and will be removed in version v5.0.0, defaulted to" +
+             " `aggregation_strategy=\"#{aggregation_strategy}\"` instead."
+           )
+         end
+       end
+
+       if !aggregation_strategy.nil?
+         if aggregation_strategy.is_a?(String)
+           aggregation_strategy = AggregationStrategy.new(aggregation_strategy.upcase).to_s
+         end
+         if (
+           [AggregationStrategy::FIRST, AggregationStrategy::MAX, AggregationStrategy::AVERAGE].include?(aggregation_strategy) &&
+           !@tokenizer.is_fast
+         )
+           raise ArgumentError,
+             "Slow tokenizers cannot handle subwords. Please set the `aggregation_strategy` option" +
+             ' to `"simple"` or use a fast tokenizer.'
+         end
+         postprocess_params[:aggregation_strategy] = aggregation_strategy
+       end
+       if !ignore_labels.nil?
+         postprocess_params[:ignore_labels] = ignore_labels
+       end
+       if !stride.nil?
+         if stride >= @tokenizer.model_max_length
+           raise ArgumentError,
+             "`stride` must be less than `tokenizer.model_max_length` (or even lower if the tokenizer adds special tokens)"
+         end
+         if aggregation_strategy == AggregationStrategy::NONE
+           raise ArgumentError,
+             "`stride` was provided to process all the text but `aggregation_strategy=" +
+             "\"#{aggregation_strategy}\"`, please select another one instead."
+         else
+           if @tokenizer.is_fast
+             tokenizer_params = {
+               return_overflowing_tokens: true,
+               padding: true,
+               stride: stride
+             }
+             preprocess_params[:tokenizer_params] = tokenizer_params
+           else
+             raise ArgumentError,
+               "`stride` was provided to process all the text but you're using a slow tokenizer." +
+               " Please use a fast tokenizer."
+           end
+         end
+       end
+       [preprocess_params, {}, postprocess_params]
+     end
+
+     def preprocess(sentence, offset_mapping: nil, **preprocess_params)
+       tokenizer_params = preprocess_params.delete(:tokenizer_params) { {} }
+       truncation = @tokenizer.model_max_length && @tokenizer.model_max_length > 0
+       inputs = @tokenizer.(
+         sentence,
+         return_tensors: @framework,
+         truncation: truncation,
+         return_special_tokens_mask: true,
+         return_offsets_mapping: @tokenizer.is_fast,
+         **tokenizer_params
+       )
+       inputs.delete(:overflow_to_sample_mapping)
+       num_chunks = inputs[:input_ids].length
+
+       num_chunks.times do |i|
+         if @framework == "tf"
+           raise Todo
+         else
+           model_inputs = inputs.to_h { |k, v| [k, v[i].unsqueeze(0)] }
+         end
+         if !offset_mapping.nil?
+           model_inputs[:offset_mapping] = offset_mapping
+         end
+         model_inputs[:sentence] = i == 0 ? sentence : nil
+         model_inputs[:is_last] = (i == num_chunks - 1)
+
+         yield model_inputs
+       end
+     end
+
+     def _forward(model_inputs)
+       # Forward
+       special_tokens_mask = model_inputs.delete(:special_tokens_mask)
+       offset_mapping = model_inputs.delete(:offset_mapping)
+       sentence = model_inputs.delete(:sentence)
+       is_last = model_inputs.delete(:is_last)
+       if @framework == "tf"
+         logits = @model.(**model_inputs)[0]
+       else
+         output = @model.(**model_inputs)
+         logits = output.is_a?(Hash) ? output[:logits] : output[0]
+       end
+
+       {
+         logits: logits,
+         special_tokens_mask: special_tokens_mask,
+         offset_mapping: offset_mapping,
+         sentence: sentence,
+         is_last: is_last,
+         **model_inputs
+       }
+     end
+
+     def postprocess(all_outputs, aggregation_strategy: AggregationStrategy::NONE, ignore_labels: nil)
+       if ignore_labels.nil?
+         ignore_labels = ["O"]
+       end
+       all_entities = []
+       all_outputs.each do |model_outputs|
+         logits = model_outputs[:logits][0].numo
+         sentence = all_outputs[0][:sentence]
+         input_ids = model_outputs[:input_ids][0]
+         offset_mapping = (
+           !model_outputs[:offset_mapping].nil? ? model_outputs[:offset_mapping][0] : nil
+         )
+         special_tokens_mask = model_outputs[:special_tokens_mask][0].numo
+
+         maxes = logits.max(axis: -1).expand_dims(-1)
+         shifted_exp = Numo::NMath.exp(logits - maxes)
+         scores = shifted_exp / shifted_exp.sum(axis: -1).expand_dims(-1)
+
+         if @framework == "tf"
+           raise Todo
+         end
+
+         pre_entities = gather_pre_entities(
+           sentence, input_ids, scores, offset_mapping, special_tokens_mask, aggregation_strategy
+         )
+         grouped_entities = aggregate(pre_entities, aggregation_strategy)
+         # Filter anything that is in self.ignore_labels
+         entities =
+           grouped_entities.select do |entity|
+             !ignore_labels.include?(entity[:entity]) && !ignore_labels.include?(entity[:entity_group])
+           end
+         all_entities.concat(entities)
+       end
+       num_chunks = all_outputs.length
+       if num_chunks > 1
+         all_entities = aggregate_overlapping_entities(all_entities)
+       end
+       all_entities
+     end
+
+     def gather_pre_entities(
+       sentence,
+       input_ids,
+       scores,
+       offset_mapping,
+       special_tokens_mask,
+       aggregation_strategy
+     )
+       pre_entities = []
+       scores.each_over_axis(0).with_index do |token_scores, idx|
+         # Filter special_tokens
+         if special_tokens_mask[idx] != 0
+           next
+         end
+
+         word = @tokenizer.convert_ids_to_tokens(input_ids[idx].to_i)
+         if !offset_mapping.nil?
+           start_ind, end_ind = offset_mapping[idx].to_a
+           if !start_ind.is_a?(Integer)
+             if @framework == "pt"
+               start_ind = start_ind.item
+               end_ind = end_ind.item
+             end
+           end
+           word_ref = sentence[start_ind...end_ind]
+           if @tokenizer.instance_variable_get(:@tokenizer).respond_to?(:continuing_subword_prefix)
+             # This is a BPE, word aware tokenizer, there is a correct way
+             # to fuse tokens
+             is_subword = word.length != word_ref.length
+           else
+             is_subword = start_ind > 0 && !sentence[(start_ind - 1)...(start_ind + 1)].include?(" ")
+           end
+
+           if input_ids[idx].to_i == @tokenizer.unk_token_id
+             word = word_ref
+             is_subword = false
+           end
+         else
+           start_ind = nil
+           end_ind = nil
+           is_subword = nil
+         end
+
+         pre_entity = {
+           word: word,
+           scores: token_scores,
+           start: start_ind,
+           end: end_ind,
+           index: idx,
+           is_subword: is_subword
+         }
+         pre_entities << pre_entity
+       end
+       pre_entities
+     end
+
+     def aggregate(pre_entities, aggregation_strategy)
+       if [AggregationStrategy::NONE, AggregationStrategy::SIMPLE].include?(aggregation_strategy)
+         entities = []
+         pre_entities.each do |pre_entity|
+           entity_idx = pre_entity[:scores].argmax
+           score = pre_entity[:scores][entity_idx]
+           entity = {
+             entity: @model.config.id2label[entity_idx],
+             score: score,
+             index: pre_entity[:index],
+             word: pre_entity[:word],
+             start: pre_entity[:start],
+             end: pre_entity[:end]
+           }
+           entities << entity
+         end
+       else
+         entities = aggregate_words(pre_entities, aggregation_strategy)
+       end
+
+       if aggregation_strategy == AggregationStrategy::NONE
+         return entities
+       end
+       group_entities(entities)
+     end
+   end
+ end
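
For orientation (illustrative only, not part of the gem's files): a minimal usage sketch of the token-classification pipeline defined above. It assumes the task is registered under the "ner" alias in pipelines/_init.rb and that call-time keyword arguments are forwarded to _sanitize_parameters, as in the upstream Python API; the default model is whatever the gem configures for the task.

    # Hypothetical usage sketch; task alias and call-time kwargs routing are assumptions.
    require "transformers-rb"

    ner = Transformers.pipeline("ner")
    ner.("Ruby is a programming language created by Matz")
    # => array of hashes like { entity: "B-PER", score: ..., index: ..., word: "Matz", start: ..., end: ... }

    # With a fast tokenizer, subword pieces can be merged (see _sanitize_parameters above):
    ner.("Ruby is a programming language created by Matz", aggregation_strategy: "simple")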
data/lib/transformers/ruby_utils.rb
@@ -0,0 +1,33 @@
+ module Transformers
+   module ClassAttribute
+     def class_attribute(name, default = nil)
+       singleton_class.attr_writer name
+       var = "@#{name}"
+       instance_variable_set(var, default)
+       singleton_class.define_method(name) do
+         # ancestors includes current module
+         ancestors.find { |c| c.instance_variable_defined?(var) }.instance_variable_get(var)
+       end
+       define_method(name) do
+         self.class.send(name)
+       end
+     end
+   end
+
+   module Copy
+     def self.deepcopy(value, memo = {})
+       key = value.object_id
+       if !memo.key?(key)
+         copy = value.dup
+         memo[key] = copy
+         if value.is_a?(Hash)
+           copy.transform_keys! { |k| deepcopy(k, memo) }
+           copy.transform_values! { |v| deepcopy(v, memo) }
+         elsif value.is_a?(Array)
+           copy.map! { |v| deepcopy(v, memo) }
+         end
+       end
+       memo[key]
+     end
+   end
+ end
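
For orientation (illustrative only): how the two helpers above are meant to behave. ClassAttribute#class_attribute defines a class-level attribute that subclasses can override without mutating the parent, and Copy.deepcopy recursively duplicates nested hashes and arrays. The Base and Child classes below are hypothetical.

    # Hypothetical sketch of class_attribute semantics.
    class Base
      extend Transformers::ClassAttribute
      class_attribute :model_type, "base"
    end

    class Child < Base
      self.model_type = "child"  # writer comes from singleton_class.attr_writer
    end

    Base.model_type       # => "base"
    Child.model_type      # => "child"  (nearest ancestor owning the ivar wins)
    Child.new.model_type  # => "child"  (instance reader delegates to the class)

    # Copy.deepcopy duplicates nested structures instead of sharing them.
    original = { "a" => [1, { "b" => 2 }] }
    copy = Transformers::Copy.deepcopy(original)
    copy["a"][1]["b"] = 3
    original["a"][1]["b"]  # => 2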
data/lib/transformers/sentence_transformer.rb
@@ -0,0 +1,37 @@
+ module Transformers
+   class SentenceTransformer
+     def initialize(model_id)
+       @model_id = model_id
+       @tokenizer = Transformers::AutoTokenizer.from_pretrained(model_id)
+       @model = Transformers::AutoModel.from_pretrained(model_id)
+     end
+
+     def encode(sentences)
+       singular = sentences.is_a?(String)
+       sentences = [sentences] if singular
+
+       input = @tokenizer.(sentences, padding: true, truncation: true, return_tensors: "pt")
+       output = Torch.no_grad { @model.(**input) }[0]
+
+       # TODO check modules.json
+       if [
+         "sentence-transformers/all-MiniLM-L6-v2",
+         "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
+       ].include?(@model_id)
+         output = mean_pooling(output, input[:attention_mask])
+         output = Torch::NN::Functional.normalize(output, p: 2, dim: 1).to_a
+       else
+         output = output[0.., 0].to_a
+       end
+
+       singular ? output[0] : output
+     end
+
+     private
+
+     def mean_pooling(output, attention_mask)
+       input_mask_expanded = attention_mask.unsqueeze(-1).expand(output.size).float
+       Torch.sum(output * input_mask_expanded, 1) / Torch.clamp(input_mask_expanded.sum(1), min: 1e-9)
+     end
+   end
+ end
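
For orientation (illustrative only): a usage sketch of the SentenceTransformer wrapper above, assuming the model weights can be fetched from the Hugging Face Hub. The embedding size of 384 is a property of the MiniLM-L6 models, not of this gem.

    # Hypothetical usage sketch.
    model = Transformers::SentenceTransformer.new("sentence-transformers/all-MiniLM-L6-v2")

    embeddings = model.encode(["This is an example sentence", "Each sentence is encoded"])
    embeddings.length     # => 2
    embeddings[0].length  # => 384, mean-pooled and L2-normalized for this model id

    model.encode("A single string returns a single embedding")  # flat array, not nested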
data/lib/transformers/tokenization_utils.rb
@@ -0,0 +1,152 @@
+ # Copyright 2020 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   class PreTrainedTokenizer < PreTrainedTokenizerBase
+     def initialize(**kwargs)
+
+       # 2. init `_added_tokens_decoder` if child class did not
+       if !instance_variable_defined?(:@added_tokens_decoder)
+         @added_tokens_decoder = {}
+       end
+
+       # 3. if a `added_tokens_decoder` is passed, we are loading from a saved tokenizer, we overwrite
+       @added_tokens_decoder.merge!(kwargs.delete(:added_tokens_decoder) { {} })
+       @added_tokens_encoder = @added_tokens_decoder.to_h { |k, v| [k.content, v] }
+
+       # 4 init the parent class
+       super(**kwargs)
+     end
+
+     def is_fast
+       false
+     end
+
+     def vocab_size
+       raise NotImplementedError
+     end
+
+     def tokenize(text, **kwargs)
+       raise Todo
+     end
+
+     def _encode_plus(
+       text:,
+       text_pair: nil,
+       add_special_tokens: true,
+       padding_strategy: PaddingStrategy::DO_NOT_PAD,
+       truncation_strategy: TruncationStrategy::DO_NOT_TRUNCATE,
+       max_length: nil,
+       stride: 0,
+       is_split_into_words: false,
+       pad_to_multiple_of: nil,
+       return_tensors: nil,
+       return_token_type_ids: nil,
+       return_attention_mask: nil,
+       return_overflowing_tokens: false,
+       return_special_tokens_mask: false,
+       return_offsets_mapping: false,
+       return_length: false,
+       verbose: true,
+       **kwargs
+     )
+       get_input_ids = lambda do |text|
+         if text.is_a?(String)
+           tokens = tokenize(text, **kwargs)
+           convert_tokens_to_ids(tokens)
+         elsif text.is_a?(Array) && text.length > 0 && text[0].is_a?(String)
+           if is_split_into_words
+             raise Todo
+           else
+             convert_tokens_to_ids(text)
+           end
+         elsif text.is_a?(Array) && text.length > 0 && text[0].is_a?(Integer)
+           text
+         else
+           if is_split_into_words
+             raise ArgumentError,
+               "Input #{text} is not valid. Should be a string or a list/tuple of strings when" +
+               " `is_split_into_words=True`."
+           else
+             raise ArgumentError,
+               "Input #{text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of" +
+               " integers."
+           end
+         end
+       end
+
+       if return_offsets_mapping
+         raise RuntimeError,
+           "return_offset_mapping is not available when using Ruby tokenizers. " +
+           "To use this feature, change your tokenizer to one deriving from " +
+           "Transformers::PreTrainedTokenizerFast. " +
+           "More information on available tokenizers at " +
+           "https://github.com/huggingface/transformers/pull/2674"
+       end
+
+       first_ids = get_input_ids.(text)
+       second_ids = !text_pair.nil? ? get_input_ids.(text_pair) : nil
+
+       prepare_for_model(
+         first_ids,
+         pair_ids: second_ids,
+         add_special_tokens: add_special_tokens,
+         padding: padding_strategy,
+         truncation: truncation_strategy,
+         max_length: max_length,
+         stride: stride,
+         pad_to_multiple_of: pad_to_multiple_of,
+         return_tensors: return_tensors,
+         prepend_batch_axis: true,
+         return_attention_mask: return_attention_mask,
+         return_token_type_ids: return_token_type_ids,
+         return_overflowing_tokens: return_overflowing_tokens,
+         return_special_tokens_mask: return_special_tokens_mask,
+         return_length: return_length,
+         verbose: verbose
+       )
+     end
+
+     def convert_tokens_to_ids(tokens)
+       if tokens.nil?
+         return nil
+       end
+
+       if tokens.is_a?(String)
+         return _convert_token_to_id_with_added_voc(tokens)
+       end
+
+       ids = []
+       tokens.each do |token|
+         ids << _convert_token_to_id_with_added_voc(token)
+       end
+       ids
+     end
+
+     def _convert_token_to_id_with_added_voc(token)
+       if token.nil?
+         return nil
+       end
+
+       if @added_tokens_encoder.include?(token)
+         return @added_tokens_encoder[token]
+       end
+       _convert_token_to_id(token)
+     end
+
+     def _convert_token_to_id(token)
+       raise NotImplementedError
+     end
+   end
+ end
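
For orientation (illustrative only): a toy subclass showing which hooks PreTrainedTokenizer leaves to concrete "slow" tokenizers such as Bert::BertTokenizer. The ToyTokenizer class, its whitespace vocabulary, and the keyword arguments passed to super are hypothetical and gloss over what PreTrainedTokenizerBase#initialize actually expects.

    # Hypothetical sketch of the subclass contract: vocab_size, tokenize, and
    # _convert_token_to_id are the methods left as NotImplementedError/Todo above.
    class ToyTokenizer < Transformers::PreTrainedTokenizer
      def initialize(vocab:, **kwargs)
        @vocab = vocab  # token => id mapping
        super(**kwargs)
      end

      def vocab_size
        @vocab.length
      end

      def tokenize(text, **kwargs)
        text.split  # naive whitespace tokenization, for illustration only
      end

      private

      def _convert_token_to_id(token)
        @vocab.fetch(token, @vocab["[UNK]"])
      end
    end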