transformers-rb 0.1.0

Files changed (65)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +203 -0
  4. data/README.md +163 -0
  5. data/lib/transformers/activations.rb +57 -0
  6. data/lib/transformers/configuration_utils.rb +285 -0
  7. data/lib/transformers/convert_slow_tokenizer.rb +90 -0
  8. data/lib/transformers/data/processors/squad.rb +115 -0
  9. data/lib/transformers/dynamic_module_utils.rb +25 -0
  10. data/lib/transformers/feature_extraction_utils.rb +110 -0
  11. data/lib/transformers/hf_hub/constants.rb +71 -0
  12. data/lib/transformers/hf_hub/errors.rb +11 -0
  13. data/lib/transformers/hf_hub/file_download.rb +764 -0
  14. data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
  15. data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
  16. data/lib/transformers/image_processing_base.rb +169 -0
  17. data/lib/transformers/image_processing_utils.rb +63 -0
  18. data/lib/transformers/image_transforms.rb +208 -0
  19. data/lib/transformers/image_utils.rb +165 -0
  20. data/lib/transformers/modeling_outputs.rb +81 -0
  21. data/lib/transformers/modeling_utils.rb +888 -0
  22. data/lib/transformers/models/auto/auto_factory.rb +138 -0
  23. data/lib/transformers/models/auto/configuration_auto.rb +61 -0
  24. data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
  25. data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
  26. data/lib/transformers/models/auto/modeling_auto.rb +80 -0
  27. data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
  28. data/lib/transformers/models/bert/configuration_bert.rb +65 -0
  29. data/lib/transformers/models/bert/modeling_bert.rb +836 -0
  30. data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
  31. data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
  32. data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
  33. data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
  34. data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
  35. data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
  36. data/lib/transformers/models/vit/configuration_vit.rb +60 -0
  37. data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
  38. data/lib/transformers/models/vit/modeling_vit.rb +506 -0
  39. data/lib/transformers/pipelines/_init.rb +348 -0
  40. data/lib/transformers/pipelines/base.rb +301 -0
  41. data/lib/transformers/pipelines/feature_extraction.rb +47 -0
  42. data/lib/transformers/pipelines/image_classification.rb +110 -0
  43. data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
  44. data/lib/transformers/pipelines/pt_utils.rb +53 -0
  45. data/lib/transformers/pipelines/question_answering.rb +508 -0
  46. data/lib/transformers/pipelines/text_classification.rb +123 -0
  47. data/lib/transformers/pipelines/token_classification.rb +282 -0
  48. data/lib/transformers/ruby_utils.rb +33 -0
  49. data/lib/transformers/sentence_transformer.rb +37 -0
  50. data/lib/transformers/tokenization_utils.rb +152 -0
  51. data/lib/transformers/tokenization_utils_base.rb +937 -0
  52. data/lib/transformers/tokenization_utils_fast.rb +386 -0
  53. data/lib/transformers/torch_utils.rb +25 -0
  54. data/lib/transformers/utils/_init.rb +31 -0
  55. data/lib/transformers/utils/generic.rb +107 -0
  56. data/lib/transformers/utils/hub.rb +209 -0
  57. data/lib/transformers/utils/import_utils.rb +45 -0
  58. data/lib/transformers/utils/logging.rb +52 -0
  59. data/lib/transformers/version.rb +3 -0
  60. data/lib/transformers-rb.rb +1 -0
  61. data/lib/transformers.rb +100 -0
  62. data/licenses/LICENSE-huggingface-hub.txt +201 -0
  63. data/licenses/LICENSE-sentence-transformers.txt +201 -0
  64. data/licenses/NOTICE-sentence-transformers.txt +5 -0
  65. metadata +161 -0
data/lib/transformers/pipelines/question_answering.rb
@@ -0,0 +1,508 @@
+ module Transformers
+   class QuestionAnsweringArgumentHandler < ArgumentHandler
+     def normalize(item)
+       if item.is_a?(SquadExample)
+         return item
+       elsif item.is_a?(Hash)
+         [:question, :context].each do |k|
+           if !item.include?(k)
+             raise KeyError, "You need to provide a dictionary with keys {question:..., context:...}"
+           elsif item[k].nil?
+             raise ArgumentError, "`#{k}` cannot be nil"
+           elsif item[k].is_a?(String) && item[k].length == 0
+             raise ArgumentError, "`#{k}` cannot be empty"
+           end
+         end
+
+         return QuestionAnsweringPipeline.create_sample(**item)
+       end
+       raise ArgumentError, "#{item} argument needs to be of type (SquadExample, dict)"
+     end
+
+     def call(*args, **kwargs)
+       # Detect where the actual inputs are
+       if args.any?
+         if args.length == 1
+           inputs = args[0]
+         elsif args.length == 2 && args.all? { |el| el.is_a?(String) }
+           inputs = [{question: args[0], context: args[1]}]
+         else
+           inputs = args.to_a
+         end
+       elsif kwargs.include?(:question) && kwargs.include?(:context)
+         if kwargs[:question].is_a?(Array) && kwargs[:context].is_a?(String)
+           inputs = kwargs[:question].map { |q| {question: q, context: kwargs[:context]} }
+         elsif kwargs[:question].is_a?(Array) && kwargs[:context].is_a?(Array)
+           if kwargs[:question].length != kwargs[:context].length
+             raise ArgumentError, "Questions and contexts don't have the same lengths"
+           end
+
+           inputs = kwargs[:question].zip(kwargs[:context]).map { |q, c| {question: q, context: c} }
+         elsif kwargs[:question].is_a?(String) && kwargs[:context].is_a?(String)
+           inputs = [{question: kwargs[:question], context: kwargs[:context]}]
+         else
+           raise ArgumentError, "Arguments can't be understood"
+         end
+       else
+         raise ArgumentError, "Unknown arguments #{kwargs}"
+       end
+
+       # Normalize inputs
+       if inputs.is_a?(Hash)
+         inputs = [inputs]
+       elsif inputs.is_a?(Enumerable)
+         # Copy to avoid overriding arguments
+         inputs = inputs.to_a.dup
+       else
+         raise ArgumentError, "Invalid arguments #{kwargs}"
+       end
+
+       inputs.each_with_index do |item, i|
+         inputs[i] = normalize(item)
+       end
+
+       inputs
+     end
+   end
+
+   class QuestionAnsweringPipeline < ChunkPipeline
+     extend ClassAttribute
+
+     class_attribute :default_input_names, "question,context"
+     class_attribute :handle_impossible_answer, false
+
+     def initialize(
+       model,
+       tokenizer:,
+       modelcard: nil,
+       framework: nil,
+       task: "",
+       **kwargs
+     )
+       super(
+         model,
+         tokenizer: tokenizer,
+         modelcard: modelcard,
+         framework: framework,
+         task: task,
+         **kwargs
+       )
+
+       @args_parser = QuestionAnsweringArgumentHandler.new
+       check_model_type(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES)
+     end
+
+     def self.create_sample(
+       question:, context:
+     )
+       if question.is_a?(Array)
+         question.zip(context).map { |q, c| SquadExample.new(nil, q, c, nil, nil, nil) }
+       else
+         SquadExample.new(nil, question, context, nil, nil, nil)
+       end
+     end
+
+     def _sanitize_parameters(
+       padding: nil,
+       topk: nil,
+       top_k: nil,
+       doc_stride: nil,
+       max_answer_len: nil,
+       max_seq_len: nil,
+       max_question_len: nil,
+       handle_impossible_answer: nil,
+       align_to_words: nil,
+       **kwargs
+     )
+       # Set default values
+       preprocess_params = {}
+       if !padding.nil?
+         preprocess_params[:padding] = padding
+       end
+       if !doc_stride.nil?
+         preprocess_params[:doc_stride] = doc_stride
+       end
+       if !max_question_len.nil?
+         preprocess_params[:max_question_len] = max_question_len
+       end
+       if !max_seq_len.nil?
+         preprocess_params[:max_seq_len] = max_seq_len
+       end
+
+       postprocess_params = {}
+       if !topk.nil? && top_k.nil?
+         warn("topk parameter is deprecated, use top_k instead")
+         top_k = topk
+       end
+       if !top_k.nil?
+         if top_k < 1
+           raise ArgumentError, "top_k parameter should be >= 1 (got #{top_k})"
+         end
+         postprocess_params[:top_k] = top_k
+       end
+       if !max_answer_len.nil?
+         if max_answer_len < 1
+           raise ArgumentError, "max_answer_len parameter should be >= 1 (got #{max_answer_len})"
+         end
+       end
+       if !max_answer_len.nil?
+         postprocess_params[:max_answer_len] = max_answer_len
+       end
+       if !handle_impossible_answer.nil?
+         postprocess_params[:handle_impossible_answer] = handle_impossible_answer
+       end
+       if !align_to_words.nil?
+         postprocess_params[:align_to_words] = align_to_words
+       end
+       [preprocess_params, {}, postprocess_params]
+     end
+
+     def call(*args, **kwargs)
+       examples = @args_parser.(*args, **kwargs)
+       if examples.is_a?(Array) && examples.length == 1
+         return super(examples[0], **kwargs)
+       end
+       super(examples, **kwargs)
+     end
+
+     def preprocess(example, padding: "do_not_pad", doc_stride: nil, max_question_len: 64, max_seq_len: nil)
+       # XXX: This is special: args_parser will not handle anything generator- or dataset-like.
+       # For those we expect the user to send a simple valid example, either directly as a SquadExample or as a simple dict.
+       # So we still need a little sanitization here.
+       if example.is_a?(Hash)
+         example = SquadExample.new(nil, example[:question], example[:context], nil, nil, nil)
+       end
+
+       if max_seq_len.nil?
+         max_seq_len = [@tokenizer.model_max_length, 384].min
+       end
+       if doc_stride.nil?
+         doc_stride = [max_seq_len.div(2), 128].min
+       end
+
+       if doc_stride > max_seq_len
+         raise ArgumentError, "`doc_stride` (#{doc_stride}) is larger than `max_seq_len` (#{max_seq_len})"
+       end
+
+       if !@tokenizer.is_fast
+         features = squad_convert_examples_to_features(
+           examples: [example],
+           tokenizer: @tokenizer,
+           max_seq_length: max_seq_len,
+           doc_stride: doc_stride,
+           max_query_length: max_question_len,
+           padding_strategy: PaddingStrategy::MAX_LENGTH,
+           is_training: false,
+           tqdm_enabled: false
+         )
+       else
+         # Define the side we want to truncate / pad and the text/pair sorting
+         question_first = @tokenizer.padding_side == "right"
+
+         encoded_inputs = @tokenizer.(
+           question_first ? example.question_text : example.context_text,
+           text_pair: question_first ? example.context_text : example.question_text,
+           padding: padding,
+           truncation: question_first ? "only_second" : "only_first",
+           max_length: max_seq_len,
+           stride: doc_stride,
+           return_token_type_ids: true,
+           return_overflowing_tokens: true,
+           return_offsets_mapping: true,
+           return_special_tokens_mask: true
+         )
+         # When the input is too long, it's converted into a batch of inputs with overflowing tokens
+         # and a stride of overlap between the inputs. If a batch of inputs is given, a special output
+         # "overflow_to_sample_mapping" indicates which member of the encoded batch belongs to which original sample.
+         # Here we tokenize examples one-by-one so we don't need to use "overflow_to_sample_mapping".
+         # "num_spans" is the number of output samples generated from the overflowing tokens.
+         num_spans = encoded_inputs[:input_ids].length
+
+         # p_mask: mask with true for tokens that cannot be in the answer (false for tokens which can be).
+         # We unmask only the context tokens, masking the question and special tokens.
+         p_mask =
+           num_spans.times.map do |span_id|
+             encoded_inputs.sequence_ids(span_id).map { |tok| tok != (question_first ? 1 : 0) }
+           end
+
+         features = []
+         num_spans.times do |span_idx|
+           input_ids_span_idx = encoded_inputs[:input_ids][span_idx]
+           attention_mask_span_idx = (
+             encoded_inputs.include?(:attention_mask) ? encoded_inputs[:attention_mask][span_idx] : nil
+           )
+           token_type_ids_span_idx = (
+             encoded_inputs.include?(:token_type_ids) ? encoded_inputs[:token_type_ids][span_idx] : nil
+           )
+           # keep the cls_token unmasked (some models use it to indicate unanswerable questions)
+           if !@tokenizer.cls_token_id.nil?
+             cls_indices = (Numo::NArray.cast(input_ids_span_idx).eq(@tokenizer.cls_token_id)).where
+             cls_indices.each do |cls_index|
+               p_mask[span_idx][cls_index] = false
+             end
+           end
+           submask = p_mask[span_idx]
+           features <<
+             SquadFeatures.new(
+               input_ids: input_ids_span_idx,
+               attention_mask: attention_mask_span_idx,
+               token_type_ids: token_type_ids_span_idx,
+               p_mask: submask,
+               encoding: encoded_inputs[span_idx],
+               # We don't use the rest of the values - and actually
+               # for Fast tokenizer we could totally avoid using SquadFeatures and SquadExample
+               cls_index: nil,
+               token_to_orig_map: {},
+               example_index: 0,
+               unique_id: 0,
+               paragraph_len: 0,
+               token_is_max_context: 0,
+               tokens: [],
+               start_position: 0,
+               end_position: 0,
+               is_impossible: false,
+               qas_id: nil
+             )
+         end
+       end
+
+       features.each_with_index do |feature, i|
+         fw_args = {}
+         others = {}
+         model_input_names = @tokenizer.model_input_names + ["p_mask", "token_type_ids"]
+
+         feature.instance_variables.each do |k|
+           v = feature.instance_variable_get(k)
+           k = k[1..]
+           if model_input_names.include?(k)
+             if @framework == "tf"
+               raise Todo
+             elsif @framework == "pt"
+               tensor = Torch.tensor(v)
+               if tensor.dtype == Torch.int32
+                 tensor = tensor.long
+               end
+               fw_args[k.to_sym] = tensor.unsqueeze(0)
+             end
+           else
+             others[k.to_sym] = v
+           end
+         end
+
+         is_last = i == features.length - 1
+         yield({example: example, is_last: is_last}.merge(fw_args).merge(others))
+       end
+     end
+
+     def _forward(inputs)
+       example = inputs[:example]
+       model_inputs = @tokenizer.model_input_names.to_h { |k| [k.to_sym, inputs[k.to_sym]] }
+       # `XXXForSequenceClassification` models should not use `use_cache=True` even if it's supported
+       # model_forward = @model.forward if self.framework == "pt" else self.model.call
+       # if "use_cache" in inspect.signature(model_forward).parameters.keys():
+       #   model_inputs[:use_cache] = false
+       # end
+       output = @model.(**model_inputs)
+       if output.is_a?(Hash)
+         {start: output[:start_logits], end: output[:end_logits], example: example}.merge(inputs)
+       else
+         start, end_ = output[...2]
+         {start: start, end: end_, example: example}.merge(inputs)
+       end
+     end
+
+     def postprocess(
+       model_outputs,
+       top_k: 1,
+       handle_impossible_answer: false,
+       max_answer_len: 15,
+       align_to_words: true
+     )
+       min_null_score = 1000000 # large and positive
+       answers = []
+       model_outputs.each do |output|
+         start_ = output[:start]
+         end_ = output[:end]
+         example = output[:example]
+         p_mask = output[:p_mask]
+         attention_mask = (
+           !output[:attention_mask].nil? ? output[:attention_mask].numo : nil
+         )
+
+         starts, ends, scores, min_null_score = select_starts_ends(
+           start_, end_, p_mask, attention_mask, min_null_score, top_k, handle_impossible_answer, max_answer_len
+         )
+
+         if !@tokenizer.is_fast
+           raise Todo
+         else
+           # Convert the answer (tokens) back to the original text
+           # Score: score from the model
+           # Start: Index of the first character of the answer in the context string
+           # End: Index of the character following the last character of the answer in the context string
+           # Answer: Plain text of the answer
+           question_first = @tokenizer.padding_side == "right"
+           enc = output[:encoding]
+
+           # The encoding was *not* padded, but input_ids *might* be.
+           # It doesn't make a difference unless we're padding on
+           # the left hand side, since then we have different offsets
+           # everywhere.
+           if @tokenizer.padding_side == "left"
+             offset = output[:input_ids].eq(@tokenizer.pad_token_id).numo.sum
+           else
+             offset = 0
+           end
+
+           # Sometimes the max probability token is in the middle of a word, so:
+           # - we start by finding the right word containing the token with `token_to_word`
+           # - then we convert this word into a character span with `word_to_chars`
+           sequence_index = question_first ? 1 : 0
+           starts.to_a.zip(ends.to_a, scores.to_a) do |s, e, score|
+             s = s - offset
+             e = e - offset
+
+             start_index, end_index = get_indices(enc, s, e, sequence_index, align_to_words)
+
+             answers <<
+               {
+                 score: score[0],
+                 start: start_index,
+                 end: end_index,
+                 answer: example.context_text[start_index...end_index]
+               }
+           end
+         end
+       end
+
+       if handle_impossible_answer
+         answers << {score: min_null_score, start: 0, end: 0, answer: ""}
+       end
+       answers = answers.sort_by { |x| -x[:score] }[...top_k]
+       if answers.length == 1
+         return answers[0]
+       end
+       answers
+     end
+
+     def get_indices(
+       enc, s, e, sequence_index, align_to_words
+     )
+       if align_to_words
+         begin
+           start_word = enc.token_to_word(s)
+           end_word = enc.token_to_word(e)
+           start_index = enc.word_to_chars(start_word, sequence_index)[0]
+           end_index = enc.word_to_chars(end_word, sequence_index)[1]
+         rescue
+           # TODO
+           raise
+           # Some tokenizers don't really handle words. Keep to offsets then.
+           start_index = enc.offsets[s][0]
+           end_index = enc.offsets[e][1]
+         end
+       else
+         start_index = enc.offsets[s][0]
+         end_index = enc.offsets[e][1]
+       end
+       [start_index, end_index]
+     end
+
+     def decode_spans(
+       start, end_, topk, max_answer_len, undesired_tokens
+     )
+       # Ensure we have a batch axis
+       if start.ndim == 1
+         start = start[nil]
+       end
+
+       if end_.ndim == 1
+         end_ = end_[nil]
+       end
+
+       # Compute the score of each tuple (start, end) to be the real answer
+       outer = start.expand_dims(-1).dot(end_.expand_dims(1))
+
+       # Remove candidates with end < start or end - start > max_answer_len
+       candidates = outer.triu.tril(max_answer_len - 1)
+
+       # Inspired by Chen et al. (https://github.com/facebookresearch/DrQA)
+       scores_flat = candidates.flatten
+       if topk == 1
+         idx_sort = [scores_flat.argmax]
+       elsif scores_flat.length < topk
+         raise Todo
+       else
+         raise Todo
+       end
+
+       starts, ends = unravel_index(idx_sort, candidates.shape)[1..]
+       desired_spans = isin(starts, undesired_tokens.where) & isin(ends, undesired_tokens.where)
+       starts = starts[desired_spans]
+       ends = ends[desired_spans]
+       scores = candidates[0, starts, ends]
+
+       [starts, ends, scores]
+     end
+
+     def unravel_index(indices, shape)
+       indices = Numo::NArray.cast(indices)
+       result = []
+       factor = 1
+       shape.size.times do |i|
+         result.unshift(indices / factor % shape[-1 - i])
+         factor *= shape[-1 - i]
+       end
+       result
+     end
+
+     def isin(element, test_elements)
+       test_elements = test_elements.to_a
+       Numo::Bit.cast(element.to_a.map { |e| test_elements.include?(e) })
+     end
+
+     def select_starts_ends(
+       start,
+       end_,
+       p_mask,
+       attention_mask,
+       min_null_score = 1000000,
+       top_k = 1,
+       handle_impossible_answer = false,
+       max_answer_len = 15
+     )
+       # Ensure padded tokens & question tokens cannot belong to the set of candidate answers.
+       undesired_tokens = ~p_mask.numo
+
+       if !attention_mask.nil?
+         undesired_tokens = undesired_tokens & attention_mask
+       end
+
+       # Generate mask
+       undesired_tokens_mask = undesired_tokens.eq(0)
+
+       # Make sure non-context indexes in the tensor cannot contribute to the softmax
+       start = start.numo
+       end_ = end_.numo
+       start[undesired_tokens_mask] = -10000.0
+       end_[undesired_tokens_mask] = -10000.0
+
+       # Normalize logits and spans to retrieve the answer
+       start = Numo::NMath.exp(start - start.max(axis: -1, keepdims: true))
+       start = start / start.sum
+
+       end_ = Numo::NMath.exp(end_ - end_.max(axis: -1, keepdims: true))
+       end_ = end_ / end_.sum
+
+       if handle_impossible_answer
+         min_null_score = [min_null_score, (start[0, 0] * end_[0, 0]).item].min
+       end
+
+       # Mask CLS
+       start[0, 0] = end_[0, 0] = 0.0
+
+       starts, ends, scores = decode_spans(start, end_, top_k, max_answer_len, undesired_tokens)
+       [starts, ends, scores, min_null_score]
+     end
+   end
+ end
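The span-scoring trick in `decode_spans` is easier to see on a toy input. A minimal standalone Numo sketch (illustrative values, not part of the gem; the batch axis is dropped for brevity):

  require "numo/narray"

  # Softmaxed probabilities that token i starts / token j ends the answer
  start_p = Numo::DFloat[0.1, 0.6, 0.2, 0.1]
  end_p = Numo::DFloat[0.1, 0.1, 0.7, 0.1]
  max_answer_len = 2

  # outer[i, j] = start_p[i] * end_p[j] scores the candidate span i..j
  outer = start_p.expand_dims(1) * end_p.expand_dims(0)

  # triu zeroes spans with end < start; tril(k) zeroes spans longer than k + 1 tokens
  candidates = outer.triu.tril(max_answer_len - 1)

  # Best span: argmax over the flattened matrix, then unravel back to (i, j)
  idx = candidates.flatten.argmax
  i, j = idx / end_p.size, idx % end_p.size
  # => i = 1, j = 2 (score 0.42): the answer spans tokens 1..2

`select_starts_ends` produces exactly this kind of normalized, masked distribution pair before handing off to `decode_spans`, and `unravel_index` and `isin` reimplement the NumPy helpers of the same names on top of Numo.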
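For orientation, a minimal end-to-end usage sketch of the pipeline this file adds (following the gem's README; the default model is downloaded on first use):

  require "transformers-rb"

  qa = Transformers.pipeline("question-answering")
  qa.(question: "Who invented Ruby?", context: "Ruby is a programming language created by Matz")
  # => {score: ..., start: ..., end: ..., answer: "Matz"} (a single hash, since top_k defaults to 1)

Hash, array-of-hashes, and positional (question, context) argument forms are all accepted, per QuestionAnsweringArgumentHandler above.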
data/lib/transformers/pipelines/text_classification.rb
@@ -0,0 +1,123 @@
+ module Transformers
+   class TextClassificationPipeline < Pipeline
+     def initialize(*args, **kwargs)
+       super
+
+       check_model_type(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES)
+     end
+
+     private
+
+     def _sanitize_parameters(return_all_scores: nil, function_to_apply: nil, top_k: "", **tokenizer_kwargs)
+       # Using "" as the default argument because we're going to use `top_k: nil` in user code to declare
+       # "no top_k"
+       preprocess_params = tokenizer_kwargs
+
+       postprocess_params = {}
+       if @model.config.respond_to?(:return_all_scores) && return_all_scores.nil?
+         return_all_scores = @model.config.return_all_scores
+       end
+
+       if top_k.is_a?(Integer) || top_k.nil?
+         postprocess_params[:top_k] = top_k
+         postprocess_params[:_legacy] = false
+       elsif !return_all_scores.nil?
+         warn(
+           "`return_all_scores` is now deprecated. If you want similar functionality, use `top_k: nil` instead of" +
+           " `return_all_scores: true`, or `top_k: 1` instead of `return_all_scores: false`.",
+         )
+         if return_all_scores
+           postprocess_params[:top_k] = nil
+         else
+           postprocess_params[:top_k] = 1
+         end
+       end
+
+       if function_to_apply.is_a?(String)
+         function_to_apply = ClassificationFunction.new(function_to_apply.upcase).to_s
+       end
+
+       if !function_to_apply.nil?
+         postprocess_params[:function_to_apply] = function_to_apply
+       end
+       [preprocess_params, {}, postprocess_params]
+     end
+
+     def preprocess(inputs, **tokenizer_kwargs)
+       return_tensors = @framework
+       if inputs.is_a?(Hash)
+         return @tokenizer.(**inputs, return_tensors: return_tensors, **tokenizer_kwargs)
+       elsif inputs.is_a?(Array) && inputs.length == 1 && inputs[0].is_a?(Array) && inputs[0].length == 2
+         # It used to be valid to use a list of lists for text pairs; keeping this path for backward compatibility
+         return @tokenizer.(
+           inputs[0][0], text_pair: inputs[0][1], return_tensors: return_tensors, **tokenizer_kwargs
+         )
+       elsif inputs.is_a?(Array)
+         # This is likely an invalid usage of the pipeline attempting to pass text pairs.
+         raise ArgumentError,
+           "The pipeline received invalid inputs. If you are trying to send text pairs, you can try to send a" +
+           ' dictionary `{"text": "My text", "text_pair": "My pair"}` in order to send a text pair.'
+       end
+       @tokenizer.(inputs, return_tensors: return_tensors, **tokenizer_kwargs)
+     end
+
+     def _forward(model_inputs)
+       @model.(**model_inputs.to_h)
+     end
+
+     def postprocess(model_outputs, function_to_apply: nil, top_k: 1, _legacy: true)
+       if function_to_apply.nil?
+         if @model.config.problem_type == "multi_label_classification" || @model.config.num_labels == 1
+           function_to_apply = ClassificationFunction::SIGMOID
+         elsif @model.config.problem_type == "single_label_classification" || @model.config.num_labels > 1
+           function_to_apply = ClassificationFunction::SOFTMAX
+         elsif @model.config.instance_variable_defined?(:@function_to_apply) && function_to_apply.nil?
+           function_to_apply = @model.config.function_to_apply
+         else
+           function_to_apply = ClassificationFunction::NONE
+         end
+       end
+
+       outputs = model_outputs["logits"][0]
+       outputs = outputs.numo
+
+       if function_to_apply == ClassificationFunction::SIGMOID
+         scores = sigmoid(outputs)
+       elsif function_to_apply == ClassificationFunction::SOFTMAX
+         scores = softmax(outputs)
+       elsif function_to_apply == ClassificationFunction::NONE
+         scores = outputs
+       else
+         raise ArgumentError, "Unrecognized `function_to_apply` argument: #{function_to_apply}"
+       end
+
+       if top_k == 1 && _legacy
+         return {label: @model.config.id2label[scores.argmax], score: scores.max}
+       end
+
+       dict_scores =
+         scores.to_a.map.with_index do |score, i|
+           {label: @model.config.id2label[i], score: score}
+         end
+       if !_legacy
+         dict_scores.sort_by! { |x| -x[:score] }
+         if !top_k.nil?
+           dict_scores = dict_scores.first(top_k)
+         end
+       end
+       dict_scores
+     end
+
+     private
+
+     def sigmoid(_outputs)
+       1.0 / (1.0 + Numo::NMath.exp(-_outputs))
+     end
+
+     def softmax(_outputs)
+       maxes = _outputs.max(axis: -1, keepdims: true)
+       shifted_exp = Numo::NMath.exp(_outputs - maxes)
+       shifted_exp / shifted_exp.sum(axis: -1, keepdims: true)
+     end
+   end
+ end
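A matching usage sketch for the text-classification pipeline (the first call follows the gem's README; the `top_k: nil` call-time keyword is an assumption based on how `_sanitize_parameters` above forwards kwargs to postprocess):

  require "transformers-rb"

  classifier = Transformers.pipeline("sentiment-analysis")
  classifier.("We are very happy to show you the 🤗 Transformers library.")
  # => {label: "POSITIVE", score: ...}

  # Assumed call-time kwarg: top_k: nil takes the non-legacy path and returns all labels with scores
  classifier.("This is great", top_k: nil)
  # => [{label: "POSITIVE", score: ...}, {label: "NEGATIVE", score: ...}]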