transformers-rb 0.1.1 → 0.1.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,792 @@
1
+ # Copyright 2018 The HuggingFace Inc. team, Microsoft Corporation.
2
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ module Transformers
17
+ module Mpnet
18
# Common base class for the MPNet models defined in this file. It declares
# the configuration class and base-model prefix and supplies the per-module
# weight initialisation hook.
class MPNetPreTrainedModel < PreTrainedModel
  self.config_class = MPNetConfig
  self.base_model_prefix = "mpnet"

  # Initialises the weights of one submodule in place.
  #
  # * Linear: normal(0, config.initializer_range) weights, zeroed bias (if any).
  # * Embedding: normal(0, config.initializer_range) weights, zeroed row at
  #   padding_idx (if any).
  # * LayerNorm: bias zeroed, weight filled with 1.0.
  #
  # Any other module type is left untouched.
  def _init_weights(module_)
    case module_
    when Torch::NN::Linear
      # Slightly different from the TF version which uses truncated_normal for initialization
      # cf https://github.com/pytorch/pytorch/pull/5617
      module_.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
      module_.bias.data.zero! unless module_.bias.nil?
    when Torch::NN::Embedding
      module_.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
      module_.weight.data.fetch(module_.padding_idx).zero! unless module_.padding_idx.nil?
    when Torch::NN::LayerNorm
      module_.bias.data.zero!
      module_.weight.data.fill!(1.0)
    end
  end
end
41
+
42
# Input embedding layer: sums word and position embeddings, then applies
# LayerNorm and dropout.
class MPNetEmbeddings < Torch::NN::Module
  def initialize(config)
    super()
    # Index treated as padding by both embedding tables and by the
    # position-id helpers below.
    @padding_idx = 1
    @word_embeddings = Torch::NN::Embedding.new(config.vocab_size, config.hidden_size, padding_idx: @padding_idx)
    @position_embeddings = Torch::NN::Embedding.new(config.max_position_embeddings, config.hidden_size, padding_idx: @padding_idx)

    @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
    @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
    # Kept as a (non-persistent) buffer so the module's registered state is
    # unchanged; forward always derives position ids from its inputs.
    register_buffer("position_ids", Torch.arange(config.max_position_embeddings).expand([1, -1]), persistent: false)
  end

  # Builds the embedding output.
  #
  # input_ids     - token id tensor, or nil when inputs_embeds is given.
  # position_ids  - optional explicit position ids; derived from the inputs
  #                 when nil.
  # inputs_embeds - precomputed word embeddings, or nil to look them up from
  #                 input_ids.
  #
  # Returns the embeddings after LayerNorm and dropout.
  # Raises ArgumentError when neither input_ids nor inputs_embeds is given.
  def forward(input_ids: nil, position_ids: nil, inputs_embeds: nil, **kwargs)
    if input_ids.nil? && inputs_embeds.nil?
      raise ArgumentError, "You have to specify either input_ids or inputs_embeds"
    end

    if position_ids.nil?
      position_ids =
        if !input_ids.nil?
          create_position_ids_from_input_ids(input_ids, @padding_idx)
        else
          create_position_ids_from_inputs_embeds(inputs_embeds)
        end
    end

    inputs_embeds = @word_embeddings.(input_ids) if inputs_embeds.nil?
    position_embeddings = @position_embeddings.(position_ids)

    embeddings = inputs_embeds + position_embeddings
    embeddings = @LayerNorm.(embeddings)
    @dropout.(embeddings)
  end

  # Position ids for embeddings supplied directly: padding cannot be
  # detected here, so every row gets the sequential range
  # padding_idx + 1 ... padding_idx + 1 + sequence_length.
  def create_position_ids_from_inputs_embeds(inputs_embeds)
    input_shape = inputs_embeds.size[...-1]
    sequence_length = input_shape[1]

    position_ids = Torch.arange(@padding_idx + 1, sequence_length + @padding_idx + 1, dtype: Torch.long, device: inputs_embeds.device)
    position_ids.unsqueeze(0).expand(input_shape)
  end

  # Position ids derived from token ids: non-padding tokens are numbered
  # consecutively starting at padding_idx + 1, padding tokens get padding_idx.
  def create_position_ids_from_input_ids(input_ids, padding_idx)
    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
    mask = input_ids.ne(padding_idx).int
    incremental_indices = Torch.cumsum(mask, dim: 1).type_as(mask) * mask
    incremental_indices.long + padding_idx
  end
end
101
+
102
# Multi-head self-attention with an optional additive position bias and
# additive attention mask.
class MPNetSelfAttention < Torch::NN::Module
  def initialize(config)
    super()
    heads_divide_evenly = (config.hidden_size % config.num_attention_heads == 0)
    if !heads_divide_evenly && !config.instance_variable_defined?(:@embedding_size)
      raise ArgumentError, "The hidden size (#{config.hidden_size}) is not a multiple of the number of attention heads (#{config.num_attention_heads})"
    end

    @num_attention_heads = config.num_attention_heads
    @attention_head_size = (config.hidden_size / config.num_attention_heads).to_i
    @all_head_size = @num_attention_heads * @attention_head_size

    # Query / key / value projections followed by the output projection.
    @q = Torch::NN::Linear.new(config.hidden_size, @all_head_size)
    @k = Torch::NN::Linear.new(config.hidden_size, @all_head_size)
    @v = Torch::NN::Linear.new(config.hidden_size, @all_head_size)
    @o = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)

    @dropout = Torch::NN::Dropout.new(p: config.attention_probs_dropout_prob)
  end

  # Splits the last dimension of x into [num_attention_heads, attention_head_size]
  # and swaps dimensions 1 and 2, so the head dimension precedes the one
  # that was at index 1.
  def transpose_for_scores(x)
    split_shape = x.size[...-1] + [@num_attention_heads, @attention_head_size]
    x.view(*split_shape).permute(0, 2, 1, 3)
  end

  # Runs self-attention over hidden_states.
  #
  # attention_mask / position_bias are added to the raw scores when given;
  # head_mask multiplies the attention probabilities when given.
  #
  # Returns [output] or, when output_attentions is truthy,
  # [output, attention_probs].
  def forward(
    hidden_states,
    attention_mask: nil,
    head_mask: nil,
    position_bias: nil,
    output_attentions: false,
    **kwargs
  )
    queries = transpose_for_scores(@q.(hidden_states))
    keys = transpose_for_scores(@k.(hidden_states))
    values = transpose_for_scores(@v.(hidden_states))

    # Scaled dot product between queries and keys gives the raw scores.
    scores = Torch.matmul(queries, keys.transpose(-1, -2))
    scores = scores / Math.sqrt(@attention_head_size)

    # Relative position embedding (precomputed in MPNetEncoder), if provided.
    scores = scores + position_bias unless position_bias.nil?
    scores = scores + attention_mask unless attention_mask.nil?

    # Turn the scores into probabilities, then apply dropout.
    attention_probs = Torch::NN::Functional.softmax(scores, dim: -1)
    attention_probs = @dropout.(attention_probs)
    attention_probs = attention_probs * head_mask unless head_mask.nil?

    # Weighted sum of the values, then merge the heads back together.
    context = Torch.matmul(attention_probs, values)
    context = context.permute(0, 2, 1, 3).contiguous
    merged_shape = context.size[...-2] + [@all_head_size]
    context = context.view(*merged_shape)

    projected = @o.(context)

    if output_attentions
      [projected, attention_probs]
    else
      [projected]
    end
  end
end
177
+
178
# Self-attention followed by dropout, a residual connection and LayerNorm.
class MPNetAttention < Torch::NN::Module
  def initialize(config)
    super()
    @attn = MPNetSelfAttention.new(config)
    @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
    @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)

    # Heads already removed by prune_heads.
    @pruned_heads = Set.new
  end

  # Removes the given attention heads from the wrapped self-attention module.
  #
  # heads - collection of head indices to prune; a no-op when empty.
  #
  # The pruned projections and the updated head bookkeeping are written to
  # @attn (the module that actually uses them). The previous code assigned
  # them to instance variables of this wrapper, which left @attn unchanged,
  # and read integer attributes of @attn that have no reader visible in this
  # file. Instance variables are accessed directly so this method does not
  # depend on accessors being defined on MPNetSelfAttention.
  def prune_heads(heads)
    return if heads.length == 0

    num_heads = @attn.instance_variable_get(:@num_attention_heads)
    head_size = @attn.instance_variable_get(:@attention_head_size)
    heads, index = TorchUtils.find_pruneable_heads_and_indices(heads, num_heads, head_size, @pruned_heads)

    # Prune the q/k/v projections, and the output projection along dim 1.
    [:@q, :@k, :@v].each do |ivar|
      @attn.instance_variable_set(ivar, TorchUtils.prune_linear_layer(@attn.instance_variable_get(ivar), index))
    end
    @attn.instance_variable_set(:@o, TorchUtils.prune_linear_layer(@attn.instance_variable_get(:@o), index, dim: 1))

    # Update the head bookkeeping and remember what was pruned.
    remaining_heads = num_heads - heads.length
    @attn.instance_variable_set(:@num_attention_heads, remaining_heads)
    @attn.instance_variable_set(:@all_head_size, head_size * remaining_heads)
    @pruned_heads = @pruned_heads.union(heads)
  end

  # Applies self-attention, then LayerNorm(dropout(attn_out) + hidden_states).
  #
  # Returns [attention_output] followed by whatever extra outputs the
  # self-attention module returned (attention probabilities when
  # output_attentions is truthy).
  def forward(
    hidden_states,
    attention_mask: nil,
    head_mask: nil,
    position_bias: nil,
    output_attentions: false,
    **kwargs
  )
    self_outputs = @attn.(hidden_states, attention_mask: attention_mask, head_mask: head_mask, position_bias: position_bias, output_attentions: output_attentions)
    attention_output = @LayerNorm.(@dropout.(self_outputs[0]) + hidden_states)
    [attention_output] + self_outputs[1..]
  end
end
218
+
219
# Feed-forward expansion: a linear layer from hidden_size to
# intermediate_size followed by the configured activation.
class MPNetIntermediate < Torch::NN::Module
  def initialize(config)
    super()
    @dense = Torch::NN::Linear.new(config.hidden_size, config.intermediate_size)
    # hidden_act may be the name of an activation (looked up in ACT2FN) or
    # an already-callable object.
    activation = config.hidden_act
    @intermediate_act_fn = activation.is_a?(String) ? ACT2FN[activation] : activation
  end

  # Returns activation(dense(hidden_states)).
  def forward(hidden_states)
    projected = @dense.(hidden_states)
    @intermediate_act_fn.(projected)
  end
end
236
+
237
# Feed-forward contraction: projects from intermediate_size back to
# hidden_size, applies dropout, adds the residual and normalises.
class MPNetOutput < Torch::NN::Module
  def initialize(config)
    super()
    @dense = Torch::NN::Linear.new(config.intermediate_size, config.hidden_size)
    @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
    @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
  end

  # Returns LayerNorm(dropout(dense(hidden_states)) + input_tensor).
  def forward(hidden_states, input_tensor)
    projected = @dropout.(@dense.(hidden_states))
    @LayerNorm.(projected + input_tensor)
  end
end
252
+
253
# One encoder layer: attention block followed by the feed-forward block.
class MPNetLayer < Torch::NN::Module
  def initialize(config)
    super()
    @attention = MPNetAttention.new(config)
    @intermediate = MPNetIntermediate.new(config)
    @output = MPNetOutput.new(config)
  end

  # Runs the layer on hidden_states.
  #
  # Returns [layer_output] followed by any extra outputs of the attention
  # block (attention probabilities when output_attentions is truthy).
  def forward(
    hidden_states,
    attention_mask: nil,
    head_mask: nil,
    position_bias: nil,
    output_attentions: false,
    **kwargs
  )
    attention_results = @attention.(
      hidden_states,
      attention_mask: attention_mask,
      head_mask: head_mask,
      position_bias: position_bias,
      output_attentions: output_attentions
    )
    attended = attention_results[0]
    extras = attention_results[1..]

    expanded = @intermediate.(attended)
    layer_output = @output.(expanded, attended)
    [layer_output] + extras
  end
end
279
+
280
# Stack of MPNetLayer modules sharing one relative-position attention bias.
class MPNetEncoder < Torch::NN::Module
  def initialize(config)
    super()
    @config = config
    @n_heads = config.num_attention_heads
    @layer = Torch::NN::ModuleList.new(config.num_hidden_layers.times.map { |_| MPNetLayer.new(config) })
    # One learned bias per (relative-position bucket, head).
    @relative_attention_bias = Torch::NN::Embedding.new(config.relative_attention_num_buckets, @n_heads)
  end

  # Runs every layer in order.
  #
  # hidden_states        - output of the embedding layer.
  # attention_mask       - additive mask forwarded to every layer.
  # head_mask            - per-layer head masks indexed by layer, or nil for
  #                        no head masking.
  # output_attentions    - collect each layer's attention probabilities.
  # output_hidden_states - collect the hidden states before every layer and
  #                        after the last one.
  # return_dict          - return a BaseModelOutput instead of an Array.
  #
  # Returns an Array of the non-nil values among
  # [hidden_states, all_hidden_states, all_attentions], or a BaseModelOutput.
  def forward(
    hidden_states,
    attention_mask: nil,
    head_mask: nil,
    output_attentions: false,
    output_hidden_states: false,
    return_dict: false,
    **kwargs
  )
    position_bias = compute_position_bias(hidden_states)
    all_hidden_states = output_hidden_states ? [] : nil
    all_attentions = output_attentions ? [] : nil

    @layer.each_with_index do |layer_module, i|
      all_hidden_states << hidden_states if output_hidden_states

      # head_mask defaults to nil; indexing it unconditionally raised
      # NoMethodError when the encoder was called without one.
      layer_head_mask = head_mask.nil? ? nil : head_mask[i]

      layer_outputs = layer_module.(hidden_states, attention_mask: attention_mask, head_mask: layer_head_mask, position_bias: position_bias, output_attentions: output_attentions, **kwargs)
      hidden_states = layer_outputs[0]

      all_attentions << layer_outputs[1] if output_attentions
    end

    # Add last layer
    all_hidden_states << hidden_states if output_hidden_states

    if !return_dict
      return [hidden_states, all_hidden_states, all_attentions].compact
    end
    BaseModelOutput.new(last_hidden_state: hidden_states, hidden_states: all_hidden_states, attentions: all_attentions)
  end

  # Builds the relative-position attention bias for input x.
  #
  # Query and key lengths are both x.size(1). Positions come from
  # position_ids when given, otherwise from 0...length. The bucketed
  # relative positions are looked up in @relative_attention_bias and
  # expanded to [batch, heads, qlen, klen].
  def compute_position_bias(x, position_ids: nil, num_buckets: 32)
    bsz, qlen, klen = [x.size(0), x.size(1), x.size(1)]
    if !position_ids.nil?
      context_position = position_ids[0.., 0.., nil]
      memory_position = position_ids[0.., nil, 0..]
    else
      context_position = Torch.arange(qlen, dtype: Torch.long)[0.., nil]
      memory_position = Torch.arange(klen, dtype: Torch.long)[nil, 0..]
    end

    relative_position = memory_position - context_position

    rp_bucket = self.class.relative_position_bucket(relative_position, num_buckets: num_buckets)
    rp_bucket = rp_bucket.to(x.device)
    values = @relative_attention_bias.(rp_bucket)
    values = values.permute([2, 0, 1]).unsqueeze(0)
    values = values.expand([bsz, -1, qlen, klen]).contiguous
    values
  end

  # Maps signed relative positions to bucket indices.
  #
  # Half of the buckets are used for each sign. Within a half, distances
  # below max_exact get their own bucket and larger distances are binned
  # logarithmically up to max_distance, clamped to the last bucket.
  def self.relative_position_bucket(relative_position, num_buckets: 32, max_distance: 128)
    ret = 0
    n = -relative_position

    num_buckets /= 2
    ret += n.lt(0).to(Torch.long) * num_buckets
    n = Torch.abs(n)

    max_exact = num_buckets / 2
    is_small = n.lt(max_exact)

    # Float division: with Integer arguments `max_distance / max_exact`
    # truncates, which skews the log scale whenever the ratio is not a
    # whole number (the defaults, 128 / 8, are unaffected).
    val_if_large = max_exact + (
      Torch.log(n.float / max_exact) / Math.log(max_distance.to_f / max_exact) * (num_buckets - max_exact)
    ).to(Torch.long)

    val_if_large = Torch.min(val_if_large, Torch.full_like(val_if_large, num_buckets - 1))
    ret += Torch.where(is_small, n, val_if_large)
    ret
  end
end
365
+
366
# Pools a sequence by passing the first token's hidden state through a
# linear layer and a tanh activation.
class MPNetPooler < Torch::NN::Module
  def initialize(config)
    super()
    @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
    @activation = Torch::NN::Tanh.new
  end

  # Returns tanh(dense(h)), where h is the hidden state of the first token
  # (index 0 along the second dimension).
  def forward(hidden_states)
    first_token = hidden_states[0.., 0]
    @activation.(@dense.(first_token))
  end
end
382
+
383
# Bare MPNet model: embeddings, encoder and an optional pooler.
class MPNetModel < MPNetPreTrainedModel
  # config            - model configuration.
  # add_pooling_layer - when false no pooler is built and the pooled output
  #                     is nil.
  def initialize(config, add_pooling_layer: true)
    super(config)
    @config = config

    @embeddings = MPNetEmbeddings.new(config)
    @encoder = MPNetEncoder.new(config)
    @pooler = add_pooling_layer ? MPNetPooler.new(config) : nil

    # Initialize weights and apply final processing
    post_init
  end

  # Returns the word embedding module of the embedding layer.
  def get_input_embeddings
    @embeddings.word_embeddings
  end

  # Replaces the word embedding module of the embedding layer.
  #
  # The previous code assigned @word_embeddings on the model itself, which
  # left @embeddings (the module that performs the lookup) unchanged.
  def set_input_embeddings(value)
    @embeddings.instance_variable_set(:@word_embeddings, value)
  end

  # heads_to_prune maps a layer index to the heads to prune in that layer.
  def _prune_heads(heads_to_prune)
    heads_to_prune.each do |layer, heads|
      @encoder.layer[layer].attention.prune_heads(heads)
    end
  end

  # Runs the model.
  #
  # Exactly one of input_ids / inputs_embeds must be given. The
  # output_attentions, output_hidden_states and return_dict options fall
  # back to the corresponding config values when nil.
  #
  # Returns [sequence_output, pooled_output, *encoder_extras] when
  # return_dict is falsy, otherwise a BaseModelOutputWithPooling.
  # Raises ArgumentError when both or neither of input_ids / inputs_embeds
  # are given.
  def forward(
    input_ids: nil,
    attention_mask: nil,
    position_ids: nil,
    head_mask: nil,
    inputs_embeds: nil,
    output_attentions: nil,
    output_hidden_states: nil,
    return_dict: nil,
    **kwargs
  )
    output_attentions = !output_attentions.nil? ? output_attentions : @config.output_attentions
    output_hidden_states = !output_hidden_states.nil? ? output_hidden_states : @config.output_hidden_states
    return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

    if !input_ids.nil? && !inputs_embeds.nil?
      raise ArgumentError, "You cannot specify both input_ids and inputs_embeds at the same time"
    elsif !input_ids.nil?
      warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
      input_shape = input_ids.size
    elsif !inputs_embeds.nil?
      input_shape = inputs_embeds.size[...-1]
    else
      raise ArgumentError, "You have to specify either input_ids or inputs_embeds"
    end

    device = !input_ids.nil? ? input_ids.device : inputs_embeds.device

    # Without an explicit mask every position is attended to.
    if attention_mask.nil?
      attention_mask = Torch.ones(input_shape, device: device)
    end
    extended_attention_mask = get_extended_attention_mask(attention_mask, input_shape)

    head_mask = get_head_mask(head_mask, @config.num_hidden_layers)
    embedding_output = @embeddings.(input_ids: input_ids, position_ids: position_ids, inputs_embeds: inputs_embeds)
    encoder_outputs = @encoder.(embedding_output, attention_mask: extended_attention_mask, head_mask: head_mask, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
    sequence_output = encoder_outputs[0]
    pooled_output = !@pooler.nil? ? @pooler.(sequence_output) : nil

    if !return_dict
      return [sequence_output, pooled_output] + encoder_outputs[1..]
    end

    BaseModelOutputWithPooling.new(last_hidden_state: sequence_output, pooler_output: pooled_output, hidden_states: encoder_outputs.hidden_states, attentions: encoder_outputs.attentions)
  end
end
456
+
457
# MPNet with a language-modelling head on top.
class MPNetForMaskedLM < MPNetPreTrainedModel
  self._tied_weights_keys = ["lm_head.decoder"]

  def initialize(config)
    super(config)

    @mpnet = MPNetModel.new(config, add_pooling_layer: false)
    @lm_head = MPNetLMHead.new(config)

    # Initialize weights and apply final processing
    post_init
  end

  # Returns the decoder (vocabulary projection) of the LM head.
  def get_output_embeddings
    @lm_head.decoder
  end

  # Replaces the decoder of the LM head and its bias.
  #
  # The previous code assigned @decoder / @bias on this model, which left
  # @lm_head (the module that performs the projection) unchanged.
  def set_output_embeddings(new_embeddings)
    @lm_head.instance_variable_set(:@decoder, new_embeddings)
    @lm_head.instance_variable_set(:@bias, new_embeddings.bias)
  end

  # Runs the model and, when labels are given, computes the masked-LM
  # cross-entropy loss over the vocabulary.
  #
  # Returns [loss?, prediction_scores, *extras] when return_dict is falsy,
  # otherwise a MaskedLMOutput.
  def forward(
    input_ids: nil,
    attention_mask: nil,
    position_ids: nil,
    head_mask: nil,
    inputs_embeds: nil,
    labels: nil,
    output_attentions: nil,
    output_hidden_states: nil,
    return_dict: nil
  )
    return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

    # MPNetModel#forward accepts keyword arguments only, so input_ids must
    # be passed by keyword (a positional argument raises ArgumentError).
    outputs = @mpnet.(input_ids: input_ids, attention_mask: attention_mask, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)

    sequence_output = outputs[0]
    prediction_scores = @lm_head.(sequence_output)

    masked_lm_loss = nil
    if !labels.nil?
      loss_fct = Torch::NN::CrossEntropyLoss.new
      masked_lm_loss = loss_fct.(prediction_scores.view(-1, @config.vocab_size), labels.view(-1))
    end

    if !return_dict
      # outputs[1] is the pooled output slot, which is skipped here.
      output = [prediction_scores] + outputs[2..]
      return !masked_lm_loss.nil? ? [masked_lm_loss] + output : output
    end

    MaskedLMOutput.new(loss: masked_lm_loss, logits: prediction_scores, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
  end
end
511
+
512
# MPNet head for masked language modelling: dense -> GELU -> LayerNorm ->
# projection to the vocabulary.
class MPNetLMHead < Torch::NN::Module
  def initialize(config)
    super()
    @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
    @layer_norm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)

    @decoder = Torch::NN::Linear.new(config.hidden_size, config.vocab_size, bias: false)
    @bias = Torch::NN::Parameter.new(Torch.zeros(config.vocab_size))

    # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
    # The previous `@bias = @bias` was a no-op, so the decoder was left
    # without any bias. NOTE(review): this relies on Torch::NN::Linear
    # keeping its bias in @bias - confirm against the torch.rb version in use.
    @decoder.instance_variable_set(:@bias, @bias)
  end

  # Re-establishes the link between the head's bias and the decoder's bias.
  def _tie_weights
    @decoder.instance_variable_set(:@bias, @bias)
  end

  # Returns vocabulary scores for features.
  def forward(features, **kwargs)
    x = @dense.(features)
    x = Activations.gelu(x)
    x = @layer_norm.(x)

    # project back to size of vocabulary with bias
    x = @decoder.(x)

    x
  end
end
540
+
541
# MPNet with a sequence classification / regression head on top.
class MPNetForSequenceClassification < MPNetPreTrainedModel
  def initialize(config)
    super(config)

    @num_labels = config.num_labels
    @mpnet = MPNetModel.new(config, add_pooling_layer: false)
    @classifier = MPNetClassificationHead.new(config)

    # Initialize weights and apply final processing
    post_init
  end

  # Runs the model and, when labels are given, computes a loss chosen by
  # the problem type:
  #
  # * "regression"                  - MSELoss
  # * "single_label_classification" - CrossEntropyLoss
  # * "multi_label_classification"  - BCEWithLogitsLoss
  #
  # The problem type is taken from the config; when the config leaves it
  # nil it is inferred from num_labels and the dtype of labels.
  #
  # Returns [loss?, logits, *extras] when return_dict is falsy, otherwise a
  # SequenceClassifierOutput.
  def forward(
    input_ids: nil,
    attention_mask: nil,
    position_ids: nil,
    head_mask: nil,
    inputs_embeds: nil,
    labels: nil,
    output_attentions: nil,
    output_hidden_states: nil,
    return_dict: nil
  )
    return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

    # MPNetModel#forward accepts keyword arguments only, so input_ids must
    # be passed by keyword (a positional argument raises ArgumentError).
    outputs = @mpnet.(input_ids: input_ids, attention_mask: attention_mask, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
    sequence_output = outputs[0]
    logits = @classifier.(sequence_output)

    loss = nil
    if !labels.nil?
      # The inferred type is kept in a local. The previous code stored it
      # in @problem_type but then dispatched on @config.problem_type, which
      # was still nil, so no loss was ever computed for inferred types.
      problem_type = @config.problem_type
      if problem_type.nil?
        problem_type =
          if @num_labels == 1
            "regression"
          # Parentheses are required: `&&` binds tighter than `||`, so the
          # unparenthesised form ignored num_labels for int labels.
          elsif @num_labels > 1 && (labels.dtype == Torch.long || labels.dtype == Torch.int)
            "single_label_classification"
          else
            "multi_label_classification"
          end
      end

      case problem_type
      when "regression"
        loss_fct = Torch::NN::MSELoss.new
        if @num_labels == 1
          loss = loss_fct.(logits.squeeze, labels.squeeze)
        else
          loss = loss_fct.(logits, labels)
        end
      when "single_label_classification"
        loss_fct = Torch::NN::CrossEntropyLoss.new
        loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
      when "multi_label_classification"
        loss_fct = Torch::NN::BCEWithLogitsLoss.new
        loss = loss_fct.(logits, labels)
      end
    end

    if !return_dict
      # outputs[1] is the pooled output slot, which is skipped here.
      output = [logits] + outputs[2..]
      return !loss.nil? ? [loss] + output : output
    end

    SequenceClassifierOutput.new(loss: loss, logits: logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
  end
end
605
+
606
# MPNet with a multiple-choice head (a linear layer on the pooled output).
class MPNetForMultipleChoice < MPNetPreTrainedModel
  def initialize(config)
    super(config)

    @mpnet = MPNetModel.new(config)
    @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
    @classifier = Torch::NN::Linear.new(config.hidden_size, 1)

    # Initialize weights and apply final processing
    post_init
  end

  # Runs the model over every choice and scores each one.
  #
  # The number of choices is read from dimension 1 of input_ids (or
  # inputs_embeds). All inputs are flattened so each choice is processed as
  # its own sequence, and the per-sequence logits are reshaped to
  # [-1, num_choices]. When labels are given a CrossEntropyLoss is computed.
  #
  # Returns [loss?, reshaped_logits, *extras] when return_dict is falsy,
  # otherwise a MultipleChoiceModelOutput.
  def forward(
    input_ids: nil,
    attention_mask: nil,
    position_ids: nil,
    head_mask: nil,
    inputs_embeds: nil,
    labels: nil,
    output_attentions: nil,
    output_hidden_states: nil,
    return_dict: nil
  )
    return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
    num_choices = !input_ids.nil? ? input_ids.shape[1] : inputs_embeds.shape[1]

    flat_input_ids = !input_ids.nil? ? input_ids.view(-1, input_ids.size(-1)) : nil
    flat_position_ids = !position_ids.nil? ? position_ids.view(-1, position_ids.size(-1)) : nil
    flat_attention_mask = !attention_mask.nil? ? attention_mask.view(-1, attention_mask.size(-1)) : nil
    flat_inputs_embeds = !inputs_embeds.nil? ? inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) : nil

    # MPNetModel#forward accepts keyword arguments only, so input_ids must
    # be passed by keyword (a positional argument raises ArgumentError).
    outputs = @mpnet.(input_ids: flat_input_ids, position_ids: flat_position_ids, attention_mask: flat_attention_mask, head_mask: head_mask, inputs_embeds: flat_inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
    pooled_output = outputs[1]

    pooled_output = @dropout.(pooled_output)
    logits = @classifier.(pooled_output)
    reshaped_logits = logits.view(-1, num_choices)

    loss = nil
    if !labels.nil?
      loss_fct = Torch::NN::CrossEntropyLoss.new
      loss = loss_fct.(reshaped_logits, labels)
    end

    if !return_dict
      output = [reshaped_logits] + outputs[2..]
      return !loss.nil? ? [loss] + output : output
    end

    MultipleChoiceModelOutput.new(loss: loss, logits: reshaped_logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
  end
end
658
+
659
# MPNet with a per-token classification head (dropout + linear layer).
class MPNetForTokenClassification < MPNetPreTrainedModel
  def initialize(config)
    super(config)
    @num_labels = config.num_labels

    @mpnet = MPNetModel.new(config, add_pooling_layer: false)
    @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
    @classifier = Torch::NN::Linear.new(config.hidden_size, config.num_labels)

    # Initialize weights and apply final processing
    post_init
  end

  # Runs the model and classifies every token; when labels are given a
  # CrossEntropyLoss over the flattened tokens is computed.
  #
  # Returns [loss?, logits, *extras] when return_dict is falsy, otherwise a
  # TokenClassifierOutput.
  def forward(
    input_ids: nil,
    attention_mask: nil,
    position_ids: nil,
    head_mask: nil,
    inputs_embeds: nil,
    labels: nil,
    output_attentions: nil,
    output_hidden_states: nil,
    return_dict: nil
  )
    return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

    # MPNetModel#forward accepts keyword arguments only, so input_ids must
    # be passed by keyword (a positional argument raises ArgumentError).
    outputs = @mpnet.(input_ids: input_ids, attention_mask: attention_mask, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)

    sequence_output = outputs[0]

    sequence_output = @dropout.(sequence_output)
    logits = @classifier.(sequence_output)

    loss = nil
    if !labels.nil?
      loss_fct = Torch::NN::CrossEntropyLoss.new
      loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
    end

    if !return_dict
      # outputs[1] is the pooled output slot, which is skipped here.
      output = [logits] + outputs[2..]
      return !loss.nil? ? [loss] + output : output
    end

    TokenClassifierOutput.new(loss: loss, logits: logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
  end
end
706
+
707
# Sentence-level classification head: takes the first token's features and
# applies dropout -> dense -> tanh -> dropout -> output projection.
class MPNetClassificationHead < Torch::NN::Module
  def initialize(config)
    super()
    @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
    @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
    @out_proj = Torch::NN::Linear.new(config.hidden_size, config.num_labels)
  end

  # Returns the logits computed from the features at index 0 of the second
  # dimension.
  def forward(features, **kwargs)
    first_token = features[0.., 0, 0..]
    hidden = @dense.(@dropout.(first_token))
    hidden = @dropout.(Torch.tanh(hidden))
    @out_proj.(hidden)
  end
end
725
+
726
# MPNet with a span-extraction head producing start and end logits.
class MPNetForQuestionAnswering < MPNetPreTrainedModel
  def initialize(config)
    super(config)

    @num_labels = config.num_labels
    @mpnet = MPNetModel.new(config, add_pooling_layer: false)
    @qa_outputs = Torch::NN::Linear.new(config.hidden_size, config.num_labels)

    # Initialize weights and apply final processing
    post_init
  end

  # Runs the model and splits the head output into start / end logits.
  # When both start_positions and end_positions are given, the loss is the
  # mean of the start and end cross-entropy losses.
  #
  # Returns [total_loss?, start_logits, end_logits, *extras] when
  # return_dict is falsy, otherwise a QuestionAnsweringModelOutput.
  def forward(
    input_ids: nil,
    attention_mask: nil,
    position_ids: nil,
    head_mask: nil,
    inputs_embeds: nil,
    start_positions: nil,
    end_positions: nil,
    output_attentions: nil,
    output_hidden_states: nil,
    return_dict: nil
  )
    return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

    # MPNetModel#forward accepts keyword arguments only, so input_ids must
    # be passed by keyword (a positional argument raises ArgumentError).
    outputs = @mpnet.(input_ids: input_ids, attention_mask: attention_mask, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)

    sequence_output = outputs[0]

    logits = @qa_outputs.(sequence_output)
    start_logits, end_logits = logits.split(1, dim: -1)
    start_logits = start_logits.squeeze(-1).contiguous
    end_logits = end_logits.squeeze(-1).contiguous

    total_loss = nil
    if !start_positions.nil? && !end_positions.nil?
      # If we are on multi-GPU, split add a dimension
      if start_positions.size.length > 1
        start_positions = start_positions.squeeze(-1)
      end
      if end_positions.size.length > 1
        end_positions = end_positions.squeeze(-1)
      end
      # sometimes the start/end positions are outside our model inputs, we ignore these terms
      # Out-of-range positions are clamped to size(1) of the logits, which
      # is then passed to the loss as ignore_index.
      ignored_index = start_logits.size(1)
      start_positions = start_positions.clamp(0, ignored_index)
      end_positions = end_positions.clamp(0, ignored_index)

      loss_fct = Torch::NN::CrossEntropyLoss.new(ignore_index: ignored_index)
      start_loss = loss_fct.(start_logits, start_positions)
      end_loss = loss_fct.(end_logits, end_positions)
      total_loss = (start_loss + end_loss) / 2
    end

    if !return_dict
      # outputs[1] is the pooled output slot, which is skipped here.
      output = [start_logits, end_logits] + outputs[2..]
      return !total_loss.nil? ? [total_loss] + output : output
    end

    QuestionAnsweringModelOutput.new(loss: total_loss, start_logits: start_logits, end_logits: end_logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
  end
end
789
+ end
790
+
791
+ MPNetForMaskedLM = Mpnet::MPNetForMaskedLM
792
+ end