transformers-rb 0.1.2 → 0.1.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
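
The hunk below adds one new file (1,210 lines): a Ruby port of the Hugging Face DeBERTa-v2 modeling code, defining the Transformers::DebertaV2 module and its model heads for masked LM, sequence classification, token classification, question answering, and multiple choice.
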
@@ -0,0 +1,1210 @@
1
+ # Copyright 2020 Microsoft and the Hugging Face Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ module Transformers
16
+ module DebertaV2
17
+ class ContextPooler < Torch::NN::Module
18
+ def initialize(config)
19
+ super()
20
+ @dense = Torch::NN::Linear.new(config.pooler_hidden_size, config.pooler_hidden_size)
21
+ @dropout = StableDropout.new(config.pooler_dropout)
22
+ @config = config
23
+ end
24
+
25
+ def forward(hidden_states)
26
+ # We "pool" the model by simply taking the hidden state corresponding
27
+ # to the first token.
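+ # For hidden_states of shape [batch, seq_len, hidden], hidden_states[0.., 0] selects
+ # the first token for every item in the batch, giving a [batch, hidden] tensor.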
28
+
29
+ context_token = hidden_states[0.., 0]
30
+ context_token = @dropout.(context_token)
31
+ pooled_output = @dense.(context_token)
32
+ pooled_output = ACT2FN[@config.pooler_hidden_act].(pooled_output)
33
+ pooled_output
34
+ end
35
+
36
+ def output_dim
37
+ @config.hidden_size
38
+ end
39
+ end
40
+
41
+ # TODO Torch::Autograd::Function
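+ # Masked softmax: masked positions are filled with the minimum Float32 value before
+ # the softmax and zeroed afterwards; only the forward pass is implemented here.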
42
+ class XSoftmax
43
+ def self.apply(input, mask, dim)
44
+ @dim = dim
45
+ rmask = mask.to(Torch.bool).bitwise_not
46
+
47
+ # TODO use Torch.finfo when available; -3.40282e+38 is the minimum finite Float32 value
+ output = input.masked_fill(rmask, Torch.tensor(-3.40282e+38))
49
+ output = Torch.softmax(output, @dim)
50
+ output.masked_fill!(rmask, 0)
51
+ # ctx.save_for_backward(output)
52
+ output
53
+ end
54
+ end
55
+
56
+ class DropoutContext
+ attr_accessor :dropout, :mask, :scale, :reuse_mask
+
+ def initialize
+ @dropout = 0
+ @mask = nil
+ @scale = 1
+ @reuse_mask = true
+ end
+ end
64
+
65
+ def self.get_mask(input, local_context)
+ if !local_context.is_a?(DropoutContext)
+ dropout = local_context
+ mask = nil
+ else
+ dropout = local_context.dropout
+ dropout *= local_context.scale
+ mask = local_context.reuse_mask ? local_context.mask : nil
+ end
+
+ if dropout > 0 && mask.nil?
+ mask = (1 - Torch.empty_like(input).bernoulli!(1 - dropout)).to(Torch.bool)
+ end
+
+ if local_context.is_a?(DropoutContext) && local_context.mask.nil?
+ local_context.mask = mask
+ end
+
+ [mask, dropout]
+ end
87
+
88
+ # TODO Torch::Autograd::Function
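+ # Dropout whose mask comes from get_mask, so a DropoutContext can share one mask
+ # across calls (reuse_mask); only the forward pass is implemented here.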
89
+ class XDropout
90
+ def self.apply(input, local_ctx)
+ mask, dropout = DebertaV2.get_mask(input, local_ctx)
+ @scale = 1.0 / (1 - dropout)
+ if dropout > 0
+ # ctx.save_for_backward(mask)
+ input.masked_fill(mask, 0) * @scale
+ else
+ input
+ end
+ end
100
+ end
101
+
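+ # Dropout module that scales surviving activations by 1 / (1 - p) and can reuse a
+ # mask through DropoutContext / XDropout above.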
102
+ class StableDropout < Torch::NN::Module
103
+ def initialize(drop_prob)
104
+ super()
105
+ @drop_prob = drop_prob
106
+ @count = 0
107
+ @context_stack = nil
108
+ end
109
+
110
+ def forward(x)
111
+ if @training && @drop_prob > 0
112
+ return XDropout.apply(x, get_context)
113
+ end
114
+ x
115
+ end
116
+
117
+ def clear_context
118
+ @count = 0
119
+ @context_stack = nil
120
+ end
121
+
122
+ def init_context(reuse_mask: true, scale: 1)
+ if @context_stack.nil?
+ @context_stack = []
+ end
+ @count = 0
+ @context_stack.each do |c|
+ c.reuse_mask = reuse_mask
+ c.scale = scale
+ end
+ end
+
+ def get_context
+ if !@context_stack.nil?
+ if @count >= @context_stack.length
+ @context_stack << DropoutContext.new
+ end
+ ctx = @context_stack.fetch(@count)
+ ctx.dropout = @drop_prob
+ @count += 1
+ ctx
+ else
+ @drop_prob
+ end
+ end
146
+ end
147
+
148
+ class DebertaV2SelfOutput < Torch::NN::Module
149
+ def initialize(config)
150
+ super()
151
+ @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
152
+ @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
153
+ @dropout = StableDropout.new(config.hidden_dropout_prob)
154
+ end
155
+
156
+ def forward(hidden_states, input_tensor)
157
+ hidden_states = @dense.(hidden_states)
158
+ hidden_states = @dropout.(hidden_states)
159
+ hidden_states = @LayerNorm.(hidden_states + input_tensor)
160
+ hidden_states
161
+ end
162
+ end
163
+
164
+ class DebertaV2Attention < Torch::NN::Module
165
+ def initialize(config)
166
+ super()
167
+ @self = DisentangledSelfAttention.new(config)
168
+ @output = DebertaV2SelfOutput.new(config)
169
+ @config = config
170
+ end
171
+
172
+ def forward(
173
+ hidden_states,
174
+ attention_mask,
175
+ output_attentions: false,
176
+ query_states: nil,
177
+ relative_pos: nil,
178
+ rel_embeddings: nil
179
+ )
180
+ self_output = @self.(hidden_states, attention_mask, output_attentions:, query_states: query_states, relative_pos: relative_pos, rel_embeddings: rel_embeddings)
181
+ if output_attentions
182
+ self_output, att_matrix = self_output
183
+ end
184
+ if query_states.nil?
185
+ query_states = hidden_states
186
+ end
187
+ attention_output = @output.(self_output, query_states)
188
+
189
+ if output_attentions
190
+ [attention_output, att_matrix]
191
+ else
192
+ attention_output
193
+ end
194
+ end
195
+ end
196
+
197
+ class DebertaV2Intermediate < Torch::NN::Module
198
+ def initialize(config)
199
+ super()
200
+ @dense = Torch::NN::Linear.new(config.hidden_size, config.intermediate_size)
201
+ if config.hidden_act.is_a?(String)
202
+ @intermediate_act_fn = ACT2FN[config.hidden_act]
203
+ else
204
+ @intermediate_act_fn = config.hidden_act
205
+ end
206
+ end
207
+
208
+ def forward(hidden_states)
209
+ hidden_states = @dense.(hidden_states)
210
+ hidden_states = @intermediate_act_fn.(hidden_states)
211
+ hidden_states
212
+ end
213
+ end
214
+
215
+ class DebertaV2Output < Torch::NN::Module
216
+ def initialize(config)
217
+ super()
218
+ @dense = Torch::NN::Linear.new(config.intermediate_size, config.hidden_size)
219
+ @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
220
+ @dropout = StableDropout.new(config.hidden_dropout_prob)
221
+ @config = config
222
+ end
223
+
224
+ def forward(hidden_states, input_tensor)
225
+ hidden_states = @dense.(hidden_states)
226
+ hidden_states = @dropout.(hidden_states)
227
+ hidden_states = @LayerNorm.(hidden_states + input_tensor)
228
+ hidden_states
229
+ end
230
+ end
231
+
232
+ class DebertaV2Layer < Torch::NN::Module
233
+ def initialize(config)
234
+ super()
235
+ @attention = DebertaV2Attention.new(config)
236
+ @intermediate = DebertaV2Intermediate.new(config)
237
+ @output = DebertaV2Output.new(config)
238
+ end
239
+
240
+ def forward(
241
+ hidden_states,
242
+ attention_mask,
243
+ query_states: nil,
244
+ relative_pos: nil,
245
+ rel_embeddings: nil,
246
+ output_attentions: false
247
+ )
248
+ attention_output = @attention.(hidden_states, attention_mask, output_attentions: output_attentions, query_states: query_states, relative_pos: relative_pos, rel_embeddings: rel_embeddings)
249
+ if output_attentions
250
+ attention_output, att_matrix = attention_output
251
+ end
252
+ intermediate_output = @intermediate.(attention_output)
253
+ layer_output = @output.(intermediate_output, attention_output)
254
+ if output_attentions
255
+ [layer_output, att_matrix]
256
+ else
257
+ layer_output
258
+ end
259
+ end
260
+ end
261
+
262
+ class ConvLayer < Torch::NN::Module
263
+ def initialize(config)
264
+ super()
265
+ kernel_size = config.getattr("conv_kernel_size", 3)
266
+ groups = config.getattr("conv_groups", 1)
267
+ @conv_act = config.getattr("conv_act", "tanh")
268
+ @conv = Torch::NN::Conv1d.new(config.hidden_size, config.hidden_size, kernel_size, padding: (kernel_size - 1) / 2, groups: groups)
269
+ @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
270
+ @dropout = StableDropout.new(config.hidden_dropout_prob)
271
+ @config = config
272
+ end
273
+
274
+ def forward(hidden_states, residual_states, input_mask)
275
+ out = @conv.(hidden_states.permute(0, 2, 1).contiguous).permute(0, 2, 1).contiguous
276
+ rmask = (1 - input_mask).bool
277
+ out.masked_fill!(rmask.unsqueeze(-1).expand(out.size), 0)
278
+ out = ACT2FN[@conv_act].(@dropout.(out))
279
+
280
+ layer_norm_input = residual_states + out
281
+ output = @LayerNorm.(layer_norm_input).to(layer_norm_input)
282
+
283
+ if input_mask.nil?
+ output_states = output
+ else
+ if input_mask.dim != layer_norm_input.dim
+ if input_mask.dim == 4
+ input_mask = input_mask.squeeze(1).squeeze(1)
+ end
+ input_mask = input_mask.unsqueeze(2)
+ end
+ input_mask = input_mask.to(output.dtype)
+ output_states = output * input_mask
+ end
+
+ output_states
293
+ end
294
+ end
295
+
296
+ class DebertaV2Encoder < Torch::NN::Module
297
+ def initialize(config)
298
+ super()
299
+
300
+ @layer = Torch::NN::ModuleList.new(config.num_hidden_layers.times.map { |_| DebertaV2Layer.new(config) })
301
+ @relative_attention = config.getattr("relative_attention", false)
302
+
303
+ if @relative_attention
304
+ @max_relative_positions = config.getattr("max_relative_positions", -1)
305
+ if @max_relative_positions < 1
306
+ @max_relative_positions = config.max_position_embeddings
307
+ end
308
+
309
+ @position_buckets = config.getattr("position_buckets", -1)
310
+ pos_ebd_size = @max_relative_positions * 2
311
+
312
+ if @position_buckets > 0
313
+ pos_ebd_size = @position_buckets * 2
314
+ end
315
+
316
+ @rel_embeddings = Torch::NN::Embedding.new(pos_ebd_size, config.hidden_size)
317
+ end
318
+
319
+ @norm_rel_ebd = config.getattr("norm_rel_ebd", "none").downcase.split("|").map { |x| x.strip }
320
+
321
+ if @norm_rel_ebd.include?("layer_norm")
322
+ @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps, elementwise_affine: true)
323
+ end
324
+
325
+ @conv = config.getattr("conv_kernel_size", 0) > 0 ? ConvLayer.new(config) : nil
326
+ @gradient_checkpointing = false
327
+ end
328
+
329
+ def get_rel_embedding
330
+ rel_embeddings = @relative_attention ? @rel_embeddings.weight : nil
331
+ if !rel_embeddings.nil? && @norm_rel_ebd.include?("layer_norm")
332
+ rel_embeddings = @LayerNorm.(rel_embeddings)
333
+ end
334
+ rel_embeddings
335
+ end
336
+
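+ # Expands a [batch, seq_len] mask to [batch, 1, seq_len, seq_len] via an outer
+ # product, so position i may attend to position j only when both are valid.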
337
+ def get_attention_mask(attention_mask)
338
+ if attention_mask.dim <= 2
339
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
340
+ attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
341
+ elsif attention_mask.dim == 3
342
+ attention_mask = attention_mask.unsqueeze(1)
343
+ end
344
+
345
+ attention_mask
346
+ end
347
+
348
+ def get_rel_pos(hidden_states, query_states: nil, relative_pos: nil)
349
+ if @relative_attention && relative_pos.nil?
350
+ q = !query_states.nil? ? query_states.size(-2) : hidden_states.size(-2)
351
+ relative_pos = DebertaV2.build_relative_position(q, hidden_states.size(-2), bucket_size: @position_buckets, max_position: @max_relative_positions, device: hidden_states.device)
352
+ end
353
+ relative_pos
354
+ end
355
+
356
+ def forward(
357
+ hidden_states,
358
+ attention_mask,
359
+ output_hidden_states: true,
360
+ output_attentions: false,
361
+ query_states: nil,
362
+ relative_pos: nil,
363
+ return_dict: true
364
+ )
365
+ if attention_mask.dim <= 2
366
+ input_mask = attention_mask
367
+ else
368
+ input_mask = attention_mask.sum(-2) > 0
369
+ end
370
+ attention_mask = get_attention_mask(attention_mask)
371
+ relative_pos = get_rel_pos(hidden_states, query_states:, relative_pos:)
372
+
373
+ all_hidden_states = output_hidden_states ? [] : nil
374
+ all_attentions = output_attentions ? [] : nil
375
+
376
+ if hidden_states.is_a?(Array)
377
+ next_kv = hidden_states[0]
378
+ else
379
+ next_kv = hidden_states
380
+ end
381
+ rel_embeddings = get_rel_embedding
382
+ output_states = next_kv
383
+ @layer.each_with_index do |layer_module, i|
384
+ if output_hidden_states
385
+ all_hidden_states = all_hidden_states + [output_states]
386
+ end
387
+
388
+ if @gradient_checkpointing && @training
389
+ output_states = _gradient_checkpointing_func(layer_module.__call__, next_kv, attention_mask, query_states, relative_pos, rel_embeddings, output_attentions)
390
+ else
391
+ output_states = layer_module.(next_kv, attention_mask, query_states: query_states, relative_pos: relative_pos, rel_embeddings: rel_embeddings, output_attentions: output_attentions)
392
+ end
393
+
394
+ if output_attentions
395
+ output_states, att_m = output_states
396
+ end
397
+
398
+ if i == 0 && !@conv.nil?
399
+ output_states = @conv.(hidden_states, output_states, input_mask)
400
+ end
401
+
402
+ if !query_states.nil?
403
+ query_states = output_states
404
+ if hidden_states.is_a?(Array)
405
+ next_kv = i + 1 < @layer.length ? hidden_states[i + 1] : nil
406
+ end
407
+ else
408
+ next_kv = output_states
409
+ end
410
+
411
+ if output_attentions
412
+ all_attentions = all_attentions + [att_m]
413
+ end
414
+ end
415
+
416
+ if output_hidden_states
417
+ all_hidden_states = all_hidden_states + [output_states]
418
+ end
419
+
420
+ if !return_dict
421
+ return [output_states, all_hidden_states, all_attentions].compact
422
+ end
423
+ BaseModelOutput.new(last_hidden_state: output_states, hidden_states: all_hidden_states, attentions: all_attentions)
424
+ end
425
+ end
426
+
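+ # Maps relative positions to log-spaced buckets: offsets within +/- bucket_size/2
+ # keep their exact value, while larger offsets are compressed logarithmically up to
+ # max_position.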
427
+ def self.make_log_bucket_position(relative_pos, bucket_size, max_position)
428
+ sign = Torch.sign(relative_pos)
429
+ mid = bucket_size / 2
430
+ abs_pos = Torch.where(relative_pos.lt(mid) & relative_pos.gt(-mid), Torch.tensor(mid - 1).type_as(relative_pos), Torch.abs(relative_pos))
431
+ log_pos = Torch.ceil((Torch.log(abs_pos / mid) / Torch.log(Torch.tensor((max_position - 1) / mid))) * (mid - 1)) + mid
432
+ bucket_pos = Torch.where(abs_pos.le(mid), relative_pos.type_as(log_pos), log_pos * sign)
433
+ bucket_pos
434
+ end
435
+
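+ # Returns a [1, query_size, key_size] tensor where entry (q, k) is q - k; for
+ # query_size = key_size = 3 (and no bucketing) this is [[0, -1, -2], [1, 0, -1], [2, 1, 0]].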
436
+ def self.build_relative_position(query_size, key_size, bucket_size: -1, max_position: -1, device: nil)
437
+ q_ids = Torch.arange(0, query_size, device: device)
438
+ k_ids = Torch.arange(0, key_size, device: device)
439
+ rel_pos_ids = q_ids[0.., nil] - k_ids[nil, 0..]
440
+ if bucket_size > 0 && max_position > 0
441
+ rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
442
+ end
443
+ rel_pos_ids = rel_pos_ids.to(Torch.long)
444
+ rel_pos_ids = rel_pos_ids[...query_size, 0..]
445
+ rel_pos_ids = rel_pos_ids.unsqueeze(0)
446
+ rel_pos_ids
447
+ end
448
+
449
+ def self.c2p_dynamic_expand(c2p_pos, query_layer, relative_pos)
450
+ c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)])
451
+ end
452
+
453
+ def self.p2c_dynamic_expand(c2p_pos, query_layer, key_layer)
454
+ c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
455
+ end
456
+
457
+ def self.pos_dynamic_expand(pos_index, p2c_att, key_layer)
458
+ pos_index.expand(p2c_att.size[...2] + [pos_index.size(-2), key_layer.size(-2)])
459
+ end
460
+
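+ # Disentangled self-attention: besides the usual content-to-content scores, relative
+ # position embeddings add content-to-position ("c2p") and position-to-content ("p2c")
+ # terms, computed in disentangled_attention_bias below.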
461
+ class DisentangledSelfAttention < Torch::NN::Module
462
+ def initialize(config)
463
+ super()
464
+ if config.hidden_size % config.num_attention_heads != 0
465
+ raise ArgumentError, "The hidden size (#{config.hidden_size}) is not a multiple of the number of attention heads (#{config.num_attention_heads})"
466
+ end
467
+ @num_attention_heads = config.num_attention_heads
468
+ _attention_head_size = config.hidden_size / config.num_attention_heads
469
+ @attention_head_size = config.getattr("attention_head_size", _attention_head_size)
470
+ @all_head_size = @num_attention_heads * @attention_head_size
471
+ @query_proj = Torch::NN::Linear.new(config.hidden_size, @all_head_size, bias: true)
472
+ @key_proj = Torch::NN::Linear.new(config.hidden_size, @all_head_size, bias: true)
473
+ @value_proj = Torch::NN::Linear.new(config.hidden_size, @all_head_size, bias: true)
474
+
475
+ @share_att_key = config.getattr("share_att_key", false)
476
+ @pos_att_type = !config.pos_att_type.nil? ? config.pos_att_type : []
477
+ @relative_attention = config.getattr("relative_attention", false)
478
+
479
+ if @relative_attention
480
+ @position_buckets = config.getattr("position_buckets", -1)
481
+ @max_relative_positions = config.getattr("max_relative_positions", -1)
482
+ if @max_relative_positions < 1
483
+ @max_relative_positions = config.max_position_embeddings
484
+ end
485
+ @pos_ebd_size = @max_relative_positions
486
+ if @position_buckets > 0
487
+ @pos_ebd_size = @position_buckets
488
+ end
489
+
490
+ @pos_dropout = StableDropout.new(config.hidden_dropout_prob)
491
+
492
+ if !@share_att_key
493
+ if @pos_att_type.include?("c2p")
494
+ @pos_key_proj = Torch::NN::Linear.new(config.hidden_size, @all_head_size, bias: true)
495
+ end
496
+ if @pos_att_type.include?("p2c")
497
+ @pos_query_proj = Torch::NN::Linear.new(config.hidden_size, @all_head_size)
498
+ end
499
+ end
500
+ end
501
+
502
+ @dropout = StableDropout.new(config.attention_probs_dropout_prob)
503
+ end
504
+
505
+ def transpose_for_scores(x, attention_heads)
506
+ new_x_shape = x.size[...-1] + [attention_heads, -1]
507
+ x = x.view(new_x_shape)
508
+ x.permute(0, 2, 1, 3).contiguous.view(-1, x.size(1), x.size(-1))
509
+ end
510
+
511
+ def forward(
512
+ hidden_states,
513
+ attention_mask,
514
+ output_attentions: false,
515
+ query_states: nil,
516
+ relative_pos: nil,
517
+ rel_embeddings: nil
518
+ )
519
+ if query_states.nil?
520
+ query_states = hidden_states
521
+ end
522
+ query_layer = transpose_for_scores(@query_proj.(query_states), @num_attention_heads)
523
+ key_layer = transpose_for_scores(@key_proj.(hidden_states), @num_attention_heads)
524
+ value_layer = transpose_for_scores(@value_proj.(hidden_states), @num_attention_heads)
525
+
526
+ rel_att = nil
527
+ # Take the dot product between "query" and "key" to get the raw attention scores.
528
+ scale_factor = 1
529
+ if @pos_att_type.include?("c2p")
530
+ scale_factor += 1
531
+ end
532
+ if @pos_att_type.include?("p2c")
533
+ scale_factor += 1
534
+ end
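+ # scale_factor counts the enabled score components (content-to-content plus optional
+ # c2p and p2c), so the scores are scaled by sqrt(head_size * scale_factor).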
535
+ scale = Torch.sqrt(Torch.tensor(query_layer.size(-1), dtype: Torch.float) * scale_factor)
536
+ attention_scores = Torch.bmm(query_layer, key_layer.transpose(-1, -2) / scale.to(dtype: query_layer.dtype))
537
+ if @relative_attention
538
+ rel_embeddings = @pos_dropout.(rel_embeddings)
539
+ rel_att = disentangled_attention_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)
540
+ end
541
+
542
+ if !rel_att.nil?
543
+ attention_scores = attention_scores + rel_att
544
+ end
545
+ attention_scores = attention_scores.view(-1, @num_attention_heads, attention_scores.size(-2), attention_scores.size(-1))
547
+
548
+ # bsz x height x length x dimension
549
+ attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
550
+ attention_probs = @dropout.(attention_probs)
551
+ context_layer = Torch.bmm(attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer)
552
+ context_layer = context_layer.view(-1, @num_attention_heads, context_layer.size(-2), context_layer.size(-1)).permute(0, 2, 1, 3).contiguous
553
+ new_context_layer_shape = context_layer.size[...-2] + [-1]
554
+ context_layer = context_layer.view(new_context_layer_shape)
555
+ if output_attentions
556
+ [context_layer, attention_probs]
557
+ else
558
+ context_layer
559
+ end
560
+ end
561
+
562
+ def disentangled_attention_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)
563
+ if relative_pos.nil?
564
+ q = query_layer.size(-2)
565
+ relative_pos = DebertaV2.build_relative_position(q, key_layer.size(-2), bucket_size: @position_buckets, max_position: @max_relative_positions, device: query_layer.device)
566
+ end
567
+ if relative_pos.dim == 2
568
+ relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
569
+ elsif relative_pos.dim == 3
570
+ relative_pos = relative_pos.unsqueeze(1)
571
+ elsif relative_pos.dim != 4
572
+ raise ArgumentError, "Relative position ids must be of dim 2 or 3 or 4. #{relative_pos.dim}"
573
+ end
574
+
575
+ att_span = @pos_ebd_size
576
+ relative_pos = relative_pos.long.to(query_layer.device)
577
+
578
+ rel_embeddings = rel_embeddings[0...att_span * 2, 0..].unsqueeze(0)
579
+ if @share_att_key
580
+ pos_query_layer = transpose_for_scores(@query_proj.(rel_embeddings), @num_attention_heads).repeat(query_layer.size(0) / @num_attention_heads, 1, 1)
581
+ pos_key_layer = transpose_for_scores(@key_proj.(rel_embeddings), @num_attention_heads).repeat(query_layer.size(0) / @num_attention_heads, 1, 1)
582
+ elsif @pos_att_type.include?("c2p")
583
+ pos_key_layer = transpose_for_scores(@pos_key_proj.(rel_embeddings), @num_attention_heads).repeat(query_layer.size(0) / @num_attention_heads, 1, 1)
584
+ end
585
+
586
+ score = 0
587
+ # content->position
588
+ if @pos_att_type.include?("c2p")
589
+ scale = Torch.sqrt(Torch.tensor(pos_key_layer.size(-1), dtype: Torch.float) * scale_factor)
590
+ c2p_att = Torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
591
+ c2p_pos = Torch.clamp(relative_pos + att_span, 0, (att_span * 2) - 1)
592
+ c2p_att = Torch.gather(c2p_att, dim: -1, index: c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]))
593
+ score += c2p_att / scale.to(dtype: c2p_att.dtype)
594
+ end
595
+
596
+ # position->content
597
+ if @pos_att_type.include?("p2c")
598
+ scale = Torch.sqrt(Torch.tensor(pos_query_layer.size(-1), dtype: Torch.float) * scale_factor)
599
+ if key_layer.size(-2) != query_layer.size(-2)
600
+ r_pos = DebertaV2.build_relative_position(key_layer.size(-2), key_layer.size(-2), bucket_size: @position_buckets, max_position: @max_relative_positions, device: query_layer.device)
601
+ r_pos = r_pos.unsqueeze(0)
602
+ else
603
+ r_pos = relative_pos
604
+ end
605
+
606
+ p2c_pos = Torch.clamp(-r_pos + att_span, 0, (att_span * 2) - 1)
607
+ p2c_att = Torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
608
+ p2c_att = Torch.gather(p2c_att, dim: -1, index: p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)])).transpose(-1, -2)
609
+ score += p2c_att / scale.to(dtype: p2c_att.dtype)
610
+ end
611
+
612
+ score
613
+ end
614
+ end
615
+
616
+ class DebertaV2Embeddings < Torch::NN::Module
617
+ def initialize(config)
618
+ super()
619
+ pad_token_id = config.getattr("pad_token_id", 0)
620
+ @embedding_size = config.getattr("embedding_size", config.hidden_size)
621
+ @word_embeddings = Torch::NN::Embedding.new(config.vocab_size, @embedding_size, padding_idx: pad_token_id)
622
+
623
+ @position_biased_input = config.getattr("position_biased_input", true)
624
+ if !@position_biased_input
625
+ @position_embeddings = nil
626
+ else
627
+ @position_embeddings = Torch::NN::Embedding.new(config.max_position_embeddings, @embedding_size)
628
+ end
629
+
630
+ if config.type_vocab_size > 0
631
+ @token_type_embeddings = Torch::NN::Embedding.new(config.type_vocab_size, @embedding_size)
632
+ end
633
+
634
+ if @embedding_size != config.hidden_size
635
+ @embed_proj = Torch::NN::Linear.new(@embedding_size, config.hidden_size, bias: false)
636
+ end
637
+ @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
638
+ @dropout = StableDropout.new(config.hidden_dropout_prob)
639
+ @config = config
640
+
641
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
642
+ register_buffer("position_ids", Torch.arange(config.max_position_embeddings).expand([1, -1]), persistent: false)
643
+ end
644
+
645
+ def forward(input_ids: nil, token_type_ids: nil, position_ids: nil, mask: nil, inputs_embeds: nil)
646
+ if !input_ids.nil?
647
+ input_shape = input_ids.size
648
+ else
649
+ input_shape = inputs_embeds.size[...-1]
650
+ end
651
+
652
+ seq_length = input_shape[1]
653
+
654
+ if position_ids.nil?
655
+ position_ids = @position_ids[0.., ...seq_length]
656
+ end
657
+
658
+ if token_type_ids.nil?
659
+ token_type_ids = Torch.zeros(input_shape, dtype: Torch.long, device: @position_ids.device)
660
+ end
661
+
662
+ if inputs_embeds.nil?
663
+ inputs_embeds = @word_embeddings.(input_ids)
664
+ end
665
+
666
+ if !@position_embeddings.nil?
667
+ position_embeddings = @position_embeddings.(position_ids.long)
668
+ else
669
+ position_embeddings = Torch.zeros_like(inputs_embeds)
670
+ end
671
+
672
+ embeddings = inputs_embeds
673
+ if @position_biased_input
674
+ embeddings += position_embeddings
675
+ end
676
+ if @config.type_vocab_size > 0
677
+ token_type_embeddings = @token_type_embeddings.(token_type_ids)
678
+ embeddings += token_type_embeddings
679
+ end
680
+
681
+ if @embedding_size != @config.hidden_size
682
+ embeddings = @embed_proj.(embeddings)
683
+ end
684
+
685
+ embeddings = @LayerNorm.(embeddings)
686
+
687
+ if !mask.nil?
688
+ if mask.dim != embeddings.dim
689
+ if mask.dim == 4
690
+ mask = mask.squeeze(1).squeeze(1)
691
+ end
692
+ mask = mask.unsqueeze(2)
693
+ end
694
+ mask = mask.to(embeddings.dtype)
695
+
696
+ embeddings = embeddings * mask
697
+ end
698
+
699
+ embeddings = @dropout.(embeddings)
700
+ embeddings
701
+ end
702
+ end
703
+
704
+ class DebertaV2PreTrainedModel < PreTrainedModel
705
+ self.config_class = DebertaV2Config
706
+ self.base_model_prefix = "deberta"
707
+ # self._keys_to_ignore_on_load_unexpected = ["position_embeddings"]
708
+ # self.supports_gradient_checkpointing = true
709
+
710
+ def _init_weights(module_)
711
+ if module_.is_a?(Torch::NN::Linear)
712
+ # Slightly different from the TF version which uses truncated_normal for initialization
713
+ # cf https://github.com/pytorch/pytorch/pull/5617
714
+ module_.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
715
+ if !module_.bias.nil?
716
+ module_.bias.data.zero!
717
+ end
718
+ elsif module_.is_a?(Torch::NN::Embedding)
719
+ module_.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
720
+ if !module_.padding_idx.nil?
721
+ module_.weight.data.fetch(module_.padding_idx).zero!
722
+ end
723
+ end
724
+ end
725
+ end
726
+
727
+ class DebertaV2Model < DebertaV2PreTrainedModel
728
+ def initialize(config)
729
+ super(config)
730
+
731
+ @embeddings = DebertaV2Embeddings.new(config)
732
+ @encoder = DebertaV2Encoder.new(config)
733
+ @z_steps = 0
734
+ @config = config
735
+ # Initialize weights and apply final processing
736
+ post_init
737
+ end
738
+
739
+ def get_input_embeddings
740
+ @embeddings.word_embeddings
741
+ end
742
+
743
+ def set_input_embeddings(new_embeddings)
744
+ @word_embeddings = new_embeddings
745
+ end
746
+
747
+ def _prune_heads(heads_to_prune)
748
+ raise NotImplementedError, "The prune function is not implemented in DeBERTa model."
749
+ end
750
+
751
+ def forward(
752
+ input_ids,
753
+ attention_mask: nil,
754
+ token_type_ids: nil,
755
+ position_ids: nil,
756
+ inputs_embeds: nil,
757
+ output_attentions: nil,
758
+ output_hidden_states: nil,
759
+ return_dict: nil
760
+ )
761
+ output_attentions = !output_attentions.nil? ? output_attentions : @config.output_attentions
762
+ output_hidden_states = !output_hidden_states.nil? ? output_hidden_states : @config.output_hidden_states
763
+ return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
764
+
765
+ if !input_ids.nil? && !inputs_embeds.nil?
766
+ raise ArgumentError, "You cannot specify both input_ids and inputs_embeds at the same time"
767
+ elsif !input_ids.nil?
768
+ warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
769
+ input_shape = input_ids.size
770
+ elsif !inputs_embeds.nil?
771
+ input_shape = inputs_embeds.size[...-1]
772
+ else
773
+ raise ArgumentError, "You have to specify either input_ids or inputs_embeds"
774
+ end
775
+
776
+ device = !input_ids.nil? ? input_ids.device : inputs_embeds.device
777
+
778
+ if attention_mask.nil?
779
+ attention_mask = Torch.ones(input_shape, device: device)
780
+ end
781
+ if token_type_ids.nil?
782
+ token_type_ids = Torch.zeros(input_shape, dtype: Torch.long, device: device)
783
+ end
784
+
785
+ embedding_output = @embeddings.(input_ids: input_ids, token_type_ids: token_type_ids, position_ids: position_ids, mask: attention_mask, inputs_embeds: inputs_embeds)
786
+
787
+ encoder_outputs = @encoder.(embedding_output, attention_mask, output_hidden_states: true, output_attentions: output_attentions, return_dict: return_dict)
788
+ encoded_layers = encoder_outputs[1]
789
+
790
+ if @z_steps > 1
791
+ hidden_states = encoded_layers[-2]
792
+ layers = @z_steps.times.map { |_| @encoder.layer[-1] }
793
+ query_states = encoded_layers[-1]
794
+ rel_embeddings = @encoder.get_rel_embedding
795
+ attention_mask = @encoder.get_attention_mask(attention_mask)
796
+ rel_pos = @encoder.get_rel_pos(embedding_output)
797
+ layers[1..].each do |layer|
798
+ query_states = layer.(hidden_states, attention_mask, output_attentions: false, query_states: query_states, relative_pos: rel_pos, rel_embeddings: rel_embeddings)
799
+ encoded_layers << query_states
800
+ end
801
+ end
802
+
803
+ sequence_output = encoded_layers[-1]
804
+
805
+ if !return_dict
806
+ return [sequence_output] + encoder_outputs[(output_hidden_states ? 1 : 2)..]
807
+ end
808
+
809
+ BaseModelOutput.new(last_hidden_state: sequence_output, hidden_states: output_hidden_states ? encoder_outputs.hidden_states : nil, attentions: encoder_outputs.attentions)
810
+ end
811
+ end
812
+
813
+ class DebertaV2ForMaskedLM < DebertaV2PreTrainedModel
814
+ self._tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
815
+
816
+ def initialize(config)
817
+ super(config)
818
+
819
+ @deberta = DebertaV2Model.new(config)
820
+ @cls = DebertaV2OnlyMLMHead.new(config)
821
+
822
+ # Initialize weights and apply final processing
823
+ post_init
824
+ end
825
+
826
+ def get_output_embeddings
827
+ @cls.predictions.decoder
828
+ end
829
+
830
+ def set_output_embeddings(new_embeddings)
831
+ @decoder = new_embeddings
832
+ @bias = new_embeddings.bias
833
+ end
834
+
835
+ def forward(
836
+ input_ids: nil,
837
+ attention_mask: nil,
838
+ token_type_ids: nil,
839
+ position_ids: nil,
840
+ inputs_embeds: nil,
841
+ labels: nil,
842
+ output_attentions: nil,
843
+ output_hidden_states: nil,
844
+ return_dict: nil
845
+ )
846
+ return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
847
+
848
+ outputs = @deberta.(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids, position_ids: position_ids, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
849
+
850
+ sequence_output = outputs[0]
851
+ prediction_scores = @cls.(sequence_output)
852
+
853
+ masked_lm_loss = nil
854
+ if !labels.nil?
855
+ loss_fct = Torch::NN::CrossEntropyLoss.new
856
+ masked_lm_loss = loss_fct.(prediction_scores.view(-1, @config.vocab_size), labels.view(-1))
857
+ end
858
+
859
+ if !return_dict
860
+ output = [prediction_scores] + outputs[1..]
861
+ return !masked_lm_loss.nil? ? [masked_lm_loss] + output : output
862
+ end
863
+
864
+ MaskedLMOutput.new(loss: masked_lm_loss, logits: prediction_scores, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
865
+ end
866
+ end
867
+
868
+ class DebertaV2PredictionHeadTransform < Torch::NN::Module
869
+ def initialize(config)
870
+ super()
871
+ @embedding_size = config.getattr("embedding_size", config.hidden_size)
872
+
873
+ @dense = Torch::NN::Linear.new(config.hidden_size, @embedding_size)
874
+ if config.hidden_act.is_a?(String)
875
+ @transform_act_fn = ACT2FN[config.hidden_act]
876
+ else
877
+ @transform_act_fn = config.hidden_act
878
+ end
879
+ @LayerNorm = Torch::NN::LayerNorm.new(@embedding_size, eps: config.layer_norm_eps)
880
+ end
881
+
882
+ def forward(hidden_states)
883
+ hidden_states = @dense.(hidden_states)
884
+ hidden_states = @transform_act_fn.(hidden_states)
885
+ hidden_states = @LayerNorm.(hidden_states)
886
+ hidden_states
887
+ end
888
+ end
889
+
890
+ class DebertaV2LMPredictionHead < Torch::NN::Module
891
+ def initialize(config)
892
+ super()
893
+ @transform = DebertaV2PredictionHeadTransform.new(config)
894
+
895
+ @embedding_size = config.getattr("embedding_size", config.hidden_size)
896
+ # The output weights are the same as the input embeddings, but there is
897
+ # an output-only bias for each token.
898
+ @decoder = Torch::NN::Linear.new(@embedding_size, config.vocab_size, bias: false)
899
+
900
+ @bias = Torch::NN::Parameter.new(Torch.zeros(config.vocab_size))
901
+
902
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
903
+ @bias = @bias
904
+ end
905
+
906
+ def _tie_weights
907
+ @bias = @bias
908
+ end
909
+
910
+ def forward(hidden_states)
911
+ hidden_states = @transform.(hidden_states)
912
+ hidden_states = @decoder.(hidden_states)
913
+ hidden_states
914
+ end
915
+ end
916
+
917
+ class DebertaV2OnlyMLMHead < Torch::NN::Module
918
+ def initialize(config)
919
+ super()
920
+ @predictions = DebertaV2LMPredictionHead.new(config)
921
+ end
922
+
923
+ def forward(sequence_output)
924
+ prediction_scores = @predictions.(sequence_output)
925
+ prediction_scores
926
+ end
927
+ end
928
+
929
+ class DebertaV2ForSequenceClassification < DebertaV2PreTrainedModel
930
+ def initialize(config)
931
+ super(config)
932
+
933
+ num_labels = config.getattr("num_labels", 2)
934
+ @num_labels = num_labels
935
+
936
+ @deberta = DebertaV2Model.new(config)
937
+ @pooler = ContextPooler.new(config)
938
+ output_dim = @pooler.output_dim
939
+
940
+ @classifier = Torch::NN::Linear.new(output_dim, num_labels)
941
+ drop_out = config.getattr("cls_dropout", nil)
942
+ drop_out = drop_out.nil? ? @config.hidden_dropout_prob : drop_out
943
+ @dropout = StableDropout.new(drop_out)
944
+
945
+ # Initialize weights and apply final processing
946
+ post_init
947
+ end
948
+
949
+ def get_input_embeddings
950
+ @deberta.get_input_embeddings
951
+ end
952
+
953
+ def set_input_embeddings(new_embeddings)
954
+ @deberta.set_input_embeddings(new_embeddings)
955
+ end
956
+
957
+ def forward(
958
+ input_ids: nil,
959
+ attention_mask: nil,
960
+ token_type_ids: nil,
961
+ position_ids: nil,
962
+ inputs_embeds: nil,
963
+ labels: nil,
964
+ output_attentions: nil,
965
+ output_hidden_states: nil,
966
+ return_dict: nil
967
+ )
968
+ return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
969
+
970
+ outputs = @deberta.(input_ids, token_type_ids: token_type_ids, attention_mask: attention_mask, position_ids: position_ids, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
971
+
972
+ encoder_layer = outputs[0]
973
+ pooled_output = @pooler.(encoder_layer)
974
+ pooled_output = @dropout.(pooled_output)
975
+ logits = @classifier.(pooled_output)
976
+
977
+ loss = nil
978
+ if !labels.nil?
979
+ if @config.problem_type.nil?
980
+ if @num_labels == 1
981
+ # regression task
982
+ loss_fn = Torch::NN::MSELoss.new
983
+ logits = logits.view(-1).to(labels.dtype)
984
+ loss = loss_fn.(logits, labels.view(-1))
985
+ elsif labels.dim == 1 || labels.size(-1) == 1
986
+ label_index = (labels >= 0).nonzero
987
+ labels = labels.long
988
+ if label_index.size(0) > 0
989
+ labeled_logits = Torch.gather(logits, 0, label_index.expand(label_index.size(0), logits.size(1)))
990
+ labels = Torch.gather(labels, 0, label_index.view(-1))
991
+ loss_fct = Torch::NN::CrossEntropyLoss.new
992
+ loss = loss_fct.(labeled_logits.view(-1, @num_labels).float, labels.view(-1))
993
+ else
994
+ loss = Torch.tensor(0).to(logits)
995
+ end
996
+ else
997
+ log_softmax = Torch::NN::LogSoftmax.new(-1)
998
+ loss = -(log_softmax.(logits) * labels).sum(-1).mean
999
+ end
1000
+ elsif @config.problem_type == "regression"
1001
+ loss_fct = Torch::NN::MSELoss.new
1002
+ if @num_labels == 1
1003
+ loss = loss_fct.(logits.squeeze, labels.squeeze)
1004
+ else
1005
+ loss = loss_fct.(logits, labels)
1006
+ end
1007
+ elsif @config.problem_type == "single_label_classification"
1008
+ loss_fct = Torch::NN::CrossEntropyLoss.new
1009
+ loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
1010
+ elsif @config.problem_type == "multi_label_classification"
1011
+ loss_fct = Torch::NN::BCEWithLogitsLoss.new
1012
+ loss = loss_fct.(logits, labels)
1013
+ end
1014
+ end
1015
+ if !return_dict
1016
+ output = [logits] + outputs[1..]
1017
+ return !loss.nil? ? [loss] + output : output
1018
+ end
1019
+
1020
+ SequenceClassifierOutput.new(loss: loss, logits: logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
1021
+ end
1022
+ end
1023
+
1024
+ class DebertaV2ForTokenClassification < DebertaV2PreTrainedModel
1025
+ def initialize(config)
1026
+ super(config)
1027
+ @num_labels = config.num_labels
1028
+
1029
+ @deberta = DebertaV2Model.new(config)
1030
+ @dropout = Torch::NN::Dropout.new(config.hidden_dropout_prob)
1031
+ @classifier = Torch::NN::Linear.new(config.hidden_size, config.num_labels)
1032
+
1033
+ # Initialize weights and apply final processing
1034
+ post_init
1035
+ end
1036
+
1037
+ def forward(
1038
+ input_ids: nil,
1039
+ attention_mask: nil,
1040
+ token_type_ids: nil,
1041
+ position_ids: nil,
1042
+ inputs_embeds: nil,
1043
+ labels: nil,
1044
+ output_attentions: nil,
1045
+ output_hidden_states: nil,
1046
+ return_dict: nil
1047
+ )
1048
+
1049
+ return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
1050
+
1051
+ outputs = @deberta.(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids, position_ids: position_ids, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
1052
+
1053
+ sequence_output = outputs[0]
1054
+
1055
+ sequence_output = @dropout.(sequence_output)
1056
+ logits = @classifier.(sequence_output)
1057
+
1058
+ loss = nil
1059
+ if !labels.nil?
1060
+ loss_fct = Torch::NN::CrossEntropyLoss.new
1061
+ loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
1062
+ end
1063
+
1064
+ if !return_dict
1065
+ output = [logits] + outputs[1..]
1066
+ return !loss.nil? ? [loss] + output : output
1067
+ end
1068
+
1069
+ TokenClassifierOutput.new(loss: loss, logits: logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
1070
+ end
1071
+ end
1072
+
1073
+ class DebertaV2ForQuestionAnswering < DebertaV2PreTrainedModel
1074
+ def initialize(config)
1075
+ super(config)
1076
+ @num_labels = config.num_labels
1077
+
1078
+ @deberta = DebertaV2Model.new(config)
1079
+ @qa_outputs = Torch::NN::Linear.new(config.hidden_size, config.num_labels)
1080
+
1081
+ # Initialize weights and apply final processing
1082
+ post_init
1083
+ end
1084
+
1085
+ def forward(
1086
+ input_ids: nil,
1087
+ attention_mask: nil,
1088
+ token_type_ids: nil,
1089
+ position_ids: nil,
1090
+ inputs_embeds: nil,
1091
+ start_positions: nil,
1092
+ end_positions: nil,
1093
+ output_attentions: nil,
1094
+ output_hidden_states: nil,
1095
+ return_dict: nil
1096
+ )
1097
+
1098
+ return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
1099
+
1100
+ outputs = @deberta.(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids, position_ids: position_ids, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
1101
+
1102
+ sequence_output = outputs[0]
1103
+
1104
+ logits = @qa_outputs.(sequence_output)
1105
+ start_logits, end_logits = logits.split(1, dim: -1)
1106
+ start_logits = start_logits.squeeze(-1).contiguous
1107
+ end_logits = end_logits.squeeze(-1).contiguous
1108
+
1109
+ total_loss = nil
1110
+ if !start_positions.nil? && !end_positions.nil?
1111
+ # If we are on multi-GPU, splitting may add an extra dimension; squeeze it away
1112
+ if start_positions.size.length > 1
1113
+ start_positions = start_positions.squeeze(-1)
1114
+ end
1115
+ if end_positions.size.length > 1
1116
+ end_positions = end_positions.squeeze(-1)
1117
+ end
1118
+ # Sometimes the start/end positions are outside our model inputs; ignore these terms
1119
+ ignored_index = start_logits.size(1)
1120
+ start_positions = start_positions.clamp(0, ignored_index)
1121
+ end_positions = end_positions.clamp(0, ignored_index)
1122
+
1123
+ loss_fct = Torch::NN::CrossEntropyLoss.new(ignore_index: ignored_index)
1124
+ start_loss = loss_fct.(start_logits, start_positions)
1125
+ end_loss = loss_fct.(end_logits, end_positions)
1126
+ total_loss = (start_loss + end_loss) / 2
1127
+ end
1128
+
1129
+ if !return_dict
1130
+ output = [start_logits, end_logits] + outputs[1..]
1131
+ return !total_loss.nil? ? [total_loss] + output : output
1132
+ end
1133
+
1134
+ QuestionAnsweringModelOutput.new(loss: total_loss, start_logits: start_logits, end_logits: end_logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
1135
+ end
1136
+ end
1137
+
1138
+ class DebertaV2ForMultipleChoice < DebertaV2PreTrainedModel
1139
+ def initialize(config)
1140
+ super(config)
1141
+
1142
+ num_labels = config.getattr("num_labels", 2)
1143
+ @num_labels = num_labels
1144
+
1145
+ @deberta = DebertaV2Model.new(config)
1146
+ @pooler = ContextPooler.new(config)
1147
+ output_dim = @pooler.output_dim
1148
+
1149
+ @classifier = Torch::NN::Linear.new(output_dim, 1)
1150
+ drop_out = config.getattr("cls_dropout", nil)
1151
+ drop_out = drop_out.nil? ? @config.hidden_dropout_prob : drop_out
1152
+ @dropout = StableDropout.new(drop_out)
1153
+
1154
+ init_weights
1155
+ end
1156
+
1157
+ def get_input_embeddings
1158
+ @deberta.get_input_embeddings
1159
+ end
1160
+
1161
+ def set_input_embeddings(new_embeddings)
1162
+ @deberta.set_input_embeddings(new_embeddings)
1163
+ end
1164
+
1165
+ def forward(
1166
+ input_ids: nil,
1167
+ attention_mask: nil,
1168
+ token_type_ids: nil,
1169
+ position_ids: nil,
1170
+ inputs_embeds: nil,
1171
+ labels: nil,
1172
+ output_attentions: nil,
1173
+ output_hidden_states: nil,
1174
+ return_dict: nil
1175
+ )
1176
+ return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
1177
+ num_choices = !input_ids.nil? ? input_ids.shape[1] : inputs_embeds.shape[1]
1178
+
1179
+ flat_input_ids = !input_ids.nil? ? input_ids.view(-1, input_ids.size(-1)) : nil
1180
+ flat_position_ids = !position_ids.nil? ? position_ids.view(-1, position_ids.size(-1)) : nil
1181
+ flat_token_type_ids = !token_type_ids.nil? ? token_type_ids.view(-1, token_type_ids.size(-1)) : nil
1182
+ flat_attention_mask = !attention_mask.nil? ? attention_mask.view(-1, attention_mask.size(-1)) : nil
1183
+ flat_inputs_embeds = !inputs_embeds.nil? ? inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) : nil
1184
+
1185
+ outputs = @deberta.(flat_input_ids, position_ids: flat_position_ids, token_type_ids: flat_token_type_ids, attention_mask: flat_attention_mask, inputs_embeds: flat_inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
1186
+
1187
+ encoder_layer = outputs[0]
1188
+ pooled_output = @pooler.(encoder_layer)
1189
+ pooled_output = @dropout.(pooled_output)
1190
+ logits = @classifier.(pooled_output)
1191
+ reshaped_logits = logits.view(-1, num_choices)
1192
+
1193
+ loss = nil
1194
+ if !labels.nil?
1195
+ loss_fct = Torch::NN::CrossEntropyLoss.new
1196
+ loss = loss_fct.(reshaped_logits, labels)
1197
+ end
1198
+
1199
+ if !return_dict
1200
+ output = [reshaped_logits] + outputs[1..]
1201
+ return !loss.nil? ? [loss] + output : output
1202
+ end
1203
+
1204
+ MultipleChoiceModelOutput.new(loss: loss, logits: reshaped_logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
1205
+ end
1206
+ end
1207
+ end
1208
+
1209
+ DebertaV2ForSequenceClassification = DebertaV2::DebertaV2ForSequenceClassification
1210
+ end
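
A minimal usage sketch of the new classes, assuming the gem's from_pretrained interface mirrors the Python library and that input_ids and attention_mask are integer tensors prepared by a matching tokenizer (the checkpoint name below is a placeholder):

    # hypothetical example; checkpoint name and tensors are assumptions
    model = Transformers::DebertaV2::DebertaV2ForSequenceClassification.from_pretrained("path-or-model-id")
    output = model.(input_ids: input_ids, attention_mask: attention_mask)
    logits = output.logits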