transformers-rb 0.1.0
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +203 -0
- data/README.md +163 -0
- data/lib/transformers/activations.rb +57 -0
- data/lib/transformers/configuration_utils.rb +285 -0
- data/lib/transformers/convert_slow_tokenizer.rb +90 -0
- data/lib/transformers/data/processors/squad.rb +115 -0
- data/lib/transformers/dynamic_module_utils.rb +25 -0
- data/lib/transformers/feature_extraction_utils.rb +110 -0
- data/lib/transformers/hf_hub/constants.rb +71 -0
- data/lib/transformers/hf_hub/errors.rb +11 -0
- data/lib/transformers/hf_hub/file_download.rb +764 -0
- data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
- data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
- data/lib/transformers/image_processing_base.rb +169 -0
- data/lib/transformers/image_processing_utils.rb +63 -0
- data/lib/transformers/image_transforms.rb +208 -0
- data/lib/transformers/image_utils.rb +165 -0
- data/lib/transformers/modeling_outputs.rb +81 -0
- data/lib/transformers/modeling_utils.rb +888 -0
- data/lib/transformers/models/auto/auto_factory.rb +138 -0
- data/lib/transformers/models/auto/configuration_auto.rb +61 -0
- data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
- data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
- data/lib/transformers/models/auto/modeling_auto.rb +80 -0
- data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
- data/lib/transformers/models/bert/configuration_bert.rb +65 -0
- data/lib/transformers/models/bert/modeling_bert.rb +836 -0
- data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
- data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
- data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
- data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
- data/lib/transformers/models/vit/configuration_vit.rb +60 -0
- data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
- data/lib/transformers/models/vit/modeling_vit.rb +506 -0
- data/lib/transformers/pipelines/_init.rb +348 -0
- data/lib/transformers/pipelines/base.rb +301 -0
- data/lib/transformers/pipelines/feature_extraction.rb +47 -0
- data/lib/transformers/pipelines/image_classification.rb +110 -0
- data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
- data/lib/transformers/pipelines/pt_utils.rb +53 -0
- data/lib/transformers/pipelines/question_answering.rb +508 -0
- data/lib/transformers/pipelines/text_classification.rb +123 -0
- data/lib/transformers/pipelines/token_classification.rb +282 -0
- data/lib/transformers/ruby_utils.rb +33 -0
- data/lib/transformers/sentence_transformer.rb +37 -0
- data/lib/transformers/tokenization_utils.rb +152 -0
- data/lib/transformers/tokenization_utils_base.rb +937 -0
- data/lib/transformers/tokenization_utils_fast.rb +386 -0
- data/lib/transformers/torch_utils.rb +25 -0
- data/lib/transformers/utils/_init.rb +31 -0
- data/lib/transformers/utils/generic.rb +107 -0
- data/lib/transformers/utils/hub.rb +209 -0
- data/lib/transformers/utils/import_utils.rb +45 -0
- data/lib/transformers/utils/logging.rb +52 -0
- data/lib/transformers/version.rb +3 -0
- data/lib/transformers-rb.rb +1 -0
- data/lib/transformers.rb +100 -0
- data/licenses/LICENSE-huggingface-hub.txt +201 -0
- data/licenses/LICENSE-sentence-transformers.txt +201 -0
- data/licenses/NOTICE-sentence-transformers.txt +5 -0
- metadata +161 -0
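The two pipeline files shown below, question_answering.rb and text_classification.rb, are the main user-facing additions in this release. For orientation, here is a minimal usage sketch, assuming the `Transformers.pipeline` factory defined in pipelines/_init.rb above and a default question-answering model; the call shape (positional strings or `question:`/`context:` keywords) follows the argument handler in the diff, but the model choice and output values are illustrative, not taken from the gem's documentation:

qa = Transformers.pipeline("question-answering")
qa.(question: "Who created Ruby?", context: "Ruby is a programming language created by Matz.")
# => {score: ..., start: ..., end: ..., answer: "Matz"}   (illustrative output)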
data/lib/transformers/pipelines/question_answering.rb (new file)
@@ -0,0 +1,508 @@
module Transformers
  class QuestionAnsweringArgumentHandler < ArgumentHandler
    def normalize(item)
      if item.is_a?(SquadExample)
        return item
      elsif item.is_a?(Hash)
        [:question, :context].each do |k|
          if !item.include?(k)
            raise KeyError, "You need to provide a dictionary with keys {question:..., context:...}"
          elsif item[k].nil?
            raise ArgumentError, "`#{k}` cannot be nil"
          elsif item[k].is_a?(String) && item[k].length == 0
            raise ArgumentError, "`#{k}` cannot be empty"
          end
        end

        return QuestionAnsweringPipeline.create_sample(**item)
      end
      raise ArgumentError, "#{item} argument needs to be of type (SquadExample, dict)"
    end

    def call(*args, **kwargs)
      # Detect where the actual inputs are
      if args.any?
        if args.length == 1
          inputs = args[0]
        elsif args.length == 2 && args.all? { |el| el.is_a?(String) }
          inputs = [{question: args[0], context: args[1]}]
        else
          inputs = args.to_a
        end
      elsif kwargs.include?(:question) && kwargs.include?(:context)
        if kwargs[:question].is_a?(Array) && kwargs[:context].is_a?(String)
          inputs = kwargs[:question].map { |q| {question: q, context: kwargs[:context]} }
        elsif kwargs[:question].is_a?(Array) && kwargs[:context].is_a?(Array)
          if kwargs[:question].length != kwargs[:context].length
            raise ArgumentError, "Questions and contexts don't have the same lengths"
          end

          inputs = kwargs[:question].zip(kwargs[:context]).map { |q, c| {question: q, context: c} }
        elsif kwargs[:question].is_a?(String) && kwargs[:context].is_a?(String)
          inputs = [{question: kwargs[:question], context: kwargs[:context]}]
        else
          raise ArgumentError, "Arguments can't be understood"
        end
      else
        raise ArgumentError, "Unknown arguments #{kwargs}"
      end

      # Normalize inputs
      if inputs.is_a?(Hash)
        inputs = [inputs]
      elsif inputs.is_a?(Enumerable)
        # Copy to avoid overriding arguments
        inputs = inputs.to_a.dup
      else
        raise ArgumentError, "Invalid arguments #{kwargs}"
      end

      inputs.each_with_index do |item, i|
        inputs[i] = normalize(item)
      end

      inputs
    end
  end

  class QuestionAnsweringPipeline < ChunkPipeline
    extend ClassAttribute

    class_attribute :default_input_names, "question,context"
    class_attribute :handle_impossible_answer, false

    def initialize(
      model,
      tokenizer:,
      modelcard: nil,
      framework: nil,
      task: "",
      **kwargs
    )
      super(
        model,
        tokenizer: tokenizer,
        modelcard: modelcard,
        framework: framework,
        task: task,
        **kwargs
      )

      @args_parser = QuestionAnsweringArgumentHandler.new
      check_model_type(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES)
    end

    def self.create_sample(
      question:, context:
    )
      if question.is_a?(Array)
        question.zip(context).map { |q, c| SquadExample.new(nil, q, c, nil, nil, nil) }
      else
        SquadExample.new(nil, question, context, nil, nil, nil)
      end
    end

    def _sanitize_parameters(
      padding: nil,
      topk: nil,
      top_k: nil,
      doc_stride: nil,
      max_answer_len: nil,
      max_seq_len: nil,
      max_question_len: nil,
      handle_impossible_answer: nil,
      align_to_words: nil,
      **kwargs
    )
      # Set default values
      preprocess_params = {}
      if !padding.nil?
        preprocess_params[:padding] = padding
      end
      if !doc_stride.nil?
        preprocess_params[:doc_stride] = doc_stride
      end
      if !max_question_len.nil?
        preprocess_params[:max_question_len] = max_question_len
      end
      if !max_seq_len.nil?
        preprocess_params[:max_seq_len] = max_seq_len
      end

      postprocess_params = {}
      if !topk.nil? && top_k.nil?
        warn("topk parameter is deprecated, use top_k instead")
        top_k = topk
      end
      if !top_k.nil?
        if top_k < 1
          raise ArgumentError, "top_k parameter should be >= 1 (got #{top_k})"
        end
        postprocess_params[:top_k] = top_k
      end
      if !max_answer_len.nil?
        if max_answer_len < 1
          raise ArgumentError, "max_answer_len parameter should be >= 1 (got #{max_answer_len})"
        end
      end
      if !max_answer_len.nil?
        postprocess_params[:max_answer_len] = max_answer_len
      end
      if !handle_impossible_answer.nil?
        postprocess_params[:handle_impossible_answer] = handle_impossible_answer
      end
      if !align_to_words.nil?
        postprocess_params[:align_to_words] = align_to_words
      end
      [preprocess_params, {}, postprocess_params]
    end

    def call(*args, **kwargs)
      examples = @args_parser.(*args, **kwargs)
      if examples.is_a?(Array) && examples.length == 1
        return super(examples[0], **kwargs)
      end
      super(examples, **kwargs)
    end

    def preprocess(example, padding: "do_not_pad", doc_stride: nil, max_question_len: 64, max_seq_len: nil)
      # XXX: This is special, args_parser will not handle anything generator or dataset like
      # For those we expect user to send a simple valid example either directly as a SquadExample or simple dict.
      # So we still need a little sanitation here.
      if example.is_a?(Hash)
        example = SquadExample.new(nil, example[:question], example[:context], nil, nil, nil)
      end

      if max_seq_len.nil?
        max_seq_len = [@tokenizer.model_max_length, 384].min
      end
      if doc_stride.nil?
        doc_stride = [max_seq_len.div(2), 128].min
      end

      if doc_stride > max_seq_len
        raise ArgumentError, "`doc_stride` (#{doc_stride}) is larger than `max_seq_len` (#{max_seq_len})"
      end

      if !@tokenizer.is_fast
        features = squad_convert_examples_to_features(
          examples: [example],
          tokenizer: @tokenizer,
          max_seq_length: max_seq_len,
          doc_stride: doc_stride,
          max_query_length: max_question_len,
          padding_strategy: PaddingStrategy::MAX_LENGTH,
          is_training: false,
          tqdm_enabled: false
        )
      else
        # Define the side we want to truncate / pad and the text/pair sorting
        question_first = @tokenizer.padding_side == "right"

        encoded_inputs = @tokenizer.(
          question_first ? example.question_text : example.context_text,
          text_pair: question_first ? example.context_text : example.question_text,
          padding: padding,
          truncation: question_first ? "only_second" : "only_first",
          max_length: max_seq_len,
          stride: doc_stride,
          return_token_type_ids: true,
          return_overflowing_tokens: true,
          return_offsets_mapping: true,
          return_special_tokens_mask: true
        )
        # When the input is too long, it's converted in a batch of inputs with overflowing tokens
        # and a stride of overlap between the inputs. If a batch of inputs is given, a special output
        # "overflow_to_sample_mapping" indicate which member of the encoded batch belong to which original batch sample.
        # Here we tokenize examples one-by-one so we don't need to use "overflow_to_sample_mapping".
        # "num_span" is the number of output samples generated from the overflowing tokens.
        num_spans = encoded_inputs[:input_ids].length

        # p_mask: mask with 1 for token that cannot be in the answer (0 for token which can be in an answer)
        # We put 0 on the tokens from the context and 1 everywhere else (question and special tokens)
        p_mask =
          num_spans.times.map do |span_id|
            encoded_inputs.sequence_ids(span_id).map { |tok| tok != (question_first ? 1 : 0) }
          end

        features = []
        num_spans.times do |span_idx|
          input_ids_span_idx = encoded_inputs[:input_ids][span_idx]
          attention_mask_span_idx = (
            encoded_inputs.include?(:attention_mask) ? encoded_inputs[:attention_mask][span_idx] : nil
          )
          token_type_ids_span_idx = (
            encoded_inputs.include?(:token_type_ids) ? encoded_inputs[:token_type_ids][span_idx] : nil
          )
          # keep the cls_token unmasked (some models use it to indicate unanswerable questions)
          if !@tokenizer.cls_token_id.nil?
            cls_indices = (Numo::NArray.cast(input_ids_span_idx).eq(@tokenizer.cls_token_id)).where
            cls_indices.each do |cls_index|
              p_mask[span_idx][cls_index] = false
            end
          end
          submask = p_mask[span_idx]
          features <<
            SquadFeatures.new(
              input_ids: input_ids_span_idx,
              attention_mask: attention_mask_span_idx,
              token_type_ids: token_type_ids_span_idx,
              p_mask: submask,
              encoding: encoded_inputs[span_idx],
              # We don't use the rest of the values - and actually
              # for Fast tokenizer we could totally avoid using SquadFeatures and SquadExample
              cls_index: nil,
              token_to_orig_map: {},
              example_index: 0,
              unique_id: 0,
              paragraph_len: 0,
              token_is_max_context: 0,
              tokens: [],
              start_position: 0,
              end_position: 0,
              is_impossible: false,
              qas_id: nil
            )
        end
      end

      features.each_with_index do |feature, i|
        fw_args = {}
        others = {}
        model_input_names = @tokenizer.model_input_names + ["p_mask", "token_type_ids"]

        feature.instance_variables.each do |k|
          v = feature.instance_variable_get(k)
          k = k[1..]
          if model_input_names.include?(k)
            if @framework == "tf"
              raise Todo
            elsif @framework == "pt"
              tensor = Torch.tensor(v)
              if tensor.dtype == Torch.int32
                tensor = tensor.long
              end
              fw_args[k.to_sym] = tensor.unsqueeze(0)
            end
          else
            others[k.to_sym] = v
          end
        end

        is_last = i == features.length - 1
        yield({example: example, is_last: is_last}.merge(fw_args).merge(others))
      end
    end

    def _forward(inputs)
      example = inputs[:example]
      model_inputs = @tokenizer.model_input_names.to_h { |k| [k.to_sym, inputs[k.to_sym]] }
      # `XXXForSequenceClassification` models should not use `use_cache=True` even if it's supported
      # model_forward = @model.forward if self.framework == "pt" else self.model.call
      # if "use_cache" in inspect.signature(model_forward).parameters.keys():
      #   model_inputs[:use_cache] = false
      # end
      output = @model.(**model_inputs)
      if output.is_a?(Hash)
        {start: output[:start_logits], end: output[:end_logits], example: example}.merge(inputs)
      else
        start, end_ = output[...2]
        {start: start, end: end_, example: example}.merge(inputs)
      end
    end

    def postprocess(
      model_outputs,
      top_k: 1,
      handle_impossible_answer: false,
      max_answer_len: 15,
      align_to_words: true
    )
      min_null_score = 1000000 # large and positive
      answers = []
      model_outputs.each do |output|
        start_ = output[:start]
        end_ = output[:end]
        example = output[:example]
        p_mask = output[:p_mask]
        attention_mask = (
          !output[:attention_mask].nil? ? output[:attention_mask].numo : nil
        )

        starts, ends, scores, min_null_score = select_starts_ends(
          start_, end_, p_mask, attention_mask, min_null_score, top_k, handle_impossible_answer, max_answer_len
        )

        if !@tokenizer.is_fast
          raise Todo
        else
          # Convert the answer (tokens) back to the original text
          # Score: score from the model
          # Start: Index of the first character of the answer in the context string
          # End: Index of the character following the last character of the answer in the context string
          # Answer: Plain text of the answer
          question_first = @tokenizer.padding_side == "right"
          enc = output[:encoding]

          # Encoding was *not* padded, input_ids *might*.
          # It doesn't make a difference unless we're padding on
          # the left hand side, since now we have different offsets
          # everywhere.
          if @tokenizer.padding_side == "left"
            offset = output[:input_ids].eq(@tokenizer.pad_token_id).numo.sum
          else
            offset = 0
          end

          # Sometimes the max probability token is in the middle of a word so:
          # - we start by finding the right word containing the token with `token_to_word`
          # - then we convert this word in a character span with `word_to_chars`
          sequence_index = question_first ? 1 : 0
          starts.to_a.zip(ends.to_a, scores.to_a) do |s, e, score|
            s = s - offset
            e = e - offset

            start_index, end_index = get_indices(enc, s, e, sequence_index, align_to_words)

            answers <<
              {
                score: score[0],
                start: start_index,
                end: end_index,
                answer: example.context_text[start_index...end_index]
              }
          end
        end
      end

      if handle_impossible_answer
        answers << {score: min_null_score, start: 0, end: 0, answer: ""}
      end
      answers = answers.sort_by { |x| -x[:score] }[...top_k]
      if answers.length == 1
        return answers[0]
      end
      answers
    end

    def get_indices(
      enc, s, e, sequence_index, align_to_words
    )
      if align_to_words
        begin
          start_word = enc.token_to_word(s)
          end_word = enc.token_to_word(e)
          start_index = enc.word_to_chars(start_word, sequence_index)[0]
          end_index = enc.word_to_chars(end_word, sequence_index)[1]
        rescue
          # TODO
          raise
          # Some tokenizers don't really handle words. Keep to offsets then.
          start_index = enc.offsets[s][0]
          end_index = enc.offsets[e][1]
        end
      else
        start_index = enc.offsets[s][0]
        end_index = enc.offsets[e][1]
      end
      [start_index, end_index]
    end

    def decode_spans(
      start, end_, topk, max_answer_len, undesired_tokens
    )
      # Ensure we have batch axis
      if start.ndim == 1
        start = start[nil]
      end

      if end_.ndim == 1
        end_ = end_[nil]
      end

      # Compute the score of each tuple(start, end) to be the real answer
      outer = start.expand_dims(-1).dot(end_.expand_dims(1))

      # Remove candidate with end < start and end - start > max_answer_len
      candidates = outer.triu.tril(max_answer_len - 1)

      # Inspired by Chen & al. (https://github.com/facebookresearch/DrQA)
      scores_flat = candidates.flatten
      if topk == 1
        idx_sort = [scores_flat.argmax]
      elsif scores_flat.length < topk
        raise Todo
      else
        raise Todo
      end

      starts, ends = unravel_index(idx_sort, candidates.shape)[1..]
      desired_spans = isin(starts, undesired_tokens.where) & isin(ends, undesired_tokens.where)
      starts = starts[desired_spans]
      ends = ends[desired_spans]
      scores = candidates[0, starts, ends]

      [starts, ends, scores]
    end

    def unravel_index(indices, shape)
      indices = Numo::NArray.cast(indices)
      result = []
      factor = 1
      shape.size.times do |i|
        result.unshift(indices / factor % shape[-1 - i])
        factor *= shape[-1 - i]
      end
      result
    end

    def isin(element, test_elements)
      test_elements = test_elements.to_a
      Numo::Bit.cast(element.to_a.map { |e| test_elements.include?(e) })
    end

    def select_starts_ends(
      start,
      end_,
      p_mask,
      attention_mask,
      min_null_score = 1000000,
      top_k = 1,
      handle_impossible_answer = false,
      max_answer_len = 15
    )
      # Ensure padded tokens & question tokens cannot belong to the set of candidate answers.
      undesired_tokens = ~p_mask.numo

      if !attention_mask.nil?
        undesired_tokens = undesired_tokens & attention_mask
      end

      # Generate mask
      undesired_tokens_mask = undesired_tokens.eq(0)

      # Make sure non-context indexes in the tensor cannot contribute to the softmax
      start = start.numo
      end_ = end_.numo
      start[undesired_tokens_mask] = -10000.0
      end_[undesired_tokens_mask] = -10000.0

      # Normalize logits and spans to retrieve the answer
      start = Numo::NMath.exp(start - start.max(axis: -1, keepdims: true))
      start = start / start.sum

      end_ = Numo::NMath.exp(end_ - end_.max(axis: -1, keepdims: true))
      end_ = end_ / end_.sum

      if handle_impossible_answer
        min_null_score = [min_null_score, (start[0, 0] * end_[0, 0]).item].min
      end

      # Mask CLS
      start[0, 0] = end_[0, 0] = 0.0

      starts, ends, scores = decode_spans(start, end_, top_k, max_answer_len, undesired_tokens)
      [starts, ends, scores, min_null_score]
    end
  end
end
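A side note on `decode_spans` above: it scores every (start, end) pair as the product of the start and end probabilities via an outer product, uses `triu`/`tril(max_answer_len - 1)` to drop spans that end before they start or run longer than `max_answer_len`, and takes the argmax of the flattened score matrix. A toy, standalone sketch of that step with Numo (not part of the gem; it assumes `argmax` returns a flat integer index, as the code above does):

require "numo/narray"

start_probs = Numo::DFloat[0.1, 0.7, 0.2]    # P(answer starts at token i)
end_probs   = Numo::DFloat[0.05, 0.15, 0.8]  # P(answer ends at token j)
max_answer_len = 2

# candidates[i, j] = start_probs[i] * end_probs[j]
candidates = start_probs.expand_dims(-1).dot(end_probs.expand_dims(0))

# triu drops spans with j < i; tril(max_answer_len - 1) drops spans longer than max_answer_len
candidates = candidates.triu.tril(max_answer_len - 1)

# Best remaining span: flat argmax, converted back to (start, end) token indices
start_idx, end_idx = candidates.flatten.argmax.divmod(end_probs.size)
# => [1, 2]: the best-scoring answer covers tokens 1..2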
data/lib/transformers/pipelines/text_classification.rb (new file)
@@ -0,0 +1,123 @@
module Transformers
  class TextClassificationPipeline < Pipeline
    def initialize(*args, **kwargs)
      super

      check_model_type(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES)
    end

    private

    def _sanitize_parameters(return_all_scores: nil, function_to_apply: nil, top_k: "", **tokenizer_kwargs)
      # Using "" as default argument because we're going to use `top_k=None` in user code to declare
      # "No top_k"
      preprocess_params = tokenizer_kwargs

      postprocess_params = {}
      if @model.config.respond_to?(:return_all_scores) && return_all_scores.nil?
        return_all_scores = @model.config.return_all_scores
      end

      if top_k.is_a?(Integer) || top_k.nil?
        postprocess_params[:top_k] = top_k
        postprocess_params[:_legacy] = false
      elsif !return_all_scores.nil?
        warn(
          "`return_all_scores` is now deprecated, if want a similar functionality use `top_k: nil` instead of" +
          " `return_all_scores: true` or `top_k: 1` instead of `return_all_scores: false`.",
        )
        if return_all_scores
          postprocess_params[:top_k] = nil
        else
          postprocess_params[:top_k] = 1
        end
      end

      if function_to_apply.is_a?(String)
        function_to_apply = ClassificationFunction.new(function_to_apply.upcase).to_s
      end

      if !function_to_apply.nil?
        postprocess_params[:function_to_apply] = function_to_apply
      end
      [preprocess_params, {}, postprocess_params]
    end

    def preprocess(inputs, **tokenizer_kwargs)
      return_tensors = @framework
      if inputs.is_a?(Hash)
        return @tokenizer.(**inputs, return_tensors: return_tensors, **tokenizer_kwargs)
      elsif inputs.is_a?(Array) && inputs.length == 1 && inputs[0].is_a?(Array) && inputs[0].length == 2
        # It used to be valid to use a list of list of list for text pairs, keeping this path for BC
        return @tokenizer.(
          inputs[0][0], text_pair: inputs[0][1], return_tensors: return_tensors, **tokenizer_kwargs
        )
      elsif inputs.is_a?(Array)
        # This is likely an invalid usage of the pipeline attempting to pass text pairs.
        raise ArgumentError,
          "The pipeline received invalid inputs, if you are trying to send text pairs, you can try to send a" +
          ' dictionary `{"text": "My text", "text_pair": "My pair"}` in order to send a text pair.'
      end
      @tokenizer.(inputs, return_tensors: return_tensors, **tokenizer_kwargs)
    end

    def _forward(model_inputs)
      @model.(**model_inputs.to_h)
    end

    def postprocess(model_outputs, function_to_apply: nil, top_k: 1, _legacy: true)
      if function_to_apply.nil?
        if @model.config.problem_type == "multi_label_classification" || @model.config.num_labels == 1
          function_to_apply = ClassificationFunction::SIGMOID
        elsif @model.config.problem_type == "single_label_classification" || @model.config.num_labels > 1
          function_to_apply = ClassificationFunction::SOFTMAX
        elsif @model.config.instance_variable_defined?(:@function_to_apply) && function_to_apply.nil?
          function_to_apply = @model.config.function_to_apply
        else
          function_to_apply = ClassificationFunction::NONE
        end
      end

      outputs = model_outputs["logits"][0]
      outputs = outputs.numo

      if function_to_apply == ClassificationFunction::SIGMOID
        scores = sigmoid(outputs)
      elsif function_to_apply == ClassificationFunction::SOFTMAX
        scores = softmax(outputs)
      elsif function_to_apply == ClassificationFunction::NONE
        scores = outputs
      else
        raise ArgumentError, "Unrecognized `function_to_apply` argument: #{function_to_apply}"
      end

      if top_k == 1 && _legacy
        return {label: @model.config.id2label[scores.argmax], score: scores.max}
      end

      dict_scores =
        scores.to_a.map.with_index do |score, i|
          {label: @model.config.id2label[i], score: score}
        end
      if !_legacy
        dict_scores.sort_by! { |x| -x[:score] }
        if !top_k.nil?
          dict_scores = dict_scores.first(top_k)
        end
      end
      dict_scores
    end

    private

    def sigmoid(_outputs)
      1.0 / (1.0 + Numo::NMath.exp(-_outputs))
    end

    def softmax(_outputs)
      maxes = _outputs.max(axis: -1, keepdims: true)
      shifted_exp = Numo::NMath.exp(_outputs - maxes)
      shifted_exp / shifted_exp.sum(axis: -1, keepdims: true)
    end
  end
end
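A matching sketch for the text-classification pipeline above. The `top_k: nil` behaviour (return all labels, sorted by score) follows from `_sanitize_parameters` and `postprocess` in the diff; the factory call, default model, label names, and scores are assumptions for illustration:

classifier = Transformers.pipeline("text-classification")

classifier.("This gem is great")
# => {label: "POSITIVE", score: 0.99}   (default legacy path, single best label)

classifier.("This gem is great", top_k: nil)
# => [{label: "POSITIVE", score: 0.99}, {label: "NEGATIVE", score: 0.01}]   (all labels, sorted)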