transformers-rb 0.1.1 → 0.1.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1216 @@
1
+ # Copyright 2019 Facebook AI Research and the HuggingFace Inc. team.
2
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ module Transformers
17
+ module XlmRoberta
18
+ class XLMRobertaEmbeddings < Torch::NN::Module
19
+ def initialize(config)
20
+ super()
21
+ @word_embeddings = Torch::NN::Embedding.new(config.vocab_size, config.hidden_size, padding_idx: config.pad_token_id)
22
+ @position_embeddings = Torch::NN::Embedding.new(config.max_position_embeddings, config.hidden_size)
23
+ @token_type_embeddings = Torch::NN::Embedding.new(config.type_vocab_size, config.hidden_size)
24
+
25
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
26
+ # any TensorFlow checkpoint file
27
+ @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
28
+ @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
29
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
30
+ @position_embedding_type = config.getattr("position_embedding_type", "absolute")
31
+ register_buffer("position_ids", Torch.arange(config.max_position_embeddings).expand([1, -1]), persistent: false)
32
+ register_buffer("token_type_ids", Torch.zeros(@position_ids.size, dtype: Torch.long), persistent: false)
33
+
34
+ @padding_idx = config.pad_token_id
35
+ @position_embeddings = Torch::NN::Embedding.new(config.max_position_embeddings, config.hidden_size, padding_idx: @padding_idx)
36
+ end
37
+
38
+ def forward(input_ids: nil, token_type_ids: nil, position_ids: nil, inputs_embeds: nil, past_key_values_length: 0)
39
+ if position_ids.nil?
40
+ if !input_ids.nil?
41
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
42
+ position_ids = create_position_ids_from_input_ids(input_ids, @padding_idx, past_key_values_length:)
43
+ else
44
+ position_ids = create_position_ids_from_inputs_embeds(inputs_embeds)
45
+ end
46
+ end
47
+
48
+ if !input_ids.nil?
49
+ input_shape = input_ids.size
50
+ else
51
+ input_shape = inputs_embeds.size[...-1]
52
+ end
53
+
54
+ seq_length = input_shape[1]
55
+
56
+ # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
57
+ # when it's auto-generated; the registered buffer helps users when tracing the model without passing token_type_ids, and solves
58
+ # issue #5664
59
+ if token_type_ids.nil?
60
+ if respond_to?(:token_type_ids)
61
+ buffered_token_type_ids = @token_type_ids[0.., ...seq_length]
62
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
63
+ token_type_ids = buffered_token_type_ids_expanded
64
+ else
65
+ token_type_ids = Torch.zeros(input_shape, dtype: Torch.long, device: @position_ids.device)
66
+ end
67
+ end
68
+
69
+ if inputs_embeds.nil?
70
+ inputs_embeds = @word_embeddings.(input_ids)
71
+ end
72
+ token_type_embeddings = @token_type_embeddings.(token_type_ids)
73
+
74
+ embeddings = inputs_embeds + token_type_embeddings
75
+ if @position_embedding_type == "absolute"
76
+ position_embeddings = @position_embeddings.(position_ids)
77
+ embeddings += position_embeddings
78
+ end
79
+ embeddings = @LayerNorm.(embeddings)
80
+ embeddings = @dropout.(embeddings)
81
+ embeddings
82
+ end
83
+
84
+ def create_position_ids_from_inputs_embeds(inputs_embeds)
85
+ input_shape = inputs_embeds.size[...-1]
86
+ sequence_length = input_shape[1]
87
+
88
+ position_ids = Torch.arange(@padding_idx + 1, sequence_length + @padding_idx + 1, dtype: Torch.long, device: inputs_embeds.device)
89
+ position_ids.unsqueeze(0).expand(input_shape)
90
+ end
91
+
92
+ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length: 0)
93
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
94
+ mask = input_ids.ne(padding_idx).int
95
+ incremental_indices = (Torch.cumsum(mask, dim: 1).type_as(mask) + past_key_values_length) * mask
96
+ incremental_indices.long + padding_idx
97
+ end
98
+ end
99
+
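A quick note on `create_position_ids_from_input_ids` above: positions are offset by `padding_idx`, so real tokens count upward from `padding_idx + 1` while padded slots keep `padding_idx` itself. A minimal torch.rb sketch (the token ids, including `1` as the pad id, are made up for illustration):

```ruby
require "torch"

padding_idx = 1
input_ids = Torch.tensor([[0, 205, 310, 2, 1, 1]]) # trailing 1s stand in for <pad>

mask = input_ids.ne(padding_idx).int
incremental_indices = Torch.cumsum(mask, dim: 1).type_as(mask) * mask
position_ids = incremental_indices.long + padding_idx

p position_ids.to_a # => [[2, 3, 4, 5, 1, 1]]
```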
100
+ class XLMRobertaSelfAttention < Torch::NN::Module
101
+ def initialize(config, position_embedding_type: nil)
102
+ super()
103
+ if config.hidden_size % config.num_attention_heads != 0 && !config.hasattr("embedding_size")
104
+ raise ArgumentError, "The hidden size (#{config.hidden_size}) is not a multiple of the number of attention heads (#{config.num_attention_heads})"
105
+ end
106
+
107
+ @num_attention_heads = config.num_attention_heads
108
+ @attention_head_size = (config.hidden_size / config.num_attention_heads).to_i
109
+ @all_head_size = @num_attention_heads * @attention_head_size
110
+
111
+ @query = Torch::NN::Linear.new(config.hidden_size, @all_head_size)
112
+ @key = Torch::NN::Linear.new(config.hidden_size, @all_head_size)
113
+ @value = Torch::NN::Linear.new(config.hidden_size, @all_head_size)
114
+
115
+ @dropout = Torch::NN::Dropout.new(p: config.attention_probs_dropout_prob)
116
+ @position_embedding_type = position_embedding_type || config.getattr("position_embedding_type", "absolute")
117
+ if @position_embedding_type == "relative_key" || @position_embedding_type == "relative_key_query"
118
+ @max_position_embeddings = config.max_position_embeddings
119
+ @distance_embedding = Torch::NN::Embedding.new((2 * config.max_position_embeddings) - 1, @attention_head_size)
120
+ end
121
+
122
+ @is_decoder = config.is_decoder
123
+ end
124
+
125
+ def transpose_for_scores(x)
126
+ new_x_shape = x.size[...-1] + [@num_attention_heads, @attention_head_size]
127
+ x = x.view(new_x_shape)
128
+ x.permute(0, 2, 1, 3)
129
+ end
130
+
131
+ def forward(
132
+ hidden_states,
133
+ attention_mask: nil,
134
+ head_mask: nil,
135
+ encoder_hidden_states: nil,
136
+ encoder_attention_mask: nil,
137
+ past_key_value: nil,
138
+ output_attentions: false
139
+ )
140
+ mixed_query_layer = @query.(hidden_states)
141
+
142
+ # If this is instantiated as a cross-attention module, the keys
143
+ # and values come from an encoder; the attention mask needs to be
144
+ # such that the encoder's padding tokens are not attended to.
145
+ is_cross_attention = !encoder_hidden_states.nil?
146
+
147
+ if is_cross_attention && !past_key_value.nil?
148
+ # reuse k,v, cross_attentions
149
+ key_layer = past_key_value[0]
150
+ value_layer = past_key_value[1]
151
+ attention_mask = encoder_attention_mask
152
+ elsif is_cross_attention
153
+ key_layer = transpose_for_scores(@key.(encoder_hidden_states))
154
+ value_layer = transpose_for_scores(@value.(encoder_hidden_states))
155
+ attention_mask = encoder_attention_mask
156
+ elsif !past_key_value.nil?
157
+ key_layer = transpose_for_scores(@key.(hidden_states))
158
+ value_layer = transpose_for_scores(@value.(hidden_states))
159
+ key_layer = Torch.cat([past_key_value[0], key_layer], dim: 2)
160
+ value_layer = Torch.cat([past_key_value[1], value_layer], dim: 2)
161
+ else
162
+ key_layer = transpose_for_scores(@key.(hidden_states))
163
+ value_layer = transpose_for_scores(@value.(hidden_states))
164
+ end
165
+
166
+ query_layer = transpose_for_scores(mixed_query_layer)
167
+
168
+ use_cache = !past_key_value.nil?
169
+ if @is_decoder
170
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
171
+ # Further calls to cross_attention layer can then reuse all cross-attention
172
+ # key/value_states (first "if" case)
173
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
174
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
175
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
176
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
177
+ past_key_value = [key_layer, value_layer]
178
+ end
179
+
180
+ # Take the dot product between "query" and "key" to get the raw attention scores.
181
+ attention_scores = Torch.matmul(query_layer, key_layer.transpose(-1, -2))
182
+
183
+ if @position_embedding_type == "relative_key" || @position_embedding_type == "relative_key_query"
184
+ query_length, key_length = [query_layer.shape[2], key_layer.shape[2]]
185
+ if use_cache
186
+ position_ids_l = Torch.tensor(key_length - 1, dtype: Torch.long, device: hidden_states.device).view(-1, 1)
187
+ else
188
+ position_ids_l = Torch.arange(query_length, dtype: Torch.long, device: hidden_states.device).view(-1, 1)
189
+ end
190
+ position_ids_r = Torch.arange(key_length, dtype: Torch.long, device: hidden_states.device).view(1, -1)
191
+ distance = position_ids_l - position_ids_r
192
+
193
+ positional_embedding = @distance_embedding.((distance + @max_position_embeddings) - 1)
194
+ positional_embedding = positional_embedding.to(dtype: query_layer.dtype)
195
+
196
+ if @position_embedding_type == "relative_key"
197
+ relative_position_scores = Torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
198
+ attention_scores = attention_scores + relative_position_scores
199
+ elsif @position_embedding_type == "relative_key_query"
200
+ relative_position_scores_query = Torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
201
+ relative_position_scores_key = Torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
202
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
203
+ end
204
+ end
205
+
206
+ attention_scores = attention_scores / Math.sqrt(@attention_head_size)
207
+ if !attention_mask.nil?
208
+ # Apply the attention mask is (precomputed for all layers in XLMRobertaModel forward() function)
209
+ attention_scores = attention_scores + attention_mask
210
+ end
211
+
212
+ # Normalize the attention scores to probabilities.
213
+ attention_probs = Torch::NN::Functional.softmax(attention_scores, dim: -1)
214
+
215
+ # This is actually dropping out entire tokens to attend to, which might
216
+ # seem a bit unusual, but is taken from the original Transformer paper.
217
+ attention_probs = @dropout.(attention_probs)
218
+
219
+ # Mask heads if we want to
220
+ if !head_mask.nil?
221
+ attention_probs = attention_probs * head_mask
222
+ end
223
+
224
+ context_layer = Torch.matmul(attention_probs, value_layer)
225
+
226
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous
227
+ new_context_layer_shape = context_layer.size[...-2] + [@all_head_size]
228
+ context_layer = context_layer.view(new_context_layer_shape)
229
+
230
+ outputs = output_attentions ? [context_layer, attention_probs] : [context_layer]
231
+
232
+ if @is_decoder
233
+ outputs = outputs + [past_key_value]
234
+ end
235
+ outputs
236
+ end
237
+ end
238
+
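To make the head-splitting in `transpose_for_scores` concrete, here is a small torch.rb sketch with toy sizes (hidden size 8 and 2 heads of size 4 are assumptions for illustration), ending with the same scaled dot-product scores computed in `forward`:

```ruby
require "torch"

bsz, seq_len, num_heads, head_size = 1, 5, 2, 4
hidden_states = Torch.randn(bsz, seq_len, num_heads * head_size)

# [bsz, seq, hidden] -> [bsz, heads, seq, head_size]
x = hidden_states.view(bsz, seq_len, num_heads, head_size).permute(0, 2, 1, 3)
p x.shape # => [1, 2, 5, 4]

# Raw scores, scaled by sqrt(head_size) as above, then softmax over the key axis.
scores = Torch.matmul(x, x.transpose(-1, -2)) / Math.sqrt(head_size)
probs = Torch::NN::Functional.softmax(scores, dim: -1)
p probs.shape # => [1, 2, 5, 5]
```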
239
+ class XLMRobertaSdpaSelfAttention < XLMRobertaSelfAttention
240
+ def initialize(config, position_embedding_type: nil)
241
+ super(config, position_embedding_type: position_embedding_type)
242
+ @dropout_prob = config.attention_probs_dropout_prob
243
+ @require_contiguous_qkv = Packaging::Version.parse(Utils.get_torch_version) < Packaging::Version.parse("2.2.0")
244
+ end
245
+
246
+ # Adapted from XLMRobertaSelfAttention
247
+ def forward(
248
+ hidden_states,
249
+ attention_mask: nil,
250
+ head_mask: nil,
251
+ encoder_hidden_states: nil,
252
+ encoder_attention_mask: nil,
253
+ past_key_value: nil,
254
+ output_attentions: false
255
+ )
256
+ if @position_embedding_type != "absolute" || output_attentions || !head_mask.nil?
257
+ # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
258
+ Transformers.logger.warn("XLMRobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support non-absolute `position_embedding_type` or `output_attentions: true` or `head_mask`. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation: \"eager\"` when loading the model.")
259
+ return super(hidden_states, attention_mask: attention_mask, head_mask: head_mask, encoder_hidden_states: encoder_hidden_states, encoder_attention_mask: encoder_attention_mask, past_key_value: past_key_value, output_attentions: output_attentions)
260
+ end
261
+
262
+ bsz, tgt_len, _ = hidden_states.size
263
+
264
+ query_layer = transpose_for_scores(@query.(hidden_states))
265
+
266
+ # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
267
+ # mask needs to be such that the encoder's padding tokens are not attended to.
268
+ is_cross_attention = !encoder_hidden_states.nil?
269
+
270
+ current_states = is_cross_attention ? encoder_hidden_states : hidden_states
271
+ attention_mask = is_cross_attention ? encoder_attention_mask : attention_mask
272
+
273
+ # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
274
+ if is_cross_attention && past_key_value && past_key_value[0].shape[2] == current_states.shape[1]
275
+ key_layer, value_layer = past_key_value
276
+ else
277
+ key_layer = transpose_for_scores(@key.(current_states))
278
+ value_layer = transpose_for_scores(@value.(current_states))
279
+ if !past_key_value.nil? && !is_cross_attention
280
+ key_layer = Torch.cat([past_key_value[0], key_layer], dim: 2)
281
+ value_layer = Torch.cat([past_key_value[1], value_layer], dim: 2)
282
+ end
283
+ end
284
+
285
+ if @is_decoder
286
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
287
+ # Further calls to cross_attention layer can then reuse all cross-attention
288
+ # key/value_states (first "if" case)
289
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
290
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
291
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
292
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
293
+ past_key_value = [key_layer, value_layer]
294
+ end
295
+
296
+ # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
297
+ # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
298
+ # Reference: https://github.com/pytorch/pytorch/issues/112577
299
+ if @require_contiguous_qkv && query_layer.device.type == "cuda" && !attention_mask.nil?
300
+ query_layer = query_layer.contiguous
301
+ key_layer = key_layer.contiguous
302
+ value_layer = value_layer.contiguous
303
+ end
304
+
305
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
306
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
307
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
308
+ # a causal mask in case tgt_len == 1.
309
+ is_causal = @is_decoder && !is_cross_attention && attention_mask.nil? && tgt_len > 1 ? true : false
310
+
311
+ attn_output = Torch::NN::Functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, attn_mask: attention_mask, dropout_p: @training ? @dropout_prob : 0.0, is_causal: is_causal)
312
+
313
+ attn_output = attn_output.transpose(1, 2)
314
+ attn_output = attn_output.reshape(bsz, tgt_len, @all_head_size)
315
+
316
+ outputs = [attn_output]
317
+ if @is_decoder
318
+ outputs = outputs + [past_key_value]
319
+ end
320
+ outputs
321
+ end
322
+ end
323
+
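The `is_causal` dispatch above reduces to a small predicate; the helper below is a hypothetical restatement for readability, not part of the gem:

```ruby
# Causal masking is only requested for decoder self-attention when no explicit
# mask is given and the target length is greater than 1 (matching
# AttentionMaskConverter.to_causal_4d, which skips tgt_len == 1).
def causal_sdpa?(is_decoder:, is_cross_attention:, attention_mask:, tgt_len:)
  is_decoder && !is_cross_attention && attention_mask.nil? && tgt_len > 1
end

p causal_sdpa?(is_decoder: true, is_cross_attention: false, attention_mask: nil, tgt_len: 7) # => true
p causal_sdpa?(is_decoder: true, is_cross_attention: false, attention_mask: nil, tgt_len: 1) # => false
p causal_sdpa?(is_decoder: false, is_cross_attention: false, attention_mask: nil, tgt_len: 7) # => false
```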
324
+ class XLMRobertaSelfOutput < Torch::NN::Module
325
+ def initialize(config)
326
+ super()
327
+ @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
328
+ @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
329
+ @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
330
+ end
331
+
332
+ def forward(hidden_states, input_tensor)
333
+ hidden_states = @dense.(hidden_states)
334
+ hidden_states = @dropout.(hidden_states)
335
+ hidden_states = @LayerNorm.(hidden_states + input_tensor)
336
+ hidden_states
337
+ end
338
+ end
339
+
340
+ XLM_ROBERTA_SELF_ATTENTION_CLASSES = {"eager" => XLMRobertaSelfAttention, "sdpa" => XLMRobertaSdpaSelfAttention}
341
+
342
+ class XLMRobertaAttention < Torch::NN::Module
343
+ def initialize(config, position_embedding_type: nil)
344
+ super()
345
+ @self = XLM_ROBERTA_SELF_ATTENTION_CLASSES.fetch(config._attn_implementation).new(config, position_embedding_type: position_embedding_type)
346
+ @output = XLMRobertaSelfOutput.new(config)
347
+ @pruned_heads = Set.new
348
+ end
349
+
350
+ def prune_heads(heads)
351
+ if heads.length == 0
352
+ return
353
+ end
354
+ heads, index = TorchUtils.find_pruneable_heads_and_indices(heads, @self.num_attention_heads, @self.attention_head_size, @pruned_heads)
355
+
356
+ # Prune linear layers
357
+ @query = TorchUtils.prune_linear_layer(@self.query, index)
358
+ @key = TorchUtils.prune_linear_layer(@self.key, index)
359
+ @value = TorchUtils.prune_linear_layer(@self.value, index)
360
+ @dense = TorchUtils.prune_linear_layer(@output.dense, index, dim: 1)
361
+
362
+ # Update hyper params and store pruned heads
363
+ @num_attention_heads = @self.num_attention_heads - heads.length
364
+ @all_head_size = @self.attention_head_size * @self.num_attention_heads
365
+ @pruned_heads = @pruned_heads.union(heads)
366
+ end
367
+
368
+ def forward(
369
+ hidden_states,
370
+ attention_mask: nil,
371
+ head_mask: nil,
372
+ encoder_hidden_states: nil,
373
+ encoder_attention_mask: nil,
374
+ past_key_value: nil,
375
+ output_attentions: false
376
+ )
377
+ self_outputs = @self.(hidden_states, attention_mask:, head_mask:, encoder_hidden_states:, encoder_attention_mask:, past_key_value:, output_attentions:)
378
+ attention_output = @output.(self_outputs[0], hidden_states)
379
+ outputs = [attention_output] + self_outputs[1..]
380
+ outputs
381
+ end
382
+ end
383
+
384
+ class XLMRobertaIntermediate < Torch::NN::Module
385
+ def initialize(config)
386
+ super()
387
+ @dense = Torch::NN::Linear.new(config.hidden_size, config.intermediate_size)
388
+ if config.hidden_act.is_a?(String)
389
+ @intermediate_act_fn = ACT2FN[config.hidden_act]
390
+ else
391
+ @intermediate_act_fn = config.hidden_act
392
+ end
393
+ end
394
+
395
+ def forward(hidden_states)
396
+ hidden_states = @dense.(hidden_states)
397
+ hidden_states = @intermediate_act_fn.(hidden_states)
398
+ hidden_states
399
+ end
400
+ end
401
+
402
+ class XLMRobertaOutput < Torch::NN::Module
403
+ def initialize(config)
404
+ super()
405
+ @dense = Torch::NN::Linear.new(config.intermediate_size, config.hidden_size)
406
+ @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
407
+ @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
408
+ end
409
+
410
+ def forward(hidden_states, input_tensor)
411
+ hidden_states = @dense.(hidden_states)
412
+ hidden_states = @dropout.(hidden_states)
413
+ hidden_states = @LayerNorm.(hidden_states + input_tensor)
414
+ hidden_states
415
+ end
416
+ end
417
+
418
+ class XLMRobertaLayer < Torch::NN::Module
419
+ def initialize(config)
420
+ super()
421
+ @chunk_size_feed_forward = config.chunk_size_feed_forward
422
+ @seq_len_dim = 1
423
+ @attention = XLMRobertaAttention.new(config)
424
+ @is_decoder = config.is_decoder
425
+ @add_cross_attention = config.add_cross_attention
426
+ if @add_cross_attention
427
+ if !@is_decoder
428
+ raise ArgumentError, "#{self} should be used as a decoder model if cross attention is added"
429
+ end
430
+ @crossattention = XLMRobertaAttention.new(config, position_embedding_type: "absolute")
431
+ end
432
+ @intermediate = XLMRobertaIntermediate.new(config)
433
+ @output = XLMRobertaOutput.new(config)
434
+ end
435
+
436
+ def forward(
437
+ hidden_states,
438
+ attention_mask: nil,
439
+ head_mask: nil,
440
+ encoder_hidden_states: nil,
441
+ encoder_attention_mask: nil,
442
+ past_key_value: nil,
443
+ output_attentions: false
444
+ )
445
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
446
+ self_attn_past_key_value = !past_key_value.nil? ? past_key_value[...2] : nil
447
+ self_attention_outputs = @attention.(hidden_states, attention_mask:, head_mask:, output_attentions: output_attentions, past_key_value: self_attn_past_key_value)
448
+ attention_output = self_attention_outputs[0]
449
+
450
+ # if decoder, the last output is tuple of self-attn cache
451
+ if @is_decoder
452
+ outputs = self_attention_outputs[1...-1]
453
+ present_key_value = self_attention_outputs[-1]
454
+ else
455
+ outputs = self_attention_outputs[1..]
456
+ end
457
+
458
+ cross_attn_present_key_value = nil
459
+ if @is_decoder && !encoder_hidden_states.nil?
460
+ if !instance_variable_defined?(:@crossattention)
461
+ raise ArgumentError, "If `encoder_hidden_states` are passed, #{self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
462
+ end
463
+
464
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
465
+ cross_attn_past_key_value = !past_key_value.nil? ? past_key_value[-2..] : nil
466
+ cross_attention_outputs = @crossattention.(attention_output, attention_mask: attention_mask, head_mask: head_mask, encoder_hidden_states: encoder_hidden_states, encoder_attention_mask: encoder_attention_mask, past_key_value: cross_attn_past_key_value, output_attentions: output_attentions)
467
+ attention_output = cross_attention_outputs[0]
468
+ outputs = outputs + cross_attention_outputs[1...-1]
469
+
470
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
471
+ cross_attn_present_key_value = cross_attention_outputs[-1]
472
+ present_key_value = present_key_value + cross_attn_present_key_value
473
+ end
474
+
475
+ layer_output = TorchUtils.apply_chunking_to_forward(method(:feed_forward_chunk), @chunk_size_feed_forward, @seq_len_dim, attention_output)
476
+ outputs = [layer_output] + outputs
477
+
478
+ # if decoder, return the attn key/values as the last output
479
+ if @is_decoder
480
+ outputs = outputs + [present_key_value]
481
+ end
482
+
483
+ outputs
484
+ end
485
+
486
+ def feed_forward_chunk(attention_output)
487
+ intermediate_output = @intermediate.(attention_output)
488
+ layer_output = @output.(intermediate_output, attention_output)
489
+ layer_output
490
+ end
491
+ end
492
+
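`apply_chunking_to_forward` trades speed for peak memory by running the feed-forward over sequence slices. A torch.rb sketch of why that is safe (a single Linear layer stands in for the intermediate/output pair; sizes are made up):

```ruby
require "torch"

dense = Torch::NN::Linear.new(4, 4)
x = Torch.randn(2, 6, 4)

full = dense.(x)
chunked = Torch.cat([dense.(x[0.., 0...3]), dense.(x[0.., 3...6])], dim: 1)

# The feed-forward acts position-wise, so chunking along the sequence
# dimension and concatenating gives the same result.
p Torch.allclose(full, chunked) # => true
```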
493
+ class XLMRobertaEncoder < Torch::NN::Module
494
+ def initialize(config)
495
+ super()
496
+ @config = config
497
+ @layer = Torch::NN::ModuleList.new(config.num_hidden_layers.times.map { |_| XLMRobertaLayer.new(config) })
498
+ @gradient_checkpointing = false
499
+ end
500
+
501
+ def forward(
502
+ hidden_states,
503
+ attention_mask: nil,
504
+ head_mask: nil,
505
+ encoder_hidden_states: nil,
506
+ encoder_attention_mask: nil,
507
+ past_key_values: nil,
508
+ use_cache: nil,
509
+ output_attentions: false,
510
+ output_hidden_states: false,
511
+ return_dict: true
512
+ )
513
+ all_hidden_states = output_hidden_states ? [] : nil
514
+ all_self_attentions = output_attentions ? [] : nil
515
+ all_cross_attentions = output_attentions && @config.add_cross_attention ? [] : nil
516
+
517
+ if @gradient_checkpointing && @training
518
+ if use_cache
519
+ Transformers.logger.warn("`use_cache: true` is incompatible with gradient checkpointing. Setting `use_cache: false`...")
520
+ use_cache = false
521
+ end
522
+ end
523
+
524
+ next_decoder_cache = use_cache ? [] : nil
525
+ @layer.each_with_index do |layer_module, i|
526
+ if output_hidden_states
527
+ all_hidden_states = all_hidden_states + [hidden_states]
528
+ end
529
+
530
+ layer_head_mask = !head_mask.nil? ? head_mask[i] : nil
531
+ past_key_value = !past_key_values.nil? ? past_key_values[i] : nil
532
+
533
+ if @gradient_checkpointing && @training
534
+ layer_outputs = _gradient_checkpointing_func(layer_module.method(:call), hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)
535
+ else
536
+ layer_outputs = layer_module.(hidden_states, attention_mask:, head_mask: layer_head_mask, encoder_hidden_states:, encoder_attention_mask:, past_key_value:, output_attentions:)
537
+ end
538
+
539
+ hidden_states = layer_outputs[0]
540
+ if use_cache
541
+ next_decoder_cache += [layer_outputs[-1]]
542
+ end
543
+ if output_attentions
544
+ all_self_attentions = all_self_attentions + [layer_outputs[1]]
545
+ if @config.add_cross_attention
546
+ all_cross_attentions = all_cross_attentions + [layer_outputs[2]]
547
+ end
548
+ end
549
+ end
550
+
551
+ if output_hidden_states
552
+ all_hidden_states = all_hidden_states + [hidden_states]
553
+ end
554
+
555
+ if !return_dict
556
+ return Array([hidden_states, next_decoder_cache, all_hidden_states, all_self_attentions, all_cross_attentions].select { |v| !v.nil? })
557
+ end
558
+ BaseModelOutputWithPastAndCrossAttentions.new(last_hidden_state: hidden_states, past_key_values: next_decoder_cache, hidden_states: all_hidden_states, attentions: all_self_attentions, cross_attentions: all_cross_attentions)
559
+ end
560
+ end
561
+
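A plain-Ruby restatement of the hidden-state bookkeeping in the encoder loop above: with `output_hidden_states` on, the collected list ends up holding the embedding output plus one entry per layer.

```ruby
num_hidden_layers = 3
all_hidden_states = []
hidden_states = :embedding_output

num_hidden_layers.times do |i|
  all_hidden_states << hidden_states   # state *entering* layer i
  hidden_states = :"layer_#{i}_output" # stand-in for layer_outputs[0]
end
all_hidden_states << hidden_states     # final append after the loop

p all_hidden_states.length # => 4 (num_hidden_layers + 1)
```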
562
+ class XLMRobertaPooler < Torch::NN::Module
563
+ def initialize(config)
564
+ super()
565
+ @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
566
+ @activation = Torch::NN::Tanh.new
567
+ end
568
+
569
+ def forward(hidden_states)
570
+ # We "pool" the model by simply taking the hidden state corresponding
571
+ # to the first token.
572
+ first_token_tensor = hidden_states[0.., 0]
573
+ pooled_output = @dense.(first_token_tensor)
574
+ pooled_output = @activation.(pooled_output)
575
+ pooled_output
576
+ end
577
+ end
578
+
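The pooler above only looks at the first token of the sequence. A torch.rb sketch with assumed toy sizes (batch 2, sequence 4, hidden 8):

```ruby
require "torch"

hidden_states = Torch.randn(2, 4, 8) # (batch, seq, hidden)

first_token_tensor = hidden_states[0.., 0]
p first_token_tensor.shape # => [2, 8]

dense = Torch::NN::Linear.new(8, 8)
pooled_output = Torch::NN::Tanh.new.(dense.(first_token_tensor))
p pooled_output.shape # => [2, 8]
```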
579
+ class XLMRobertaPreTrainedModel < PreTrainedModel
580
+ self.config_class = XLMRobertaConfig
581
+ self.base_model_prefix = "roberta"
582
+ # self.supports_gradient_checkpointing = true
583
+ # self._no_split_modules = ["XLMRobertaEmbeddings", "XLMRobertaSelfAttention", "XLMRobertaSdpaSelfAttention"]
584
+ # self._supports_sdpa = true
585
+
586
+ def _init_weights(module_)
587
+ if module_.is_a?(Torch::NN::Linear)
588
+ # Slightly different from the TF version which uses truncated_normal for initialization
589
+ # cf https://github.com/pytorch/pytorch/pull/5617
590
+ module_.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
591
+ if !module_.bias.nil?
592
+ module_.bias.data.zero!
593
+ end
594
+ elsif module_.is_a?(Torch::NN::Embedding)
595
+ module_.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
596
+ if !module_.padding_idx.nil?
597
+ module_.weight.data[module_.padding_idx].zero!
598
+ end
599
+ elsif module_.is_a?(Torch::NN::LayerNorm)
600
+ module_.bias.data.zero!
601
+ module_.weight.data.fill!(1.0)
602
+ end
603
+ end
604
+ end
605
+
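`_init_weights` in isolation, as a torch.rb sketch (the 0.02 standard deviation is the usual default for `initializer_range`, assumed here rather than read from a config):

```ruby
require "torch"

initializer_range = 0.02

linear = Torch::NN::Linear.new(16, 16)
linear.weight.data.normal!(mean: 0.0, std: initializer_range)
linear.bias.data.zero!

layer_norm = Torch::NN::LayerNorm.new(16, eps: 1e-5)
layer_norm.bias.data.zero!
layer_norm.weight.data.fill!(1.0)

p linear.bias.data.sum.item # => 0.0
```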
606
+ class XLMRobertaModel < XLMRobertaPreTrainedModel
607
+ # self._no_split_modules = ["XLMRobertaEmbeddings", "XLMRobertaLayer"]
608
+
609
+ def initialize(config, add_pooling_layer: true)
610
+ super(config)
611
+ @config = config
612
+
613
+ @embeddings = XLMRobertaEmbeddings.new(config)
614
+ @encoder = XLMRobertaEncoder.new(config)
615
+
616
+ @pooler = add_pooling_layer ? XLMRobertaPooler.new(config) : nil
617
+
618
+ @attn_implementation = config._attn_implementation
619
+ @position_embedding_type = config.position_embedding_type
620
+
621
+ # Initialize weights and apply final processing
622
+ post_init
623
+ end
624
+
625
+ def get_input_embeddings
626
+ @embeddings.word_embeddings
627
+ end
628
+
629
+ def set_input_embeddings(value)
630
+ @word_embeddings = value
631
+ end
632
+
633
+ def _prune_heads(heads_to_prune)
634
+ heads_to_prune.each do |layer, heads|
635
+ @encoder.layer[layer].attention.prune_heads(heads)
636
+ end
637
+ end
638
+
639
+ def forward(
640
+ input_ids,
641
+ attention_mask: nil,
642
+ token_type_ids: nil,
643
+ position_ids: nil,
644
+ head_mask: nil,
645
+ inputs_embeds: nil,
646
+ encoder_hidden_states: nil,
647
+ encoder_attention_mask: nil,
648
+ past_key_values: nil,
649
+ use_cache: nil,
650
+ output_attentions: nil,
651
+ output_hidden_states: nil,
652
+ return_dict: nil
653
+ )
654
+ output_attentions = !output_attentions.nil? ? output_attentions : @config.output_attentions
655
+ output_hidden_states = !output_hidden_states.nil? ? output_hidden_states : @config.output_hidden_states
656
+ return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
657
+
658
+ if @config.is_decoder
659
+ use_cache = !use_cache.nil? ? use_cache : @config.use_cache
660
+ else
661
+ use_cache = false
662
+ end
663
+
664
+ if !input_ids.nil? && !inputs_embeds.nil?
665
+ raise ArgumentError, "You cannot specify both input_ids and inputs_embeds at the same time"
666
+ elsif !input_ids.nil?
667
+ warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
668
+ input_shape = input_ids.size
669
+ elsif !inputs_embeds.nil?
670
+ input_shape = inputs_embeds.size[...-1]
671
+ else
672
+ raise ArgumentError, "You have to specify either input_ids or inputs_embeds"
673
+ end
674
+
675
+ batch_size, seq_length = input_shape
676
+ device = !input_ids.nil? ? input_ids.device : inputs_embeds.device
677
+
678
+ # past_key_values_length
679
+ past_key_values_length = !past_key_values.nil? ? past_key_values[0][0].shape[2] : 0
680
+
681
+ if token_type_ids.nil?
682
+ if @embeddings.respond_to?(:token_type_ids)
683
+ buffered_token_type_ids = @embeddings.token_type_ids[0.., ...seq_length]
684
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
685
+ token_type_ids = buffered_token_type_ids_expanded
686
+ else
687
+ token_type_ids = Torch.zeros(input_shape, dtype: Torch.long, device: device)
688
+ end
689
+ end
690
+
691
+ embedding_output = @embeddings.(input_ids: input_ids, position_ids: position_ids, token_type_ids: token_type_ids, inputs_embeds: inputs_embeds, past_key_values_length: past_key_values_length)
692
+
693
+ if attention_mask.nil?
694
+ attention_mask = Torch.ones([batch_size, seq_length + past_key_values_length], device: device)
695
+ end
696
+
697
+ use_sdpa_attention_masks = @attn_implementation == "sdpa" && @position_embedding_type == "absolute" && head_mask.nil? && !output_attentions
698
+
699
+ # Expand the attention mask
700
+ if use_sdpa_attention_masks && attention_mask.dim == 2
701
+ # Expand the attention mask for SDPA.
702
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
703
+ if @config.is_decoder
704
+ extended_attention_mask = ModelingAttnMaskUtils._prepare_4d_causal_attention_mask_for_sdpa(attention_mask, input_shape, embedding_output, past_key_values_length)
705
+ else
706
+ extended_attention_mask = ModelingAttnMaskUtils._prepare_4d_attention_mask_for_sdpa(attention_mask, embedding_output.dtype, tgt_len: seq_length)
707
+ end
708
+ else
709
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
710
+ # ourselves in which case we just need to make it broadcastable to all heads.
711
+ extended_attention_mask = get_extended_attention_mask(attention_mask, input_shape)
712
+ end
713
+
714
+ # If a 2D or 3D attention mask is provided for the cross-attention
715
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
716
+ if @config.is_decoder && !encoder_hidden_states.nil?
717
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size
718
+ encoder_hidden_shape = [encoder_batch_size, encoder_sequence_length]
719
+ if encoder_attention_mask.nil?
720
+ encoder_attention_mask = Torch.ones(encoder_hidden_shape, device: device)
721
+ end
722
+
723
+ if use_sdpa_attention_masks && encoder_attention_mask.dim == 2
724
+ # Expand the attention mask for SDPA.
725
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
726
+ encoder_extended_attention_mask = ModelingAttnMaskUtils._prepare_4d_attention_mask_for_sdpa(encoder_attention_mask, embedding_output.dtype, tgt_len: seq_length)
727
+ else
728
+ encoder_extended_attention_mask = invert_attention_mask(encoder_attention_mask)
729
+ end
730
+ else
731
+ encoder_extended_attention_mask = nil
732
+ end
733
+
734
+ # Prepare head mask if needed
735
+ # 1.0 in head_mask indicate we keep the head
736
+ # attention_probs has shape bsz x n_heads x N x N
737
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
738
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
739
+ head_mask = get_head_mask(head_mask, @config.num_hidden_layers)
740
+
741
+ encoder_outputs = @encoder.(embedding_output, attention_mask: extended_attention_mask, head_mask: head_mask, encoder_hidden_states: encoder_hidden_states, encoder_attention_mask: encoder_extended_attention_mask, past_key_values: past_key_values, use_cache: use_cache, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
742
+ sequence_output = encoder_outputs[0]
743
+ pooled_output = !@pooler.nil? ? @pooler.(sequence_output) : nil
744
+
745
+ if !return_dict
746
+ return [sequence_output, pooled_output] + encoder_outputs[1..]
747
+ end
748
+
749
+ BaseModelOutputWithPoolingAndCrossAttentions.new(last_hidden_state: sequence_output, pooler_output: pooled_output, past_key_values: encoder_outputs.past_key_values, hidden_states: encoder_outputs.hidden_states, attentions: encoder_outputs.attentions, cross_attentions: encoder_outputs.cross_attentions)
750
+ end
751
+ end
752
+
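For context, a hedged usage sketch of how an XLM-RoBERTa based checkpoint might be driven through the gem's pipeline API; the task name and model id below are assumptions for illustration and are not taken from this diff:

```ruby
require "transformers-rb"

# Assumed: an embedding pipeline backed by an XLM-RoBERTa based checkpoint.
embed = Transformers.pipeline("embedding", "intfloat/multilingual-e5-base")

embeddings = embed.(["Hello world", "Bonjour le monde"])
p embeddings.length # one embedding vector per input sentence
```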
753
+ class XLMRobertaForCausalLM < XLMRobertaPreTrainedModel
754
+ self._tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
755
+
756
+ def initialize(config)
757
+ super(config)
758
+
759
+ if !config.is_decoder
760
+ Transformers.logger.warn("If you want to use `XLMRobertaLMHeadModel` as a standalone, add `is_decoder=True.`")
761
+ end
762
+
763
+ @roberta = XLMRobertaModel.new(config, add_pooling_layer: false)
764
+ @lm_head = XLMRobertaLMHead.new(config)
765
+
766
+ # Initialize weights and apply final processing
767
+ post_init
768
+ end
769
+
770
+ def get_output_embeddings
771
+ @lm_head.decoder
772
+ end
773
+
774
+ def set_output_embeddings(new_embeddings)
775
+ @decoder = new_embeddings
776
+ end
777
+
778
+ def forward(
779
+ input_ids: nil,
780
+ attention_mask: nil,
781
+ token_type_ids: nil,
782
+ position_ids: nil,
783
+ head_mask: nil,
784
+ inputs_embeds: nil,
785
+ encoder_hidden_states: nil,
786
+ encoder_attention_mask: nil,
787
+ labels: nil,
788
+ past_key_values: nil,
789
+ use_cache: nil,
790
+ output_attentions: nil,
791
+ output_hidden_states: nil,
792
+ return_dict: nil
793
+ )
794
+ return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
795
+ if !labels.nil?
796
+ use_cache = false
797
+ end
798
+
799
+ outputs = @roberta.(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, encoder_hidden_states: encoder_hidden_states, encoder_attention_mask: encoder_attention_mask, past_key_values: past_key_values, use_cache: use_cache, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
800
+
801
+ sequence_output = outputs[0]
802
+ prediction_scores = @lm_head.(sequence_output)
803
+
804
+ lm_loss = nil
805
+ if !labels.nil?
806
+ # move labels to correct device to enable model parallelism
807
+ labels = labels.to(prediction_scores.device)
808
+ # we are doing next-token prediction; shift prediction scores and input ids by one
809
+ shifted_prediction_scores = prediction_scores[0.., ...-1, 0..].contiguous
810
+ labels = labels[0.., 1..].contiguous
811
+ loss_fct = Torch::NN::CrossEntropyLoss.new
812
+ lm_loss = loss_fct.(shifted_prediction_scores.view(-1, @config.vocab_size), labels.view(-1))
813
+ end
814
+
815
+ if !return_dict
816
+ output = [prediction_scores] + outputs[2..]
817
+ return !lm_loss.nil? ? [lm_loss] + output : output
818
+ end
819
+
820
+ CausalLMOutputWithCrossAttentions.new(loss: lm_loss, logits: prediction_scores, past_key_values: outputs.past_key_values, hidden_states: outputs.hidden_states, attentions: outputs.attentions, cross_attentions: outputs.cross_attentions)
821
+ end
822
+
823
+ def prepare_inputs_for_generation(input_ids, past_key_values: nil, attention_mask: nil, **model_kwargs)
824
+ input_shape = input_ids.shape
825
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
826
+ if attention_mask.nil?
827
+ attention_mask = input_ids.new_ones(input_shape)
828
+ end
829
+
830
+ # cut decoder_input_ids if past_key_values is used
831
+ if !past_key_values.nil?
832
+ past_length = past_key_values[0][0].shape[2]
833
+
834
+ # Some generation methods already pass only the last input ID
835
+ if input_ids.shape[1] > past_length
836
+ remove_prefix_length = past_length
837
+ else
838
+ # Default to old behavior: keep only final ID
839
+ remove_prefix_length = input_ids.shape[1] - 1
840
+ end
841
+
842
+ input_ids = input_ids[0.., remove_prefix_length..]
843
+ end
844
+
845
+ {"input_ids" => input_ids, "attention_mask" => attention_mask, "past_key_values" => past_key_values}
846
+ end
847
+
848
+ def _reorder_cache(past_key_values, beam_idx)
849
+ reordered_past = []
850
+ past_key_values.each do |layer_past|
851
+ reordered_past += [layer_past.map { |past_state| past_state.index_select(0, beam_idx.to(past_state.device)) }]
852
+ end
853
+ reordered_past
854
+ end
855
+ end
856
+
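The label shifting in the causal LM loss above, isolated as a torch.rb sketch (vocabulary size and token ids are made up):

```ruby
require "torch"

vocab_size = 11
prediction_scores = Torch.randn(1, 4, vocab_size) # (batch, seq, vocab)
labels = Torch.tensor([[5, 7, 2, 9]])

# The score at position t is compared against the token at position t + 1.
shifted_prediction_scores = prediction_scores[0.., ...-1, 0..].contiguous
shifted_labels = labels[0.., 1..].contiguous

loss_fct = Torch::NN::CrossEntropyLoss.new
lm_loss = loss_fct.(shifted_prediction_scores.view(-1, vocab_size), shifted_labels.view(-1))
p lm_loss.item
```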
857
+ class XLMRobertaForMaskedLM < XLMRobertaPreTrainedModel
858
+ self._tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
859
+
860
+ def initialize(config)
861
+ super(config)
862
+
863
+ if config.is_decoder
864
+ Transformers.logger.warn("If you want to use `XLMRobertaForMaskedLM` make sure `config.is_decoder: false` for bi-directional self-attention.")
865
+ end
866
+
867
+ @roberta = XLMRobertaModel.new(config, add_pooling_layer: false)
868
+ @lm_head = XLMRobertaLMHead.new(config)
869
+
870
+ # Initialize weights and apply final processing
871
+ post_init
872
+ end
873
+
874
+ def get_output_embeddings
875
+ @lm_head.decoder
876
+ end
877
+
878
+ def set_output_embeddings(new_embeddings)
879
+ @decoder = new_embeddings
880
+ end
881
+
882
+ def forward(
883
+ input_ids: nil,
884
+ attention_mask: nil,
885
+ token_type_ids: nil,
886
+ position_ids: nil,
887
+ head_mask: nil,
888
+ inputs_embeds: nil,
889
+ encoder_hidden_states: nil,
890
+ encoder_attention_mask: nil,
891
+ labels: nil,
892
+ output_attentions: nil,
893
+ output_hidden_states: nil,
894
+ return_dict: nil
895
+ )
896
+ return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
897
+
898
+ outputs = @roberta.(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, encoder_hidden_states: encoder_hidden_states, encoder_attention_mask: encoder_attention_mask, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
899
+ sequence_output = outputs[0]
900
+ prediction_scores = @lm_head.(sequence_output)
901
+
902
+ masked_lm_loss = nil
903
+ if !labels.nil?
904
+ # move labels to correct device to enable model parallelism
905
+ labels = labels.to(prediction_scores.device)
906
+ loss_fct = Torch::NN::CrossEntropyLoss.new
907
+ masked_lm_loss = loss_fct.(prediction_scores.view(-1, @config.vocab_size), labels.view(-1))
908
+ end
909
+
910
+ if !return_dict
911
+ output = [prediction_scores] + outputs[2..]
912
+ return !masked_lm_loss.nil? ? [masked_lm_loss] + output : output
913
+ end
914
+
915
+ MaskedLMOutput.new(loss: masked_lm_loss, logits: prediction_scores, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
916
+ end
917
+ end
918
+
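For the masked LM loss above, only the masked positions carry real labels; the rest are conventionally set to -100, which CrossEntropyLoss ignores by default. A torch.rb sketch with made-up sizes:

```ruby
require "torch"

vocab_size = 11
prediction_scores = Torch.randn(1, 4, vocab_size)
labels = Torch.tensor([[-100, 7, -100, -100]]) # only position 1 is supervised

loss_fct = Torch::NN::CrossEntropyLoss.new # ignore_index defaults to -100
masked_lm_loss = loss_fct.(prediction_scores.view(-1, vocab_size), labels.view(-1))
p masked_lm_loss.item
```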
919
+ class XLMRobertaLMHead < Torch::NN::Module
920
+ def initialize(config)
921
+ super()
922
+ @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
923
+ @layer_norm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
924
+
925
+ @decoder = Torch::NN::Linear.new(config.hidden_size, config.vocab_size)
926
+ @bias = Torch::NN::Parameter.new(Torch.zeros(config.vocab_size))
927
+ @decoder.instance_variable_set(:@bias, @bias) # tie the decoder bias to the head bias (Python: self.decoder.bias = self.bias)
928
+ end
929
+
930
+ def forward(features, **kwargs)
931
+ x = @dense.(features)
932
+ x = Activations.gelu(x)
933
+ x = @layer_norm.(x)
934
+
935
+ # project back to size of vocabulary with bias
936
+ x = @decoder.(x)
937
+
938
+ x
939
+ end
940
+
941
+ def _tie_weights
942
+ # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
943
+ # For accelerate compatibility and to not break backward compatibility
944
+ if @decoder.bias.device.type == "meta"
945
+ @decoder.instance_variable_set(:@bias, @bias)
946
+ else
947
+ @bias = @decoder.bias
948
+ end
949
+ end
950
+ end
951
+
952
+ class XLMRobertaForSequenceClassification < XLMRobertaPreTrainedModel
953
+ def initialize(config)
954
+ super(config)
955
+ @num_labels = config.num_labels
956
+ @config = config
957
+
958
+ @roberta = XLMRobertaModel.new(config, add_pooling_layer: false)
959
+ @classifier = XLMRobertaClassificationHead.new(config)
960
+
961
+ # Initialize weights and apply final processing
962
+ post_init
963
+ end
964
+
965
+ def forward(
966
+ input_ids: nil,
967
+ attention_mask: nil,
968
+ token_type_ids: nil,
969
+ position_ids: nil,
970
+ head_mask: nil,
971
+ inputs_embeds: nil,
972
+ labels: nil,
973
+ output_attentions: nil,
974
+ output_hidden_states: nil,
975
+ return_dict: nil
976
+ )
977
+ return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
978
+
979
+ outputs = @roberta.(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
980
+ sequence_output = outputs[0]
981
+ logits = @classifier.(sequence_output)
982
+
983
+ loss = nil
984
+ if !labels.nil?
985
+ # move labels to correct device to enable model parallelism
986
+ labels = labels.to(logits.device)
987
+ if @config.problem_type.nil?
988
+ if @num_labels == 1
989
+ @problem_type = "regression"
990
+ elsif @num_labels > 1 && (labels.dtype == Torch.long || labels.dtype == Torch.int)
991
+ @problem_type = "single_label_classification"
992
+ else
993
+ @problem_type = "multi_label_classification"
994
+ end
995
+ end
996
+
997
+ if @config.problem_type == "regression"
998
+ loss_fct = Torch::NN::MSELoss.new
999
+ if @num_labels == 1
1000
+ loss = loss_fct.(logits.squeeze, labels.squeeze)
1001
+ else
1002
+ loss = loss_fct.(logits, labels)
1003
+ end
1004
+ elsif @config.problem_type == "single_label_classification"
1005
+ loss_fct = Torch::NN::CrossEntropyLoss.new
1006
+ loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
1007
+ elsif @config.problem_type == "multi_label_classification"
1008
+ loss_fct = Torch::NN::BCEWithLogitsLoss.new
1009
+ loss = loss_fct.(logits, labels)
1010
+ end
1011
+ end
1012
+
1013
+ if !return_dict
1014
+ output = [logits] + outputs[2..]
1015
+ return !loss.nil? ? [loss] + output : output
1016
+ end
1017
+
1018
+ SequenceClassifierOutput.new(loss: loss, logits: logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
1019
+ end
1020
+ end
1021
+
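The problem-type selection in the loss branch above, restated as a plain function (a hypothetical helper for readability, not part of the gem):

```ruby
def infer_problem_type(num_labels:, integer_labels:)
  if num_labels == 1
    "regression"
  elsif num_labels > 1 && integer_labels
    "single_label_classification"
  else
    "multi_label_classification"
  end
end

p infer_problem_type(num_labels: 1, integer_labels: true)  # => "regression"
p infer_problem_type(num_labels: 3, integer_labels: true)  # => "single_label_classification"
p infer_problem_type(num_labels: 3, integer_labels: false) # => "multi_label_classification"
```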
1022
+ class XLMRobertaForMultipleChoice < XLMRobertaPreTrainedModel
1023
+ def initialize(config)
1024
+ super(config)
1025
+
1026
+ @roberta = XLMRobertaModel.new(config)
1027
+ @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
1028
+ @classifier = Torch::NN::Linear.new(config.hidden_size, 1)
1029
+
1030
+ # Initialize weights and apply final processing
1031
+ post_init
1032
+ end
1033
+
1034
+ def forward(
1035
+ input_ids: nil,
1036
+ token_type_ids: nil,
1037
+ attention_mask: nil,
1038
+ labels: nil,
1039
+ position_ids: nil,
1040
+ head_mask: nil,
1041
+ inputs_embeds: nil,
1042
+ output_attentions: nil,
1043
+ output_hidden_states: nil,
1044
+ return_dict: nil
1045
+ )
1046
+ return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
1047
+ num_choices = !input_ids.nil? ? input_ids.shape[1] : inputs_embeds.shape[1]
1048
+
1049
+ flat_input_ids = !input_ids.nil? ? input_ids.view(-1, input_ids.size(-1)) : nil
1050
+ flat_position_ids = !position_ids.nil? ? position_ids.view(-1, position_ids.size(-1)) : nil
1051
+ flat_token_type_ids = !token_type_ids.nil? ? token_type_ids.view(-1, token_type_ids.size(-1)) : nil
1052
+ flat_attention_mask = !attention_mask.nil? ? attention_mask.view(-1, attention_mask.size(-1)) : nil
1053
+ flat_inputs_embeds = !inputs_embeds.nil? ? inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) : nil
1054
+
1055
+ outputs = @roberta.(flat_input_ids, position_ids: flat_position_ids, token_type_ids: flat_token_type_ids, attention_mask: flat_attention_mask, head_mask: head_mask, inputs_embeds: flat_inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
1056
+ pooled_output = outputs[1]
1057
+
1058
+ pooled_output = @dropout.(pooled_output)
1059
+ logits = @classifier.(pooled_output)
1060
+ reshaped_logits = logits.view(-1, num_choices)
1061
+
1062
+ loss = nil
1063
+ if !labels.nil?
1064
+ # move labels to correct device to enable model parallelism
1065
+ labels = labels.to(reshaped_logits.device)
1066
+ loss_fct = Torch::NN::CrossEntropyLoss.new
1067
+ loss = loss_fct.(reshaped_logits, labels)
1068
+ end
1069
+
1070
+ if !return_dict
1071
+ output = [reshaped_logits] + outputs[2..]
1072
+ return !loss.nil? ? [loss] + output : output
1073
+ end
1074
+
1075
+ MultipleChoiceModelOutput.new(loss: loss, logits: reshaped_logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
1076
+ end
1077
+ end
1078
+
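The reshaping convention in the multiple-choice head above, sketched with torch.rb and toy sizes (batch 2, 4 choices, sequence length 6 are assumptions):

```ruby
require "torch"

batch_size, num_choices, seq_len = 2, 4, 6
input_ids = Torch.zeros(batch_size, num_choices, seq_len)

# (batch, num_choices, seq) -> (batch * num_choices, seq) for the encoder ...
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
p flat_input_ids.shape # => [8, 6]

# ... and one logit per flattened choice is folded back to (batch, num_choices).
logits = Torch.randn(batch_size * num_choices, 1)
p logits.view(-1, num_choices).shape # => [2, 4]
```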
1079
+ class XLMRobertaForTokenClassification < XLMRobertaPreTrainedModel
1080
+ def initialize(config)
1081
+ super(config)
1082
+ @num_labels = config.num_labels
1083
+
1084
+ @roberta = XLMRobertaModel.new(config, add_pooling_layer: false)
1085
+ classifier_dropout = !config.classifier_dropout.nil? ? config.classifier_dropout : config.hidden_dropout_prob
1086
+ @dropout = Torch::NN::Dropout.new(p: classifier_dropout)
1087
+ @classifier = Torch::NN::Linear.new(config.hidden_size, config.num_labels)
1088
+
1089
+ # Initialize weights and apply final processing
1090
+ post_init
1091
+ end
1092
+
1093
+ def forward(
1094
+ input_ids: nil,
1095
+ attention_mask: nil,
1096
+ token_type_ids: nil,
1097
+ position_ids: nil,
1098
+ head_mask: nil,
1099
+ inputs_embeds: nil,
1100
+ labels: nil,
1101
+ output_attentions: nil,
1102
+ output_hidden_states: nil,
1103
+ return_dict: nil
1104
+ )
1105
+ return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
1106
+
1107
+ outputs = @roberta.(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
1108
+
1109
+ sequence_output = outputs[0]
1110
+
1111
+ sequence_output = @dropout.(sequence_output)
1112
+ logits = @classifier.(sequence_output)
1113
+
1114
+ loss = nil
1115
+ if !labels.nil?
1116
+ # move labels to correct device to enable model parallelism
1117
+ labels = labels.to(logits.device)
1118
+ loss_fct = Torch::NN::CrossEntropyLoss.new
1119
+ loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
1120
+ end
1121
+
1122
+ if !return_dict
1123
+ output = [logits] + outputs[2..]
1124
+ return !loss.nil? ? [loss] + output : output
1125
+ end
1126
+
1127
+ TokenClassifierOutput.new(loss: loss, logits: logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
1128
+ end
1129
+ end
1130
+
1131
+ class XLMRobertaClassificationHead < Torch::NN::Module
1132
+ def initialize(config)
1133
+ super()
1134
+ @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
1135
+ classifier_dropout = !config.classifier_dropout.nil? ? config.classifier_dropout : config.hidden_dropout_prob
1136
+ @dropout = Torch::NN::Dropout.new(p: classifier_dropout)
1137
+ @out_proj = Torch::NN::Linear.new(config.hidden_size, config.num_labels)
1138
+ end
1139
+
1140
+ def forward(features, **kwargs)
1141
+ x = features[0.., 0, 0..]
1142
+ x = @dropout.(x)
1143
+ x = @dense.(x)
1144
+ x = Torch.tanh(x)
1145
+ x = @dropout.(x)
1146
+ x = @out_proj.(x)
1147
+ x
1148
+ end
1149
+ end
1150
+
1151
+ class XLMRobertaForQuestionAnswering < XLMRobertaPreTrainedModel
1152
+ def initialize(config)
1153
+ super(config)
1154
+ @num_labels = config.num_labels
1155
+
1156
+ @roberta = XLMRobertaModel.new(config, add_pooling_layer: false)
1157
+ @qa_outputs = Torch::NN::Linear.new(config.hidden_size, config.num_labels)
1158
+
1159
+ # Initialize weights and apply final processing
1160
+ post_init
1161
+ end
1162
+
1163
+ def forward(
1164
+ input_ids: nil,
1165
+ attention_mask: nil,
1166
+ token_type_ids: nil,
1167
+ position_ids: nil,
1168
+ head_mask: nil,
1169
+ inputs_embeds: nil,
1170
+ start_positions: nil,
1171
+ end_positions: nil,
1172
+ output_attentions: nil,
1173
+ output_hidden_states: nil,
1174
+ return_dict: nil
1175
+ )
1176
+ return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
1177
+
1178
+ outputs = @roberta.(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
1179
+
1180
+ sequence_output = outputs[0]
1181
+
1182
+ logits = @qa_outputs.(sequence_output)
1183
+ start_logits, end_logits = logits.split(1, dim: -1)
1184
+ start_logits = start_logits.squeeze(-1).contiguous
1185
+ end_logits = end_logits.squeeze(-1).contiguous
1186
+
1187
+ total_loss = nil
1188
+ if !start_positions.nil? && !end_positions.nil?
1189
+ # If we are on multi-GPU, split add a dimension
1190
+ if start_positions.size.length > 1
1191
+ start_positions = start_positions.squeeze(-1)
1192
+ end
1193
+ if end_positions.size.length > 1
1194
+ end_positions = end_positions.squeeze(-1)
1195
+ end
1196
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1197
+ ignored_index = start_logits.size(1)
1198
+ start_positions = start_positions.clamp(0, ignored_index)
1199
+ end_positions = end_positions.clamp(0, ignored_index)
1200
+
1201
+ loss_fct = Torch::NN::CrossEntropyLoss.new(ignore_index: ignored_index)
1202
+ start_loss = loss_fct.(start_logits, start_positions)
1203
+ end_loss = loss_fct.(end_logits, end_positions)
1204
+ total_loss = (start_loss + end_loss) / 2
1205
+ end
1206
+
1207
+ if !return_dict
1208
+ output = [start_logits, end_logits] + outputs[2..]
1209
+ return !total_loss.nil? ? [total_loss] + output : output
1210
+ end
1211
+
1212
+ QuestionAnsweringModelOutput.new(loss: total_loss, start_logits: start_logits, end_logits: end_logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
1213
+ end
1214
+ end
1215
+ end
1216
+ end
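Finally, the start/end split in the question-answering head, isolated as a torch.rb sketch (toy sizes assumed):

```ruby
require "torch"

batch_size, seq_len = 2, 7
logits = Torch.randn(batch_size, seq_len, 2) # qa_outputs produces 2 scores per token

start_logits, end_logits = logits.split(1, dim: -1)
start_logits = start_logits.squeeze(-1).contiguous
end_logits = end_logits.squeeze(-1).contiguous

p [start_logits.shape, end_logits.shape] # => [[2, 7], [2, 7]]
```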