transformers-rb 0.1.0
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +203 -0
- data/README.md +163 -0
- data/lib/transformers/activations.rb +57 -0
- data/lib/transformers/configuration_utils.rb +285 -0
- data/lib/transformers/convert_slow_tokenizer.rb +90 -0
- data/lib/transformers/data/processors/squad.rb +115 -0
- data/lib/transformers/dynamic_module_utils.rb +25 -0
- data/lib/transformers/feature_extraction_utils.rb +110 -0
- data/lib/transformers/hf_hub/constants.rb +71 -0
- data/lib/transformers/hf_hub/errors.rb +11 -0
- data/lib/transformers/hf_hub/file_download.rb +764 -0
- data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
- data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
- data/lib/transformers/image_processing_base.rb +169 -0
- data/lib/transformers/image_processing_utils.rb +63 -0
- data/lib/transformers/image_transforms.rb +208 -0
- data/lib/transformers/image_utils.rb +165 -0
- data/lib/transformers/modeling_outputs.rb +81 -0
- data/lib/transformers/modeling_utils.rb +888 -0
- data/lib/transformers/models/auto/auto_factory.rb +138 -0
- data/lib/transformers/models/auto/configuration_auto.rb +61 -0
- data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
- data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
- data/lib/transformers/models/auto/modeling_auto.rb +80 -0
- data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
- data/lib/transformers/models/bert/configuration_bert.rb +65 -0
- data/lib/transformers/models/bert/modeling_bert.rb +836 -0
- data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
- data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
- data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
- data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
- data/lib/transformers/models/vit/configuration_vit.rb +60 -0
- data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
- data/lib/transformers/models/vit/modeling_vit.rb +506 -0
- data/lib/transformers/pipelines/_init.rb +348 -0
- data/lib/transformers/pipelines/base.rb +301 -0
- data/lib/transformers/pipelines/feature_extraction.rb +47 -0
- data/lib/transformers/pipelines/image_classification.rb +110 -0
- data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
- data/lib/transformers/pipelines/pt_utils.rb +53 -0
- data/lib/transformers/pipelines/question_answering.rb +508 -0
- data/lib/transformers/pipelines/text_classification.rb +123 -0
- data/lib/transformers/pipelines/token_classification.rb +282 -0
- data/lib/transformers/ruby_utils.rb +33 -0
- data/lib/transformers/sentence_transformer.rb +37 -0
- data/lib/transformers/tokenization_utils.rb +152 -0
- data/lib/transformers/tokenization_utils_base.rb +937 -0
- data/lib/transformers/tokenization_utils_fast.rb +386 -0
- data/lib/transformers/torch_utils.rb +25 -0
- data/lib/transformers/utils/_init.rb +31 -0
- data/lib/transformers/utils/generic.rb +107 -0
- data/lib/transformers/utils/hub.rb +209 -0
- data/lib/transformers/utils/import_utils.rb +45 -0
- data/lib/transformers/utils/logging.rb +52 -0
- data/lib/transformers/version.rb +3 -0
- data/lib/transformers-rb.rb +1 -0
- data/lib/transformers.rb +100 -0
- data/licenses/LICENSE-huggingface-hub.txt +201 -0
- data/licenses/LICENSE-sentence-transformers.txt +201 -0
- data/licenses/NOTICE-sentence-transformers.txt +5 -0
- metadata +161 -0
data/lib/transformers/pipelines/question_answering.rb
@@ -0,0 +1,508 @@
+module Transformers
+  class QuestionAnsweringArgumentHandler < ArgumentHandler
+    def normalize(item)
+      if item.is_a?(SquadExample)
+        return item
+      elsif item.is_a?(Hash)
+        [:question, :context].each do |k|
+          if !item.include?(k)
+            raise KeyError, "You need to provide a dictionary with keys {question:..., context:...}"
+          elsif item[k].nil?
+            raise ArgumentError, "`#{k}` cannot be nil"
+          elsif item[k].is_a?(String) && item[k].length == 0
+            raise ArgumentError, "`#{k}` cannot be empty"
+          end
+        end
+
+        return QuestionAnsweringPipeline.create_sample(**item)
+      end
+      raise ArgumentError, "#{item} argument needs to be of type (SquadExample, dict)"
+    end
+
+    def call(*args, **kwargs)
+      # Detect where the actual inputs are
+      if args.any?
+        if args.length == 1
+          inputs = args[0]
+        elsif args.length == 2 && args.all? { |el| el.is_a?(String) }
+          inputs = [{question: args[0], context: args[1]}]
+        else
+          inputs = args.to_a
+        end
+      elsif kwargs.include?(:question) && kwargs.include?(:context)
+        if kwargs[:question].is_a?(Array) && kwargs[:context].is_a?(String)
+          inputs = kwargs[:question].map { |q| {question: q, context: kwargs[:context]} }
+        elsif kwargs[:question].is_a?(Array) && kwargs[:context].is_a?(Array)
+          if kwargs[:question].length != kwargs[:context].length
+            raise ArgumentError, "Questions and contexts don't have the same lengths"
+          end
+
+          inputs = kwargs[:question].zip(kwargs[:context]).map { |q, c| {question: q, context: c} }
+        elsif kwargs[:question].is_a?(String) && kwargs[:context].is_a?(String)
+          inputs = [{question: kwargs[:question], context: kwargs[:context]}]
+        else
+          raise ArgumentError, "Arguments can't be understood"
+        end
+      else
+        raise ArgumentError, "Unknown arguments #{kwargs}"
+      end
+
+      # Normalize inputs
+      if inputs.is_a?(Hash)
+        inputs = [inputs]
+      elsif inputs.is_a?(Enumerable)
+        # Copy to avoid overriding arguments
+        inputs = inputs.to_a.dup
+      else
+        raise ArgumentError, "Invalid arguments #{kwargs}"
+      end
+
+      inputs.each_with_index do |item, i|
+        inputs[i] = normalize(item)
+      end
+
+      inputs
+    end
+  end
+
+  class QuestionAnsweringPipeline < ChunkPipeline
+    extend ClassAttribute
+
+    class_attribute :default_input_names, "question,context"
+    class_attribute :handle_impossible_answer, false
+
+    def initialize(
+      model,
+      tokenizer:,
+      modelcard: nil,
+      framework: nil,
+      task: "",
+      **kwargs
+    )
+      super(
+        model,
+        tokenizer: tokenizer,
+        modelcard: modelcard,
+        framework: framework,
+        task: task,
+        **kwargs
+      )
+
+      @args_parser = QuestionAnsweringArgumentHandler.new
+      check_model_type(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES)
+    end
+
+    def self.create_sample(
+      question:, context:
+    )
+      if question.is_a?(Array)
+        question.zip(context).map { |q, c| SquadExample.new(nil, q, c, nil, nil, nil) }
+      else
+        SquadExample.new(nil, question, context, nil, nil, nil)
+      end
+    end
+
+    def _sanitize_parameters(
+      padding: nil,
+      topk: nil,
+      top_k: nil,
+      doc_stride: nil,
+      max_answer_len: nil,
+      max_seq_len: nil,
+      max_question_len: nil,
+      handle_impossible_answer: nil,
+      align_to_words: nil,
+      **kwargs
+    )
+      # Set default values
+      preprocess_params = {}
+      if !padding.nil?
+        preprocess_params[:padding] = padding
+      end
+      if !doc_stride.nil?
+        preprocess_params[:doc_stride] = doc_stride
+      end
+      if !max_question_len.nil?
+        preprocess_params[:max_question_len] = max_question_len
+      end
+      if !max_seq_len.nil?
+        preprocess_params[:max_seq_len] = max_seq_len
+      end
+
+      postprocess_params = {}
+      if !topk.nil? && top_k.nil?
+        warn("topk parameter is deprecated, use top_k instead")
+        top_k = topk
+      end
+      if !top_k.nil?
+        if top_k < 1
+          raise ArgumentError, "top_k parameter should be >= 1 (got #{top_k})"
+        end
+        postprocess_params[:top_k] = top_k
+      end
+      if !max_answer_len.nil?
+        if max_answer_len < 1
+          raise ArgumentError, "max_answer_len parameter should be >= 1 (got #{max_answer_len})"
+        end
+      end
+      if !max_answer_len.nil?
+        postprocess_params[:max_answer_len] = max_answer_len
+      end
+      if !handle_impossible_answer.nil?
+        postprocess_params[:handle_impossible_answer] = handle_impossible_answer
+      end
+      if !align_to_words.nil?
+        postprocess_params[:align_to_words] = align_to_words
+      end
+      [preprocess_params, {}, postprocess_params]
+    end
+
+    def call(*args, **kwargs)
+      examples = @args_parser.(*args, **kwargs)
+      if examples.is_a?(Array) && examples.length == 1
+        return super(examples[0], **kwargs)
+      end
+      super(examples, **kwargs)
+    end
+
+    def preprocess(example, padding: "do_not_pad", doc_stride: nil, max_question_len: 64, max_seq_len: nil)
+      # XXX: This is special, args_parser will not handle anything generator or dataset like
+      # For those we expect user to send a simple valid example either directly as a SquadExample or simple dict.
+      # So we still need a little sanitation here.
+      if example.is_a?(Hash)
+        example = SquadExample.new(nil, example[:question], example[:context], nil, nil, nil)
+      end
+
+      if max_seq_len.nil?
+        max_seq_len = [@tokenizer.model_max_length, 384].min
+      end
+      if doc_stride.nil?
+        doc_stride = [max_seq_len.div(2), 128].min
+      end
+
+      if doc_stride > max_seq_len
+        raise ArgumentError, "`doc_stride` (#{doc_stride}) is larger than `max_seq_len` (#{max_seq_len})"
+      end
+
+      if !@tokenizer.is_fast
+        features = squad_convert_examples_to_features(
+          examples: [example],
+          tokenizer: @tokenizer,
+          max_seq_length: max_seq_len,
+          doc_stride: doc_stride,
+          max_query_length: max_question_len,
+          padding_strategy: PaddingStrategy::MAX_LENGTH,
+          is_training: false,
+          tqdm_enabled: false
+        )
+      else
+        # Define the side we want to truncate / pad and the text/pair sorting
+        question_first = @tokenizer.padding_side == "right"
+
+        encoded_inputs = @tokenizer.(
+          question_first ? example.question_text : example.context_text,
+          text_pair: question_first ? example.context_text : example.question_text,
+          padding: padding,
+          truncation: question_first ? "only_second" : "only_first",
+          max_length: max_seq_len,
+          stride: doc_stride,
+          return_token_type_ids: true,
+          return_overflowing_tokens: true,
+          return_offsets_mapping: true,
+          return_special_tokens_mask: true
+        )
+        # When the input is too long, it's converted in a batch of inputs with overflowing tokens
+        # and a stride of overlap between the inputs. If a batch of inputs is given, a special output
+        # "overflow_to_sample_mapping" indicates which member of the encoded batch belongs to which original batch sample.
+        # Here we tokenize examples one-by-one so we don't need to use "overflow_to_sample_mapping".
+        # "num_span" is the number of output samples generated from the overflowing tokens.
+        num_spans = encoded_inputs[:input_ids].length
+
+        # p_mask: mask with 1 for tokens that cannot be in the answer (0 for tokens which can be in an answer)
+        # We put 0 on the tokens from the context and 1 everywhere else (question and special tokens)
+        p_mask =
+          num_spans.times.map do |span_id|
+            encoded_inputs.sequence_ids(span_id).map { |tok| tok != (question_first ? 1 : 0) }
+          end
+
+        features = []
+        num_spans.times do |span_idx|
+          input_ids_span_idx = encoded_inputs[:input_ids][span_idx]
+          attention_mask_span_idx = (
+            encoded_inputs.include?(:attention_mask) ? encoded_inputs[:attention_mask][span_idx] : nil
+          )
+          token_type_ids_span_idx = (
+            encoded_inputs.include?(:token_type_ids) ? encoded_inputs[:token_type_ids][span_idx] : nil
+          )
+          # keep the cls_token unmasked (some models use it to indicate unanswerable questions)
+          if !@tokenizer.cls_token_id.nil?
+            cls_indices = (Numo::NArray.cast(input_ids_span_idx).eq(@tokenizer.cls_token_id)).where
+            cls_indices.each do |cls_index|
+              p_mask[span_idx][cls_index] = false
+            end
+          end
+          submask = p_mask[span_idx]
+          features <<
+            SquadFeatures.new(
+              input_ids: input_ids_span_idx,
+              attention_mask: attention_mask_span_idx,
+              token_type_ids: token_type_ids_span_idx,
+              p_mask: submask,
+              encoding: encoded_inputs[span_idx],
+              # We don't use the rest of the values - and actually
+              # for Fast tokenizer we could totally avoid using SquadFeatures and SquadExample
+              cls_index: nil,
+              token_to_orig_map: {},
+              example_index: 0,
+              unique_id: 0,
+              paragraph_len: 0,
+              token_is_max_context: 0,
+              tokens: [],
+              start_position: 0,
+              end_position: 0,
+              is_impossible: false,
+              qas_id: nil
+            )
+        end
+      end
+
+      features.each_with_index do |feature, i|
+        fw_args = {}
+        others = {}
+        model_input_names = @tokenizer.model_input_names + ["p_mask", "token_type_ids"]
+
+        feature.instance_variables.each do |k|
+          v = feature.instance_variable_get(k)
+          k = k[1..]
+          if model_input_names.include?(k)
+            if @framework == "tf"
+              raise Todo
+            elsif @framework == "pt"
+              tensor = Torch.tensor(v)
+              if tensor.dtype == Torch.int32
+                tensor = tensor.long
+              end
+              fw_args[k.to_sym] = tensor.unsqueeze(0)
+            end
+          else
+            others[k.to_sym] = v
+          end
+        end
+
+        is_last = i == features.length - 1
+        yield({example: example, is_last: is_last}.merge(fw_args).merge(others))
+      end
+    end
+
+    def _forward(inputs)
+      example = inputs[:example]
+      model_inputs = @tokenizer.model_input_names.to_h { |k| [k.to_sym, inputs[k.to_sym]] }
+      # `XXXForSequenceClassification` models should not use `use_cache=True` even if it's supported
+      # model_forward = @model.forward if self.framework == "pt" else self.model.call
+      # if "use_cache" in inspect.signature(model_forward).parameters.keys():
+      #   model_inputs[:use_cache] = false
+      # end
+      output = @model.(**model_inputs)
+      if output.is_a?(Hash)
+        {start: output[:start_logits], end: output[:end_logits], example: example}.merge(inputs)
+      else
+        start, end_ = output[...2]
+        {start: start, end: end_, example: example}.merge(inputs)
+      end
+    end
+
+    def postprocess(
+      model_outputs,
+      top_k: 1,
+      handle_impossible_answer: false,
+      max_answer_len: 15,
+      align_to_words: true
+    )
+      min_null_score = 1000000 # large and positive
+      answers = []
+      model_outputs.each do |output|
+        start_ = output[:start]
+        end_ = output[:end]
+        example = output[:example]
+        p_mask = output[:p_mask]
+        attention_mask = (
+          !output[:attention_mask].nil? ? output[:attention_mask].numo : nil
+        )
+
+        starts, ends, scores, min_null_score = select_starts_ends(
+          start_, end_, p_mask, attention_mask, min_null_score, top_k, handle_impossible_answer, max_answer_len
+        )
+
+        if !@tokenizer.is_fast
+          raise Todo
+        else
+          # Convert the answer (tokens) back to the original text
+          # Score: score from the model
+          # Start: Index of the first character of the answer in the context string
+          # End: Index of the character following the last character of the answer in the context string
+          # Answer: Plain text of the answer
+          question_first = @tokenizer.padding_side == "right"
+          enc = output[:encoding]
+
+          # Encoding was *not* padded, input_ids *might* be.
+          # It doesn't make a difference unless we're padding on
+          # the left hand side, since now we have different offsets
+          # everywhere.
+          if @tokenizer.padding_side == "left"
+            offset = output[:input_ids].eq(@tokenizer.pad_token_id).numo.sum
+          else
+            offset = 0
+          end
+
+          # Sometimes the max probability token is in the middle of a word so:
+          # - we start by finding the right word containing the token with `token_to_word`
+          # - then we convert this word in a character span with `word_to_chars`
+          sequence_index = question_first ? 1 : 0
+          starts.to_a.zip(ends.to_a, scores.to_a) do |s, e, score|
+            s = s - offset
+            e = e - offset
+
+            start_index, end_index = get_indices(enc, s, e, sequence_index, align_to_words)
+
+            answers <<
+              {
+                score: score[0],
+                start: start_index,
+                end: end_index,
+                answer: example.context_text[start_index...end_index]
+              }
+          end
+        end
+      end
+
+      if handle_impossible_answer
+        answers << {score: min_null_score, start: 0, end: 0, answer: ""}
+      end
+      answers = answers.sort_by { |x| -x[:score] }[...top_k]
+      if answers.length == 1
+        return answers[0]
+      end
+      answers
+    end
+
+    def get_indices(
+      enc, s, e, sequence_index, align_to_words
+    )
+      if align_to_words
+        begin
+          start_word = enc.token_to_word(s)
+          end_word = enc.token_to_word(e)
+          start_index = enc.word_to_chars(start_word, sequence_index)[0]
+          end_index = enc.word_to_chars(end_word, sequence_index)[1]
+        rescue
+          # TODO
+          raise
+          # Some tokenizers don't really handle words. Keep to offsets then.
+          start_index = enc.offsets[s][0]
+          end_index = enc.offsets[e][1]
+        end
+      else
+        start_index = enc.offsets[s][0]
+        end_index = enc.offsets[e][1]
+      end
+      [start_index, end_index]
+    end
+
+    def decode_spans(
+      start, end_, topk, max_answer_len, undesired_tokens
+    )
+      # Ensure we have batch axis
+      if start.ndim == 1
+        start = start[nil]
+      end
+
+      if end_.ndim == 1
+        end_ = end_[nil]
+      end
+
+      # Compute the score of each tuple(start, end) to be the real answer
+      outer = start.expand_dims(-1).dot(end_.expand_dims(1))
+
+      # Remove candidates with end < start and end - start > max_answer_len
+      candidates = outer.triu.tril(max_answer_len - 1)
+
+      # Inspired by Chen & al. (https://github.com/facebookresearch/DrQA)
+      scores_flat = candidates.flatten
+      if topk == 1
+        idx_sort = [scores_flat.argmax]
+      elsif scores_flat.length < topk
+        raise Todo
+      else
+        raise Todo
+      end
+
+      starts, ends = unravel_index(idx_sort, candidates.shape)[1..]
+      desired_spans = isin(starts, undesired_tokens.where) & isin(ends, undesired_tokens.where)
+      starts = starts[desired_spans]
+      ends = ends[desired_spans]
+      scores = candidates[0, starts, ends]
+
+      [starts, ends, scores]
+    end
+
+    def unravel_index(indices, shape)
+      indices = Numo::NArray.cast(indices)
+      result = []
+      factor = 1
+      shape.size.times do |i|
+        result.unshift(indices / factor % shape[-1 - i])
+        factor *= shape[-1 - i]
+      end
+      result
+    end
+
+    def isin(element, test_elements)
+      test_elements = test_elements.to_a
+      Numo::Bit.cast(element.to_a.map { |e| test_elements.include?(e) })
+    end
+
+    def select_starts_ends(
+      start,
+      end_,
+      p_mask,
+      attention_mask,
+      min_null_score = 1000000,
+      top_k = 1,
+      handle_impossible_answer = false,
+      max_answer_len = 15
+    )
+      # Ensure padded tokens & question tokens cannot belong to the set of candidate answers.
+      undesired_tokens = ~p_mask.numo
+
+      if !attention_mask.nil?
+        undesired_tokens = undesired_tokens & attention_mask
+      end
+
+      # Generate mask
+      undesired_tokens_mask = undesired_tokens.eq(0)
+
+      # Make sure non-context indexes in the tensor cannot contribute to the softmax
+      start = start.numo
+      end_ = end_.numo
+      start[undesired_tokens_mask] = -10000.0
+      end_[undesired_tokens_mask] = -10000.0
+
+      # Normalize logits and spans to retrieve the answer
+      start = Numo::NMath.exp(start - start.max(axis: -1, keepdims: true))
+      start = start / start.sum
+
+      end_ = Numo::NMath.exp(end_ - end_.max(axis: -1, keepdims: true))
+      end_ = end_ / end_.sum
+
+      if handle_impossible_answer
+        min_null_score = [min_null_score, (start[0, 0] * end_[0, 0]).item].min
+      end
+
+      # Mask CLS
+      start[0, 0] = end_[0, 0] = 0.0
+
+      starts, ends, scores = decode_spans(start, end_, top_k, max_answer_len, undesired_tokens)
+      [starts, ends, scores, min_null_score]
+    end
+  end
+end
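For orientation, here is a minimal usage sketch of the question answering pipeline added above. It assumes the gem exposes a `Transformers.pipeline` factory via `data/lib/transformers/pipelines/_init.rb` (mirroring the Python library), that "question-answering" is a registered task with a default checkpoint downloadable from the Hugging Face Hub, and that per-call keyword arguments reach `_sanitize_parameters`; none of that is shown in this hunk, so treat the names as illustrative.

require "transformers-rb"

# Assumed factory call; task name and default model are illustrative.
qa = Transformers.pipeline("question-answering")

# A single question/context pair; QuestionAnsweringArgumentHandler#call
# normalizes it into one SquadExample before preprocessing.
qa.(question: "Who created Ruby?", context: "Ruby was created by Yukihiro Matsumoto.")
# e.g. {score: 0.99..., start: 20, end: 38, answer: "Yukihiro Matsumoto"}

# Several questions against one shared context are also accepted by the handler.
qa.(
  question: ["Who created Ruby?", "What did Matsumoto create?"],
  context: "Ruby was created by Yukihiro Matsumoto."
)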
data/lib/transformers/pipelines/text_classification.rb
@@ -0,0 +1,123 @@
+module Transformers
+  class TextClassificationPipeline < Pipeline
+    def initialize(*args, **kwargs)
+      super
+
+      check_model_type(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES)
+    end
+
+    private
+
+    def _sanitize_parameters(return_all_scores: nil, function_to_apply: nil, top_k: "", **tokenizer_kwargs)
+      # Using "" as default argument because we're going to use `top_k=None` in user code to declare
+      # "No top_k"
+      preprocess_params = tokenizer_kwargs
+
+      postprocess_params = {}
+      if @model.config.respond_to?(:return_all_scores) && return_all_scores.nil?
+        return_all_scores = @model.config.return_all_scores
+      end
+
+      if top_k.is_a?(Integer) || top_k.nil?
+        postprocess_params[:top_k] = top_k
+        postprocess_params[:_legacy] = false
+      elsif !return_all_scores.nil?
+        warn(
+          "`return_all_scores` is now deprecated, if want a similar functionality use `top_k: nil` instead of" +
+          " `return_all_scores: true` or `top_k: 1` instead of `return_all_scores: false`.",
+        )
+        if return_all_scores
+          postprocess_params[:top_k] = nil
+        else
+          postprocess_params[:top_k] = 1
+        end
+      end
+
+      if function_to_apply.is_a?(String)
+        function_to_apply = ClassificationFunction.new(function_to_apply.upcase).to_s
+      end
+
+      if !function_to_apply.nil?
+        postprocess_params[:function_to_apply] = function_to_apply
+      end
+      [preprocess_params, {}, postprocess_params]
+    end
+
+    def preprocess(inputs, **tokenizer_kwargs)
+      return_tensors = @framework
+      if inputs.is_a?(Hash)
+        return @tokenizer.(**inputs, return_tensors: return_tensors, **tokenizer_kwargs)
+      elsif inputs.is_a?(Array) && inputs.length == 1 && inputs[0].is_a?(Array) && inputs[0].length == 2
+        # It used to be valid to use a list of list of list for text pairs, keeping this path for BC
+        return @tokenizer.(
+          inputs[0][0], text_pair: inputs[0][1], return_tensors: return_tensors, **tokenizer_kwargs
+        )
+      elsif inputs.is_a?(Array)
+        # This is likely an invalid usage of the pipeline attempting to pass text pairs.
+        raise ArgumentError,
+          "The pipeline received invalid inputs, if you are trying to send text pairs, you can try to send a" +
+          ' dictionary `{"text": "My text", "text_pair": "My pair"}` in order to send a text pair.'
+      end
+      @tokenizer.(inputs, return_tensors: return_tensors, **tokenizer_kwargs)
+    end
+
+    def _forward(model_inputs)
+      @model.(**model_inputs.to_h)
+    end
+
+    def postprocess(model_outputs, function_to_apply: nil, top_k: 1, _legacy: true)
+      if function_to_apply.nil?
+        if @model.config.problem_type == "multi_label_classification" || @model.config.num_labels == 1
+          function_to_apply = ClassificationFunction::SIGMOID
+        elsif @model.config.problem_type == "single_label_classification" || @model.config.num_labels > 1
+          function_to_apply = ClassificationFunction::SOFTMAX
+        elsif @model.config.instance_variable_defined?(:@function_to_apply) && function_to_apply.nil?
+          function_to_apply = @model.config.function_to_apply
+        else
+          function_to_apply = ClassificationFunction::NONE
+        end
+      end
+
+      outputs = model_outputs["logits"][0]
+      outputs = outputs.numo
+
+      if function_to_apply == ClassificationFunction::SIGMOID
+        scores = sigmoid(outputs)
+      elsif function_to_apply == ClassificationFunction::SOFTMAX
+        scores = softmax(outputs)
+      elsif function_to_apply == ClassificationFunction::NONE
+        scores = outputs
+      else
+        raise ArgumentError, "Unrecognized `function_to_apply` argument: #{function_to_apply}"
+      end
+
+      if top_k == 1 && _legacy
+        return {label: @model.config.id2label[scores.argmax], score: scores.max}
+      end
+
+      dict_scores =
+        scores.to_a.map.with_index do |score, i|
+          {label: @model.config.id2label[i], score: score}
+        end
+      if !_legacy
+        dict_scores.sort_by! { |x| -x[:score] }
+        if !top_k.nil?
+          dict_scores = dict_scores.first(top_k)
+        end
+      end
+      dict_scores
+    end
+
+    private
+
+    def sigmoid(_outputs)
+      1.0 / (1.0 + Numo::NMath.exp(-_outputs))
+    end
+
+    def softmax(_outputs)
+      maxes = _outputs.max(axis: -1, keepdims: true)
+      shifted_exp = Numo::NMath.exp(_outputs - maxes)
+      shifted_exp / shifted_exp.sum(axis: -1, keepdims: true)
+    end
+  end
+end
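Similarly, a minimal sketch of how the text classification pipeline above might be invoked, under the same assumed `Transformers.pipeline` factory and a sentiment-style checkpoint whose config provides `id2label`; task name, model, and labels are illustrative only.

classifier = Transformers.pipeline("text-classification")

# Default legacy format: a single {label:, score:} hash (top_k: 1, _legacy: true in postprocess).
classifier.("This gem makes transformer models pleasant to use from Ruby.")
# e.g. {label: "POSITIVE", score: 0.999}

# Passing top_k: nil takes the non-legacy path in _sanitize_parameters/postprocess and
# returns every label with its score, sorted by descending score.
classifier.("This gem makes transformer models pleasant to use from Ruby.", top_k: nil)
# e.g. [{label: "POSITIVE", score: 0.999}, {label: "NEGATIVE", score: 0.001}]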