transformers-rb 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +203 -0
  4. data/README.md +163 -0
  5. data/lib/transformers/activations.rb +57 -0
  6. data/lib/transformers/configuration_utils.rb +285 -0
  7. data/lib/transformers/convert_slow_tokenizer.rb +90 -0
  8. data/lib/transformers/data/processors/squad.rb +115 -0
  9. data/lib/transformers/dynamic_module_utils.rb +25 -0
  10. data/lib/transformers/feature_extraction_utils.rb +110 -0
  11. data/lib/transformers/hf_hub/constants.rb +71 -0
  12. data/lib/transformers/hf_hub/errors.rb +11 -0
  13. data/lib/transformers/hf_hub/file_download.rb +764 -0
  14. data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
  15. data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
  16. data/lib/transformers/image_processing_base.rb +169 -0
  17. data/lib/transformers/image_processing_utils.rb +63 -0
  18. data/lib/transformers/image_transforms.rb +208 -0
  19. data/lib/transformers/image_utils.rb +165 -0
  20. data/lib/transformers/modeling_outputs.rb +81 -0
  21. data/lib/transformers/modeling_utils.rb +888 -0
  22. data/lib/transformers/models/auto/auto_factory.rb +138 -0
  23. data/lib/transformers/models/auto/configuration_auto.rb +61 -0
  24. data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
  25. data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
  26. data/lib/transformers/models/auto/modeling_auto.rb +80 -0
  27. data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
  28. data/lib/transformers/models/bert/configuration_bert.rb +65 -0
  29. data/lib/transformers/models/bert/modeling_bert.rb +836 -0
  30. data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
  31. data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
  32. data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
  33. data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
  34. data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
  35. data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
  36. data/lib/transformers/models/vit/configuration_vit.rb +60 -0
  37. data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
  38. data/lib/transformers/models/vit/modeling_vit.rb +506 -0
  39. data/lib/transformers/pipelines/_init.rb +348 -0
  40. data/lib/transformers/pipelines/base.rb +301 -0
  41. data/lib/transformers/pipelines/feature_extraction.rb +47 -0
  42. data/lib/transformers/pipelines/image_classification.rb +110 -0
  43. data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
  44. data/lib/transformers/pipelines/pt_utils.rb +53 -0
  45. data/lib/transformers/pipelines/question_answering.rb +508 -0
  46. data/lib/transformers/pipelines/text_classification.rb +123 -0
  47. data/lib/transformers/pipelines/token_classification.rb +282 -0
  48. data/lib/transformers/ruby_utils.rb +33 -0
  49. data/lib/transformers/sentence_transformer.rb +37 -0
  50. data/lib/transformers/tokenization_utils.rb +152 -0
  51. data/lib/transformers/tokenization_utils_base.rb +937 -0
  52. data/lib/transformers/tokenization_utils_fast.rb +386 -0
  53. data/lib/transformers/torch_utils.rb +25 -0
  54. data/lib/transformers/utils/_init.rb +31 -0
  55. data/lib/transformers/utils/generic.rb +107 -0
  56. data/lib/transformers/utils/hub.rb +209 -0
  57. data/lib/transformers/utils/import_utils.rb +45 -0
  58. data/lib/transformers/utils/logging.rb +52 -0
  59. data/lib/transformers/version.rb +3 -0
  60. data/lib/transformers-rb.rb +1 -0
  61. data/lib/transformers.rb +100 -0
  62. data/licenses/LICENSE-huggingface-hub.txt +201 -0
  63. data/licenses/LICENSE-sentence-transformers.txt +201 -0
  64. data/licenses/NOTICE-sentence-transformers.txt +5 -0
  65. metadata +161 -0
data/lib/transformers/pipelines/question_answering.rb
@@ -0,0 +1,508 @@
+ module Transformers
+   class QuestionAnsweringArgumentHandler < ArgumentHandler
+     def normalize(item)
+       if item.is_a?(SquadExample)
+         return item
+       elsif item.is_a?(Hash)
+         [:question, :context].each do |k|
+           if !item.include?(k)
+             raise KeyError, "You need to provide a dictionary with keys {question:..., context:...}"
+           elsif item[k].nil?
+             raise ArgumentError, "`#{k}` cannot be nil"
+           elsif item[k].is_a?(String) && item[k].length == 0
+             raise ArgumentError, "`#{k}` cannot be empty"
+           end
+         end
+
+         return QuestionAnsweringPipeline.create_sample(**item)
+       end
+       raise ArgumentError, "#{item} argument needs to be of type (SquadExample, dict)"
+     end
+
+     def call(*args, **kwargs)
+       # Detect where the actual inputs are
+       if args.any?
+         if args.length == 1
+           inputs = args[0]
+         elsif args.length == 2 && args.all? { |el| el.is_a?(String) }
+           inputs = [{question: args[0], context: args[1]}]
+         else
+           inputs = args.to_a
+         end
+       elsif kwargs.include?(:question) && kwargs.include?(:context)
+         if kwargs[:question].is_a?(Array) && kwargs[:context].is_a?(String)
+           inputs = kwargs[:question].map { |q| {question: q, context: kwargs[:context]} }
+         elsif kwargs[:question].is_a?(Array) && kwargs[:context].is_a?(Array)
+           if kwargs[:question].length != kwargs[:context].length
+             raise ArgumentError, "Questions and contexts don't have the same lengths"
+           end
+
+           inputs = kwargs[:question].zip(kwargs[:context]).map { |q, c| {question: q, context: c} }
+         elsif kwargs[:question].is_a?(String) && kwargs[:context].is_a?(String)
+           inputs = [{question: kwargs[:question], context: kwargs[:context]}]
+         else
+           raise ArgumentError, "Arguments can't be understood"
+         end
+       else
+         raise ArgumentError, "Unknown arguments #{kwargs}"
+       end
+
+       # Normalize inputs
+       if inputs.is_a?(Hash)
+         inputs = [inputs]
+       elsif inputs.is_a?(Enumerable)
+         # Copy to avoid overriding arguments
+         inputs = inputs.to_a.dup
+       else
+         raise ArgumentError, "Invalid arguments #{kwargs}"
+       end
+
+       inputs.each_with_index do |item, i|
+         inputs[i] = normalize(item)
+       end
+
+       inputs
+     end
+   end
+
+   class QuestionAnsweringPipeline < ChunkPipeline
+     extend ClassAttribute
+
+     class_attribute :default_input_names, "question,context"
+     class_attribute :handle_impossible_answer, false
+
+     def initialize(
+       model,
+       tokenizer:,
+       modelcard: nil,
+       framework: nil,
+       task: "",
+       **kwargs
+     )
+       super(
+         model,
+         tokenizer: tokenizer,
+         modelcard: modelcard,
+         framework: framework,
+         task: task,
+         **kwargs
+       )
+
+       @args_parser = QuestionAnsweringArgumentHandler.new
+       check_model_type(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES)
+     end
+
+     def self.create_sample(
+       question:, context:
+     )
+       if question.is_a?(Array)
+         question.zip(context).map { |q, c| SquadExample.new(nil, q, c, nil, nil, nil) }
+       else
+         SquadExample.new(nil, question, context, nil, nil, nil)
+       end
+     end
+
+     def _sanitize_parameters(
+       padding: nil,
+       topk: nil,
+       top_k: nil,
+       doc_stride: nil,
+       max_answer_len: nil,
+       max_seq_len: nil,
+       max_question_len: nil,
+       handle_impossible_answer: nil,
+       align_to_words: nil,
+       **kwargs
+     )
+       # Set default values
+       preprocess_params = {}
+       if !padding.nil?
+         preprocess_params[:padding] = padding
+       end
+       if !doc_stride.nil?
+         preprocess_params[:doc_stride] = doc_stride
+       end
+       if !max_question_len.nil?
+         preprocess_params[:max_question_len] = max_question_len
+       end
+       if !max_seq_len.nil?
+         preprocess_params[:max_seq_len] = max_seq_len
+       end
+
+       postprocess_params = {}
+       if !topk.nil? && top_k.nil?
+         warn("topk parameter is deprecated, use top_k instead")
+         top_k = topk
+       end
+       if !top_k.nil?
+         if top_k < 1
+           raise ArgumentError, "top_k parameter should be >= 1 (got #{top_k})"
+         end
+         postprocess_params[:top_k] = top_k
+       end
+       if !max_answer_len.nil?
+         if max_answer_len < 1
+           raise ArgumentError, "max_answer_len parameter should be >= 1 (got #{max_answer_len})"
+         end
+         postprocess_params[:max_answer_len] = max_answer_len
+       end
+       if !handle_impossible_answer.nil?
+         postprocess_params[:handle_impossible_answer] = handle_impossible_answer
+       end
+       if !align_to_words.nil?
+         postprocess_params[:align_to_words] = align_to_words
+       end
+       [preprocess_params, {}, postprocess_params]
+     end
+
+     def call(*args, **kwargs)
+       examples = @args_parser.(*args, **kwargs)
+       if examples.is_a?(Array) && examples.length == 1
+         return super(examples[0], **kwargs)
+       end
+       super(examples, **kwargs)
+     end
+
+     def preprocess(example, padding: "do_not_pad", doc_stride: nil, max_question_len: 64, max_seq_len: nil)
+       # XXX: This is special, args_parser will not handle anything generator or dataset like.
+       # For those we expect the user to send a simple valid example either directly as a SquadExample or simple dict.
+       # So we still need a little sanitation here.
+       if example.is_a?(Hash)
+         example = SquadExample.new(nil, example[:question], example[:context], nil, nil, nil)
+       end
+
+       if max_seq_len.nil?
+         max_seq_len = [@tokenizer.model_max_length, 384].min
+       end
+       if doc_stride.nil?
+         doc_stride = [max_seq_len.div(2), 128].min
+       end
+
+       if doc_stride > max_seq_len
+         raise ArgumentError, "`doc_stride` (#{doc_stride}) is larger than `max_seq_len` (#{max_seq_len})"
+       end
+
+       if !@tokenizer.is_fast
+         features = squad_convert_examples_to_features(
+           examples: [example],
+           tokenizer: @tokenizer,
+           max_seq_length: max_seq_len,
+           doc_stride: doc_stride,
+           max_query_length: max_question_len,
+           padding_strategy: PaddingStrategy::MAX_LENGTH,
+           is_training: false,
+           tqdm_enabled: false
+         )
+       else
+         # Define the side we want to truncate / pad and the text/pair sorting
+         question_first = @tokenizer.padding_side == "right"
+
+         encoded_inputs = @tokenizer.(
+           question_first ? example.question_text : example.context_text,
+           text_pair: question_first ? example.context_text : example.question_text,
+           padding: padding,
+           truncation: question_first ? "only_second" : "only_first",
+           max_length: max_seq_len,
+           stride: doc_stride,
+           return_token_type_ids: true,
+           return_overflowing_tokens: true,
+           return_offsets_mapping: true,
+           return_special_tokens_mask: true
+         )
+         # When the input is too long, it's converted into a batch of inputs with overflowing tokens
+         # and a stride of overlap between the inputs. If a batch of inputs is given, a special output
+         # "overflow_to_sample_mapping" indicates which member of the encoded batch belongs to which original batch sample.
+         # Here we tokenize examples one-by-one so we don't need to use "overflow_to_sample_mapping".
+         # "num_span" is the number of output samples generated from the overflowing tokens.
+         num_spans = encoded_inputs[:input_ids].length
+
+         # p_mask: mask with 1 for tokens that cannot be in the answer (0 for tokens which can be in an answer)
+         # We put 0 on the tokens from the context and 1 everywhere else (question and special tokens)
+         p_mask =
+           num_spans.times.map do |span_id|
+             encoded_inputs.sequence_ids(span_id).map { |tok| tok != (question_first ? 1 : 0) }
+           end
+
+         features = []
+         num_spans.times do |span_idx|
+           input_ids_span_idx = encoded_inputs[:input_ids][span_idx]
+           attention_mask_span_idx = (
+             encoded_inputs.include?(:attention_mask) ? encoded_inputs[:attention_mask][span_idx] : nil
+           )
+           token_type_ids_span_idx = (
+             encoded_inputs.include?(:token_type_ids) ? encoded_inputs[:token_type_ids][span_idx] : nil
+           )
+           # keep the cls_token unmasked (some models use it to indicate unanswerable questions)
+           if !@tokenizer.cls_token_id.nil?
+             cls_indices = (Numo::NArray.cast(input_ids_span_idx).eq(@tokenizer.cls_token_id)).where
+             cls_indices.each do |cls_index|
+               p_mask[span_idx][cls_index] = false
+             end
+           end
+           submask = p_mask[span_idx]
+           features <<
+             SquadFeatures.new(
+               input_ids: input_ids_span_idx,
+               attention_mask: attention_mask_span_idx,
+               token_type_ids: token_type_ids_span_idx,
+               p_mask: submask,
+               encoding: encoded_inputs[span_idx],
+               # We don't use the rest of the values - and actually
+               # for Fast tokenizer we could totally avoid using SquadFeatures and SquadExample
+               cls_index: nil,
+               token_to_orig_map: {},
+               example_index: 0,
+               unique_id: 0,
+               paragraph_len: 0,
+               token_is_max_context: 0,
+               tokens: [],
+               start_position: 0,
+               end_position: 0,
+               is_impossible: false,
+               qas_id: nil
+             )
+         end
+       end
+
+       features.each_with_index do |feature, i|
+         fw_args = {}
+         others = {}
+         model_input_names = @tokenizer.model_input_names + ["p_mask", "token_type_ids"]
+
+         feature.instance_variables.each do |k|
+           v = feature.instance_variable_get(k)
+           k = k[1..]
+           if model_input_names.include?(k)
+             if @framework == "tf"
+               raise Todo
+             elsif @framework == "pt"
+               tensor = Torch.tensor(v)
+               if tensor.dtype == Torch.int32
+                 tensor = tensor.long
+               end
+               fw_args[k.to_sym] = tensor.unsqueeze(0)
+             end
+           else
+             others[k.to_sym] = v
+           end
+         end
+
+         is_last = i == features.length - 1
+         yield({example: example, is_last: is_last}.merge(fw_args).merge(others))
+       end
+     end
+
+     def _forward(inputs)
+       example = inputs[:example]
+       model_inputs = @tokenizer.model_input_names.to_h { |k| [k.to_sym, inputs[k.to_sym]] }
+       # `XXXForSequenceClassification` models should not use `use_cache=True` even if it's supported
+       # model_forward = @model.forward if self.framework == "pt" else self.model.call
+       # if "use_cache" in inspect.signature(model_forward).parameters.keys():
+       #   model_inputs[:use_cache] = false
+       # end
+       output = @model.(**model_inputs)
+       if output.is_a?(Hash)
+         {start: output[:start_logits], end: output[:end_logits], example: example}.merge(inputs)
+       else
+         start, end_ = output[...2]
+         {start: start, end: end_, example: example}.merge(inputs)
+       end
+     end
+
+     def postprocess(
+       model_outputs,
+       top_k: 1,
+       handle_impossible_answer: false,
+       max_answer_len: 15,
+       align_to_words: true
+     )
+       min_null_score = 1000000 # large and positive
+       answers = []
+       model_outputs.each do |output|
+         start_ = output[:start]
+         end_ = output[:end]
+         example = output[:example]
+         p_mask = output[:p_mask]
+         attention_mask = (
+           !output[:attention_mask].nil? ? output[:attention_mask].numo : nil
+         )
+
+         starts, ends, scores, min_null_score = select_starts_ends(
+           start_, end_, p_mask, attention_mask, min_null_score, top_k, handle_impossible_answer, max_answer_len
+         )
+
+         if !@tokenizer.is_fast
+           raise Todo
+         else
+           # Convert the answer (tokens) back to the original text
+           # Score: score from the model
+           # Start: Index of the first character of the answer in the context string
+           # End: Index of the character following the last character of the answer in the context string
+           # Answer: Plain text of the answer
+           question_first = @tokenizer.padding_side == "right"
+           enc = output[:encoding]
+
+           # Encoding was *not* padded, input_ids *might* be.
+           # It doesn't make a difference unless we're padding on
+           # the left hand side, since now we have different offsets
+           # everywhere.
+           if @tokenizer.padding_side == "left"
+             offset = output[:input_ids].eq(@tokenizer.pad_token_id).numo.sum
+           else
+             offset = 0
+           end
+
+           # Sometimes the max probability token is in the middle of a word so:
+           # - we start by finding the right word containing the token with `token_to_word`
+           # - then we convert this word into a character span with `word_to_chars`
+           sequence_index = question_first ? 1 : 0
+           starts.to_a.zip(ends.to_a, scores.to_a) do |s, e, score|
+             s = s - offset
+             e = e - offset
+
+             start_index, end_index = get_indices(enc, s, e, sequence_index, align_to_words)
+
+             answers <<
+               {
+                 score: score[0],
+                 start: start_index,
+                 end: end_index,
+                 answer: example.context_text[start_index...end_index]
+               }
+           end
+         end
+       end
+
+       if handle_impossible_answer
+         answers << {score: min_null_score, start: 0, end: 0, answer: ""}
+       end
+       answers = answers.sort_by { |x| -x[:score] }[...top_k]
+       if answers.length == 1
+         return answers[0]
+       end
+       answers
+     end
+
+     def get_indices(
+       enc, s, e, sequence_index, align_to_words
+     )
+       if align_to_words
+         begin
+           start_word = enc.token_to_word(s)
+           end_word = enc.token_to_word(e)
+           start_index = enc.word_to_chars(start_word, sequence_index)[0]
+           end_index = enc.word_to_chars(end_word, sequence_index)[1]
+         rescue
+           # TODO
+           raise
+           # Some tokenizers don't really handle words. Keep to offsets then.
+           start_index = enc.offsets[s][0]
+           end_index = enc.offsets[e][1]
+         end
+       else
+         start_index = enc.offsets[s][0]
+         end_index = enc.offsets[e][1]
+       end
+       [start_index, end_index]
+     end
+
+     def decode_spans(
+       start, end_, topk, max_answer_len, undesired_tokens
+     )
+       # Ensure we have batch axis
+       if start.ndim == 1
+         start = start[nil]
+       end
+
+       if end_.ndim == 1
+         end_ = end_[nil]
+       end
+
+       # Compute the score of each tuple(start, end) to be the real answer
+       outer = start.expand_dims(-1).dot(end_.expand_dims(1))
+
+       # Remove candidates with end < start and end - start > max_answer_len
+       candidates = outer.triu.tril(max_answer_len - 1)
+
+       # Inspired by Chen et al. (https://github.com/facebookresearch/DrQA)
+       scores_flat = candidates.flatten
+       if topk == 1
+         idx_sort = [scores_flat.argmax]
+       elsif scores_flat.length < topk
+         raise Todo
+       else
+         raise Todo
+       end
+
+       starts, ends = unravel_index(idx_sort, candidates.shape)[1..]
+       desired_spans = isin(starts, undesired_tokens.where) & isin(ends, undesired_tokens.where)
+       starts = starts[desired_spans]
+       ends = ends[desired_spans]
+       scores = candidates[0, starts, ends]
+
+       [starts, ends, scores]
+     end
+
+     # Equivalent of NumPy's unravel_index: converts flat indices into
+     # one coordinate array per dimension of the given shape
+     def unravel_index(indices, shape)
+       indices = Numo::NArray.cast(indices)
+       result = []
+       factor = 1
+       shape.size.times do |i|
+         result.unshift(indices / factor % shape[-1 - i])
+         factor *= shape[-1 - i]
+       end
+       result
+     end
+
+     # Equivalent of NumPy's isin: element-wise membership test against test_elements
+     def isin(element, test_elements)
+       test_elements = test_elements.to_a
+       Numo::Bit.cast(element.to_a.map { |e| test_elements.include?(e) })
+     end
+
+     def select_starts_ends(
+       start,
+       end_,
+       p_mask,
+       attention_mask,
+       min_null_score = 1000000,
+       top_k = 1,
+       handle_impossible_answer = false,
+       max_answer_len = 15
+     )
+       # Ensure padded tokens & question tokens cannot belong to the set of candidate answers.
+       undesired_tokens = ~p_mask.numo
+
+       if !attention_mask.nil?
+         undesired_tokens = undesired_tokens & attention_mask
+       end
+
+       # Generate mask
+       undesired_tokens_mask = undesired_tokens.eq(0)
+
+       # Make sure non-context indexes in the tensor cannot contribute to the softmax
+       start = start.numo
+       end_ = end_.numo
+       start[undesired_tokens_mask] = -10000.0
+       end_[undesired_tokens_mask] = -10000.0
+
+       # Normalize logits and spans to retrieve the answer
+       start = Numo::NMath.exp(start - start.max(axis: -1, keepdims: true))
+       start = start / start.sum
+
+       end_ = Numo::NMath.exp(end_ - end_.max(axis: -1, keepdims: true))
+       end_ = end_ / end_.sum
+
+       if handle_impossible_answer
+         min_null_score = [min_null_score, (start[0, 0] * end_[0, 0]).item].min
+       end
+
+       # Mask CLS
+       start[0, 0] = end_[0, 0] = 0.0
+
+       starts, ends, scores = decode_spans(start, end_, top_k, max_answer_len, undesired_tokens)
+       [starts, ends, scores, min_null_score]
+     end
+   end
+ end
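
For orientation, here is a minimal usage sketch of the pipeline this file adds. It is a hypothetical example, not taken from the diff: it assumes the `Transformers.pipeline` helper (presumably wired up in data/lib/transformers/pipelines/_init.rb) registers a "question-answering" task with a default model. The output keys come from `postprocess` above.

    qa = Transformers.pipeline("question-answering")
    result = qa.(question: "Who invented Ruby?", context: "Ruby was designed by Yukihiro Matsumoto.")
    # result is a Hash with keys :score, :start, :end, :answer,
    # e.g. {score: 0.99, start: 21, end: 39, answer: "Yukihiro Matsumoto"}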
data/lib/transformers/pipelines/text_classification.rb
@@ -0,0 +1,123 @@
+ module Transformers
+   class TextClassificationPipeline < Pipeline
+     def initialize(*args, **kwargs)
+       super
+
+       check_model_type(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES)
+     end
+
+     private
+
+     def _sanitize_parameters(return_all_scores: nil, function_to_apply: nil, top_k: "", **tokenizer_kwargs)
+       # Using "" as the default argument because we're going to use `top_k: nil` in user code to declare
+       # "no top_k"
+       preprocess_params = tokenizer_kwargs
+
+       postprocess_params = {}
+       if @model.config.respond_to?(:return_all_scores) && return_all_scores.nil?
+         return_all_scores = @model.config.return_all_scores
+       end
+
+       if top_k.is_a?(Integer) || top_k.nil?
+         postprocess_params[:top_k] = top_k
+         postprocess_params[:_legacy] = false
+       elsif !return_all_scores.nil?
+         warn(
+           "`return_all_scores` is now deprecated, if you want similar functionality use `top_k: nil` instead of" +
+             " `return_all_scores: true` or `top_k: 1` instead of `return_all_scores: false`."
+         )
+         if return_all_scores
+           postprocess_params[:top_k] = nil
+         else
+           postprocess_params[:top_k] = 1
+         end
+       end
+
+       if function_to_apply.is_a?(String)
+         function_to_apply = ClassificationFunction.new(function_to_apply.upcase).to_s
+       end
+
+       if !function_to_apply.nil?
+         postprocess_params[:function_to_apply] = function_to_apply
+       end
+       [preprocess_params, {}, postprocess_params]
+     end
+
+     def preprocess(inputs, **tokenizer_kwargs)
+       return_tensors = @framework
+       if inputs.is_a?(Hash)
+         return @tokenizer.(**inputs, return_tensors: return_tensors, **tokenizer_kwargs)
+       elsif inputs.is_a?(Array) && inputs.length == 1 && inputs[0].is_a?(Array) && inputs[0].length == 2
+         # It used to be valid to use a list of list of list for text pairs, keeping this path for BC
+         return @tokenizer.(
+           inputs[0][0], text_pair: inputs[0][1], return_tensors: return_tensors, **tokenizer_kwargs
+         )
+       elsif inputs.is_a?(Array)
+         # This is likely an invalid usage of the pipeline attempting to pass text pairs.
+         raise ArgumentError,
+           "The pipeline received invalid inputs, if you are trying to send text pairs, you can try to send a" +
+           ' dictionary `{"text": "My text", "text_pair": "My pair"}` in order to send a text pair.'
+       end
+       @tokenizer.(inputs, return_tensors: return_tensors, **tokenizer_kwargs)
+     end
+
+     def _forward(model_inputs)
+       @model.(**model_inputs.to_h)
+     end
+
+     def postprocess(model_outputs, function_to_apply: nil, top_k: 1, _legacy: true)
+       if function_to_apply.nil?
+         if @model.config.problem_type == "multi_label_classification" || @model.config.num_labels == 1
+           function_to_apply = ClassificationFunction::SIGMOID
+         elsif @model.config.problem_type == "single_label_classification" || @model.config.num_labels > 1
+           function_to_apply = ClassificationFunction::SOFTMAX
+         elsif @model.config.instance_variable_defined?(:@function_to_apply) && function_to_apply.nil?
+           function_to_apply = @model.config.function_to_apply
+         else
+           function_to_apply = ClassificationFunction::NONE
+         end
+       end
+
+       outputs = model_outputs["logits"][0]
+       outputs = outputs.numo
+
+       if function_to_apply == ClassificationFunction::SIGMOID
+         scores = sigmoid(outputs)
+       elsif function_to_apply == ClassificationFunction::SOFTMAX
+         scores = softmax(outputs)
+       elsif function_to_apply == ClassificationFunction::NONE
+         scores = outputs
+       else
+         raise ArgumentError, "Unrecognized `function_to_apply` argument: #{function_to_apply}"
+       end
+
+       if top_k == 1 && _legacy
+         return {label: @model.config.id2label[scores.argmax], score: scores.max}
+       end
+
+       dict_scores =
+         scores.to_a.map.with_index do |score, i|
+           {label: @model.config.id2label[i], score: score}
+         end
+       if !_legacy
+         dict_scores.sort_by! { |x| -x[:score] }
+         if !top_k.nil?
+           dict_scores = dict_scores.first(top_k)
+         end
+       end
+       dict_scores
+     end
+
+     def sigmoid(_outputs)
+       1.0 / (1.0 + Numo::NMath.exp(-_outputs))
+     end
+
+     def softmax(_outputs)
+       maxes = _outputs.max(axis: -1, keepdims: true)
+       shifted_exp = Numo::NMath.exp(_outputs - maxes)
+       shifted_exp / shifted_exp.sum(axis: -1, keepdims: true)
+     end
+   end
+ end
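
And similarly for this file, a hypothetical usage sketch (the "sentiment-analysis" task name is assumed from the gem's README conventions, and the label names depend on the loaded model's id2label mapping):

    classifier = Transformers.pipeline("sentiment-analysis")
    classifier.("This gem works great!")
    # legacy single-result form from postprocess above, e.g.
    # => {label: "POSITIVE", score: 0.9998}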