transformers-rb 0.1.1 → 0.1.3
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/README.md +64 -3
- data/lib/transformers/configuration_utils.rb +32 -4
- data/lib/transformers/modeling_utils.rb +10 -3
- data/lib/transformers/models/auto/auto_factory.rb +1 -1
- data/lib/transformers/models/auto/configuration_auto.rb +5 -2
- data/lib/transformers/models/auto/modeling_auto.rb +9 -3
- data/lib/transformers/models/auto/tokenization_auto.rb +5 -2
- data/lib/transformers/models/deberta_v2/configuration_deberta_v2.rb +80 -0
- data/lib/transformers/models/deberta_v2/modeling_deberta_v2.rb +1210 -0
- data/lib/transformers/models/deberta_v2/tokenization_deberta_v2_fast.rb +78 -0
- data/lib/transformers/models/mpnet/configuration_mpnet.rb +61 -0
- data/lib/transformers/models/mpnet/modeling_mpnet.rb +792 -0
- data/lib/transformers/models/mpnet/tokenization_mpnet_fast.rb +106 -0
- data/lib/transformers/models/xlm_roberta/configuration_xlm_roberta.rb +68 -0
- data/lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb +1216 -0
- data/lib/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.rb +68 -0
- data/lib/transformers/pipelines/_init.rb +16 -5
- data/lib/transformers/pipelines/reranking.rb +33 -0
- data/lib/transformers/version.rb +1 -1
- data/lib/transformers.rb +16 -0
- metadata +15 -5
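Taken together, the new model directories (DeBERTa-v2, MPNet, XLM-RoBERTa), the auto-class updates, and the new `pipelines/reranking.rb` registered in `pipelines/_init.rb` add three model families and a reranking pipeline. A minimal usage sketch follows; the task name, call signature, require path, and checkpoints are assumptions inferred from the file list above, not confirmed by this diff:

    # Gemfile: gem "transformers-rb"
    require "transformers-rb"

    # Reranking pipeline (registered via pipelines/reranking.rb and pipelines/_init.rb).
    # Task name and call signature are assumptions; the checkpoint is illustrative.
    ranker = Transformers.pipeline("reranking", "mixedbread-ai/mxbai-rerank-base-v1")
    ranker.("Who wrote the novel?", ["Jane Austen wrote Emma.", "The weather is nice today."])

    # The XLM-RoBERTa classes are wired into the auto classes (configuration_auto.rb,
    # modeling_auto.rb, tokenization_auto.rb), so a from_pretrained call mirroring the
    # Python API should route to the code shown below (assumption).
    model = Transformers::XlmRoberta::XLMRobertaModel.from_pretrained("xlm-roberta-base")

The diff below is the largest of the added files, the XLM-RoBERTa modeling code.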
data/lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb
@@ -0,0 +1,1216 @@
# Copyright 2019 Facebook AI Research and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  module XlmRoberta
    class XLMRobertaEmbeddings < Torch::NN::Module
      def initialize(config)
        super()
        @word_embeddings = Torch::NN::Embedding.new(config.vocab_size, config.hidden_size, padding_idx: config.pad_token_id)
        @position_embeddings = Torch::NN::Embedding.new(config.max_position_embeddings, config.hidden_size)
        @token_type_embeddings = Torch::NN::Embedding.new(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
        @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        @position_embedding_type = config.getattr("position_embedding_type", "absolute")
        register_buffer("position_ids", Torch.arange(config.max_position_embeddings).expand([1, -1]), persistent: false)
        register_buffer("token_type_ids", Torch.zeros(@position_ids.size, dtype: Torch.long), persistent: false)

        @padding_idx = config.pad_token_id
        @position_embeddings = Torch::NN::Embedding.new(config.max_position_embeddings, config.hidden_size, padding_idx: @padding_idx)
      end

      def forward(input_ids: nil, token_type_ids: nil, position_ids: nil, inputs_embeds: nil, past_key_values_length: 0)
        if position_ids.nil?
          if !input_ids.nil?
            # Create the position ids from the input token ids. Any padded tokens remain padded.
            position_ids = create_position_ids_from_input_ids(input_ids, @padding_idx, past_key_values_length:)
          else
            position_ids = create_position_ids_from_inputs_embeds(inputs_embeds)
          end
        end

        if !input_ids.nil?
          input_shape = input_ids.size
        else
          input_shape = inputs_embeds.size[...-1]
        end

        seq_length = input_shape[1]

        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
        # issue #5664
        if token_type_ids.nil?
          if respond_to?(:token_type_ids)
            buffered_token_type_ids = token_type_ids[0.., ...seq_length]
            buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
            token_type_ids = buffered_token_type_ids_expanded
          else
            token_type_ids = Torch.zeros(input_shape, dtype: Torch.long, device: @position_ids.device)
          end
        end

        if inputs_embeds.nil?
          inputs_embeds = @word_embeddings.(input_ids)
        end
        token_type_embeddings = @token_type_embeddings.(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if @position_embedding_type == "absolute"
          position_embeddings = @position_embeddings.(position_ids)
          embeddings += position_embeddings
        end
        embeddings = @LayerNorm.(embeddings)
        embeddings = @dropout.(embeddings)
        embeddings
      end

      def create_position_ids_from_inputs_embeds(inputs_embeds)
        input_shape = inputs_embeds.size[...-1]
        sequence_length = input_shape[1]

        position_ids = Torch.arange(@padding_idx + 1, sequence_length + @padding_idx + 1, dtype: Torch.long, device: inputs_embeds.device)
        position_ids.unsqueeze(0).expand(input_shape)
      end

      def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length: 0)
        # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
        mask = input_ids.ne(padding_idx).int
        incremental_indices = (Torch.cumsum(mask, dim: 1).type_as(mask) + past_key_values_length) * mask
        incremental_indices.long + padding_idx
      end
    end

    class XLMRobertaSelfAttention < Torch::NN::Module
      def initialize(config, position_embedding_type: nil)
        super()
        if config.hidden_size % config.num_attention_heads != 0 && !config.hasattr("embedding_size")
          raise ArgumentError, "The hidden size (#{config.hidden_size}) is not a multiple of the number of attention heads (#{config.num_attention_heads})"
        end

        @num_attention_heads = config.num_attention_heads
        @attention_head_size = (config.hidden_size / config.num_attention_heads).to_i
        @all_head_size = @num_attention_heads * @attention_head_size

        @query = Torch::NN::Linear.new(config.hidden_size, @all_head_size)
        @key = Torch::NN::Linear.new(config.hidden_size, @all_head_size)
        @value = Torch::NN::Linear.new(config.hidden_size, @all_head_size)

        @dropout = Torch::NN::Dropout.new(p: config.attention_probs_dropout_prob)
        @position_embedding_type = position_embedding_type || config.getattr("position_embedding_type", "absolute")
        if @position_embedding_type == "relative_key" || @position_embedding_type == "relative_key_query"
          @max_position_embeddings = config.max_position_embeddings
          @distance_embedding = Torch::NN::Embedding.new((2 * config.max_position_embeddings) - 1, @attention_head_size)
        end

        @is_decoder = config.is_decoder
      end

      def transpose_for_scores(x)
        new_x_shape = x.size[...-1] + [@num_attention_heads, @attention_head_size]
        x = x.view(new_x_shape)
        x.permute(0, 2, 1, 3)
      end

      def forward(
        hidden_states,
        attention_mask: nil,
        head_mask: nil,
        encoder_hidden_states: nil,
        encoder_attention_mask: nil,
        past_key_value: nil,
        output_attentions: false
      )
        mixed_query_layer = @query.(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = !encoder_hidden_states.nil?

        if is_cross_attention && !past_key_value.nil?
          # reuse k,v, cross_attentions
          key_layer = past_key_value[0]
          value_layer = past_key_value[1]
          attention_mask = encoder_attention_mask
        elsif is_cross_attention
          key_layer = transpose_for_scores(@key.(encoder_hidden_states))
          value_layer = transpose_for_scores(@value.(encoder_hidden_states))
          attention_mask = encoder_attention_mask
        elsif !past_key_value.nil?
          key_layer = transpose_for_scores(@key.(hidden_states))
          value_layer = transpose_for_scores(@value.(hidden_states))
          key_layer = Torch.cat([past_key_value[0], key_layer], dim: 2)
          value_layer = Torch.cat([past_key_value[1], value_layer], dim: 2)
        else
          key_layer = transpose_for_scores(@key.(hidden_states))
          value_layer = transpose_for_scores(@value.(hidden_states))
        end

        query_layer = transpose_for_scores(mixed_query_layer)

        use_cache = !past_key_value.nil?
        if @is_decoder
          # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
          # Further calls to cross_attention layer can then reuse all cross-attention
          # key/value_states (first "if" case)
          # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
          # all previous decoder key/value_states. Further calls to uni-directional self-attention
          # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
          # if encoder bi-directional self-attention `past_key_value` is always `None`
          past_key_value = [key_layer, value_layer]
        end

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = Torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if @position_embedding_type == "relative_key" || @position_embedding_type == "relative_key_query"
          query_length, key_length = [query_layer.shape[2], key_layer.shape[2]]
          if use_cache
            position_ids_l = Torch.tensor(key_length - 1, dtype: Torch.long, device: hidden_states.device).view(-1, 1)
          else
            position_ids_l = Torch.arange(query_length, dtype: Torch.long, device: hidden_states.device).view(-1, 1)
          end
          position_ids_r = Torch.arange(key_length, dtype: Torch.long, device: hidden_states.device).view(1, -1)
          distance = position_ids_l - position_ids_r

          positional_embedding = @distance_embedding.((distance + @max_position_embeddings) - 1)
          positional_embedding = positional_embedding.to(dtype: query_layer.dtype)

          if @position_embedding_type == "relative_key"
            relative_position_scores = Torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
            attention_scores = attention_scores + relative_position_scores
          elsif @position_embedding_type == "relative_key_query"
            relative_position_scores_query = Torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
            relative_position_scores_key = Torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
            attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
          end
        end

        attention_scores = attention_scores / Math.sqrt(@attention_head_size)
        if !attention_mask.nil?
          # Apply the attention mask is (precomputed for all layers in XLMRobertaModel forward() function)
          attention_scores = attention_scores + attention_mask
        end

        # Normalize the attention scores to probabilities.
        attention_probs = Torch::NN::Functional.softmax(attention_scores, dim: -1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = @dropout.(attention_probs)

        # Mask heads if we want to
        if !head_mask.nil?
          attention_probs = attention_probs * head_mask
        end

        context_layer = Torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous
        new_context_layer_shape = context_layer.size[...-2] + [@all_head_size]
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = output_attentions ? [context_layer, attention_probs] : [context_layer]

        if @is_decoder
          outputs = outputs + [past_key_value]
        end
        outputs
      end
    end

    class XLMRobertaSdpaSelfAttention < XLMRobertaSelfAttention
      def initialize(config, position_embedding_type: nil)
        super(config, position_embedding_type: position_embedding_type)
        @dropout_prob = config.attention_probs_dropout_prob
        @require_contiguous_qkv = Packaging::Version.parse(Utils.get_torch_version) < Packaging::Version.parse("2.2.0")
      end

      # Adapted from XLMRobertaSelfAttention
      def forward(
        hidden_states,
        attention_mask: nil,
        head_mask: nil,
        encoder_hidden_states: nil,
        encoder_attention_mask: nil,
        past_key_value: nil,
        output_attentions: false
      )
        if @position_embedding_type != "absolute" || output_attentions || !head_mask.nil?
          # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
          Transformers.logger.warn("XLMRobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support non-absolute `position_embedding_type` or `output_attentions: true` or `head_mask`. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation: \"eager\"` when loading the model.")
          return super(hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)
        end

        bsz, tgt_len, _ = hidden_states.size

        query_layer = transpose_for_scores(@query.(hidden_states))

        # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
        # mask needs to be such that the encoder's padding tokens are not attended to.
        is_cross_attention = !encoder_hidden_states.nil?

        current_states = is_cross_attention ? encoder_hidden_states : hidden_states
        attention_mask = is_cross_attention ? encoder_attention_mask : attention_mask

        # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
        if is_cross_attention && past_key_value && past_key_value[0].shape[2] == current_states.shape[1]
          key_layer, value_layer = past_key_value
        else
          key_layer = transpose_for_scores(@key.(current_states))
          value_layer = transpose_for_scores(@value.(current_states))
          if !past_key_value.nil? && !is_cross_attention
            key_layer = Torch.cat([past_key_value[0], key_layer], dim: 2)
            value_layer = Torch.cat([past_key_value[1], value_layer], dim: 2)
          end
        end

        if @is_decoder
          # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
          # Further calls to cross_attention layer can then reuse all cross-attention
          # key/value_states (first "if" case)
          # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
          # all previous decoder key/value_states. Further calls to uni-directional self-attention
          # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
          # if encoder bi-directional self-attention `past_key_value` is always `None`
          past_key_value = [key_layer, value_layer]
        end

        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
        # Reference: https://github.com/pytorch/pytorch/issues/112577
        if @require_contiguous_qkv && query_layer.device.type == "cuda" && !attention_mask.nil?
          query_layer = query_layer.contiguous
          key_layer = key_layer.contiguous
          value_layer = value_layer.contiguous
        end

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
        # a causal mask in case tgt_len == 1.
        is_causal = @is_decoder && !is_cross_attention && attention_mask.nil? && tgt_len > 1 ? true : false

        attn_output = Torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, attn_mask: attention_mask, dropout_p: @training ? @dropout_prob : 0.0, is_causal: is_causal)

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, @all_head_size)

        outputs = [attn_output]
        if @is_decoder
          outputs = outputs + [past_key_value]
        end
        outputs
      end
    end

    class XLMRobertaSelfOutput < Torch::NN::Module
      def initialize(config)
        super()
        @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
        @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
        @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
      end

      def forward(hidden_states, input_tensor)
        hidden_states = @dense.(hidden_states)
        hidden_states = @dropout.(hidden_states)
        hidden_states = @LayerNorm.(hidden_states + input_tensor)
        hidden_states
      end
    end

    XLM_ROBERTA_SELF_ATTENTION_CLASSES = {"eager" => XLMRobertaSelfAttention, "sdpa" => XLMRobertaSdpaSelfAttention}

    class XLMRobertaAttention < Torch::NN::Module
      def initialize(config, position_embedding_type: nil)
        super()
        @self = XLM_ROBERTA_SELF_ATTENTION_CLASSES.fetch(config._attn_implementation).new(config, position_embedding_type: position_embedding_type)
        @output = XLMRobertaSelfOutput.new(config)
        @pruned_heads = Set.new
      end

      def prune_heads(heads)
        if heads.length == 0
          return
        end
        heads, index = TorchUtils.find_pruneable_heads_and_indices(heads, @self.num_attention_heads, @self.attention_head_size, @pruned_heads)

        # Prune linear layers
        @query = TorchUtils.prune_linear_layer(@self.query, index)
        @key = TorchUtils.prune_linear_layer(@self.key, index)
        @value = TorchUtils.prune_linear_layer(@self.value, index)
        @dense = TorchUtils.prune_linear_layer(@output.dense, index, dim: 1)

        # Update hyper params and store pruned heads
        @num_attention_heads = @self.num_attention_heads - heads.length
        @all_head_size = @self.attention_head_size * @self.num_attention_heads
        @pruned_heads = @pruned_heads.union(heads)
      end

      def forward(
        hidden_states,
        attention_mask: nil,
        head_mask: nil,
        encoder_hidden_states: nil,
        encoder_attention_mask: nil,
        past_key_value: nil,
        output_attentions: false
      )
        self_outputs = @self.(hidden_states, attention_mask:, head_mask:, encoder_hidden_states:, encoder_attention_mask:, past_key_value:, output_attentions:)
        attention_output = @output.(self_outputs[0], hidden_states)
        outputs = [attention_output] + self_outputs[1..]
        outputs
      end
    end

    class XLMRobertaIntermediate < Torch::NN::Module
      def initialize(config)
        super()
        @dense = Torch::NN::Linear.new(config.hidden_size, config.intermediate_size)
        if config.hidden_act.is_a?(String)
          @intermediate_act_fn = ACT2FN[config.hidden_act]
        else
          @intermediate_act_fn = config.hidden_act
        end
      end

      def forward(hidden_states)
        hidden_states = @dense.(hidden_states)
        hidden_states = @intermediate_act_fn.(hidden_states)
        hidden_states
      end
    end

    class XLMRobertaOutput < Torch::NN::Module
      def initialize(config)
        super()
        @dense = Torch::NN::Linear.new(config.intermediate_size, config.hidden_size)
        @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
        @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
      end

      def forward(hidden_states, input_tensor)
        hidden_states = @dense.(hidden_states)
        hidden_states = @dropout.(hidden_states)
        hidden_states = @LayerNorm.(hidden_states + input_tensor)
        hidden_states
      end
    end

    class XLMRobertaLayer < Torch::NN::Module
      def initialize(config)
        super()
        @chunk_size_feed_forward = config.chunk_size_feed_forward
        @seq_len_dim = 1
        @attention = XLMRobertaAttention.new(config)
        @is_decoder = config.is_decoder
        @add_cross_attention = config.add_cross_attention
        if @add_cross_attention
          if !@is_decoder
            raise ArgumentError, "#{self} should be used as a decoder model if cross attention is added"
          end
          @crossattention = XLMRobertaAttention.new(config, position_embedding_type: "absolute")
        end
        @intermediate = XLMRobertaIntermediate.new(config)
        @output = XLMRobertaOutput.new(config)
      end

      def forward(
        hidden_states,
        attention_mask: nil,
        head_mask: nil,
        encoder_hidden_states: nil,
        encoder_attention_mask: nil,
        past_key_value: nil,
        output_attentions: false
      )
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = !past_key_value.nil? ? past_key_value[...2] : nil
        self_attention_outputs = @attention.(hidden_states, attention_mask:, head_mask:, output_attentions: output_attentions, past_key_value: self_attn_past_key_value)
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is tuple of self-attn cache
        if @is_decoder
          outputs = self_attention_outputs[1...-1]
          present_key_value = self_attention_outputs[-1]
        else
          outputs = self_attention_outputs[1..]
        end

        cross_attn_present_key_value = nil
        if @is_decoder && !encoder_hidden_states.nil?
          if instance_variable_defined?(:@crossattention)
            raise ArgumentError, "If `encoder_hidden_states` are passed, #{self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
          end

          # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
          cross_attn_past_key_value = !past_key_value.nil? ? past_key_value[-2..] : nil
          cross_attention_outputs = @crossattention.(attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, cross_attn_past_key_value, output_attentions)
          attention_output = cross_attention_outputs[0]
          outputs = outputs + cross_attention_outputs[1...-1]

          # add cross-attn cache to positions 3,4 of present_key_value tuple
          cross_attn_present_key_value = cross_attention_outputs[-1]
          present_key_value = present_key_value + cross_attn_present_key_value
        end

        layer_output = TorchUtils.apply_chunking_to_forward(method(:feed_forward_chunk), @chunk_size_feed_forward, @seq_len_dim, attention_output)
        outputs = [layer_output] + outputs

        # if decoder, return the attn key/values as the last output
        if @is_decoder
          outputs = outputs + [present_key_value]
        end

        outputs
      end

      def feed_forward_chunk(attention_output)
        intermediate_output = @intermediate.(attention_output)
        layer_output = @output.(intermediate_output, attention_output)
        layer_output
      end
    end

    class XLMRobertaEncoder < Torch::NN::Module
      def initialize(config)
        super()
        @config = config
        @layer = Torch::NN::ModuleList.new(config.num_hidden_layers.times.map { |_| XLMRobertaLayer.new(config) })
        @gradient_checkpointing = false
      end

      def forward(
        hidden_states,
        attention_mask: nil,
        head_mask: nil,
        encoder_hidden_states: nil,
        encoder_attention_mask: nil,
        past_key_values: nil,
        use_cache: nil,
        output_attentions: false,
        output_hidden_states: false,
        return_dict: true
      )
        all_hidden_states = output_hidden_states ? [] : nil
        all_self_attentions = output_attentions ? [] : nil
        all_cross_attentions = output_attentions && @config.add_cross_attention ? [] : nil

        if @gradient_checkpointing && @training
          if use_cache
            Transformers.logger.warn("`use_cache: true` is incompatible with gradient checkpointing. Setting `use_cache: false`...")
            use_cache = false
          end
        end

        next_decoder_cache = use_cache ? [] : nil
        @layer.each_with_index do |layer_module, i|
          if output_hidden_states
            all_hidden_states = all_hidden_states + [hidden_states]
          end

          layer_head_mask = !head_mask.nil? ? head_mask[i] : nil
          past_key_value = !past_key_values.nil? ? past_key_values[i] : nil

          if @gradient_checkpointing && @training
            layer_outputs = _gradient_checkpointing_func(layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)
          else
            layer_outputs = layer_module.(hidden_states, attention_mask:, head_mask: layer_head_mask, encoder_hidden_states:, encoder_attention_mask:, past_key_value:, output_attentions:)
          end

          hidden_states = layer_outputs[0]
          if use_cache
            next_decoder_cache += [layer_outputs[-1]]
          end
          if output_attentions
            all_self_attentions = all_self_attentions + [layer_outputs[1]]
            if @config.add_cross_attention
              all_cross_attentions = all_cross_attentions + [layer_outputs[2]]
            end
          end
        end

        if output_hidden_states
          all_hidden_states = all_hidden_states + [hidden_states]
        end

        if !return_dict
          return Array([hidden_states, next_decoder_cache, all_hidden_states, all_self_attentions, all_cross_attentions].select { |v| !v.nil? })
        end
        BaseModelOutputWithPastAndCrossAttentions.new(last_hidden_state: hidden_states, past_key_values: next_decoder_cache, hidden_states: all_hidden_states, attentions: all_self_attentions, cross_attentions: all_cross_attentions)
      end
    end

    class XLMRobertaPooler < Torch::NN::Module
      def initialize(config)
        super()
        @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
        @activation = Torch::NN::Tanh.new
      end

      def forward(hidden_states)
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[0.., 0]
        pooled_output = @dense.(first_token_tensor)
        pooled_output = @activation.(pooled_output)
        pooled_output
      end
    end

    class XLMRobertaPreTrainedModel < PreTrainedModel
      self.config_class = XLMRobertaConfig
      self.base_model_prefix = "roberta"
      # self.supports_gradient_checkpointing = true
      # self._no_split_modules = ["XLMRobertaEmbeddings", "XLMRobertaSelfAttention", "XLMRobertaSdpaSelfAttention"]
      # self._supports_sdpa = true

      def _init_weights(module_)
        if module_.is_a?(Torch::NN::Linear)
          # Slightly different from the TF version which uses truncated_normal for initialization
          # cf https://github.com/pytorch/pytorch/pull/5617
          module_.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
          if !module_.bias.nil?
            module_.bias.data.zero!
          end
        elsif module_.is_a?(Torch::NN::Embedding)
          module_.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
          if !module_.padding_idx.nil?
            module_.weight.data.fetch(module_.padding_idx).zero!
          end
        elsif module_.is_a?(Torch::NN::LayerNorm)
          module_.bias.data.zero!
          module_.weight.data.fill!(1.0)
        end
      end
    end

    class XLMRobertaModel < XLMRobertaPreTrainedModel
      # self._no_split_modules = ["XLMRobertaEmbeddings", "XLMRobertaLayer"]

      def initialize(config, add_pooling_layer: true)
        super(config)
        @config = config

        @embeddings = XLMRobertaEmbeddings.new(config)
        @encoder = XLMRobertaEncoder.new(config)

        @pooler = add_pooling_layer ? XLMRobertaPooler.new(config) : nil

        @attn_implementation = config._attn_implementation
        @position_embedding_type = config.position_embedding_type

        # Initialize weights and apply final processing
        post_init
      end

      def get_input_embeddings
        @embeddings.word_embeddings
      end

      def set_input_embeddings(value)
        @word_embeddings = value
      end

      def _prune_heads(heads_to_prune)
        heads_to_prune.each do |layer, heads|
          @encoder.layer[layer].attention.prune_heads(heads)
        end
      end

      def forward(
        input_ids,
        attention_mask: nil,
        token_type_ids: nil,
        position_ids: nil,
        head_mask: nil,
        inputs_embeds: nil,
        encoder_hidden_states: nil,
        encoder_attention_mask: nil,
        past_key_values: nil,
        use_cache: nil,
        output_attentions: nil,
        output_hidden_states: nil,
        return_dict: nil
      )
        output_attentions = !output_attentions.nil? ? output_attentions : @config.output_attentions
        output_hidden_states = !output_hidden_states.nil? ? output_hidden_states : @config.output_hidden_states
        return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

        if @config.is_decoder
          use_cache = !use_cache.nil? ? use_cache : @config.use_cache
        else
          use_cache = false
        end

        if !input_ids.nil? && !inputs_embeds.nil?
          raise ArgumentError, "You cannot specify both input_ids and inputs_embeds at the same time"
        elsif !input_ids.nil?
          warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
          input_shape = input_ids.size
        elsif !inputs_embeds.nil?
          input_shape = inputs_embeds.size[...-1]
        else
          raise ArgumentError, "You have to specify either input_ids or inputs_embeds"
        end

        batch_size, seq_length = input_shape
        device = !input_ids.nil? ? input_ids.device : inputs_embeds.device

        # past_key_values_length
        past_key_values_length = !past_key_values.nil? ? past_key_values[0][0].shape[2] : 0

        if token_type_ids.nil?
          if @embeddings.respond_to?(:token_type_ids)
            buffered_token_type_ids = @embeddings.token_type_ids[0.., ...seq_length]
            buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
            token_type_ids = buffered_token_type_ids_expanded
          else
            token_type_ids = Torch.zeros(input_shape, dtype: Torch.long, device: device)
          end
        end

        embedding_output = @embeddings.(input_ids: input_ids, position_ids: position_ids, token_type_ids: token_type_ids, inputs_embeds: inputs_embeds, past_key_values_length: past_key_values_length)

        if attention_mask.nil?
          attention_mask = Torch.ones([batch_size, seq_length + past_key_values_length], device: device)
        end

        use_sdpa_attention_masks = @attn_implementation == "sdpa" && @position_embedding_type == "absolute" && head_mask.nil? && !output_attentions

        # Expand the attention mask
        if use_sdpa_attention_masks && attention_mask.dim == 2
          # Expand the attention mask for SDPA.
          # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
          if @config.is_decoder
            extended_attention_mask = ModelingAttnMaskUtils._prepare_4d_causal_attention_mask_for_sdpa(attention_mask, input_shape, embedding_output, past_key_values_length)
          else
            extended_attention_mask = ModelingAttnMaskUtils._prepare_4d_attention_mask_for_sdpa(attention_mask, embedding_output.dtype, tgt_len: seq_length)
          end
        else
          # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
          # ourselves in which case we just need to make it broadcastable to all heads.
          extended_attention_mask = get_extended_attention_mask(attention_mask, input_shape)
        end

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if @config.is_decoder && !encoder_hidden_states.nil?
          encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size
          encoder_hidden_shape = [encoder_batch_size, encoder_sequence_length]
          if encoder_attention_mask.nil?
            encoder_attention_mask = Torch.ones(encoder_hidden_shape, device: device)
          end

          if use_sdpa_attention_masks && encoder_attention_mask.dim == 2
            # Expand the attention mask for SDPA.
            # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
            encoder_extended_attention_mask = ModelingAttnMaskUtils._prepare_4d_attention_mask_for_sdpa(encoder_attention_mask, embedding_output.dtype, tgt_len: seq_length)
          else
            encoder_extended_attention_mask = invert_attention_mask(encoder_attention_mask)
          end
        else
          encoder_extended_attention_mask = nil
        end

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = get_head_mask(head_mask, @config.num_hidden_layers)

        encoder_outputs = @encoder.(embedding_output, attention_mask: extended_attention_mask, head_mask: head_mask, encoder_hidden_states: encoder_hidden_states, encoder_attention_mask: encoder_extended_attention_mask, past_key_values: past_key_values, use_cache: use_cache, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
        sequence_output = encoder_outputs[0]
        pooled_output = !@pooler.nil? ? @pooler.(sequence_output) : nil

        if !return_dict
          return [sequence_output, pooled_output] + encoder_outputs[1..]
        end

        BaseModelOutputWithPoolingAndCrossAttentions.new(last_hidden_state: sequence_output, pooler_output: pooled_output, past_key_values: encoder_outputs.past_key_values, hidden_states: encoder_outputs.hidden_states, attentions: encoder_outputs.attentions, cross_attentions: encoder_outputs.cross_attentions)
      end
    end

    class XLMRobertaForCausalLM < XLMRobertaPreTrainedModel
      self._tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

      def initialize(config)
        super(config)

        if !config.is_decoder
          Transformers.logger.warn("If you want to use `XLMRobertaLMHeadModel` as a standalone, add `is_decoder=True.`")
        end

        @roberta = XLMRobertaModel.new(config, add_pooling_layer: false)
        @lm_head = XLMRobertaLMHead.new(config)

        # Initialize weights and apply final processing
        post_init
      end

      def get_output_embeddings
        @lm_head.decoder
      end

      def set_output_embeddings(new_embeddings)
        @decoder = new_embeddings
      end

      def forward(
        input_ids: nil,
        attention_mask: nil,
        token_type_ids: nil,
        position_ids: nil,
        head_mask: nil,
        inputs_embeds: nil,
        encoder_hidden_states: nil,
        encoder_attention_mask: nil,
        labels: nil,
        past_key_values: nil,
        use_cache: nil,
        output_attentions: nil,
        output_hidden_states: nil,
        return_dict: nil
      )
        return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
        if !labels.nil?
          use_cache = false
        end

        outputs = @roberta.(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, encoder_hidden_states: encoder_hidden_states, encoder_attention_mask: encoder_attention_mask, past_key_values: past_key_values, use_cache: use_cache, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)

        sequence_output = outputs[0]
        prediction_scores = @lm_head.(sequence_output)

        lm_loss = nil
        if !labels.nil?
          # move labels to correct device to enable model parallelism
          labels = labels.to(prediction_scores.device)
          # we are doing next-token prediction; shift prediction scores and input ids by one
          shifted_prediction_scores = prediction_scores[0.., ...-1, 0..].contiguous
          labels = labels[0.., 1..].contiguous
          loss_fct = Torch::NN::CrossEntropyLoss.new
          lm_loss = loss_fct.(shifted_prediction_scores.view(-1, @config.vocab_size), labels.view(-1))
        end

        if !return_dict
          output = [prediction_scores] + outputs[2..]
          return !lm_loss.nil? ? [lm_loss] + output : output
        end

        CausalLMOutputWithCrossAttentions.new(loss: lm_loss, logits: prediction_scores, past_key_values: outputs.past_key_values, hidden_states: outputs.hidden_states, attentions: outputs.attentions, cross_attentions: outputs.cross_attentions)
      end

      def prepare_inputs_for_generation(input_ids, past_key_values: nil, attention_mask: nil, **model_kwargs)
        input_shape = input_ids.shape
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask.nil?
          attention_mask = input_ids.new_ones(input_shape)
        end

        # cut decoder_input_ids if past_key_values is used
        if !past_key_values.nil?
          past_length = past_key_values[0][0].shape[2]

          # Some generation methods already pass only the last input ID
          if input_ids.shape[1] > past_length
            remove_prefix_length = past_length
          else
            # Default to old behavior: keep only final ID
            remove_prefix_length = input_ids.shape[1] - 1
          end

          input_ids = input_ids[0.., remove_prefix_length..]
        end

        {"input_ids" => input_ids, "attention_mask" => attention_mask, "past_key_values" => past_key_values}
      end

      def _reorder_cache(past_key_values, beam_idx)
        reordered_past = []
        past_key_values.each do |layer_past|
          reordered_past += [Array(layer_past.select { |past_state| past_state })]
        end
        reordered_past
      end
    end

    class XLMRobertaForMaskedLM < XLMRobertaPreTrainedModel
      self._tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

      def initialize(config)
        super(config)

        if config.is_decoder
          Transformers.logger.warn("If you want to use `XLMRobertaForMaskedLM` make sure `config.is_decoder: false` for bi-directional self-attention.")
        end

        @roberta = XLMRobertaModel.new(config, add_pooling_layer: false)
        @lm_head = XLMRobertaLMHead.new(config)

        # Initialize weights and apply final processing
        post_init
      end

      def get_output_embeddings
        @lm_head.decoder
      end

      def set_output_embeddings(new_embeddings)
        @decoder = new_embeddings
      end

      def forward(
        input_ids: nil,
        attention_mask: nil,
        token_type_ids: nil,
        position_ids: nil,
        head_mask: nil,
        inputs_embeds: nil,
        encoder_hidden_states: nil,
        encoder_attention_mask: nil,
        labels: nil,
        output_attentions: nil,
        output_hidden_states: nil,
        return_dict: nil
      )
        return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

        outputs = @roberta.(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, encoder_hidden_states: encoder_hidden_states, encoder_attention_mask: encoder_attention_mask, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
        sequence_output = outputs[0]
        prediction_scores = @lm_head.(sequence_output)

        masked_lm_loss = nil
        if !labels.nil?
          # move labels to correct device to enable model parallelism
          labels = labels.to(prediction_scores.device)
          loss_fct = Torch::NN::CrossEntropyLoss.new
          masked_lm_loss = loss_fct.(prediction_scores.view(-1, @config.vocab_size), labels.view(-1))
        end

        if !return_dict
          output = [prediction_scores] + outputs[2..]
          return !masked_lm_loss.nil? ? [masked_lm_loss] + output : output
        end

        MaskedLMOutput.new(loss: masked_lm_loss, logits: prediction_scores, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
      end
    end

    class XLMRobertaLMHead < Torch::NN::Module
      def initialize(config)
        super()
        @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
        @layer_norm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)

        @decoder = Torch::NN::Linear.new(config.hidden_size, config.vocab_size)
        @bias = Torch::NN::Parameter.new(Torch.zeros(config.vocab_size))
        @bias = @bias
      end

      def forward(features, **kwargs)
        x = @dense.(features)
        x = Activations.gelu(x)
        x = @layer_norm.(x)

        # project back to size of vocabulary with bias
        x = @decoder.(x)

        x
      end

      def _tie_weights
        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
        # For accelerate compatibility and to not break backward compatibility
        if @decoder.bias.device.type == "meta"
          @bias = @bias
        else
          @bias = @decoder.bias
        end
      end
    end

    class XLMRobertaForSequenceClassification < XLMRobertaPreTrainedModel
      def initialize(config)
        super(config)
        @num_labels = config.num_labels
        @config = config

        @roberta = XLMRobertaModel.new(config, add_pooling_layer: false)
        @classifier = XLMRobertaClassificationHead.new(config)

        # Initialize weights and apply final processing
        post_init
      end

      def forward(
        input_ids: nil,
        attention_mask: nil,
        token_type_ids: nil,
        position_ids: nil,
        head_mask: nil,
        inputs_embeds: nil,
        labels: nil,
        output_attentions: nil,
        output_hidden_states: nil,
        return_dict: nil
      )
        return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

        outputs = @roberta.(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
        sequence_output = outputs[0]
        logits = @classifier.(sequence_output)

        loss = nil
        if !labels.nil?
          # move labels to correct device to enable model parallelism
          labels = labels.to(logits.device)
          if @config.problem_type.nil?
            if @num_labels == 1
              @problem_type = "regression"
            elsif @num_labels > 1 && labels.dtype == Torch.long || labels.dtype == Torch.int
              @problem_type = "single_label_classification"
            else
              @problem_type = "multi_label_classification"
            end
          end

          if @config.problem_type == "regression"
            loss_fct = Torch::NN::MSELoss.new
            if @num_labels == 1
              loss = loss_fct.(logits.squeeze, labels.squeeze)
            else
              loss = loss_fct.(logits, labels)
            end
          elsif @config.problem_type == "single_label_classification"
            loss_fct = Torch::NN::CrossEntropyLoss.new
            loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
          elsif @config.problem_type == "multi_label_classification"
            loss_fct = Torch::NN::BCEWithLogitsLoss.new
            loss = loss_fct.(logits, labels)
          end
        end

        if !return_dict
          output = [logits] + outputs[2..]
          return !loss.nil? ? [loss] + output : output
        end

        SequenceClassifierOutput.new(loss: loss, logits: logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
      end
    end

    class XLMRobertaForMultipleChoice < XLMRobertaPreTrainedModel
      def initialize(config)
        super(config)

        @roberta = XLMRobertaModel.new(config)
        @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
        @classifier = Torch::NN::Linear.new(config.hidden_size, 1)

        # Initialize weights and apply final processing
        post_init
      end

      def forward(
        input_ids: nil,
        token_type_ids: nil,
        attention_mask: nil,
        labels: nil,
        position_ids: nil,
        head_mask: nil,
        inputs_embeds: nil,
        output_attentions: nil,
        output_hidden_states: nil,
        return_dict: nil
      )
        return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
        num_choices = !input_ids.nil? ? input_ids.shape[1] : inputs_embeds.shape[1]

        flat_input_ids = !input_ids.nil? ? input_ids.view(-1, input_ids.size(-1)) : nil
        flat_position_ids = !position_ids.nil? ? position_ids.view(-1, position_ids.size(-1)) : nil
        flat_token_type_ids = !token_type_ids.nil? ? token_type_ids.view(-1, token_type_ids.size(-1)) : nil
        flat_attention_mask = !attention_mask.nil? ? attention_mask.view(-1, attention_mask.size(-1)) : nil
        flat_inputs_embeds = !inputs_embeds.nil? ? inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) : nil

        outputs = @roberta.(flat_input_ids, position_ids: flat_position_ids, token_type_ids: flat_token_type_ids, attention_mask: flat_attention_mask, head_mask: head_mask, inputs_embeds: flat_inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
        pooled_output = outputs[1]

        pooled_output = @dropout.(pooled_output)
        logits = @classifier.(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        loss = nil
        if !labels.nil?
          # move labels to correct device to enable model parallelism
          labels = labels.to(reshaped_logits.device)
          loss_fct = Torch::NN::CrossEntropyLoss.new
          loss = loss_fct.(reshaped_logits, labels)
        end

        if !return_dict
          output = [reshaped_logits] + outputs[2..]
          return !loss.nil? ? [loss] + output : output
        end

        MultipleChoiceModelOutput.new(loss: loss, logits: reshaped_logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
      end
    end

    class XLMRobertaForTokenClassification < XLMRobertaPreTrainedModel
      def initialize(config)
        super(config)
        @num_labels = config.num_labels

        @roberta = XLMRobertaModel.new(config, add_pooling_layer: false)
        classifier_dropout = !config.classifier_dropout.nil? ? config.classifier_dropout : config.hidden_dropout_prob
        @dropout = Torch::NN::Dropout.new(p: classifier_dropout)
        @classifier = Torch::NN::Linear.new(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        post_init
      end

      def forward(
        input_ids: nil,
        attention_mask: nil,
        token_type_ids: nil,
        position_ids: nil,
        head_mask: nil,
        inputs_embeds: nil,
        labels: nil,
        output_attentions: nil,
        output_hidden_states: nil,
        return_dict: nil
      )
        return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

        outputs = @roberta.(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)

        sequence_output = outputs[0]

        sequence_output = @dropout.(sequence_output)
        logits = @classifier.(sequence_output)

        loss = nil
        if !labels.nil?
          # move labels to correct device to enable model parallelism
          labels = labels.to(logits.device)
          loss_fct = Torch::NN::CrossEntropyLoss.new
          loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
        end

        if !return_dict
          output = [logits] + outputs[2..]
          return !loss.nil? ? [loss] + output : output
        end

        TokenClassifierOutput.new(loss: loss, logits: logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
      end
    end

    class XLMRobertaClassificationHead < Torch::NN::Module
      def initialize(config)
        super()
        @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
        classifier_dropout = !config.classifier_dropout.nil? ? config.classifier_dropout : config.hidden_dropout_prob
        @dropout = Torch::NN::Dropout.new(p: classifier_dropout)
        @out_proj = Torch::NN::Linear.new(config.hidden_size, config.num_labels)
      end

      def forward(features, **kwargs)
        x = features[0.., 0, 0..]
        x = @dropout.(x)
        x = @dense.(x)
        x = Torch.tanh(x)
        x = @dropout.(x)
        x = @out_proj.(x)
        x
      end
    end

    class XLMRobertaForQuestionAnswering < XLMRobertaPreTrainedModel
      def initialize(config)
        super(config)
        @num_labels = config.num_labels

        @roberta = XLMRobertaModel.new(config, add_pooling_layer: false)
        @qa_outputs = Torch::NN::Linear.new(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        post_init
      end

      def forward(
        input_ids: nil,
        attention_mask: nil,
        token_type_ids: nil,
        position_ids: nil,
        head_mask: nil,
        inputs_embeds: nil,
        start_positions: nil,
        end_positions: nil,
        output_attentions: nil,
        output_hidden_states: nil,
        return_dict: nil
      )
        return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

        outputs = @roberta.(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)

        sequence_output = outputs[0]

        logits = @qa_outputs.(sequence_output)
        start_logits, end_logits = logits.split(1, dim: -1)
        start_logits = start_logits.squeeze(-1).contiguous
        end_logits = end_logits.squeeze(-1).contiguous

        total_loss = nil
        if !start_positions.nil? && !end_positions.nil?
          # If we are on multi-GPU, split add a dimension
          if start_positions.size.length > 1
            start_positions = start_positions.squeeze(-1)
          end
          if end_positions.size.length > 1
            end_positions = end_positions.squeeze(-1)
          end
          # sometimes the start/end positions are outside our model inputs, we ignore these terms
          ignored_index = start_logits.size(1)
          start_positions = start_positions.clamp(0, ignored_index)
          end_positions = end_positions.clamp(0, ignored_index)

          loss_fct = Torch::NN::CrossEntropyLoss.new(ignore_index: ignored_index)
          start_loss = loss_fct.(start_logits, start_positions)
          end_loss = loss_fct.(end_logits, end_positions)
          total_loss = (start_loss + end_loss) / 2
        end

        if !return_dict
          output = [start_logits, end_logits] + outputs[2..]
          return !total_loss.nil? ? [total_loss] + output : output
        end

        QuestionAnsweringModelOutput.new(loss: total_loss, start_logits: start_logits, end_logits: end_logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
      end
    end
  end
end
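The file above registers two self-attention implementations (XLM_ROBERTA_SELF_ATTENTION_CLASSES maps "eager" and "sdpa") and picks one via config._attn_implementation; the SDPA variant's warning text points to `attn_implementation: "eager"` as the opt-out. A minimal sketch of that opt-out (hedged: whether from_pretrained forwards this keyword in the Ruby port is an assumption based on the warning text, and the checkpoint is illustrative):

    # Hedged sketch: force the eager attention path referenced in the SDPA fallback warning.
    # The keyword plumbing is assumed to mirror the Python API; not confirmed by this diff.
    model = Transformers::XlmRoberta::XLMRobertaForSequenceClassification.from_pretrained(
      "xlm-roberta-base",
      attn_implementation: "eager"
    )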