transformers-rb 0.1.2 → 0.1.3

@@ -0,0 +1,1216 @@
1
+ # Copyright 2019 Facebook AI Research and the HuggingFace Inc. team.
2
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ module Transformers
17
+ module XlmRoberta
18
+ class XLMRobertaEmbeddings < Torch::NN::Module
19
+ def initialize(config)
20
+ super()
21
+ @word_embeddings = Torch::NN::Embedding.new(config.vocab_size, config.hidden_size, padding_idx: config.pad_token_id)
22
+ @position_embeddings = Torch::NN::Embedding.new(config.max_position_embeddings, config.hidden_size)
23
+ @token_type_embeddings = Torch::NN::Embedding.new(config.type_vocab_size, config.hidden_size)
24
+
25
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
26
+ # any TensorFlow checkpoint file
27
+ @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
28
+ @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
29
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
30
+ @position_embedding_type = config.getattr("position_embedding_type", "absolute")
31
+ register_buffer("position_ids", Torch.arange(config.max_position_embeddings).expand([1, -1]), persistent: false)
32
+ register_buffer("token_type_ids", Torch.zeros(@position_ids.size, dtype: Torch.long), persistent: false)
33
+
34
+ @padding_idx = config.pad_token_id
35
+ @position_embeddings = Torch::NN::Embedding.new(config.max_position_embeddings, config.hidden_size, padding_idx: @padding_idx)
36
+ end
37
+
38
+ def forward(input_ids: nil, token_type_ids: nil, position_ids: nil, inputs_embeds: nil, past_key_values_length: 0)
39
+ if position_ids.nil?
40
+ if !input_ids.nil?
41
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
42
+ position_ids = create_position_ids_from_input_ids(input_ids, @padding_idx, past_key_values_length:)
43
+ else
44
+ position_ids = create_position_ids_from_inputs_embeds(inputs_embeds)
45
+ end
46
+ end
47
+
48
+ if !input_ids.nil?
49
+ input_shape = input_ids.size
50
+ else
51
+ input_shape = inputs_embeds.size[...-1]
52
+ end
53
+
54
+ seq_length = input_shape[1]
55
+
56
+ # Setting token_type_ids to the registered buffer (all zeros) when it is not passed,
+ # which usually happens when it is auto-generated. The registered buffer lets users
+ # trace the model without passing token_type_ids (solves issue #5664).
+ if token_type_ids.nil?
+ if respond_to?(:token_type_ids)
+ buffered_token_type_ids = @token_type_ids[0.., ...seq_length]
62
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
63
+ token_type_ids = buffered_token_type_ids_expanded
64
+ else
65
+ token_type_ids = Torch.zeros(input_shape, dtype: Torch.long, device: @position_ids.device)
66
+ end
67
+ end
68
+
69
+ if inputs_embeds.nil?
70
+ inputs_embeds = @word_embeddings.(input_ids)
71
+ end
72
+ token_type_embeddings = @token_type_embeddings.(token_type_ids)
73
+
74
+ embeddings = inputs_embeds + token_type_embeddings
75
+ if @position_embedding_type == "absolute"
76
+ position_embeddings = @position_embeddings.(position_ids)
77
+ embeddings += position_embeddings
78
+ end
79
+ embeddings = @LayerNorm.(embeddings)
80
+ embeddings = @dropout.(embeddings)
81
+ embeddings
82
+ end
83
+
84
+ def create_position_ids_from_inputs_embeds(inputs_embeds)
85
+ input_shape = inputs_embeds.size[...-1]
86
+ sequence_length = input_shape[1]
87
+
88
+ position_ids = Torch.arange(@padding_idx + 1, sequence_length + @padding_idx + 1, dtype: Torch.long, device: inputs_embeds.device)
89
+ position_ids.unsqueeze(0).expand(input_shape)
90
+ end
91
+
92
+ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length: 0)
93
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
94
+ mask = input_ids.ne(padding_idx).int
95
+ incremental_indices = (Torch.cumsum(mask, dim: 1).type_as(mask) + past_key_values_length) * mask
96
+ incremental_indices.long + padding_idx
97
+ end
98
+ end
99
+
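Editorial aside, not part of the diff: a worked example of `create_position_ids_from_input_ids`, showing how padded positions stay at `padding_idx` while real tokens are numbered from `padding_idx + 1`. The token ids below are illustrative.

    padding_idx = 1
    past_key_values_length = 0
    input_ids = Torch.tensor([[0, 31414, 232, 2, 1, 1]])  # last two ids are padding
    mask = input_ids.ne(padding_idx).int                  # [[1, 1, 1, 1, 0, 0]]
    incremental = (Torch.cumsum(mask, dim: 1).type_as(mask) + past_key_values_length) * mask
    position_ids = incremental.long + padding_idx         # [[2, 3, 4, 5, 1, 1]]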
100
+ class XLMRobertaSelfAttention < Torch::NN::Module
101
+ def initialize(config, position_embedding_type: nil)
102
+ super()
103
+ if config.hidden_size % config.num_attention_heads != 0 && !config.hasattr("embedding_size")
104
+ raise ArgumentError, "The hidden size (#{config.hidden_size}) is not a multiple of the number of attention heads (#{config.num_attention_heads})"
105
+ end
106
+
107
+ @num_attention_heads = config.num_attention_heads
108
+ @attention_head_size = (config.hidden_size / config.num_attention_heads).to_i
109
+ @all_head_size = @num_attention_heads * @attention_head_size
110
+
111
+ @query = Torch::NN::Linear.new(config.hidden_size, @all_head_size)
112
+ @key = Torch::NN::Linear.new(config.hidden_size, @all_head_size)
113
+ @value = Torch::NN::Linear.new(config.hidden_size, @all_head_size)
114
+
115
+ @dropout = Torch::NN::Dropout.new(p: config.attention_probs_dropout_prob)
116
+ @position_embedding_type = position_embedding_type || config.getattr("position_embedding_type", "absolute")
117
+ if @position_embedding_type == "relative_key" || @position_embedding_type == "relative_key_query"
118
+ @max_position_embeddings = config.max_position_embeddings
119
+ @distance_embedding = Torch::NN::Embedding.new((2 * config.max_position_embeddings) - 1, @attention_head_size)
120
+ end
121
+
122
+ @is_decoder = config.is_decoder
123
+ end
124
+
125
+ def transpose_for_scores(x)
126
+ new_x_shape = x.size[...-1] + [@num_attention_heads, @attention_head_size]
127
+ x = x.view(new_x_shape)
128
+ x.permute(0, 2, 1, 3)
129
+ end
130
+
131
+ def forward(
132
+ hidden_states,
133
+ attention_mask: nil,
134
+ head_mask: nil,
135
+ encoder_hidden_states: nil,
136
+ encoder_attention_mask: nil,
137
+ past_key_value: nil,
138
+ output_attentions: false
139
+ )
140
+ mixed_query_layer = @query.(hidden_states)
141
+
142
+ # If this is instantiated as a cross-attention module, the keys
143
+ # and values come from an encoder; the attention mask needs to be
144
+ # such that the encoder's padding tokens are not attended to.
145
+ is_cross_attention = !encoder_hidden_states.nil?
146
+
147
+ if is_cross_attention && !past_key_value.nil?
148
+ # reuse k,v, cross_attentions
149
+ key_layer = past_key_value[0]
150
+ value_layer = past_key_value[1]
151
+ attention_mask = encoder_attention_mask
152
+ elsif is_cross_attention
153
+ key_layer = transpose_for_scores(@key.(encoder_hidden_states))
154
+ value_layer = transpose_for_scores(@value.(encoder_hidden_states))
155
+ attention_mask = encoder_attention_mask
156
+ elsif !past_key_value.nil?
157
+ key_layer = transpose_for_scores(@key.(hidden_states))
158
+ value_layer = transpose_for_scores(@value.(hidden_states))
159
+ key_layer = Torch.cat([past_key_value[0], key_layer], dim: 2)
160
+ value_layer = Torch.cat([past_key_value[1], value_layer], dim: 2)
161
+ else
162
+ key_layer = transpose_for_scores(@key.(hidden_states))
163
+ value_layer = transpose_for_scores(@value.(hidden_states))
164
+ end
165
+
166
+ query_layer = transpose_for_scores(mixed_query_layer)
167
+
168
+ use_cache = !past_key_value.nil?
169
+ if @is_decoder
170
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
171
+ # Further calls to cross_attention layer can then reuse all cross-attention
172
+ # key/value_states (first "if" case)
173
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
174
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
175
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
176
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
177
+ past_key_value = [key_layer, value_layer]
178
+ end
179
+
180
+ # Take the dot product between "query" and "key" to get the raw attention scores.
181
+ attention_scores = Torch.matmul(query_layer, key_layer.transpose(-1, -2))
182
+
183
+ if @position_embedding_type == "relative_key" || @position_embedding_type == "relative_key_query"
184
+ query_length, key_length = [query_layer.shape[2], key_layer.shape[2]]
185
+ if use_cache
186
+ position_ids_l = Torch.tensor(key_length - 1, dtype: Torch.long, device: hidden_states.device).view(-1, 1)
187
+ else
188
+ position_ids_l = Torch.arange(query_length, dtype: Torch.long, device: hidden_states.device).view(-1, 1)
189
+ end
190
+ position_ids_r = Torch.arange(key_length, dtype: Torch.long, device: hidden_states.device).view(1, -1)
191
+ distance = position_ids_l - position_ids_r
192
+
193
+ positional_embedding = @distance_embedding.((distance + @max_position_embeddings) - 1)
194
+ positional_embedding = positional_embedding.to(dtype: query_layer.dtype)
195
+
196
+ if @position_embedding_type == "relative_key"
197
+ relative_position_scores = Torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
198
+ attention_scores = attention_scores + relative_position_scores
199
+ elsif @position_embedding_type == "relative_key_query"
200
+ relative_position_scores_query = Torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
201
+ relative_position_scores_key = Torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
202
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
203
+ end
204
+ end
205
+
206
+ attention_scores = attention_scores / Math.sqrt(@attention_head_size)
207
+ if !attention_mask.nil?
208
+ # Apply the attention mask is (precomputed for all layers in XLMRobertaModel forward() function)
209
+ attention_scores = attention_scores + attention_mask
210
+ end
211
+
212
+ # Normalize the attention scores to probabilities.
213
+ attention_probs = Torch::NN::Functional.softmax(attention_scores, dim: -1)
214
+
215
+ # This is actually dropping out entire tokens to attend to, which might
216
+ # seem a bit unusual, but is taken from the original Transformer paper.
217
+ attention_probs = @dropout.(attention_probs)
218
+
219
+ # Mask heads if we want to
220
+ if !head_mask.nil?
221
+ attention_probs = attention_probs * head_mask
222
+ end
223
+
224
+ context_layer = Torch.matmul(attention_probs, value_layer)
225
+
226
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous
227
+ new_context_layer_shape = context_layer.size[...-2] + [@all_head_size]
228
+ context_layer = context_layer.view(new_context_layer_shape)
229
+
230
+ outputs = output_attentions ? [context_layer, attention_probs] : [context_layer]
231
+
232
+ if @is_decoder
233
+ outputs = outputs + [past_key_value]
234
+ end
235
+ outputs
236
+ end
237
+ end
238
+
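Editorial shape sketch of `transpose_for_scores` and the attention score computation, assuming an illustrative layout of 12 heads with 64 dimensions each (hidden size 768); nothing here is part of the diff.

    x = Torch.randn(2, 5, 768)                        # [batch, seq_len, hidden_size]
    new_shape = x.size[...-1] + [12, 64]              # [2, 5, 12, 64]
    per_head = x.view(new_shape).permute(0, 2, 1, 3)  # [2, 12, 5, 64], one slice per head
    scores = Torch.matmul(per_head, per_head.transpose(-1, -2)) / Math.sqrt(64)  # [2, 12, 5, 5]
    probs = Torch::NN::Functional.softmax(scores, dim: -1)
    context = Torch.matmul(probs, per_head)           # [2, 12, 5, 64]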
239
+ class XLMRobertaSdpaSelfAttention < XLMRobertaSelfAttention
240
+ def initialize(config, position_embedding_type: nil)
241
+ super(config, position_embedding_type: position_embedding_type)
242
+ @dropout_prob = config.attention_probs_dropout_prob
243
+ @require_contiguous_qkv = Packaging::Version.parse(Utils.get_torch_version) < Packaging::Version.parse("2.2.0")
244
+ end
245
+
246
+ # Adapted from XLMRobertaSelfAttention
247
+ def forward(
248
+ hidden_states,
249
+ attention_mask: nil,
250
+ head_mask: nil,
251
+ encoder_hidden_states: nil,
252
+ encoder_attention_mask: nil,
253
+ past_key_value: nil,
254
+ output_attentions: false
255
+ )
256
+ if @position_embedding_type != "absolute" || output_attentions || !head_mask.nil?
257
+ # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
258
+ Transformers.logger.warn("XLMRobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support non-absolute `position_embedding_type` or `output_attentions: true` or `head_mask`. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation: \"eager\"` when loading the model.")
259
+ return super(hidden_states, attention_mask: attention_mask, head_mask: head_mask, encoder_hidden_states: encoder_hidden_states, encoder_attention_mask: encoder_attention_mask, past_key_value: past_key_value, output_attentions: output_attentions)
260
+ end
261
+
262
+ bsz, tgt_len, _ = hidden_states.size
263
+
264
+ query_layer = transpose_for_scores(@query.(hidden_states))
265
+
266
+ # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
267
+ # mask needs to be such that the encoder's padding tokens are not attended to.
268
+ is_cross_attention = !encoder_hidden_states.nil?
269
+
270
+ current_states = is_cross_attention ? encoder_hidden_states : hidden_states
271
+ attention_mask = is_cross_attention ? encoder_attention_mask : attention_mask
272
+
273
+ # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
274
+ if is_cross_attention && past_key_value && past_key_value[0].shape[2] == current_states.shape[1]
275
+ key_layer, value_layer = past_key_value
276
+ else
277
+ key_layer = transpose_for_scores(@key.(current_states))
278
+ value_layer = transpose_for_scores(@value.(current_states))
279
+ if !past_key_value.nil? && !is_cross_attention
280
+ key_layer = Torch.cat([past_key_value[0], key_layer], dim: 2)
281
+ value_layer = Torch.cat([past_key_value[1], value_layer], dim: 2)
282
+ end
283
+ end
284
+
285
+ if @is_decoder
286
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
287
+ # Further calls to cross_attention layer can then reuse all cross-attention
288
+ # key/value_states (first "if" case)
289
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
290
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
291
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
292
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
293
+ past_key_value = [key_layer, value_layer]
294
+ end
295
+
296
+ # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
297
+ # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
298
+ # Reference: https://github.com/pytorch/pytorch/issues/112577
299
+ if @require_contiguous_qkv && query_layer.device.type == "cuda" && !attention_mask.nil?
300
+ query_layer = query_layer.contiguous
301
+ key_layer = key_layer.contiguous
302
+ value_layer = value_layer.contiguous
303
+ end
304
+
305
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
306
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
307
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
308
+ # a causal mask in case tgt_len == 1.
309
+ is_causal = @is_decoder && !is_cross_attention && attention_mask.nil? && tgt_len > 1 ? true : false
310
+
311
+ attn_output = Torch::NN::Functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, attn_mask: attention_mask, dropout_p: @training ? @dropout_prob : 0.0, is_causal: is_causal)
312
+
313
+ attn_output = attn_output.transpose(1, 2)
314
+ attn_output = attn_output.reshape(bsz, tgt_len, @all_head_size)
315
+
316
+ outputs = [attn_output]
317
+ if @is_decoder
318
+ outputs = outputs + [past_key_value]
319
+ end
320
+ outputs
321
+ end
322
+ end
323
+
324
+ class XLMRobertaSelfOutput < Torch::NN::Module
325
+ def initialize(config)
326
+ super()
327
+ @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
328
+ @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
329
+ @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
330
+ end
331
+
332
+ def forward(hidden_states, input_tensor)
333
+ hidden_states = @dense.(hidden_states)
334
+ hidden_states = @dropout.(hidden_states)
335
+ hidden_states = @LayerNorm.(hidden_states + input_tensor)
336
+ hidden_states
337
+ end
338
+ end
339
+
340
+ XLM_ROBERTA_SELF_ATTENTION_CLASSES = {"eager" => XLMRobertaSelfAttention, "sdpa" => XLMRobertaSdpaSelfAttention}
341
+
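Editorial note: the attention backend is chosen once per layer from `config._attn_implementation`, and `fetch` is used so an unsupported value fails loudly instead of silently returning `nil`. For example:

    XLM_ROBERTA_SELF_ATTENTION_CLASSES.fetch("sdpa")
    # => Transformers::XlmRoberta::XLMRobertaSdpaSelfAttention
    XLM_ROBERTA_SELF_ATTENTION_CLASSES.fetch("flash_attention_2")
    # => KeyError, since only "eager" and "sdpa" are registered here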
342
+ class XLMRobertaAttention < Torch::NN::Module
343
+ def initialize(config, position_embedding_type: nil)
344
+ super()
345
+ @self = XLM_ROBERTA_SELF_ATTENTION_CLASSES.fetch(config._attn_implementation).new(config, position_embedding_type: position_embedding_type)
346
+ @output = XLMRobertaSelfOutput.new(config)
347
+ @pruned_heads = Set.new
348
+ end
349
+
350
+ def prune_heads(heads)
351
+ if heads.length == 0
352
+ return
353
+ end
354
+ heads, index = TorchUtils.find_pruneable_heads_and_indices(heads, @self.num_attention_heads, @self.attention_head_size, @pruned_heads)
355
+
356
+ # Prune linear layers
357
+ @query = TorchUtils.prune_linear_layer(@self.query, index)
358
+ @key = TorchUtils.prune_linear_layer(@self.key, index)
359
+ @value = TorchUtils.prune_linear_layer(@self.value, index)
360
+ @dense = TorchUtils.prune_linear_layer(@output.dense, index, dim: 1)
361
+
362
+ # Update hyper params and store pruned heads
363
+ @num_attention_heads = @self.num_attention_heads - heads.length
364
+ @all_head_size = @self.attention_head_size * @self.num_attention_heads
365
+ @pruned_heads = @pruned_heads.union(heads)
366
+ end
367
+
368
+ def forward(
369
+ hidden_states,
370
+ attention_mask: nil,
371
+ head_mask: nil,
372
+ encoder_hidden_states: nil,
373
+ encoder_attention_mask: nil,
374
+ past_key_value: nil,
375
+ output_attentions: false
376
+ )
377
+ self_outputs = @self.(hidden_states, attention_mask:, head_mask:, encoder_hidden_states:, encoder_attention_mask:, past_key_value:, output_attentions:)
378
+ attention_output = @output.(self_outputs[0], hidden_states)
379
+ outputs = [attention_output] + self_outputs[1..]
380
+ outputs
381
+ end
382
+ end
383
+
384
+ class XLMRobertaIntermediate < Torch::NN::Module
385
+ def initialize(config)
386
+ super()
387
+ @dense = Torch::NN::Linear.new(config.hidden_size, config.intermediate_size)
388
+ if config.hidden_act.is_a?(String)
389
+ @intermediate_act_fn = ACT2FN[config.hidden_act]
390
+ else
391
+ @intermediate_act_fn = config.hidden_act
392
+ end
393
+ end
394
+
395
+ def forward(hidden_states)
396
+ hidden_states = @dense.(hidden_states)
397
+ hidden_states = @intermediate_act_fn.(hidden_states)
398
+ hidden_states
399
+ end
400
+ end
401
+
402
+ class XLMRobertaOutput < Torch::NN::Module
403
+ def initialize(config)
404
+ super()
405
+ @dense = Torch::NN::Linear.new(config.intermediate_size, config.hidden_size)
406
+ @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
407
+ @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
408
+ end
409
+
410
+ def forward(hidden_states, input_tensor)
411
+ hidden_states = @dense.(hidden_states)
412
+ hidden_states = @dropout.(hidden_states)
413
+ hidden_states = @LayerNorm.(hidden_states + input_tensor)
414
+ hidden_states
415
+ end
416
+ end
417
+
418
+ class XLMRobertaLayer < Torch::NN::Module
419
+ def initialize(config)
420
+ super()
421
+ @chunk_size_feed_forward = config.chunk_size_feed_forward
422
+ @seq_len_dim = 1
423
+ @attention = XLMRobertaAttention.new(config)
424
+ @is_decoder = config.is_decoder
425
+ @add_cross_attention = config.add_cross_attention
426
+ if @add_cross_attention
427
+ if !@is_decoder
428
+ raise ArgumentError, "#{self} should be used as a decoder model if cross attention is added"
429
+ end
430
+ @crossattention = XLMRobertaAttention.new(config, position_embedding_type: "absolute")
431
+ end
432
+ @intermediate = XLMRobertaIntermediate.new(config)
433
+ @output = XLMRobertaOutput.new(config)
434
+ end
435
+
436
+ def forward(
437
+ hidden_states,
438
+ attention_mask: nil,
439
+ head_mask: nil,
440
+ encoder_hidden_states: nil,
441
+ encoder_attention_mask: nil,
442
+ past_key_value: nil,
443
+ output_attentions: false
444
+ )
445
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
446
+ self_attn_past_key_value = !past_key_value.nil? ? past_key_value[...2] : nil
447
+ self_attention_outputs = @attention.(hidden_states, attention_mask:, head_mask:, output_attentions: output_attentions, past_key_value: self_attn_past_key_value)
448
+ attention_output = self_attention_outputs[0]
449
+
450
+ # if decoder, the last output is tuple of self-attn cache
451
+ if @is_decoder
452
+ outputs = self_attention_outputs[1...-1]
453
+ present_key_value = self_attention_outputs[-1]
454
+ else
455
+ outputs = self_attention_outputs[1..]
456
+ end
457
+
458
+ cross_attn_present_key_value = nil
459
+ if @is_decoder && !encoder_hidden_states.nil?
460
+ if !instance_variable_defined?(:@crossattention)
461
+ raise ArgumentError, "If `encoder_hidden_states` are passed, #{self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
462
+ end
463
+
464
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
465
+ cross_attn_past_key_value = !past_key_value.nil? ? past_key_value[-2..] : nil
466
+ cross_attention_outputs = @crossattention.(attention_output, attention_mask: attention_mask, head_mask: head_mask, encoder_hidden_states: encoder_hidden_states, encoder_attention_mask: encoder_attention_mask, past_key_value: cross_attn_past_key_value, output_attentions: output_attentions)
467
+ attention_output = cross_attention_outputs[0]
468
+ outputs = outputs + cross_attention_outputs[1...-1]
469
+
470
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
471
+ cross_attn_present_key_value = cross_attention_outputs[-1]
472
+ present_key_value = present_key_value + cross_attn_present_key_value
473
+ end
474
+
475
+ layer_output = TorchUtils.apply_chunking_to_forward(method(:feed_forward_chunk), @chunk_size_feed_forward, @seq_len_dim, attention_output)
476
+ outputs = [layer_output] + outputs
477
+
478
+ # if decoder, return the attn key/values as the last output
479
+ if @is_decoder
480
+ outputs = outputs + [present_key_value]
481
+ end
482
+
483
+ outputs
484
+ end
485
+
486
+ def feed_forward_chunk(attention_output)
487
+ intermediate_output = @intermediate.(attention_output)
488
+ layer_output = @output.(intermediate_output, attention_output)
489
+ layer_output
490
+ end
491
+ end
492
+
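Editorial sketch of what `TorchUtils.apply_chunking_to_forward` does when `config.chunk_size_feed_forward` is greater than 0 (with 0, the default, the feed-forward block runs on the whole sequence at once). The chunk size, shapes, and `feed_forward` lambda below are illustrative stand-ins.

    chunk_size = 2
    attention_output = Torch.randn(1, 6, 768)            # [batch, seq_len, hidden]
    feed_forward = ->(chunk) { chunk }                    # stands in for feed_forward_chunk
    chunks = attention_output.split(chunk_size, dim: 1)   # three [1, 2, 768] slices
    layer_output = Torch.cat(chunks.map { |c| feed_forward.(c) }, dim: 1)  # back to [1, 6, 768]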
493
+ class XLMRobertaEncoder < Torch::NN::Module
494
+ def initialize(config)
495
+ super()
496
+ @config = config
497
+ @layer = Torch::NN::ModuleList.new(config.num_hidden_layers.times.map { |_| XLMRobertaLayer.new(config) })
498
+ @gradient_checkpointing = false
499
+ end
500
+
501
+ def forward(
502
+ hidden_states,
503
+ attention_mask: nil,
504
+ head_mask: nil,
505
+ encoder_hidden_states: nil,
506
+ encoder_attention_mask: nil,
507
+ past_key_values: nil,
508
+ use_cache: nil,
509
+ output_attentions: false,
510
+ output_hidden_states: false,
511
+ return_dict: true
512
+ )
513
+ all_hidden_states = output_hidden_states ? [] : nil
514
+ all_self_attentions = output_attentions ? [] : nil
515
+ all_cross_attentions = output_attentions && @config.add_cross_attention ? [] : nil
516
+
517
+ if @gradient_checkpointing && @training
518
+ if use_cache
519
+ Transformers.logger.warn("`use_cache: true` is incompatible with gradient checkpointing. Setting `use_cache: false`...")
520
+ use_cache = false
521
+ end
522
+ end
523
+
524
+ next_decoder_cache = use_cache ? [] : nil
525
+ @layer.each_with_index do |layer_module, i|
526
+ if output_hidden_states
527
+ all_hidden_states = all_hidden_states + [hidden_states]
528
+ end
529
+
530
+ layer_head_mask = !head_mask.nil? ? head_mask[i] : nil
531
+ past_key_value = !past_key_values.nil? ? past_key_values[i] : nil
532
+
533
+ if @gradient_checkpointing && @training
534
+ layer_outputs = _gradient_checkpointing_func(layer_module.method(:call), hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)
535
+ else
536
+ layer_outputs = layer_module.(hidden_states, attention_mask:, head_mask: layer_head_mask, encoder_hidden_states:, encoder_attention_mask:, past_key_value:, output_attentions:)
537
+ end
538
+
539
+ hidden_states = layer_outputs[0]
540
+ if use_cache
541
+ next_decoder_cache += [layer_outputs[-1]]
542
+ end
543
+ if output_attentions
544
+ all_self_attentions = all_self_attentions + [layer_outputs[1]]
545
+ if @config.add_cross_attention
546
+ all_cross_attentions = all_cross_attentions + [layer_outputs[2]]
547
+ end
548
+ end
549
+ end
550
+
551
+ if output_hidden_states
552
+ all_hidden_states = all_hidden_states + [hidden_states]
553
+ end
554
+
555
+ if !return_dict
556
+ return Array([hidden_states, next_decoder_cache, all_hidden_states, all_self_attentions, all_cross_attentions].select { |v| !v.nil? })
557
+ end
558
+ BaseModelOutputWithPastAndCrossAttentions.new(last_hidden_state: hidden_states, past_key_values: next_decoder_cache, hidden_states: all_hidden_states, attentions: all_self_attentions, cross_attentions: all_cross_attentions)
559
+ end
560
+ end
561
+
562
+ class XLMRobertaPooler < Torch::NN::Module
563
+ def initialize(config)
564
+ super()
565
+ @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
566
+ @activation = Torch::NN::Tanh.new
567
+ end
568
+
569
+ def forward(hidden_states)
570
+ # We "pool" the model by simply taking the hidden state corresponding
571
+ # to the first token.
572
+ first_token_tensor = hidden_states[0.., 0]
573
+ pooled_output = @dense.(first_token_tensor)
574
+ pooled_output = @activation.(pooled_output)
575
+ pooled_output
576
+ end
577
+ end
578
+
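Editorial note: the pooler only looks at the first token (`<s>`, the BOS/CLS-equivalent). The indexing idiom `hidden_states[0.., 0]` selects that token for every item in the batch:

    hidden_states = Torch.randn(2, 5, 768)
    first_token_tensor = hidden_states[0.., 0]   # [2, 768], position 0 of each sequence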
579
+ class XLMRobertaPreTrainedModel < PreTrainedModel
580
+ self.config_class = XLMRobertaConfig
581
+ self.base_model_prefix = "roberta"
582
+ # self.supports_gradient_checkpointing = true
583
+ # self._no_split_modules = ["XLMRobertaEmbeddings", "XLMRobertaSelfAttention", "XLMRobertaSdpaSelfAttention"]
584
+ # self._supports_sdpa = true
585
+
586
+ def _init_weights(module_)
587
+ if module_.is_a?(Torch::NN::Linear)
588
+ # Slightly different from the TF version which uses truncated_normal for initialization
589
+ # cf https://github.com/pytorch/pytorch/pull/5617
590
+ module_.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
591
+ if !module_.bias.nil?
592
+ module_.bias.data.zero!
593
+ end
594
+ elsif module_.is_a?(Torch::NN::Embedding)
595
+ module_.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
596
+ if !module_.padding_idx.nil?
597
+ module_.weight.data[module_.padding_idx].zero!
598
+ end
599
+ elsif module_.is_a?(Torch::NN::LayerNorm)
600
+ module_.bias.data.zero!
601
+ module_.weight.data.fill!(1.0)
602
+ end
603
+ end
604
+ end
605
+
606
+ class XLMRobertaModel < XLMRobertaPreTrainedModel
607
+ # self._no_split_modules = ["XLMRobertaEmbeddings", "XLMRobertaLayer"]
608
+
609
+ def initialize(config, add_pooling_layer: true)
610
+ super(config)
611
+ @config = config
612
+
613
+ @embeddings = XLMRobertaEmbeddings.new(config)
614
+ @encoder = XLMRobertaEncoder.new(config)
615
+
616
+ @pooler = add_pooling_layer ? XLMRobertaPooler.new(config) : nil
617
+
618
+ @attn_implementation = config._attn_implementation
619
+ @position_embedding_type = config.position_embedding_type
620
+
621
+ # Initialize weights and apply final processing
622
+ post_init
623
+ end
624
+
625
+ def get_input_embeddings
626
+ @embeddings.word_embeddings
627
+ end
628
+
629
+ def set_input_embeddings(value)
630
+ @word_embeddings = value
631
+ end
632
+
633
+ def _prune_heads(heads_to_prune)
634
+ heads_to_prune.each do |layer, heads|
635
+ @encoder.layer[layer].attention.prune_heads(heads)
636
+ end
637
+ end
638
+
639
+ def forward(
640
+ input_ids,
641
+ attention_mask: nil,
642
+ token_type_ids: nil,
643
+ position_ids: nil,
644
+ head_mask: nil,
645
+ inputs_embeds: nil,
646
+ encoder_hidden_states: nil,
647
+ encoder_attention_mask: nil,
648
+ past_key_values: nil,
649
+ use_cache: nil,
650
+ output_attentions: nil,
651
+ output_hidden_states: nil,
652
+ return_dict: nil
653
+ )
654
+ output_attentions = !output_attentions.nil? ? output_attentions : @config.output_attentions
655
+ output_hidden_states = !output_hidden_states.nil? ? output_hidden_states : @config.output_hidden_states
656
+ return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
657
+
658
+ if @config.is_decoder
659
+ use_cache = !use_cache.nil? ? use_cache : @config.use_cache
660
+ else
661
+ use_cache = false
662
+ end
663
+
664
+ if !input_ids.nil? && !inputs_embeds.nil?
665
+ raise ArgumentError, "You cannot specify both input_ids and inputs_embeds at the same time"
666
+ elsif !input_ids.nil?
667
+ warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
668
+ input_shape = input_ids.size
669
+ elsif !inputs_embeds.nil?
670
+ input_shape = inputs_embeds.size[...-1]
671
+ else
672
+ raise ArgumentError, "You have to specify either input_ids or inputs_embeds"
673
+ end
674
+
675
+ batch_size, seq_length = input_shape
676
+ device = !input_ids.nil? ? input_ids.device : inputs_embeds.device
677
+
678
+ # past_key_values_length
679
+ past_key_values_length = !past_key_values.nil? ? past_key_values[0][0].shape[2] : 0
680
+
681
+ if token_type_ids.nil?
682
+ if @embeddings.respond_to?(:token_type_ids)
683
+ buffered_token_type_ids = @embeddings.token_type_ids[0.., ...seq_length]
684
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
685
+ token_type_ids = buffered_token_type_ids_expanded
686
+ else
687
+ token_type_ids = Torch.zeros(input_shape, dtype: Torch.long, device: device)
688
+ end
689
+ end
690
+
691
+ embedding_output = @embeddings.(input_ids: input_ids, position_ids: position_ids, token_type_ids: token_type_ids, inputs_embeds: inputs_embeds, past_key_values_length: past_key_values_length)
692
+
693
+ if attention_mask.nil?
694
+ attention_mask = Torch.ones([batch_size, seq_length + past_key_values_length], device: device)
695
+ end
696
+
697
+ use_sdpa_attention_masks = @attn_implementation == "sdpa" && @position_embedding_type == "absolute" && head_mask.nil? && !output_attentions
698
+
699
+ # Expand the attention mask
700
+ if use_sdpa_attention_masks && attention_mask.dim == 2
701
+ # Expand the attention mask for SDPA.
702
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
703
+ if @config.is_decoder
704
+ extended_attention_mask = ModelingAttnMaskUtils._prepare_4d_causal_attention_mask_for_sdpa(attention_mask, input_shape, embedding_output, past_key_values_length)
705
+ else
706
+ extended_attention_mask = ModelingAttnMaskUtils._prepare_4d_attention_mask_for_sdpa(attention_mask, embedding_output.dtype, tgt_len: seq_length)
707
+ end
708
+ else
709
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
710
+ # ourselves in which case we just need to make it broadcastable to all heads.
711
+ extended_attention_mask = get_extended_attention_mask(attention_mask, input_shape)
712
+ end
713
+
714
+ # If a 2D or 3D attention mask is provided for the cross-attention
715
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
716
+ if @config.is_decoder && !encoder_hidden_states.nil?
717
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size
718
+ encoder_hidden_shape = [encoder_batch_size, encoder_sequence_length]
719
+ if encoder_attention_mask.nil?
720
+ encoder_attention_mask = Torch.ones(encoder_hidden_shape, device: device)
721
+ end
722
+
723
+ if use_sdpa_attention_masks && encoder_attention_mask.dim == 2
724
+ # Expand the attention mask for SDPA.
725
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
726
+ encoder_extended_attention_mask = ModelingAttnMaskUtils._prepare_4d_attention_mask_for_sdpa(encoder_attention_mask, embedding_output.dtype, tgt_len: seq_length)
727
+ else
728
+ encoder_extended_attention_mask = invert_attention_mask(encoder_attention_mask)
729
+ end
730
+ else
731
+ encoder_extended_attention_mask = nil
732
+ end
733
+
734
+ # Prepare head mask if needed
735
+ # 1.0 in head_mask indicate we keep the head
736
+ # attention_probs has shape bsz x n_heads x N x N
737
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
738
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
739
+ head_mask = get_head_mask(head_mask, @config.num_hidden_layers)
740
+
741
+ encoder_outputs = @encoder.(embedding_output, attention_mask: extended_attention_mask, head_mask: head_mask, encoder_hidden_states: encoder_hidden_states, encoder_attention_mask: encoder_extended_attention_mask, past_key_values: past_key_values, use_cache: use_cache, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
742
+ sequence_output = encoder_outputs[0]
743
+ pooled_output = !@pooler.nil? ? @pooler.(sequence_output) : nil
744
+
745
+ if !return_dict
746
+ return [sequence_output, pooled_output] + encoder_outputs[1..]
747
+ end
748
+
749
+ BaseModelOutputWithPoolingAndCrossAttentions.new(last_hidden_state: sequence_output, pooler_output: pooled_output, past_key_values: encoder_outputs.past_key_values, hidden_states: encoder_outputs.hidden_states, attentions: encoder_outputs.attentions, cross_attentions: encoder_outputs.cross_attentions)
750
+ end
751
+ end
752
+
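A minimal end-to-end usage sketch (editorial, not from the diff): it assumes the gem's usual `from_pretrained` entry point and network access to the Hugging Face Hub, and the checkpoint name and hand-written token ids are purely illustrative.

    model = Transformers::XlmRoberta::XLMRobertaModel.from_pretrained("xlm-roberta-base")

    input_ids = Torch.tensor([[0, 10, 250, 2]])   # already-tokenized <s> ... </s>
    attention_mask = Torch.ones([1, 4])

    output = nil
    Torch.no_grad do
      output = model.(input_ids, attention_mask: attention_mask)
    end
    output.last_hidden_state.shape  # => [1, 4, 768]
    output.pooler_output.shape      # => [1, 768] (tanh-pooled <s> state)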
753
+ class XLMRobertaForCausalLM < XLMRobertaPreTrainedModel
754
+ self._tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
755
+
756
+ def initialize(config)
757
+ super(config)
758
+
759
+ if !config.is_decoder
760
+ Transformers.logger.warn("If you want to use `XLMRobertaLMHeadModel` as a standalone, add `is_decoder=True.`")
761
+ end
762
+
763
+ @roberta = XLMRobertaModel.new(config, add_pooling_layer: false)
764
+ @lm_head = XLMRobertaLMHead.new(config)
765
+
766
+ # Initialize weights and apply final processing
767
+ post_init
768
+ end
769
+
770
+ def get_output_embeddings
771
+ @lm_head.decoder
772
+ end
773
+
774
+ def set_output_embeddings(new_embeddings)
775
+ @decoder = new_embeddings
776
+ end
777
+
778
+ def forward(
779
+ input_ids: nil,
780
+ attention_mask: nil,
781
+ token_type_ids: nil,
782
+ position_ids: nil,
783
+ head_mask: nil,
784
+ inputs_embeds: nil,
785
+ encoder_hidden_states: nil,
786
+ encoder_attention_mask: nil,
787
+ labels: nil,
788
+ past_key_values: nil,
789
+ use_cache: nil,
790
+ output_attentions: nil,
791
+ output_hidden_states: nil,
792
+ return_dict: nil
793
+ )
794
+ return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
795
+ if !labels.nil?
796
+ use_cache = false
797
+ end
798
+
799
+ outputs = @roberta.(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, encoder_hidden_states: encoder_hidden_states, encoder_attention_mask: encoder_attention_mask, past_key_values: past_key_values, use_cache: use_cache, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
800
+
801
+ sequence_output = outputs[0]
802
+ prediction_scores = @lm_head.(sequence_output)
803
+
804
+ lm_loss = nil
805
+ if !labels.nil?
806
+ # move labels to correct device to enable model parallelism
807
+ labels = labels.to(prediction_scores.device)
808
+ # we are doing next-token prediction; shift prediction scores and input ids by one
809
+ shifted_prediction_scores = prediction_scores[0.., ...-1, 0..].contiguous
810
+ labels = labels[0.., 1..].contiguous
811
+ loss_fct = Torch::NN::CrossEntropyLoss.new
812
+ lm_loss = loss_fct.(shifted_prediction_scores.view(-1, @config.vocab_size), labels.view(-1))
813
+ end
814
+
815
+ if !return_dict
816
+ output = [prediction_scores] + outputs[2..]
817
+ return !lm_loss.nil? ? [lm_loss] + output : output
818
+ end
819
+
820
+ CausalLMOutputWithCrossAttentions.new(loss: lm_loss, logits: prediction_scores, past_key_values: outputs.past_key_values, hidden_states: outputs.hidden_states, attentions: outputs.attentions, cross_attentions: outputs.cross_attentions)
821
+ end
822
+
823
+ def prepare_inputs_for_generation(input_ids, past_key_values: nil, attention_mask: nil, **model_kwargs)
824
+ input_shape = input_ids.shape
825
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
826
+ if attention_mask.nil?
827
+ attention_mask = input_ids.new_ones(input_shape)
828
+ end
829
+
830
+ # cut decoder_input_ids if past_key_values is used
831
+ if !past_key_values.nil?
832
+ past_length = past_key_values[0][0].shape[2]
833
+
834
+ # Some generation methods already pass only the last input ID
835
+ if input_ids.shape[1] > past_length
836
+ remove_prefix_length = past_length
837
+ else
838
+ # Default to old behavior: keep only final ID
839
+ remove_prefix_length = input_ids.shape[1] - 1
840
+ end
841
+
842
+ input_ids = input_ids[0.., remove_prefix_length..]
843
+ end
844
+
845
+ {"input_ids" => input_ids, "attention_mask" => attention_mask, "past_key_values" => past_key_values}
846
+ end
847
+
848
+ def _reorder_cache(past_key_values, beam_idx)
849
+ reordered_past = []
850
+ past_key_values.each do |layer_past|
851
+ reordered_past += [layer_past.map { |past_state| past_state.index_select(0, beam_idx) }]
852
+ end
853
+ reordered_past
854
+ end
855
+ end
856
+
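Editorial sketch of the label shift used for the causal-LM loss above: the score at position t is matched against the token at position t + 1, so the last score and the first label are dropped. The vocabulary size here is deliberately tiny.

    prediction_scores = Torch.randn(1, 4, 32)                        # [batch, seq_len, vocab_size]
    labels = Torch.tensor([[0, 10, 25, 2]])

    shifted_scores = prediction_scores[0.., ...-1, 0..].contiguous   # [1, 3, 32]
    shifted_labels = labels[0.., 1..].contiguous                     # [1, 3]
    loss = Torch::NN::CrossEntropyLoss.new.(shifted_scores.view(-1, 32), shifted_labels.view(-1))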
857
+ class XLMRobertaForMaskedLM < XLMRobertaPreTrainedModel
858
+ self._tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
859
+
860
+ def initialize(config)
861
+ super(config)
862
+
863
+ if config.is_decoder
864
+ Transformers.logger.warn("If you want to use `XLMRobertaForMaskedLM` make sure `config.is_decoder: false` for bi-directional self-attention.")
865
+ end
866
+
867
+ @roberta = XLMRobertaModel.new(config, add_pooling_layer: false)
868
+ @lm_head = XLMRobertaLMHead.new(config)
869
+
870
+ # Initialize weights and apply final processing
871
+ post_init
872
+ end
873
+
874
+ def get_output_embeddings
875
+ @lm_head.decoder
876
+ end
877
+
878
+ def set_output_embeddings(new_embeddings)
879
+ @decoder = new_embeddings
880
+ end
881
+
882
+ def forward(
883
+ input_ids: nil,
884
+ attention_mask: nil,
885
+ token_type_ids: nil,
886
+ position_ids: nil,
887
+ head_mask: nil,
888
+ inputs_embeds: nil,
889
+ encoder_hidden_states: nil,
890
+ encoder_attention_mask: nil,
891
+ labels: nil,
892
+ output_attentions: nil,
893
+ output_hidden_states: nil,
894
+ return_dict: nil
895
+ )
896
+ return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
897
+
898
+ outputs = @roberta.(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, encoder_hidden_states: encoder_hidden_states, encoder_attention_mask: encoder_attention_mask, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
899
+ sequence_output = outputs[0]
900
+ prediction_scores = @lm_head.(sequence_output)
901
+
902
+ masked_lm_loss = nil
903
+ if !labels.nil?
904
+ # move labels to correct device to enable model parallelism
905
+ labels = labels.to(prediction_scores.device)
906
+ loss_fct = Torch::NN::CrossEntropyLoss.new
907
+ masked_lm_loss = loss_fct.(prediction_scores.view(-1, @config.vocab_size), labels.view(-1))
908
+ end
909
+
910
+ if !return_dict
911
+ output = [prediction_scores] + outputs[2..]
912
+ return !masked_lm_loss.nil? ? [masked_lm_loss] + output : output
913
+ end
914
+
915
+ MaskedLMOutput.new(loss: masked_lm_loss, logits: prediction_scores, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
916
+ end
917
+ end
918
+
919
+ class XLMRobertaLMHead < Torch::NN::Module
920
+ def initialize(config)
921
+ super()
922
+ @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
923
+ @layer_norm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
924
+
925
+ @decoder = Torch::NN::Linear.new(config.hidden_size, config.vocab_size)
926
+ @bias = Torch::NN::Parameter.new(Torch.zeros(config.vocab_size))
927
+ # tie the decoder's bias to this parameter so it is resized together with the vocabulary
+ @decoder.instance_variable_set(:@bias, @bias)
928
+ end
929
+
930
+ def forward(features, **kwargs)
931
+ x = @dense.(features)
932
+ x = Activations.gelu(x)
933
+ x = @layer_norm.(x)
934
+
935
+ # project back to size of vocabulary with bias
936
+ x = @decoder.(x)
937
+
938
+ x
939
+ end
940
+
941
+ def _tie_weights
942
+ # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
943
+ # For accelerate compatibility and to not break backward compatibility
944
+ if @decoder.bias.device.type == "meta"
945
+ @decoder.instance_variable_set(:@bias, @bias)
946
+ else
947
+ @bias = @decoder.bias
948
+ end
949
+ end
950
+ end
951
+
952
+ class XLMRobertaForSequenceClassification < XLMRobertaPreTrainedModel
953
+ def initialize(config)
954
+ super(config)
955
+ @num_labels = config.num_labels
956
+ @config = config
957
+
958
+ @roberta = XLMRobertaModel.new(config, add_pooling_layer: false)
959
+ @classifier = XLMRobertaClassificationHead.new(config)
960
+
961
+ # Initialize weights and apply final processing
962
+ post_init
963
+ end
964
+
965
+ def forward(
966
+ input_ids: nil,
967
+ attention_mask: nil,
968
+ token_type_ids: nil,
969
+ position_ids: nil,
970
+ head_mask: nil,
971
+ inputs_embeds: nil,
972
+ labels: nil,
973
+ output_attentions: nil,
974
+ output_hidden_states: nil,
975
+ return_dict: nil
976
+ )
977
+ return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
978
+
979
+ outputs = @roberta.(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
980
+ sequence_output = outputs[0]
981
+ logits = @classifier.(sequence_output)
982
+
983
+ loss = nil
984
+ if !labels.nil?
985
+ # move labels to correct device to enable model parallelism
986
+ labels = labels.to(logits.device)
987
+ problem_type = @config.problem_type
+ if problem_type.nil?
+ if @num_labels == 1
+ problem_type = "regression"
+ elsif @num_labels > 1 && (labels.dtype == Torch.long || labels.dtype == Torch.int)
+ problem_type = "single_label_classification"
+ else
+ problem_type = "multi_label_classification"
+ end
+ end
+
+ if problem_type == "regression"
+ loss_fct = Torch::NN::MSELoss.new
+ if @num_labels == 1
+ loss = loss_fct.(logits.squeeze, labels.squeeze)
+ else
+ loss = loss_fct.(logits, labels)
+ end
+ elsif problem_type == "single_label_classification"
+ loss_fct = Torch::NN::CrossEntropyLoss.new
+ loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
+ elsif problem_type == "multi_label_classification"
+ loss_fct = Torch::NN::BCEWithLogitsLoss.new
+ loss = loss_fct.(logits, labels)
+ end
1011
+ end
1012
+
1013
+ if !return_dict
1014
+ output = [logits] + outputs[2..]
1015
+ return !loss.nil? ? [loss] + output : output
1016
+ end
1017
+
1018
+ SequenceClassifierOutput.new(loss: loss, logits: logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
1019
+ end
1020
+ end
1021
+
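Editorial note on the loss selection above: when `config.problem_type` is not set, it is inferred from `num_labels` and the label dtype; one output means regression (MSE), several outputs with integer labels mean single-label classification (cross entropy), and anything else is treated as multi-label classification (BCE with logits). For the common single-label case:

    logits = Torch.randn(2, 3)       # batch of 2, num_labels = 3
    labels = Torch.tensor([0, 2])    # integer class ids
    loss = Torch::NN::CrossEntropyLoss.new.(logits.view(-1, 3), labels.view(-1))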
1022
+ class XLMRobertaForMultipleChoice < XLMRobertaPreTrainedModel
1023
+ def initialize(config)
1024
+ super(config)
1025
+
1026
+ @roberta = XLMRobertaModel.new(config)
1027
+ @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
1028
+ @classifier = Torch::NN::Linear.new(config.hidden_size, 1)
1029
+
1030
+ # Initialize weights and apply final processing
1031
+ post_init
1032
+ end
1033
+
1034
+ def forward(
1035
+ input_ids: nil,
1036
+ token_type_ids: nil,
1037
+ attention_mask: nil,
1038
+ labels: nil,
1039
+ position_ids: nil,
1040
+ head_mask: nil,
1041
+ inputs_embeds: nil,
1042
+ output_attentions: nil,
1043
+ output_hidden_states: nil,
1044
+ return_dict: nil
1045
+ )
1046
+ return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
1047
+ num_choices = !input_ids.nil? ? input_ids.shape[1] : inputs_embeds.shape[1]
1048
+
1049
+ flat_input_ids = !input_ids.nil? ? input_ids.view(-1, input_ids.size(-1)) : nil
1050
+ flat_position_ids = !position_ids.nil? ? position_ids.view(-1, position_ids.size(-1)) : nil
1051
+ flat_token_type_ids = !token_type_ids.nil? ? token_type_ids.view(-1, token_type_ids.size(-1)) : nil
1052
+ flat_attention_mask = !attention_mask.nil? ? attention_mask.view(-1, attention_mask.size(-1)) : nil
1053
+ flat_inputs_embeds = !inputs_embeds.nil? ? inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) : nil
1054
+
1055
+ outputs = @roberta.(flat_input_ids, position_ids: flat_position_ids, token_type_ids: flat_token_type_ids, attention_mask: flat_attention_mask, head_mask: head_mask, inputs_embeds: flat_inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
1056
+ pooled_output = outputs[1]
1057
+
1058
+ pooled_output = @dropout.(pooled_output)
1059
+ logits = @classifier.(pooled_output)
1060
+ reshaped_logits = logits.view(-1, num_choices)
1061
+
1062
+ loss = nil
1063
+ if !labels.nil?
1064
+ # move labels to correct device to enable model parallelism
1065
+ labels = labels.to(reshaped_logits.device)
1066
+ loss_fct = Torch::NN::CrossEntropyLoss.new
1067
+ loss = loss_fct.(reshaped_logits, labels)
1068
+ end
1069
+
1070
+ if !return_dict
1071
+ output = [reshaped_logits] + outputs[2..]
1072
+ return !loss.nil? ? [loss] + output : output
1073
+ end
1074
+
1075
+ MultipleChoiceModelOutput.new(loss: loss, logits: reshaped_logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
1076
+ end
1077
+ end
1078
+
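Editorial sketch of the flattening above: multiple-choice inputs arrive as `[batch, num_choices, seq_len]`, are flattened to `[batch * num_choices, seq_len]` for the shared encoder, and the per-choice scores are reshaped back so each row can be soft-maxed over its choices. Shapes are illustrative.

    num_choices = 4
    input_ids = Torch.zeros([2, num_choices, 7], dtype: Torch.long)  # 2 questions x 4 choices x 7 tokens
    flat_input_ids = input_ids.view(-1, input_ids.size(-1))          # [8, 7]

    logits = Torch.randn(8, 1)                                       # one score per (question, choice)
    reshaped_logits = logits.view(-1, num_choices)                   # [2, 4]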
1079
+ class XLMRobertaForTokenClassification < XLMRobertaPreTrainedModel
1080
+ def initialize(config)
1081
+ super(config)
1082
+ @num_labels = config.num_labels
1083
+
1084
+ @roberta = XLMRobertaModel.new(config, add_pooling_layer: false)
1085
+ classifier_dropout = !config.classifier_dropout.nil? ? config.classifier_dropout : config.hidden_dropout_prob
1086
+ @dropout = Torch::NN::Dropout.new(p: classifier_dropout)
1087
+ @classifier = Torch::NN::Linear.new(config.hidden_size, config.num_labels)
1088
+
1089
+ # Initialize weights and apply final processing
1090
+ post_init
1091
+ end
1092
+
1093
+ def forward(
1094
+ input_ids: nil,
1095
+ attention_mask: nil,
1096
+ token_type_ids: nil,
1097
+ position_ids: nil,
1098
+ head_mask: nil,
1099
+ inputs_embeds: nil,
1100
+ labels: nil,
1101
+ output_attentions: nil,
1102
+ output_hidden_states: nil,
1103
+ return_dict: nil
1104
+ )
1105
+ return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
1106
+
1107
+ outputs = @roberta.(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
1108
+
1109
+ sequence_output = outputs[0]
1110
+
1111
+ sequence_output = @dropout.(sequence_output)
1112
+ logits = @classifier.(sequence_output)
1113
+
1114
+ loss = nil
1115
+ if !labels.nil?
1116
+ # move labels to correct device to enable model parallelism
1117
+ labels = labels.to(logits.device)
1118
+ loss_fct = Torch::NN::CrossEntropyLoss.new
1119
+ loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
1120
+ end
1121
+
1122
+ if !return_dict
1123
+ output = [logits] + outputs[2..]
1124
+ return !loss.nil? ? [loss] + output : output
1125
+ end
1126
+
1127
+ TokenClassifierOutput.new(loss: loss, logits: logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
1128
+ end
1129
+ end
1130
+
1131
+ class XLMRobertaClassificationHead < Torch::NN::Module
1132
+ def initialize(config)
1133
+ super()
1134
+ @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
1135
+ classifier_dropout = !config.classifier_dropout.nil? ? config.classifier_dropout : config.hidden_dropout_prob
1136
+ @dropout = Torch::NN::Dropout.new(p: classifier_dropout)
1137
+ @out_proj = Torch::NN::Linear.new(config.hidden_size, config.num_labels)
1138
+ end
1139
+
1140
+ def forward(features, **kwargs)
1141
+ x = features[0.., 0, 0..]
1142
+ x = @dropout.(x)
1143
+ x = @dense.(x)
1144
+ x = Torch.tanh(x)
1145
+ x = @dropout.(x)
1146
+ x = @out_proj.(x)
1147
+ x
1148
+ end
1149
+ end
1150
+
1151
+ class XLMRobertaForQuestionAnswering < XLMRobertaPreTrainedModel
1152
+ def initialize(config)
1153
+ super(config)
1154
+ @num_labels = config.num_labels
1155
+
1156
+ @roberta = XLMRobertaModel.new(config, add_pooling_layer: false)
1157
+ @qa_outputs = Torch::NN::Linear.new(config.hidden_size, config.num_labels)
1158
+
1159
+ # Initialize weights and apply final processing
1160
+ post_init
1161
+ end
1162
+
1163
+ def forward(
1164
+ input_ids: nil,
1165
+ attention_mask: nil,
1166
+ token_type_ids: nil,
1167
+ position_ids: nil,
1168
+ head_mask: nil,
1169
+ inputs_embeds: nil,
1170
+ start_positions: nil,
1171
+ end_positions: nil,
1172
+ output_attentions: nil,
1173
+ output_hidden_states: nil,
1174
+ return_dict: nil
1175
+ )
1176
+ return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
1177
+
1178
+ outputs = @roberta.(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
1179
+
1180
+ sequence_output = outputs[0]
1181
+
1182
+ logits = @qa_outputs.(sequence_output)
1183
+ start_logits, end_logits = logits.split(1, dim: -1)
1184
+ start_logits = start_logits.squeeze(-1).contiguous
1185
+ end_logits = end_logits.squeeze(-1).contiguous
1186
+
1187
+ total_loss = nil
1188
+ if !start_positions.nil? && !end_positions.nil?
1189
+ # If we are on multi-GPU, the split adds an extra dimension, so squeeze it away
1190
+ if start_positions.size.length > 1
1191
+ start_positions = start_positions.squeeze(-1)
1192
+ end
1193
+ if end_positions.size.length > 1
1194
+ end_positions = end_positions.squeeze(-1)
1195
+ end
1196
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1197
+ ignored_index = start_logits.size(1)
1198
+ start_positions = start_positions.clamp(0, ignored_index)
1199
+ end_positions = end_positions.clamp(0, ignored_index)
1200
+
1201
+ loss_fct = Torch::NN::CrossEntropyLoss.new(ignore_index: ignored_index)
1202
+ start_loss = loss_fct.(start_logits, start_positions)
1203
+ end_loss = loss_fct.(end_logits, end_positions)
1204
+ total_loss = (start_loss + end_loss) / 2
1205
+ end
1206
+
1207
+ if !return_dict
1208
+ output = [start_logits, end_logits] + outputs[2..]
1209
+ return !total_loss.nil? ? [total_loss] + output : output
1210
+ end
1211
+
1212
+ QuestionAnsweringModelOutput.new(loss: total_loss, start_logits: start_logits, end_logits: end_logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
1213
+ end
1214
+ end
1215
+ end
1216
+ end
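Editorial sketch of how the question-answering head above turns per-token logits into start/end span scores; the shapes are illustrative.

    sequence_output = Torch.randn(1, 6, 768)             # encoder output for one 6-token input
    qa_outputs = Torch::NN::Linear.new(768, 2)

    logits = qa_outputs.(sequence_output)                 # [1, 6, 2]
    start_logits, end_logits = logits.split(1, dim: -1)
    start_logits = start_logits.squeeze(-1).contiguous    # [1, 6]
    end_logits = end_logits.squeeze(-1).contiguous        # [1, 6]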