transformers-rb 0.1.2 → 0.1.3

@@ -0,0 +1,1210 @@
1
+ # Copyright 2020 Microsoft and the Hugging Face Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ module Transformers
16
+ module DebertaV2
17
+ class ContextPooler < Torch::NN::Module
18
+ def initialize(config)
19
+ super()
20
+ @dense = Torch::NN::Linear.new(config.pooler_hidden_size, config.pooler_hidden_size)
21
+ @dropout = StableDropout.new(config.pooler_dropout)
22
+ @config = config
23
+ end
24
+
25
+ def forward(hidden_states)
26
+ # We "pool" the model by simply taking the hidden state corresponding
27
+ # to the first token.
28
+
29
+ context_token = hidden_states[0.., 0]
30
+ context_token = @dropout.(context_token)
31
+ pooled_output = @dense.(context_token)
32
+ pooled_output = ACT2FN[@config.pooler_hidden_act].(pooled_output)
33
+ pooled_output
34
+ end
35
+
36
+ def output_dim
37
+ @config.hidden_size
38
+ end
39
+ end
40
+
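For orientation, a minimal sketch (not from the gem) of what the pooler consumes and produces; the shapes are invented:

hidden_states = Torch.randn(2, 16, 8)    # [batch, seq_len, hidden]
context_token = hidden_states[0.., 0]    # hidden state of the first token in each sequence
context_token.shape                      # => [2, 8], which @dense and the activation then project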
41
+ # TODO Torch::Autograd::Function
42
+ class XSoftmax
43
+ def self.apply(input, mask, dim)
44
+ @dim = dim
45
+ rmask = mask.to(Torch.bool).bitwise_not
46
+
47
+ # TODO use Torch.finfo
48
+ output = input.masked_fill(rmask, Torch.tensor(-3.40282e+38))
49
+ output = Torch.softmax(output, @dim)
50
+ output.masked_fill!(rmask, 0)
51
+ # ctx.save_for_backward(output)
52
+ output
53
+ end
54
+ end
55
+
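A small illustration (not part of the gem) of the masking behaviour, using a 1/0 attention mask as the encoder does:

scores = Torch.tensor([[1.0, 2.0, 3.0]])
mask = Torch.tensor([[1, 1, 0]])
probs = Transformers::DebertaV2::XSoftmax.apply(scores, mask, -1)
# masked positions end up with probability 0.0; the remaining positions share the softmax mass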
56
+ class DropoutContext
57
+ attr_accessor :dropout, :mask, :scale, :reuse_mask
+
+ def initialize
58
+ @dropout = 0
59
+ @mask = nil
60
+ @scale = 1
61
+ @reuse_mask = true
62
+ end
63
+ end
64
+
65
+ # Resolve the dropout probability and (optionally reusable) mask from either
+ # a plain dropout value or a DropoutContext.
+ def self.get_mask(input, local_context)
66
+ if !local_context.is_a?(DropoutContext)
67
+ dropout = local_context
68
+ mask = nil
69
+ else
70
+ dropout = local_context.dropout
71
+ dropout *= local_context.scale
72
+ mask = local_context.reuse_mask ? local_context.mask : nil
73
+ end
74
+
75
+ if dropout > 0 && mask.nil?
76
+ mask = (1 - Torch.empty_like(input).bernoulli!(1 - dropout)).to(Torch.bool)
77
+ end
78
+
79
+ if local_context.is_a?(DropoutContext)
80
+ if local_context.mask.nil?
81
+ local_context.mask = mask
82
+ end
83
+ end
84
+
85
+ [mask, dropout]
86
+ end
87
+
88
+ # TODO Torch::Autograd::Function
89
+ class XDropout
90
+ def self.apply(input, local_ctx)
91
+ mask, dropout = DebertaV2.get_mask(input, local_ctx)
92
+ @scale = 1.0 / (1 - dropout)
93
+ if dropout > 0
94
+ # ctx.save_for_backward(mask)
95
+ input.masked_fill(mask, 0) * @scale
96
+ else
97
+ input
98
+ end
99
+ end
100
+ end
101
+
102
+ class StableDropout < Torch::NN::Module
103
+ def initialize(drop_prob)
104
+ super()
105
+ @drop_prob = drop_prob
106
+ @count = 0
107
+ @context_stack = nil
108
+ end
109
+
110
+ def forward(x)
111
+ if @training && @drop_prob > 0
112
+ return XDropout.apply(x, get_context)
113
+ end
114
+ x
115
+ end
116
+
117
+ def clear_context
118
+ @count = 0
119
+ @context_stack = nil
120
+ end
121
+
122
+ def init_context(reuse_mask: true, scale: 1)
123
+ if @context_stack.nil?
124
+ @context_stack = []
125
+ end
126
+ @count = 0
127
+ @context_stack.each do |c|
128
+ c.reuse_mask = reuse_mask
129
+ c.scale = scale
130
+ end
131
+ end
132
+
133
+ def get_context
134
+ if !@context_stack.nil?
135
+ if @count >= @context_stack.length
136
+ @context_stack << DropoutContext.new
137
+ end
138
+ ctx = @context_stack.fetch(@count)
139
+ ctx.dropout = @drop_prob
140
+ @count += 1
141
+ ctx
142
+ else
143
+ @drop_prob
144
+ end
145
+ end
146
+ end
147
+
148
+ class DebertaV2SelfOutput < Torch::NN::Module
149
+ def initialize(config)
150
+ super()
151
+ @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
152
+ @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
153
+ @dropout = StableDropout.new(config.hidden_dropout_prob)
154
+ end
155
+
156
+ def forward(hidden_states, input_tensor)
157
+ hidden_states = @dense.(hidden_states)
158
+ hidden_states = @dropout.(hidden_states)
159
+ hidden_states = @LayerNorm.(hidden_states + input_tensor)
160
+ hidden_states
161
+ end
162
+ end
163
+
164
+ class DebertaV2Attention < Torch::NN::Module
165
+ def initialize(config)
166
+ super()
167
+ @self = DisentangledSelfAttention.new(config)
168
+ @output = DebertaV2SelfOutput.new(config)
169
+ @config = config
170
+ end
171
+
172
+ def forward(
173
+ hidden_states,
174
+ attention_mask,
175
+ output_attentions: false,
176
+ query_states: nil,
177
+ relative_pos: nil,
178
+ rel_embeddings: nil
179
+ )
180
+ self_output = @self.(hidden_states, attention_mask, output_attentions:, query_states: query_states, relative_pos: relative_pos, rel_embeddings: rel_embeddings)
181
+ if output_attentions
182
+ self_output, att_matrix = self_output
183
+ end
184
+ if query_states.nil?
185
+ query_states = hidden_states
186
+ end
187
+ attention_output = @output.(self_output, query_states)
188
+
189
+ if output_attentions
190
+ [attention_output, att_matrix]
191
+ else
192
+ attention_output
193
+ end
194
+ end
195
+ end
196
+
197
+ class DebertaV2Intermediate < Torch::NN::Module
198
+ def initialize(config)
199
+ super()
200
+ @dense = Torch::NN::Linear.new(config.hidden_size, config.intermediate_size)
201
+ if config.hidden_act.is_a?(String)
202
+ @intermediate_act_fn = ACT2FN[config.hidden_act]
203
+ else
204
+ @intermediate_act_fn = config.hidden_act
205
+ end
206
+ end
207
+
208
+ def forward(hidden_states)
209
+ hidden_states = @dense.(hidden_states)
210
+ hidden_states = @intermediate_act_fn.(hidden_states)
211
+ hidden_states
212
+ end
213
+ end
214
+
215
+ class DebertaV2Output < Torch::NN::Module
216
+ def initialize(config)
217
+ super()
218
+ @dense = Torch::NN::Linear.new(config.intermediate_size, config.hidden_size)
219
+ @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
220
+ @dropout = StableDropout.new(config.hidden_dropout_prob)
221
+ @config = config
222
+ end
223
+
224
+ def forward(hidden_states, input_tensor)
225
+ hidden_states = @dense.(hidden_states)
226
+ hidden_states = @dropout.(hidden_states)
227
+ hidden_states = @LayerNorm.(hidden_states + input_tensor)
228
+ hidden_states
229
+ end
230
+ end
231
+
232
+ class DebertaV2Layer < Torch::NN::Module
233
+ def initialize(config)
234
+ super()
235
+ @attention = DebertaV2Attention.new(config)
236
+ @intermediate = DebertaV2Intermediate.new(config)
237
+ @output = DebertaV2Output.new(config)
238
+ end
239
+
240
+ def forward(
241
+ hidden_states,
242
+ attention_mask,
243
+ query_states: nil,
244
+ relative_pos: nil,
245
+ rel_embeddings: nil,
246
+ output_attentions: false
247
+ )
248
+ attention_output = @attention.(hidden_states, attention_mask, output_attentions: output_attentions, query_states: query_states, relative_pos: relative_pos, rel_embeddings: rel_embeddings)
249
+ if output_attentions
250
+ attention_output, att_matrix = attention_output
251
+ end
252
+ intermediate_output = @intermediate.(attention_output)
253
+ layer_output = @output.(intermediate_output, attention_output)
254
+ if output_attentions
255
+ [layer_output, att_matrix]
256
+ else
257
+ layer_output
258
+ end
259
+ end
260
+ end
261
+
262
+ class ConvLayer < Torch::NN::Module
263
+ def initialize(config)
264
+ super()
265
+ kernel_size = config.getattr("conv_kernel_size", 3)
266
+ groups = config.getattr("conv_groups", 1)
267
+ @conv_act = config.getattr("conv_act", "tanh")
268
+ @conv = Torch::NN::Conv1d.new(config.hidden_size, config.hidden_size, kernel_size, padding: (kernel_size - 1) / 2, groups: groups)
269
+ @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
270
+ @dropout = StableDropout.new(config.hidden_dropout_prob)
271
+ @config = config
272
+ end
273
+
274
+ def forward(hidden_states, residual_states, input_mask)
275
+ out = @conv.(hidden_states.permute(0, 2, 1).contiguous).permute(0, 2, 1).contiguous
276
+ rmask = (1 - input_mask).bool
277
+ out.masked_fill!(rmask.unsqueeze(-1).expand(out.size), 0)
278
+ out = ACT2FN[@conv_act].(@dropout.(out))
279
+
280
+ layer_norm_input = residual_states + out
281
+ output = @LayerNorm.(layer_norm_input).to(layer_norm_input)
282
+
283
+ if input_mask.nil?
284
+ output_states = output
285
+ else
+ if input_mask.dim != layer_norm_input.dim
+ if input_mask.dim == 4
+ input_mask = input_mask.squeeze(1).squeeze(1)
+ end
+ input_mask = input_mask.unsqueeze(2)
+ end
+ # scale the conv output by the padding mask
+ input_mask = input_mask.to(output.dtype)
+ output_states = output * input_mask
+ end
291
+
292
+ output_states
293
+ end
294
+ end
295
+
296
+ class DebertaV2Encoder < Torch::NN::Module
297
+ def initialize(config)
298
+ super()
299
+
300
+ @layer = Torch::NN::ModuleList.new(config.num_hidden_layers.times.map { |_| DebertaV2Layer.new(config) })
301
+ @relative_attention = config.getattr("relative_attention", false)
302
+
303
+ if @relative_attention
304
+ @max_relative_positions = config.getattr("max_relative_positions", -1)
305
+ if @max_relative_positions < 1
306
+ @max_relative_positions = config.max_position_embeddings
307
+ end
308
+
309
+ @position_buckets = config.getattr("position_buckets", -1)
310
+ pos_ebd_size = @max_relative_positions * 2
311
+
312
+ if @position_buckets > 0
313
+ pos_ebd_size = @position_buckets * 2
314
+ end
315
+
316
+ @rel_embeddings = Torch::NN::Embedding.new(pos_ebd_size, config.hidden_size)
317
+ end
318
+
319
+ @norm_rel_ebd = config.getattr("norm_rel_ebd", "none").downcase.split("|").map { |x| x.strip }
320
+
321
+ if @norm_rel_ebd.include?("layer_norm")
322
+ @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps, elementwise_affine: true)
323
+ end
324
+
325
+ @conv = config.getattr("conv_kernel_size", 0) > 0 ? ConvLayer.new(config) : nil
326
+ @gradient_checkpointing = false
327
+ end
328
+
329
+ def get_rel_embedding
330
+ rel_embeddings = @relative_attention ? @rel_embeddings.weight : nil
331
+ if !rel_embeddings.nil? && @norm_rel_ebd.include?("layer_norm")
332
+ rel_embeddings = @LayerNorm.(rel_embeddings)
333
+ end
334
+ rel_embeddings
335
+ end
336
+
337
+ def get_attention_mask(attention_mask)
338
+ if attention_mask.dim <= 2
339
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
340
+ attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
341
+ elsif attention_mask.dim == 3
342
+ attention_mask = attention_mask.unsqueeze(1)
343
+ end
344
+
345
+ attention_mask
346
+ end
347
+
348
+ def get_rel_pos(hidden_states, query_states: nil, relative_pos: nil)
349
+ if @relative_attention && relative_pos.nil?
350
+ q = !query_states.nil? ? query_states.size(-2) : hidden_states.size(-2)
351
+ relative_pos = DebertaV2.build_relative_position(q, hidden_states.size(-2), bucket_size: @position_buckets, max_position: @max_relative_positions, device: hidden_states.device)
352
+ end
353
+ relative_pos
354
+ end
355
+
356
+ def forward(
357
+ hidden_states,
358
+ attention_mask,
359
+ output_hidden_states: true,
360
+ output_attentions: false,
361
+ query_states: nil,
362
+ relative_pos: nil,
363
+ return_dict: true
364
+ )
365
+ if attention_mask.dim <= 2
366
+ input_mask = attention_mask
367
+ else
368
+ input_mask = attention_mask.sum(-2) > 0
369
+ end
370
+ attention_mask = get_attention_mask(attention_mask)
371
+ relative_pos = get_rel_pos(hidden_states, query_states:, relative_pos:)
372
+
373
+ all_hidden_states = output_hidden_states ? [] : nil
374
+ all_attentions = output_attentions ? [] : nil
375
+
376
+ if hidden_states.is_a?(Array)
377
+ next_kv = hidden_states[0]
378
+ else
379
+ next_kv = hidden_states
380
+ end
381
+ rel_embeddings = get_rel_embedding
382
+ output_states = next_kv
383
+ @layer.each_with_index do |layer_module, i|
384
+ if output_hidden_states
385
+ all_hidden_states = all_hidden_states + [output_states]
386
+ end
387
+
388
+ if @gradient_checkpointing && @training
389
+ output_states = _gradient_checkpointing_func(layer_module.__call__, next_kv, attention_mask, query_states, relative_pos, rel_embeddings, output_attentions)
390
+ else
391
+ output_states = layer_module.(next_kv, attention_mask, query_states: query_states, relative_pos: relative_pos, rel_embeddings: rel_embeddings, output_attentions: output_attentions)
392
+ end
393
+
394
+ if output_attentions
395
+ output_states, att_m = output_states
396
+ end
397
+
398
+ if i == 0 && !@conv.nil?
399
+ output_states = @conv.(hidden_states, output_states, input_mask)
400
+ end
401
+
402
+ if !query_states.nil?
403
+ query_states = output_states
404
+ if hidden_states.is_a?(Array)
405
+ next_kv = i + 1 < @layer.length ? hidden_states[i + 1] : nil
406
+ end
407
+ else
408
+ next_kv = output_states
409
+ end
410
+
411
+ if output_attentions
412
+ all_attentions = all_attentions + [att_m]
413
+ end
414
+ end
415
+
416
+ if output_hidden_states
417
+ all_hidden_states = all_hidden_states + [output_states]
418
+ end
419
+
420
+ if !return_dict
421
+ return Array([output_states, all_hidden_states, all_attentions].select { |v| !v.nil? })
422
+ end
423
+ BaseModelOutput.new(last_hidden_state: output_states, hidden_states: all_hidden_states, attentions: all_attentions)
424
+ end
425
+ end
426
+
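For reference, a sketch (not from the gem) of how get_attention_mask above expands a 2-D padding mask before it reaches the attention layers:

mask = Torch.tensor([[1, 1, 0]])             # [batch, seq]
ext = mask.unsqueeze(1).unsqueeze(2)         # [batch, 1, 1, seq]
full = ext * ext.squeeze(-2).unsqueeze(-1)   # [batch, 1, seq, seq]; 1 only where query and key are both real tokens
full.shape                                   # => [1, 1, 3, 3]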
427
+ def self.make_log_bucket_position(relative_pos, bucket_size, max_position)
428
+ sign = Torch.sign(relative_pos)
429
+ mid = bucket_size / 2
430
+ abs_pos = Torch.where(relative_pos.lt(mid) & relative_pos.gt(-mid), Torch.tensor(mid - 1).type_as(relative_pos), Torch.abs(relative_pos))
431
+ log_pos = Torch.ceil((Torch.log(abs_pos / mid) / Torch.log(Torch.tensor((max_position - 1) / mid))) * (mid - 1)) + mid
432
+ bucket_pos = Torch.where(abs_pos.le(mid), relative_pos.type_as(log_pos), log_pos * sign)
433
+ bucket_pos
434
+ end
435
+
436
+ def self.build_relative_position(query_size, key_size, bucket_size: -1, max_position: -1, device: nil)
437
+ q_ids = Torch.arange(0, query_size, device: device)
438
+ k_ids = Torch.arange(0, key_size, device: device)
439
+ rel_pos_ids = q_ids[0.., nil] - k_ids[nil, 0..]
440
+ if bucket_size > 0 && max_position > 0
441
+ rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
442
+ end
443
+ rel_pos_ids = rel_pos_ids.to(Torch.long)
444
+ rel_pos_ids = rel_pos_ids[...query_size, 0..]
445
+ rel_pos_ids = rel_pos_ids.unsqueeze(0)
446
+ rel_pos_ids
447
+ end
448
+
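A worked example (illustration only) of the relative position matrix returned when bucketing is disabled:

rel = Transformers::DebertaV2.build_relative_position(3, 3, device: "cpu")
# rel.shape => [1, 3, 3]; entry [0, i, j] is i - j:
# [[0, -1, -2],
#  [1,  0, -1],
#  [2,  1,  0]]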
449
+ def self.c2p_dynamic_expand(c2p_pos, query_layer, relative_pos)
450
+ c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)])
451
+ end
452
+
453
+ def self.p2c_dynamic_expand(c2p_pos, query_layer, key_layer)
454
+ c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
455
+ end
456
+
457
+ def self.pos_dynamic_expand(pos_index, p2c_att, key_layer)
458
+ pos_index.expand(p2c_att.size[...2] + [pos_index.size(-2), key_layer.size(-2)])
459
+ end
460
+
461
+ class DisentangledSelfAttention < Torch::NN::Module
462
+ def initialize(config)
463
+ super()
464
+ if config.hidden_size % config.num_attention_heads != 0
465
+ raise ArgumentError, "The hidden size (#{config.hidden_size}) is not a multiple of the number of attention heads (#{config.num_attention_heads})"
466
+ end
467
+ @num_attention_heads = config.num_attention_heads
468
+ _attention_head_size = config.hidden_size / config.num_attention_heads
469
+ @attention_head_size = config.getattr("attention_head_size", _attention_head_size)
470
+ @all_head_size = @num_attention_heads * @attention_head_size
471
+ @query_proj = Torch::NN::Linear.new(config.hidden_size, @all_head_size, bias: true)
472
+ @key_proj = Torch::NN::Linear.new(config.hidden_size, @all_head_size, bias: true)
473
+ @value_proj = Torch::NN::Linear.new(config.hidden_size, @all_head_size, bias: true)
474
+
475
+ @share_att_key = config.getattr("share_att_key", false)
476
+ @pos_att_type = !config.pos_att_type.nil? ? config.pos_att_type : []
477
+ @relative_attention = config.getattr("relative_attention", false)
478
+
479
+ if @relative_attention
480
+ @position_buckets = config.getattr("position_buckets", -1)
481
+ @max_relative_positions = config.getattr("max_relative_positions", -1)
482
+ if @max_relative_positions < 1
483
+ @max_relative_positions = config.max_position_embeddings
484
+ end
485
+ @pos_ebd_size = @max_relative_positions
486
+ if @position_buckets > 0
487
+ @pos_ebd_size = @position_buckets
488
+ end
489
+
490
+ @pos_dropout = StableDropout.new(config.hidden_dropout_prob)
491
+
492
+ if !@share_att_key
493
+ if @pos_att_type.include?("c2p")
494
+ @pos_key_proj = Torch::NN::Linear.new(config.hidden_size, @all_head_size, bias: true)
495
+ end
496
+ if @pos_att_type.include?("p2c")
497
+ @pos_query_proj = Torch::NN::Linear.new(config.hidden_size, @all_head_size)
498
+ end
499
+ end
500
+ end
501
+
502
+ @dropout = StableDropout.new(config.attention_probs_dropout_prob)
503
+ end
504
+
505
+ def transpose_for_scores(x, attention_heads)
506
+ new_x_shape = x.size[...-1] + [attention_heads, -1]
507
+ x = x.view(new_x_shape)
508
+ x.permute(0, 2, 1, 3).contiguous.view(-1, x.size(1), x.size(-1))
509
+ end
510
+
511
+ def forward(
512
+ hidden_states,
513
+ attention_mask,
514
+ output_attentions: false,
515
+ query_states: nil,
516
+ relative_pos: nil,
517
+ rel_embeddings: nil
518
+ )
519
+ if query_states.nil?
520
+ query_states = hidden_states
521
+ end
522
+ query_layer = transpose_for_scores(@query_proj.(query_states), @num_attention_heads)
523
+ key_layer = transpose_for_scores(@key_proj.(hidden_states), @num_attention_heads)
524
+ value_layer = transpose_for_scores(@value_proj.(hidden_states), @num_attention_heads)
525
+
526
+ rel_att = nil
527
+ # Take the dot product between "query" and "key" to get the raw attention scores.
528
+ scale_factor = 1
529
+ if @pos_att_type.include?("c2p")
530
+ scale_factor += 1
531
+ end
532
+ if @pos_att_type.include?("p2c")
533
+ scale_factor += 1
534
+ end
535
+ scale = Torch.sqrt(Torch.tensor(query_layer.size(-1), dtype: Torch.float) * scale_factor)
536
+ attention_scores = Torch.bmm(query_layer, key_layer.transpose(-1, -2) / scale.to(dtype: query_layer.dtype))
537
+ if @relative_attention
538
+ rel_embeddings = @pos_dropout.(rel_embeddings)
539
+ rel_att = disentangled_attention_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)
540
+ end
541
+
542
+ if !rel_att.nil?
543
+ attention_scores = attention_scores + rel_att
544
+ end
545
+ attention_scores = attention_scores
546
+ attention_scores = attention_scores.view(-1, @num_attention_heads, attention_scores.size(-2), attention_scores.size(-1))
547
+
548
+ # bsz x height x length x dimension
549
+ attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
550
+ attention_probs = @dropout.(attention_probs)
551
+ context_layer = Torch.bmm(attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer)
552
+ context_layer = context_layer.view(-1, @num_attention_heads, context_layer.size(-2), context_layer.size(-1)).permute(0, 2, 1, 3).contiguous
553
+ new_context_layer_shape = context_layer.size[...-2] + [-1]
554
+ context_layer = context_layer.view(new_context_layer_shape)
555
+ if output_attentions
556
+ [context_layer, attention_probs]
557
+ else
558
+ context_layer
559
+ end
560
+ end
561
+
562
+ def disentangled_attention_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)
563
+ if relative_pos.nil?
564
+ q = query_layer.size(-2)
565
+ relative_pos = DebertaV2.build_relative_position(q, key_layer.size(-2), bucket_size: @position_buckets, max_position: @max_relative_positions, device: query_layer.device)
566
+ end
567
+ if relative_pos.dim == 2
568
+ relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
569
+ elsif relative_pos.dim == 3
570
+ relative_pos = relative_pos.unsqueeze(1)
571
+ elsif relative_pos.dim != 4
572
+ raise ArgumentError, "Relative position ids must be of dim 2 or 3 or 4. #{relative_pos.dim}"
573
+ end
574
+
575
+ att_span = @pos_ebd_size
576
+ relative_pos = relative_pos.long.to(query_layer.device)
577
+
578
+ rel_embeddings = rel_embeddings[0...att_span * 2, 0..].unsqueeze(0)
579
+ if @share_att_key
580
+ pos_query_layer = transpose_for_scores(@query_proj.(rel_embeddings), @num_attention_heads).repeat(query_layer.size(0) / @num_attention_heads, 1, 1)
581
+ pos_key_layer = transpose_for_scores(@key_proj.(rel_embeddings), @num_attention_heads).repeat(query_layer.size(0) / @num_attention_heads, 1, 1)
582
+ elsif @pos_att_type.include?("c2p")
583
+ pos_key_layer = transpose_for_scores(@pos_key_proj.(rel_embeddings), @num_attention_heads).repeat(query_layer.size(0) / @num_attention_heads, 1, 1)
584
+ end
585
+
586
+ score = 0
587
+ # content->position
588
+ if @pos_att_type.include?("c2p")
589
+ scale = Torch.sqrt(Torch.tensor(pos_key_layer.size(-1), dtype: Torch.float) * scale_factor)
590
+ c2p_att = Torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
591
+ c2p_pos = Torch.clamp(relative_pos + att_span, 0, (att_span * 2) - 1)
592
+ c2p_att = Torch.gather(c2p_att, dim: -1, index: c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]))
593
+ score += c2p_att / scale.to(dtype: c2p_att.dtype)
594
+ end
595
+
596
+ # position->content
597
+ if @pos_att_type.include?("p2c")
598
+ scale = Torch.sqrt(Torch.tensor(pos_query_layer.size(-1), dtype: Torch.float) * scale_factor)
599
+ if key_layer.size(-2) != query_layer.size(-2)
600
+ r_pos = DebertaV2.build_relative_position(key_layer.size(-2), key_layer.size(-2), bucket_size: @position_buckets, max_position: @max_relative_positions, device: query_layer.device)
601
+ r_pos = r_pos.unsqueeze(0)
602
+ else
603
+ r_pos = relative_pos
604
+ end
605
+
606
+ p2c_pos = Torch.clamp(-r_pos + att_span, 0, (att_span * 2) - 1)
607
+ p2c_att = Torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
608
+ p2c_att = Torch.gather(p2c_att, dim: -1, index: p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)])).transpose(-1, -2)
609
+ score += p2c_att / scale.to(dtype: p2c_att.dtype)
610
+ end
611
+
612
+ score
613
+ end
614
+ end
615
+
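The scaling above mirrors the DeBERTa paper: each enabled disentangled term adds one to scale_factor, so the raw scores are divided by sqrt(d * scale_factor) rather than the usual sqrt(d). A tiny worked example with invented numbers:

# head size 64 with pos_att_type = ["c2p", "p2c"]
scale_factor = 1 + 1 + 1       # content-to-content, content-to-position, position-to-content
Math.sqrt(64 * scale_factor)   # => ~13.86, versus sqrt(64) = 8 for standard attention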
616
+ class DebertaV2Embeddings < Torch::NN::Module
617
+ def initialize(config)
618
+ super()
619
+ pad_token_id = config.getattr("pad_token_id", 0)
620
+ @embedding_size = config.getattr("embedding_size", config.hidden_size)
621
+ @word_embeddings = Torch::NN::Embedding.new(config.vocab_size, @embedding_size, padding_idx: pad_token_id)
622
+
623
+ @position_biased_input = config.getattr("position_biased_input", true)
624
+ if !@position_biased_input
625
+ @position_embeddings = nil
626
+ else
627
+ @position_embeddings = Torch::NN::Embedding.new(config.max_position_embeddings, @embedding_size)
628
+ end
629
+
630
+ if config.type_vocab_size > 0
631
+ @token_type_embeddings = Torch::NN::Embedding.new(config.type_vocab_size, @embedding_size)
632
+ end
633
+
634
+ if @embedding_size != config.hidden_size
635
+ @embed_proj = Torch::NN::Linear.new(@embedding_size, config.hidden_size, bias: false)
636
+ end
637
+ @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
638
+ @dropout = StableDropout.new(config.hidden_dropout_prob)
639
+ @config = config
640
+
641
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
642
+ register_buffer("position_ids", Torch.arange(config.max_position_embeddings).expand([1, -1]), persistent: false)
643
+ end
644
+
645
+ def forward(input_ids: nil, token_type_ids: nil, position_ids: nil, mask: nil, inputs_embeds: nil)
646
+ if !input_ids.nil?
647
+ input_shape = input_ids.size
648
+ else
649
+ input_shape = inputs_embeds.size[...-1]
650
+ end
651
+
652
+ seq_length = input_shape[1]
653
+
654
+ if position_ids.nil?
655
+ position_ids = @position_ids[0.., ...seq_length]
656
+ end
657
+
658
+ if token_type_ids.nil?
659
+ token_type_ids = Torch.zeros(input_shape, dtype: Torch.long, device: @position_ids.device)
660
+ end
661
+
662
+ if inputs_embeds.nil?
663
+ inputs_embeds = @word_embeddings.(input_ids)
664
+ end
665
+
666
+ if !@position_embeddings.nil?
667
+ position_embeddings = @position_embeddings.(position_ids.long)
668
+ else
669
+ position_embeddings = Torch.zeros_like(inputs_embeds)
670
+ end
671
+
672
+ embeddings = inputs_embeds
673
+ if @position_biased_input
674
+ embeddings += position_embeddings
675
+ end
676
+ if @config.type_vocab_size > 0
677
+ token_type_embeddings = @token_type_embeddings.(token_type_ids)
678
+ embeddings += token_type_embeddings
679
+ end
680
+
681
+ if @embedding_size != @config.hidden_size
682
+ embeddings = @embed_proj.(embeddings)
683
+ end
684
+
685
+ embeddings = @LayerNorm.(embeddings)
686
+
687
+ if !mask.nil?
688
+ if mask.dim != embeddings.dim
689
+ if mask.dim == 4
690
+ mask = mask.squeeze(1).squeeze(1)
691
+ end
692
+ mask = mask.unsqueeze(2)
693
+ end
694
+ mask = mask.to(embeddings.dtype)
695
+
696
+ embeddings = embeddings * mask
697
+ end
698
+
699
+ embeddings = @dropout.(embeddings)
700
+ embeddings
701
+ end
702
+ end
703
+
704
+ class DebertaV2PreTrainedModel < PreTrainedModel
705
+ self.config_class = DebertaV2Config
706
+ self.base_model_prefix = "deberta"
707
+ # self._keys_to_ignore_on_load_unexpected = ["position_embeddings"]
708
+ # self.supports_gradient_checkpointing = true
709
+
710
+ def _init_weights(module_)
711
+ if module_.is_a?(Torch::NN::Linear)
712
+ # Slightly different from the TF version which uses truncated_normal for initialization
713
+ # cf https://github.com/pytorch/pytorch/pull/5617
714
+ module_.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
715
+ if !module_.bias.nil?
716
+ module_.bias.data.zero!
717
+ end
718
+ elsif module_.is_a?(Torch::NN::Embedding)
719
+ module_.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
720
+ if !module_.padding_idx.nil?
721
+ module_.weight.data.fetch(module_.padding_idx).zero!
722
+ end
723
+ end
724
+ end
725
+ end
726
+
727
+ class DebertaV2Model < DebertaV2PreTrainedModel
728
+ def initialize(config)
729
+ super(config)
730
+
731
+ @embeddings = DebertaV2Embeddings.new(config)
732
+ @encoder = DebertaV2Encoder.new(config)
733
+ @z_steps = 0
734
+ @config = config
735
+ # Initialize weights and apply final processing
736
+ post_init
737
+ end
738
+
739
+ def get_input_embeddings
740
+ @embeddings.word_embeddings
741
+ end
742
+
743
+ def set_input_embeddings(new_embeddings)
744
+ @word_embeddings = new_embeddings
745
+ end
746
+
747
+ def _prune_heads(heads_to_prune)
748
+ raise NotImplementedError, "The prune function is not implemented in DeBERTa model."
749
+ end
750
+
751
+ def forward(
752
+ input_ids,
753
+ attention_mask: nil,
754
+ token_type_ids: nil,
755
+ position_ids: nil,
756
+ inputs_embeds: nil,
757
+ output_attentions: nil,
758
+ output_hidden_states: nil,
759
+ return_dict: nil
760
+ )
761
+ output_attentions = !output_attentions.nil? ? output_attentions : @config.output_attentions
762
+ output_hidden_states = !output_hidden_states.nil? ? output_hidden_states : @config.output_hidden_states
763
+ return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
764
+
765
+ if !input_ids.nil? && !inputs_embeds.nil?
766
+ raise ArgumentError, "You cannot specify both input_ids and inputs_embeds at the same time"
767
+ elsif !input_ids.nil?
768
+ warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
769
+ input_shape = input_ids.size
770
+ elsif !inputs_embeds.nil?
771
+ input_shape = inputs_embeds.size[...-1]
772
+ else
773
+ raise ArgumentError, "You have to specify either input_ids or inputs_embeds"
774
+ end
775
+
776
+ device = !input_ids.nil? ? input_ids.device : inputs_embeds.device
777
+
778
+ if attention_mask.nil?
779
+ attention_mask = Torch.ones(input_shape, device: device)
780
+ end
781
+ if token_type_ids.nil?
782
+ token_type_ids = Torch.zeros(input_shape, dtype: Torch.long, device: device)
783
+ end
784
+
785
+ embedding_output = @embeddings.(input_ids: input_ids, token_type_ids: token_type_ids, position_ids: position_ids, mask: attention_mask, inputs_embeds: inputs_embeds)
786
+
787
+ encoder_outputs = @encoder.(embedding_output, attention_mask, output_hidden_states: true, output_attentions: output_attentions, return_dict: return_dict)
788
+ encoded_layers = encoder_outputs[1]
789
+
790
+ if @z_steps > 1
791
+ hidden_states = encoded_layers[-2]
792
+ layers = @z_steps.times.map { |_| @encoder.layer[-1] }
793
+ query_states = encoded_layers[-1]
794
+ rel_embeddings = @encoder.get_rel_embedding
795
+ attention_mask = @encoder.get_attention_mask(attention_mask)
796
+ rel_pos = @encoder.get_rel_pos(embedding_output)
797
+ layers[1..].each do |layer|
798
+ query_states = layer.(hidden_states, attention_mask, output_attentions: false, query_states: query_states, relative_pos: rel_pos, rel_embeddings: rel_embeddings)
799
+ encoded_layers << query_states
800
+ end
801
+ end
802
+
803
+ sequence_output = encoded_layers[-1]
804
+
805
+ if !return_dict
806
+ return [sequence_output] + encoder_outputs[output_hidden_states ? 1 : 2..]
807
+ end
808
+
809
+ BaseModelOutput.new(last_hidden_state: sequence_output, hidden_states: output_hidden_states ? encoder_outputs.hidden_states : nil, attentions: encoder_outputs.attentions)
810
+ end
811
+ end
812
+
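A rough smoke-test sketch for the base model, not taken from the gem's documentation; it assumes Transformers::DebertaV2Config accepts the same keyword arguments as its Python counterpart and that a randomly initialized model is acceptable:

config = Transformers::DebertaV2Config.new(vocab_size: 100, hidden_size: 32, num_hidden_layers: 2, num_attention_heads: 2, intermediate_size: 64)
model = Transformers::DebertaV2::DebertaV2Model.new(config)
input_ids = Torch.tensor([[5, 6, 7, 8]])
output = Torch.no_grad { model.(input_ids) }
output.last_hidden_state.shape # => [1, 4, 32]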
813
+ class DebertaV2ForMaskedLM < DebertaV2PreTrainedModel
814
+ self._tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
815
+
816
+ def initialize(config)
817
+ super(config)
818
+
819
+ @deberta = DebertaV2Model.new(config)
820
+ @cls = DebertaV2OnlyMLMHead.new(config)
821
+
822
+ # Initialize weights and apply final processing
823
+ post_init
824
+ end
825
+
826
+ def get_output_embeddings
827
+ @cls.predictions.decoder
828
+ end
829
+
830
+ def set_output_embeddings(new_embeddings)
831
+ @decoder = new_embeddings
832
+ @bias = new_embeddings.bias
833
+ end
834
+
835
+ def forward(
836
+ input_ids: nil,
837
+ attention_mask: nil,
838
+ token_type_ids: nil,
839
+ position_ids: nil,
840
+ inputs_embeds: nil,
841
+ labels: nil,
842
+ output_attentions: nil,
843
+ output_hidden_states: nil,
844
+ return_dict: nil
845
+ )
846
+ return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
847
+
848
+ outputs = @deberta.(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids, position_ids: position_ids, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
849
+
850
+ sequence_output = outputs[0]
851
+ prediction_scores = @cls.(sequence_output)
852
+
853
+ masked_lm_loss = nil
854
+ if !labels.nil?
855
+ loss_fct = Torch::NN::CrossEntropyLoss.new
856
+ masked_lm_loss = loss_fct.(prediction_scores.view(-1, @config.vocab_size), labels.view(-1))
857
+ end
858
+
859
+ if !return_dict
860
+ output = [prediction_scores] + outputs[1..]
861
+ return !masked_lm_loss.nil? ? [masked_lm_loss] + output : output
862
+ end
863
+
864
+ MaskedLMOutput.new(loss: masked_lm_loss, logits: prediction_scores, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
865
+ end
866
+ end
867
+
868
+ class DebertaV2PredictionHeadTransform < Torch::NN::Module
869
+ def initialize(config)
870
+ super()
871
+ @embedding_size = config.getattr("embedding_size", config.hidden_size)
872
+
873
+ @dense = Torch::NN::Linear.new(config.hidden_size, @embedding_size)
874
+ if config.hidden_act.is_a?(String)
875
+ @transform_act_fn = ACT2FN[config.hidden_act]
876
+ else
877
+ @transform_act_fn = config.hidden_act
878
+ end
879
+ @LayerNorm = Torch::NN::LayerNorm.new(@embedding_size, eps: config.layer_norm_eps)
880
+ end
881
+
882
+ def forward(hidden_states)
883
+ hidden_states = @dense.(hidden_states)
884
+ hidden_states = @transform_act_fn.(hidden_states)
885
+ hidden_states = @LayerNorm.(hidden_states)
886
+ hidden_states
887
+ end
888
+ end
889
+
890
+ class DebertaV2LMPredictionHead < Torch::NN::Module
891
+ def initialize(config)
892
+ super()
893
+ @transform = DebertaV2PredictionHeadTransform.new(config)
894
+
895
+ @embedding_size = config.getattr("embedding_size", config.hidden_size)
896
+ # The output weights are the same as the input embeddings, but there is
897
+ # an output-only bias for each token.
898
+ @decoder = Torch::NN::Linear.new(@embedding_size, config.vocab_size, bias: false)
899
+
900
+ @bias = Torch::NN::Parameter.new(Torch.zeros(config.vocab_size))
901
+
902
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
903
+ @bias = @bias
904
+ end
905
+
906
+ def _tie_weights
907
+ @bias = @bias
908
+ end
909
+
910
+ def forward(hidden_states)
911
+ hidden_states = @transform.(hidden_states)
912
+ hidden_states = @decoder.(hidden_states)
913
+ hidden_states
914
+ end
915
+ end
916
+
917
+ class DebertaV2OnlyMLMHead < Torch::NN::Module
918
+ def initialize(config)
919
+ super()
920
+ @predictions = DebertaV2LMPredictionHead.new(config)
921
+ end
922
+
923
+ def forward(sequence_output)
924
+ prediction_scores = @predictions.(sequence_output)
925
+ prediction_scores
926
+ end
927
+ end
928
+
929
+ class DebertaV2ForSequenceClassification < DebertaV2PreTrainedModel
930
+ def initialize(config)
931
+ super(config)
932
+
933
+ num_labels = config.getattr("num_labels", 2)
934
+ @num_labels = num_labels
935
+
936
+ @deberta = DebertaV2Model.new(config)
937
+ @pooler = ContextPooler.new(config)
938
+ output_dim = @pooler.output_dim
939
+
940
+ @classifier = Torch::NN::Linear.new(output_dim, num_labels)
941
+ drop_out = config.getattr("cls_dropout", nil)
942
+ drop_out = drop_out.nil? ? @config.hidden_dropout_prob : drop_out
943
+ @dropout = StableDropout.new(drop_out)
944
+
945
+ # Initialize weights and apply final processing
946
+ post_init
947
+ end
948
+
949
+ def get_input_embeddings
950
+ @deberta.get_input_embeddings
951
+ end
952
+
953
+ def set_input_embeddings(new_embeddings)
954
+ @deberta.set_input_embeddings(new_embeddings)
955
+ end
956
+
957
+ def forward(
958
+ input_ids: nil,
959
+ attention_mask: nil,
960
+ token_type_ids: nil,
961
+ position_ids: nil,
962
+ inputs_embeds: nil,
963
+ labels: nil,
964
+ output_attentions: nil,
965
+ output_hidden_states: nil,
966
+ return_dict: nil
967
+ )
968
+ return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
969
+
970
+ outputs = @deberta.(input_ids, token_type_ids: token_type_ids, attention_mask: attention_mask, position_ids: position_ids, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
971
+
972
+ encoder_layer = outputs[0]
973
+ pooled_output = @pooler.(encoder_layer)
974
+ pooled_output = @dropout.(pooled_output)
975
+ logits = @classifier.(pooled_output)
976
+
977
+ loss = nil
978
+ if !labels.nil?
979
+ if @config.problem_type.nil?
980
+ if @num_labels == 1
981
+ # regression task
982
+ loss_fn = Torch::NN::MSELoss.new
983
+ logits = logits.view(-1).to(labels.dtype)
984
+ loss = loss_fn.(logits, labels.view(-1))
985
+ elsif labels.dim == 1 || labels.size(-1) == 1
986
+ label_index = (labels >= 0).nonzero
987
+ labels = labels.long
988
+ if label_index.size(0) > 0
989
+ labeled_logits = Torch.gather(logits, 0, label_index.expand(label_index.size(0), logits.size(1)))
990
+ labels = Torch.gather(labels, 0, label_index.view(-1))
991
+ loss_fct = Torch::NN::CrossEntropyLoss.new
992
+ loss = loss_fct.(labeled_logits.view(-1, @num_labels).float, labels.view(-1))
993
+ else
994
+ loss = Torch.tensor(0).to(logits)
995
+ end
996
+ else
997
+ log_softmax = Torch::NN::LogSoftmax.new(-1)
998
+ loss = -(log_softmax.(logits) * labels).sum(-1).mean
999
+ end
1000
+ elsif @config.problem_type == "regression"
1001
+ loss_fct = Torch::NN::MSELoss.new
1002
+ if @num_labels == 1
1003
+ loss = loss_fct.(logits.squeeze, labels.squeeze)
1004
+ else
1005
+ loss = loss_fct.(logits, labels)
1006
+ end
1007
+ elsif @config.problem_type == "single_label_classification"
1008
+ loss_fct = Torch::NN::CrossEntropyLoss.new
1009
+ loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
1010
+ elsif @config.problem_type == "multi_label_classification"
1011
+ loss_fct = Torch::NN::BCEWithLogitsLoss.new
1012
+ loss = loss_fct.(logits, labels)
1013
+ end
1014
+ end
1015
+ if !return_dict
1016
+ output = [logits] + outputs[1..]
1017
+ return !loss.nil? ? [loss] + output : output
1018
+ end
1019
+
1020
+ SequenceClassifierOutput.new(loss: loss, logits: logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
1021
+ end
1022
+ end
1023
+
1024
+ class DebertaV2ForTokenClassification < DebertaV2PreTrainedModel
1025
+ def initialize(config)
1026
+ super(config)
1027
+ @num_labels = config.num_labels
1028
+
1029
+ @deberta = DebertaV2Model.new(config)
1030
+ @dropout = Torch::NN::Dropout.new(config.hidden_dropout_prob)
1031
+ @classifier = Torch::NN::Linear.new(config.hidden_size, config.num_labels)
1032
+
1033
+ # Initialize weights and apply final processing
1034
+ post_init
1035
+ end
1036
+
1037
+ def forward(
1038
+ input_ids: nil,
1039
+ attention_mask: nil,
1040
+ token_type_ids: nil,
1041
+ position_ids: nil,
1042
+ inputs_embeds: nil,
1043
+ labels: nil,
1044
+ output_attentions: nil,
1045
+ output_hidden_states: nil,
1046
+ return_dict: nil
1047
+ )
1048
+
1049
+ return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
1050
+
1051
+ outputs = @deberta.(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids, position_ids: position_ids, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
1052
+
1053
+ sequence_output = outputs[0]
1054
+
1055
+ sequence_output = @dropout.(sequence_output)
1056
+ logits = @classifier.(sequence_output)
1057
+
1058
+ loss = nil
1059
+ if !labels.nil?
1060
+ loss_fct = Torch::NN::CrossEntropyLoss.new
1061
+ loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
1062
+ end
1063
+
1064
+ if !return_dict
1065
+ output = [logits] + outputs[1..]
1066
+ return !loss.nil? ? [loss] + output : output
1067
+ end
1068
+
1069
+ TokenClassifierOutput.new(loss: loss, logits: logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
1070
+ end
1071
+ end
1072
+
1073
+ class DebertaV2ForQuestionAnswering < DebertaV2PreTrainedModel
1074
+ def initialize(config)
1075
+ super(config)
1076
+ @num_labels = config.num_labels
1077
+
1078
+ @deberta = DebertaV2Model.new(config)
1079
+ @qa_outputs = Torch::NN::Linear.new(config.hidden_size, config.num_labels)
1080
+
1081
+ # Initialize weights and apply final processing
1082
+ post_init
1083
+ end
1084
+
1085
+ def forward(
1086
+ input_ids: nil,
1087
+ attention_mask: nil,
1088
+ token_type_ids: nil,
1089
+ position_ids: nil,
1090
+ inputs_embeds: nil,
1091
+ start_positions: nil,
1092
+ end_positions: nil,
1093
+ output_attentions: nil,
1094
+ output_hidden_states: nil,
1095
+ return_dict: nil
1096
+ )
1097
+
1098
+ return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
1099
+
1100
+ outputs = @deberta.(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids, position_ids: position_ids, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
1101
+
1102
+ sequence_output = outputs[0]
1103
+
1104
+ logits = @qa_outputs.(sequence_output)
1105
+ start_logits, end_logits = logits.split(1, dim: -1)
1106
+ start_logits = start_logits.squeeze(-1).contiguous
1107
+ end_logits = end_logits.squeeze(-1).contiguous
1108
+
1109
+ total_loss = nil
1110
+ if !start_positions.nil? && !end_positions.nil?
1111
+ # If we are on multi-GPU, splitting adds a dimension
1112
+ if start_positions.size.length > 1
1113
+ start_positions = start_positions.squeeze(-1)
1114
+ end
1115
+ if end_positions.size.length > 1
1116
+ end_positions = end_positions.squeeze(-1)
1117
+ end
1118
+ # Sometimes the start/end positions are outside our model inputs; we ignore these terms
1119
+ ignored_index = start_logits.size(1)
1120
+ start_positions = start_positions.clamp(0, ignored_index)
1121
+ end_positions = end_positions.clamp(0, ignored_index)
1122
+
1123
+ loss_fct = Torch::NN::CrossEntropyLoss.new(ignore_index: ignored_index)
1124
+ start_loss = loss_fct.(start_logits, start_positions)
1125
+ end_loss = loss_fct.(end_logits, end_positions)
1126
+ total_loss = (start_loss + end_loss) / 2
1127
+ end
1128
+
1129
+ if !return_dict
1130
+ output = [start_logits, end_logits] + outputs[1..]
1131
+ return !total_loss.nil? ? [total_loss] + output : output
1132
+ end
1133
+
1134
+ QuestionAnsweringModelOutput.new(loss: total_loss, start_logits: start_logits, end_logits: end_logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
1135
+ end
1136
+ end
1137
+
1138
+ class DebertaV2ForMultipleChoice < DebertaV2PreTrainedModel
1139
+ def initialize(config)
1140
+ super(config)
1141
+
1142
+ num_labels = config.getattr("num_labels", 2)
1143
+ @num_labels = num_labels
1144
+
1145
+ @deberta = DebertaV2Model.new(config)
1146
+ @pooler = ContextPooler.new(config)
1147
+ output_dim = @pooler.output_dim
1148
+
1149
+ @classifier = Torch::NN::Linear.new(output_dim, 1)
1150
+ drop_out = config.getattr("cls_dropout", nil)
1151
+ drop_out = drop_out.nil? ? @config.hidden_dropout_prob : drop_out
1152
+ @dropout = StableDropout.new(drop_out)
1153
+
1154
+ init_weights
1155
+ end
1156
+
1157
+ def get_input_embeddings
1158
+ @deberta.get_input_embeddings
1159
+ end
1160
+
1161
+ def set_input_embeddings(new_embeddings)
1162
+ @deberta.set_input_embeddings(new_embeddings)
1163
+ end
1164
+
1165
+ def forward(
1166
+ input_ids: nil,
1167
+ attention_mask: nil,
1168
+ token_type_ids: nil,
1169
+ position_ids: nil,
1170
+ inputs_embeds: nil,
1171
+ labels: nil,
1172
+ output_attentions: nil,
1173
+ output_hidden_states: nil,
1174
+ return_dict: nil
1175
+ )
1176
+ return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
1177
+ num_choices = !input_ids.nil? ? input_ids.shape[1] : inputs_embeds.shape[1]
1178
+
1179
+ flat_input_ids = !input_ids.nil? ? input_ids.view(-1, input_ids.size(-1)) : nil
1180
+ flat_position_ids = !position_ids.nil? ? position_ids.view(-1, position_ids.size(-1)) : nil
1181
+ flat_token_type_ids = !token_type_ids.nil? ? token_type_ids.view(-1, token_type_ids.size(-1)) : nil
1182
+ flat_attention_mask = !attention_mask.nil? ? attention_mask.view(-1, attention_mask.size(-1)) : nil
1183
+ flat_inputs_embeds = !inputs_embeds.nil? ? inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) : nil
1184
+
1185
+ outputs = @deberta.(flat_input_ids, position_ids: flat_position_ids, token_type_ids: flat_token_type_ids, attention_mask: flat_attention_mask, inputs_embeds: flat_inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
1186
+
1187
+ encoder_layer = outputs[0]
1188
+ pooled_output = @pooler.(encoder_layer)
1189
+ pooled_output = @dropout.(pooled_output)
1190
+ logits = @classifier.(pooled_output)
1191
+ reshaped_logits = logits.view(-1, num_choices)
1192
+
1193
+ loss = nil
1194
+ if !labels.nil?
1195
+ loss_fct = Torch::NN::CrossEntropyLoss.new
1196
+ loss = loss_fct.(reshaped_logits, labels)
1197
+ end
1198
+
1199
+ if !return_dict
1200
+ output = [reshaped_logits] + outputs[1..]
1201
+ return !loss.nil? ? [loss] + output : output
1202
+ end
1203
+
1204
+ MultipleChoiceModelOutput.new(loss: loss, logits: reshaped_logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
1205
+ end
1206
+ end
1207
+ end
1208
+
1209
+ DebertaV2ForSequenceClassification = DebertaV2::DebertaV2ForSequenceClassification
1210
+ end
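The top-level constant exposes the new classification head without the DebertaV2 namespace. A usage sketch (the model id is a placeholder, not a tested checkpoint):

model = Transformers::DebertaV2ForSequenceClassification.from_pretrained("some-org/some-deberta-v2-classifier")
# equivalent to Transformers::DebertaV2::DebertaV2ForSequenceClassification.from_pretrained(...)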