transformers-rb 0.1.2 → 0.1.3

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,792 @@
1
+ # Copyright 2018 The HuggingFace Inc. team, Microsoft Corporation.
2
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ module Transformers
17
+ module Mpnet
18
class MPNetPreTrainedModel < PreTrainedModel
  self.config_class = MPNetConfig
  self.base_model_prefix = "mpnet"

  # Initializes the weights of one submodule in place.
  #
  # * Linear: weight drawn from a normal distribution with standard
  #   deviation +@config.initializer_range+; bias (when present) zeroed.
  # * Embedding: weight drawn from the same distribution; the row at
  #   +padding_idx+ (when set) zeroed.
  # * LayerNorm: bias zeroed, weight filled with 1.0.
  #
  # Modules of any other class are left untouched.
  def _init_weights(module_)
    case module_
    when Torch::NN::Linear
      # Slightly different from the TF version which uses truncated_normal for initialization
      # cf https://github.com/pytorch/pytorch/pull/5617
      module_.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
      module_.bias.data.zero! unless module_.bias.nil?
    when Torch::NN::Embedding
      module_.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
      unless module_.padding_idx.nil?
        module_.weight.data.fetch(module_.padding_idx).zero!
      end
    when Torch::NN::LayerNorm
      module_.bias.data.zero!
      module_.weight.data.fill!(1.0)
    end
  end
end
41
+
42
class MPNetEmbeddings < Torch::NN::Module
  # Builds the word and position embedding tables, the LayerNorm and the
  # dropout from +config+, and registers a non-persistent "position_ids"
  # buffer holding 0...max_position_embeddings expanded to one row.
  def initialize(config)
    super()
    @padding_idx = 1
    hidden = config.hidden_size
    @word_embeddings = Torch::NN::Embedding.new(config.vocab_size, hidden, padding_idx: @padding_idx)
    @position_embeddings = Torch::NN::Embedding.new(config.max_position_embeddings, hidden, padding_idx: @padding_idx)

    @LayerNorm = Torch::NN::LayerNorm.new(hidden, eps: config.layer_norm_eps)
    @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
    all_positions = Torch.arange(config.max_position_embeddings).expand([1, -1])
    register_buffer("position_ids", all_positions, persistent: false)
  end

  # Returns dropout(LayerNorm(inputs_embeds + position_embeddings)).
  #
  # When +position_ids+ is nil it is derived from +input_ids+ if given,
  # otherwise from +inputs_embeds+. When +inputs_embeds+ is nil it is looked
  # up from +input_ids+ in the word embedding table.
  def forward(input_ids: nil, position_ids: nil, inputs_embeds: nil, **kwargs)
    if position_ids.nil?
      position_ids =
        if input_ids.nil?
          create_position_ids_from_inputs_embeds(inputs_embeds)
        else
          create_position_ids_from_input_ids(input_ids, @padding_idx)
        end
    end

    input_shape = input_ids.nil? ? inputs_embeds.size[...-1] : input_ids.size
    seq_length = input_shape[1]

    # Fallback to the registered buffer if no position ids were produced above.
    position_ids = @position_ids[0.., ...seq_length] if position_ids.nil?

    inputs_embeds = @word_embeddings.(input_ids) if inputs_embeds.nil?
    summed = inputs_embeds + @position_embeddings.(position_ids)

    @dropout.(@LayerNorm.(summed))
  end

  # Builds sequential position ids for +inputs_embeds+, starting at
  # +@padding_idx + 1+ and expanded to the shape of +inputs_embeds+ without
  # its last dimension.
  def create_position_ids_from_inputs_embeds(inputs_embeds)
    leading_shape = inputs_embeds.size[...-1]
    length = leading_shape[1]
    first = @padding_idx + 1

    ids = Torch.arange(first, length + first, dtype: Torch.long, device: inputs_embeds.device)
    ids.unsqueeze(0).expand(leading_shape)
  end

  # Builds position ids from +input_ids+: entries equal to +padding_idx+ get
  # +padding_idx+, all other entries get their running count along dim 1
  # plus +padding_idx+.
  def create_position_ids_from_input_ids(input_ids, padding_idx)
    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
    not_pad = input_ids.ne(padding_idx).int
    running = Torch.cumsum(not_pad, dim: 1).type_as(not_pad) * not_pad
    running.long + padding_idx
  end
end
101
+
102
class MPNetSelfAttention < Torch::NN::Module
  # Builds the q/k/v/o linear projections and the attention dropout.
  #
  # Raises ArgumentError when hidden_size is not a multiple of
  # num_attention_heads, unless +config+ has an +@embedding_size+ instance
  # variable.
  def initialize(config)
    super()
    hidden = config.hidden_size
    heads = config.num_attention_heads
    unless (hidden % heads).zero? || config.instance_variable_defined?(:@embedding_size)
      raise ArgumentError, "The hidden size (#{config.hidden_size}) is not a multiple of the number of attention heads (#{config.num_attention_heads})"
    end

    @num_attention_heads = heads
    @attention_head_size = (hidden / heads).to_i
    @all_head_size = @num_attention_heads * @attention_head_size

    @q = Torch::NN::Linear.new(hidden, @all_head_size)
    @k = Torch::NN::Linear.new(hidden, @all_head_size)
    @v = Torch::NN::Linear.new(hidden, @all_head_size)
    @o = Torch::NN::Linear.new(hidden, hidden)

    @dropout = Torch::NN::Dropout.new(p: config.attention_probs_dropout_prob)
  end

  # Splits the last dimension of +x+ into (heads, head_size) and swaps
  # dimensions 1 and 2 of the resulting 4-D view.
  def transpose_for_scores(x)
    split_shape = x.size[...-1] + [@num_attention_heads, @attention_head_size]
    x.view(*split_shape).permute(0, 2, 1, 3)
  end

  # Scaled dot-product self-attention.
  #
  # Returns +[output]+, or +[output, attention_probs]+ when
  # +output_attentions+ is truthy. +position_bias+ and +attention_mask+ are
  # added to the raw scores when given; +head_mask+ multiplies the
  # probabilities after dropout.
  def forward(
    hidden_states,
    attention_mask: nil,
    head_mask: nil,
    position_bias: nil,
    output_attentions: false,
    **kwargs
  )
    query, key, value = [@q, @k, @v].map do |projection|
      transpose_for_scores(projection.(hidden_states))
    end

    # Take the dot product between "query" and "key" to get the raw attention scores.
    scores = Torch.matmul(query, key.transpose(-1, -2)) / Math.sqrt(@attention_head_size)

    # Apply relative position embedding (precomputed in MPNetEncoder) if provided.
    scores = scores + position_bias unless position_bias.nil?
    scores = scores + attention_mask unless attention_mask.nil?

    # Normalize the attention scores to probabilities.
    attention_probs = @dropout.(Torch::NN::Functional.softmax(scores, dim: -1))
    attention_probs = attention_probs * head_mask unless head_mask.nil?

    # Merge the heads back into a single trailing dimension.
    context = Torch.matmul(attention_probs, value).permute(0, 2, 1, 3).contiguous
    merged_shape = context.size[...-2] + [@all_head_size]
    projected = @o.(context.view(*merged_shape))

    output_attentions ? [projected, attention_probs] : [projected]
  end
end
177
+
178
class MPNetAttention < Torch::NN::Module
  # Wraps MPNetSelfAttention with dropout, a residual connection and LayerNorm.
  def initialize(config)
    super()
    @attn = MPNetSelfAttention.new(config)
    @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
    @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)

    @pruned_heads = Set.new
  end

  # Removes the given attention heads from the wrapped self-attention module.
  #
  # The pruned projections and the updated head counts are written back onto
  # +@attn+, because that is the object whose +forward+ reads them. Writing
  # them to instance variables of this wrapper would leave the attention
  # computation unchanged.
  #
  # MPNetSelfAttention defines no accessors for its head counts, so its
  # state is read and written through its instance variables directly.
  #
  # heads - collection of head indices to remove; a no-op when empty.
  def prune_heads(heads)
    return if heads.length == 0

    num_heads = @attn.instance_variable_get(:@num_attention_heads)
    head_size = @attn.instance_variable_get(:@attention_head_size)
    heads, index = TorchUtils.find_pruneable_heads_and_indices(heads, num_heads, head_size, @pruned_heads)

    # Prune the input projections, then the output projection along dim 1.
    [:@q, :@k, :@v].each do |ivar|
      @attn.instance_variable_set(ivar, TorchUtils.prune_linear_layer(@attn.instance_variable_get(ivar), index))
    end
    @attn.instance_variable_set(:@o, TorchUtils.prune_linear_layer(@attn.instance_variable_get(:@o), index, dim: 1))

    # Update hyper params and store pruned heads.
    remaining_heads = num_heads - heads.length
    @attn.instance_variable_set(:@num_attention_heads, remaining_heads)
    @attn.instance_variable_set(:@all_head_size, head_size * remaining_heads)
    @pruned_heads = @pruned_heads.union(heads)
  end

  # Runs self-attention and returns
  # +[LayerNorm(dropout(attn_output) + hidden_states), *extra_attn_outputs]+.
  def forward(
    hidden_states,
    attention_mask: nil,
    head_mask: nil,
    position_bias: nil,
    output_attentions: false,
    **kwargs
  )
    self_outputs = @attn.(hidden_states, attention_mask: attention_mask, head_mask: head_mask, position_bias: position_bias, output_attentions: output_attentions)
    attention_output = @LayerNorm.(@dropout.(self_outputs[0]) + hidden_states)
    [attention_output] + self_outputs[1..]
  end
end
218
+
219
class MPNetIntermediate < Torch::NN::Module
  # Builds the hidden_size -> intermediate_size projection and resolves the
  # activation: a String +config.hidden_act+ is looked up in ACT2FN, any
  # other value is used as the activation itself.
  def initialize(config)
    super()
    @dense = Torch::NN::Linear.new(config.hidden_size, config.intermediate_size)
    activation = config.hidden_act
    @intermediate_act_fn = activation.is_a?(String) ? ACT2FN[activation] : activation
  end

  # Applies the dense projection followed by the activation.
  def forward(hidden_states)
    @intermediate_act_fn.(@dense.(hidden_states))
  end
end
236
+
237
class MPNetOutput < Torch::NN::Module
  # Builds the intermediate_size -> hidden_size projection, LayerNorm and dropout.
  def initialize(config)
    super()
    @dense = Torch::NN::Linear.new(config.intermediate_size, config.hidden_size)
    @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
    @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
  end

  # Returns LayerNorm(dropout(dense(hidden_states)) + input_tensor).
  def forward(hidden_states, input_tensor)
    projected = @dropout.(@dense.(hidden_states))
    @LayerNorm.(projected + input_tensor)
  end
end
252
+
253
class MPNetLayer < Torch::NN::Module
  # One encoder layer: attention, intermediate projection, output projection.
  def initialize(config)
    super()
    @attention = MPNetAttention.new(config)
    @intermediate = MPNetIntermediate.new(config)
    @output = MPNetOutput.new(config)
  end

  # Returns +[layer_output, *extras]+ where +extras+ are whatever the
  # attention block returned after its first element.
  def forward(
    hidden_states,
    attention_mask: nil,
    head_mask: nil,
    position_bias: nil,
    output_attentions: false,
    **kwargs
  )
    attention_results = @attention.(hidden_states, attention_mask: attention_mask, head_mask: head_mask, position_bias: position_bias, output_attentions: output_attentions)
    attended = attention_results[0]
    extras = attention_results[1..]

    layer_output = @output.(@intermediate.(attended), attended)
    [layer_output, *extras]
  end
end
279
+
280
class MPNetEncoder < Torch::NN::Module
  # Builds +config.num_hidden_layers+ MPNetLayer instances and the embedding
  # table that maps relative-position buckets to one bias value per head.
  def initialize(config)
    super()
    @config = config
    @n_heads = config.num_attention_heads
    @layer = Torch::NN::ModuleList.new(config.num_hidden_layers.times.map { |_| MPNetLayer.new(config) })
    @relative_attention_bias = Torch::NN::Embedding.new(config.relative_attention_num_buckets, @n_heads)
  end

  # Runs every layer in order over +hidden_states+.
  #
  # head_mask - nil (no masking), or an object indexable by layer number
  #             whose entry is passed to that layer.
  #
  # Returns a BaseModelOutput when +return_dict+ is truthy. Otherwise returns
  # an Array of the last hidden state followed by the collected hidden
  # states and attentions, each included only when it was requested.
  def forward(
    hidden_states,
    attention_mask: nil,
    head_mask: nil,
    output_attentions: false,
    output_hidden_states: false,
    return_dict: false,
    **kwargs
  )
    position_bias = compute_position_bias(hidden_states)
    all_hidden_states = output_hidden_states ? [] : nil
    all_attentions = output_attentions ? [] : nil

    @layer.each_with_index do |layer_module, i|
      all_hidden_states << hidden_states if output_hidden_states

      # head_mask defaults to nil, so guard the per-layer lookup.
      layer_head_mask = head_mask.nil? ? nil : head_mask[i]
      layer_outputs = layer_module.(hidden_states, attention_mask: attention_mask, head_mask: layer_head_mask, position_bias: position_bias, output_attentions: output_attentions, **kwargs)
      hidden_states = layer_outputs[0]

      all_attentions << layer_outputs[1] if output_attentions
    end

    # Add last layer
    all_hidden_states << hidden_states if output_hidden_states

    return [hidden_states, all_hidden_states, all_attentions].compact if !return_dict

    BaseModelOutput.new(last_hidden_state: hidden_states, hidden_states: all_hidden_states, attentions: all_attentions)
  end

  # Computes the relative-position attention bias for +x+.
  #
  # Query and key lengths are both taken from +x.size(1)+. The result is
  # expanded to +[x.size(0), heads, qlen, klen]+ and made contiguous.
  #
  # NOTE(review): +num_buckets+ defaults to 32 regardless of
  # +config.relative_attention_num_buckets+, which sizes the embedding table.
  # Confirm the two agree for non-default configs.
  def compute_position_bias(x, position_ids: nil, num_buckets: 32)
    bsz, qlen, klen = [x.size(0), x.size(1), x.size(1)]
    if !position_ids.nil?
      context_position = position_ids[0.., 0.., nil]
      memory_position = position_ids[0.., nil, 0..]
    else
      context_position = Torch.arange(qlen, dtype: Torch.long)[0.., nil]
      memory_position = Torch.arange(klen, dtype: Torch.long)[nil, 0..]
    end

    relative_position = memory_position - context_position

    rp_bucket = self.class.relative_position_bucket(relative_position, num_buckets: num_buckets)
    rp_bucket = rp_bucket.to(x.device)
    values = @relative_attention_bias.(rp_bucket)
    values = values.permute([2, 0, 1]).unsqueeze(0)
    values.expand([bsz, -1, qlen, klen]).contiguous
  end

  # Maps signed relative positions to bucket indices.
  #
  # Half of +num_buckets+ is used per sign. Within a half, distances below
  # +max_exact+ keep their own bucket and larger distances are spaced
  # logarithmically up to +max_distance+, clamped to the last bucket.
  def self.relative_position_bucket(relative_position, num_buckets: 32, max_distance: 128)
    n = -relative_position

    num_buckets /= 2
    ret = n.lt(0).to(Torch.long) * num_buckets
    n = Torch.abs(n)

    max_exact = num_buckets / 2
    is_small = n.lt(max_exact)

    # Use float division for the log base so the ratio is not truncated when
    # max_distance is not an exact multiple of max_exact.
    val_if_large = max_exact + (
      Torch.log(n.float / max_exact) / Math.log(max_distance.to_f / max_exact) * (num_buckets - max_exact)
    ).to(Torch.long)

    val_if_large = Torch.min(val_if_large, Torch.full_like(val_if_large, num_buckets - 1))
    ret + Torch.where(is_small, n, val_if_large)
  end
end
365
+
366
class MPNetPooler < Torch::NN::Module
  # Builds a hidden_size -> hidden_size projection followed by Tanh.
  def initialize(config)
    super()
    @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
    @activation = Torch::NN::Tanh.new
  end

  # Pools by taking index 0 along the second dimension of +hidden_states+
  # (the hidden state of the first token) and passing it through the dense
  # layer and the Tanh activation.
  def forward(hidden_states)
    first_token = hidden_states[0.., 0]
    @activation.(@dense.(first_token))
  end
end
382
+
383
class MPNetModel < MPNetPreTrainedModel
  # Builds embeddings, encoder and (unless +add_pooling_layer+ is false) the
  # pooler, then runs +post_init+.
  def initialize(config, add_pooling_layer: true)
    super(config)
    @config = config

    @embeddings = MPNetEmbeddings.new(config)
    @encoder = MPNetEncoder.new(config)
    @pooler = add_pooling_layer ? MPNetPooler.new(config) : nil

    # Initialize weights and apply final processing
    post_init
  end

  # Returns the word embedding module held by the embeddings layer.
  def get_input_embeddings
    @embeddings.word_embeddings
  end

  # Replaces the word embedding module used by the embeddings layer.
  #
  # The table lives on +@embeddings+, so it is assigned there. Setting an
  # instance variable on the model itself would never reach the forward pass.
  def set_input_embeddings(value)
    @embeddings.instance_variable_set(:@word_embeddings, value)
  end

  # Prunes attention heads.
  #
  # heads_to_prune - enumerable of (layer index, heads to remove in that layer).
  def _prune_heads(heads_to_prune)
    heads_to_prune.each do |layer, heads|
      @encoder.layer[layer].attention.prune_heads(heads)
    end
  end

  # Encodes the inputs.
  #
  # Exactly one of +input_ids+ / +inputs_embeds+ must be given, otherwise
  # ArgumentError is raised. nil values for +output_attentions+,
  # +output_hidden_states+ and +return_dict+ fall back to the config.
  #
  # Returns a BaseModelOutputWithPooling when +return_dict+ is truthy,
  # otherwise +[sequence_output, pooled_output, *remaining_encoder_outputs]+.
  # +pooled_output+ is nil when the model was built without a pooler.
  def forward(
    input_ids: nil,
    attention_mask: nil,
    position_ids: nil,
    head_mask: nil,
    inputs_embeds: nil,
    output_attentions: nil,
    output_hidden_states: nil,
    return_dict: nil,
    **kwargs
  )
    output_attentions = !output_attentions.nil? ? output_attentions : @config.output_attentions
    output_hidden_states = !output_hidden_states.nil? ? output_hidden_states : @config.output_hidden_states
    return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

    if !input_ids.nil? && !inputs_embeds.nil?
      raise ArgumentError, "You cannot specify both input_ids and inputs_embeds at the same time"
    elsif !input_ids.nil?
      warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
      input_shape = input_ids.size
    elsif !inputs_embeds.nil?
      input_shape = inputs_embeds.size[...-1]
    else
      raise ArgumentError, "You have to specify either input_ids or inputs_embeds"
    end

    device = !input_ids.nil? ? input_ids.device : inputs_embeds.device

    # Default to attending to every position.
    if attention_mask.nil?
      attention_mask = Torch.ones(input_shape, device: device)
    end
    extended_attention_mask = get_extended_attention_mask(attention_mask, input_shape)

    head_mask = get_head_mask(head_mask, @config.num_hidden_layers)
    embedding_output = @embeddings.(input_ids: input_ids, position_ids: position_ids, inputs_embeds: inputs_embeds)
    encoder_outputs = @encoder.(embedding_output, attention_mask: extended_attention_mask, head_mask: head_mask, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
    sequence_output = encoder_outputs[0]
    pooled_output = !@pooler.nil? ? @pooler.(sequence_output) : nil

    if !return_dict
      return [sequence_output, pooled_output] + encoder_outputs[1..]
    end

    BaseModelOutputWithPooling.new(last_hidden_state: sequence_output, pooler_output: pooled_output, hidden_states: encoder_outputs.hidden_states, attentions: encoder_outputs.attentions)
  end
end
456
+
457
class MPNetForMaskedLM < MPNetPreTrainedModel
  self._tied_weights_keys = ["lm_head.decoder"]

  # Builds the base model without a pooler plus the language-modeling head.
  def initialize(config)
    super(config)

    @mpnet = MPNetModel.new(config, add_pooling_layer: false)
    @lm_head = MPNetLMHead.new(config)

    # Initialize weights and apply final processing
    post_init
  end

  # Returns the decoder module of the language-modeling head.
  def get_output_embeddings
    @lm_head.decoder
  end

  # Replaces the decoder (and its bias) of the language-modeling head.
  #
  # Both live on +@lm_head+, so they are assigned there rather than on this
  # wrapper, where they would be ignored by the forward pass.
  def set_output_embeddings(new_embeddings)
    @lm_head.instance_variable_set(:@decoder, new_embeddings)
    @lm_head.instance_variable_set(:@bias, new_embeddings.bias)
  end

  # Computes vocabulary logits and, when +labels+ is given, the
  # cross-entropy masked-LM loss.
  #
  # Returns a MaskedLMOutput when +return_dict+ is truthy, otherwise an
  # Array of +[loss (only if computed), prediction_scores, *outputs[2..]]+.
  def forward(
    input_ids: nil,
    attention_mask: nil,
    position_ids: nil,
    head_mask: nil,
    inputs_embeds: nil,
    labels: nil,
    output_attentions: nil,
    output_hidden_states: nil,
    return_dict: nil
  )
    return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

    # MPNetModel#forward declares input_ids as a keyword argument.
    outputs = @mpnet.(input_ids: input_ids, attention_mask: attention_mask, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)

    sequence_output = outputs[0]
    prediction_scores = @lm_head.(sequence_output)

    masked_lm_loss = nil
    if !labels.nil?
      loss_fct = Torch::NN::CrossEntropyLoss.new
      masked_lm_loss = loss_fct.(prediction_scores.view(-1, @config.vocab_size), labels.view(-1))
    end

    if !return_dict
      output = [prediction_scores] + outputs[2..]
      return !masked_lm_loss.nil? ? [masked_lm_loss] + output : output
    end

    MaskedLMOutput.new(loss: masked_lm_loss, logits: prediction_scores, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
  end
end
511
+
512
class MPNetLMHead < Torch::NN::Module
  # Builds the dense + LayerNorm transform, a bias-free decoder projecting to
  # the vocabulary, and a separate vocabulary-sized bias parameter.
  def initialize(config)
    super()
    @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
    @layer_norm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)

    @decoder = Torch::NN::Linear.new(config.hidden_size, config.vocab_size, bias: false)
    # Kept on the head (not on the decoder) and applied explicitly in
    # #forward, so the projection always uses this parameter.
    @bias = Torch::NN::Parameter.new(Torch.zeros(config.vocab_size))
  end

  # The bias is read directly from +@bias+ in #forward, so there is nothing
  # to re-link here. Returns the bias parameter.
  def _tie_weights
    @bias
  end

  # Returns vocabulary logits for +features+:
  # dense -> gelu -> LayerNorm -> decoder weight projection plus +@bias+.
  def forward(features, **kwargs)
    x = @dense.(features)
    x = Activations.gelu(x)
    x = @layer_norm.(x)

    # project back to size of vocabulary with bias
    Torch::NN::Functional.linear(x, @decoder.weight, @bias)
  end
end
540
+
541
class MPNetForSequenceClassification < MPNetPreTrainedModel
  # Builds the base model without a pooler plus the classification head.
  def initialize(config)
    super(config)

    @num_labels = config.num_labels
    @mpnet = MPNetModel.new(config, add_pooling_layer: false)
    @classifier = MPNetClassificationHead.new(config)

    # Initialize weights and apply final processing
    post_init
  end

  # Computes classification logits and, when +labels+ is given, a loss.
  #
  # The loss type comes from +@config.problem_type+. When that is nil it is
  # inferred for this call: "regression" for a single label,
  # "single_label_classification" for several labels with integer
  # (long/int) targets, otherwise "multi_label_classification".
  #
  # Returns a SequenceClassifierOutput when +return_dict+ is truthy,
  # otherwise an Array of +[loss (only if computed), logits, *outputs[2..]]+.
  def forward(
    input_ids: nil,
    attention_mask: nil,
    position_ids: nil,
    head_mask: nil,
    inputs_embeds: nil,
    labels: nil,
    output_attentions: nil,
    output_hidden_states: nil,
    return_dict: nil
  )
    return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

    # MPNetModel#forward declares input_ids as a keyword argument.
    outputs = @mpnet.(input_ids: input_ids, attention_mask: attention_mask, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
    sequence_output = outputs[0]
    logits = @classifier.(sequence_output)

    loss = nil
    if !labels.nil?
      # Resolve the problem type into a local so the branches below see the
      # inferred value; the config itself is left untouched.
      problem_type = @config.problem_type
      if problem_type.nil?
        problem_type =
          if @num_labels == 1
            "regression"
          elsif @num_labels > 1 && (labels.dtype == Torch.long || labels.dtype == Torch.int)
            "single_label_classification"
          else
            "multi_label_classification"
          end
        # Still recorded on the instance, as before.
        @problem_type = problem_type
      end

      if problem_type == "regression"
        loss_fct = Torch::NN::MSELoss.new
        if @num_labels == 1
          loss = loss_fct.(logits.squeeze, labels.squeeze)
        else
          loss = loss_fct.(logits, labels)
        end
      elsif problem_type == "single_label_classification"
        loss_fct = Torch::NN::CrossEntropyLoss.new
        loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
      elsif problem_type == "multi_label_classification"
        loss_fct = Torch::NN::BCEWithLogitsLoss.new
        loss = loss_fct.(logits, labels)
      end
    end

    if !return_dict
      output = [logits] + outputs[2..]
      return !loss.nil? ? [loss] + output : output
    end

    SequenceClassifierOutput.new(loss: loss, logits: logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
  end
end
605
+
606
class MPNetForMultipleChoice < MPNetPreTrainedModel
  # Builds the base model (with pooler), dropout and a single-output classifier.
  def initialize(config)
    super(config)

    @mpnet = MPNetModel.new(config)
    @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
    @classifier = Torch::NN::Linear.new(config.hidden_size, 1)

    # Initialize weights and apply final processing
    post_init
  end

  # Scores each choice and, when +labels+ is given, computes the
  # cross-entropy loss over the choices.
  #
  # The number of choices is read from dimension 1 of +input_ids+ (or of
  # +inputs_embeds+). Inputs are flattened so every choice is encoded as its
  # own sequence, and the logits are reshaped back to (-1, num_choices).
  #
  # Returns a MultipleChoiceModelOutput when +return_dict+ is truthy,
  # otherwise an Array of
  # +[loss (only if computed), reshaped_logits, *outputs[2..]]+.
  def forward(
    input_ids: nil,
    attention_mask: nil,
    position_ids: nil,
    head_mask: nil,
    inputs_embeds: nil,
    labels: nil,
    output_attentions: nil,
    output_hidden_states: nil,
    return_dict: nil
  )
    return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
    num_choices = !input_ids.nil? ? input_ids.shape[1] : inputs_embeds.shape[1]

    flat_input_ids = !input_ids.nil? ? input_ids.view(-1, input_ids.size(-1)) : nil
    flat_position_ids = !position_ids.nil? ? position_ids.view(-1, position_ids.size(-1)) : nil
    flat_attention_mask = !attention_mask.nil? ? attention_mask.view(-1, attention_mask.size(-1)) : nil
    flat_inputs_embeds = !inputs_embeds.nil? ? inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) : nil

    # MPNetModel#forward declares input_ids as a keyword argument.
    outputs = @mpnet.(input_ids: flat_input_ids, position_ids: flat_position_ids, attention_mask: flat_attention_mask, head_mask: head_mask, inputs_embeds: flat_inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
    pooled_output = outputs[1]

    pooled_output = @dropout.(pooled_output)
    logits = @classifier.(pooled_output)
    reshaped_logits = logits.view(-1, num_choices)

    loss = nil
    if !labels.nil?
      loss_fct = Torch::NN::CrossEntropyLoss.new
      loss = loss_fct.(reshaped_logits, labels)
    end

    if !return_dict
      output = [reshaped_logits] + outputs[2..]
      return !loss.nil? ? [loss] + output : output
    end

    MultipleChoiceModelOutput.new(loss: loss, logits: reshaped_logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
  end
end
658
+
659
class MPNetForTokenClassification < MPNetPreTrainedModel
  # Builds the base model without a pooler, dropout and a per-token classifier.
  def initialize(config)
    super(config)
    @num_labels = config.num_labels

    @mpnet = MPNetModel.new(config, add_pooling_layer: false)
    @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
    @classifier = Torch::NN::Linear.new(config.hidden_size, config.num_labels)

    # Initialize weights and apply final processing
    post_init
  end

  # Computes per-token logits and, when +labels+ is given, the
  # cross-entropy loss over all flattened positions.
  #
  # Returns a TokenClassifierOutput when +return_dict+ is truthy, otherwise
  # an Array of +[loss (only if computed), logits, *outputs[2..]]+.
  def forward(
    input_ids: nil,
    attention_mask: nil,
    position_ids: nil,
    head_mask: nil,
    inputs_embeds: nil,
    labels: nil,
    output_attentions: nil,
    output_hidden_states: nil,
    return_dict: nil
  )
    return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

    # MPNetModel#forward declares input_ids as a keyword argument.
    outputs = @mpnet.(input_ids: input_ids, attention_mask: attention_mask, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)

    sequence_output = outputs[0]

    sequence_output = @dropout.(sequence_output)
    logits = @classifier.(sequence_output)

    loss = nil
    if !labels.nil?
      loss_fct = Torch::NN::CrossEntropyLoss.new
      loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
    end

    if !return_dict
      output = [logits] + outputs[2..]
      return !loss.nil? ? [loss] + output : output
    end

    TokenClassifierOutput.new(loss: loss, logits: logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
  end
end
706
+
707
class MPNetClassificationHead < Torch::NN::Module
  # Builds the dense layer, dropout and the projection to +config.num_labels+.
  def initialize(config)
    super()
    @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
    @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
    @out_proj = Torch::NN::Linear.new(config.hidden_size, config.num_labels)
  end

  # Classifies from index 0 along the second dimension of +features+
  # (presumably the first token of each sequence):
  # dropout -> dense -> tanh -> dropout -> out_proj.
  def forward(features, **kwargs)
    first_token = features[0.., 0, 0..]
    hidden = Torch.tanh(@dense.(@dropout.(first_token)))
    @out_proj.(@dropout.(hidden))
  end
end
725
+
726
class MPNetForQuestionAnswering < MPNetPreTrainedModel
  # Builds the base model without a pooler plus the span-logit projection.
  def initialize(config)
    super(config)

    @num_labels = config.num_labels
    @mpnet = MPNetModel.new(config, add_pooling_layer: false)
    @qa_outputs = Torch::NN::Linear.new(config.hidden_size, config.num_labels)

    # Initialize weights and apply final processing
    post_init
  end

  # Computes start/end span logits and, when both +start_positions+ and
  # +end_positions+ are given, the mean of the two cross-entropy losses.
  #
  # Returns a QuestionAnsweringModelOutput when +return_dict+ is truthy,
  # otherwise an Array of
  # +[loss (only if computed), start_logits, end_logits, *outputs[2..]]+.
  def forward(
    input_ids: nil,
    attention_mask: nil,
    position_ids: nil,
    head_mask: nil,
    inputs_embeds: nil,
    start_positions: nil,
    end_positions: nil,
    output_attentions: nil,
    output_hidden_states: nil,
    return_dict: nil
  )
    return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

    # MPNetModel#forward declares input_ids as a keyword argument.
    outputs = @mpnet.(input_ids: input_ids, attention_mask: attention_mask, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)

    sequence_output = outputs[0]

    logits = @qa_outputs.(sequence_output)
    start_logits, end_logits = logits.split(1, dim: -1)
    start_logits = start_logits.squeeze(-1).contiguous
    end_logits = end_logits.squeeze(-1).contiguous

    total_loss = nil
    if !start_positions.nil? && !end_positions.nil?
      # If we are on multi-GPU, split add a dimension
      if start_positions.size.length > 1
        start_positions = start_positions.squeeze(-1)
      end
      if end_positions.size.length > 1
        end_positions = end_positions.squeeze(-1)
      end
      # sometimes the start/end positions are outside our model inputs, we ignore these terms
      # Positions are clamped to the sequence length, and that value is then
      # used as the loss's ignore_index.
      ignored_index = start_logits.size(1)
      start_positions = start_positions.clamp(0, ignored_index)
      end_positions = end_positions.clamp(0, ignored_index)

      loss_fct = Torch::NN::CrossEntropyLoss.new(ignore_index: ignored_index)
      start_loss = loss_fct.(start_logits, start_positions)
      end_loss = loss_fct.(end_logits, end_positions)
      total_loss = (start_loss + end_loss) / 2
    end

    if !return_dict
      output = [start_logits, end_logits] + outputs[2..]
      return !total_loss.nil? ? [total_loss] + output : output
    end

    QuestionAnsweringModelOutput.new(loss: total_loss, start_logits: start_logits, end_logits: end_logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
  end
end
789
+ end
790
+
791
+ MPNetForMaskedLM = Mpnet::MPNetForMaskedLM # Re-export so the class is also reachable as Transformers::MPNetForMaskedLM.
792
+ end