transformers-rb 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +203 -0
- data/README.md +163 -0
- data/lib/transformers/activations.rb +57 -0
- data/lib/transformers/configuration_utils.rb +285 -0
- data/lib/transformers/convert_slow_tokenizer.rb +90 -0
- data/lib/transformers/data/processors/squad.rb +115 -0
- data/lib/transformers/dynamic_module_utils.rb +25 -0
- data/lib/transformers/feature_extraction_utils.rb +110 -0
- data/lib/transformers/hf_hub/constants.rb +71 -0
- data/lib/transformers/hf_hub/errors.rb +11 -0
- data/lib/transformers/hf_hub/file_download.rb +764 -0
- data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
- data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
- data/lib/transformers/image_processing_base.rb +169 -0
- data/lib/transformers/image_processing_utils.rb +63 -0
- data/lib/transformers/image_transforms.rb +208 -0
- data/lib/transformers/image_utils.rb +165 -0
- data/lib/transformers/modeling_outputs.rb +81 -0
- data/lib/transformers/modeling_utils.rb +888 -0
- data/lib/transformers/models/auto/auto_factory.rb +138 -0
- data/lib/transformers/models/auto/configuration_auto.rb +61 -0
- data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
- data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
- data/lib/transformers/models/auto/modeling_auto.rb +80 -0
- data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
- data/lib/transformers/models/bert/configuration_bert.rb +65 -0
- data/lib/transformers/models/bert/modeling_bert.rb +836 -0
- data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
- data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
- data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
- data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
- data/lib/transformers/models/vit/configuration_vit.rb +60 -0
- data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
- data/lib/transformers/models/vit/modeling_vit.rb +506 -0
- data/lib/transformers/pipelines/_init.rb +348 -0
- data/lib/transformers/pipelines/base.rb +301 -0
- data/lib/transformers/pipelines/feature_extraction.rb +47 -0
- data/lib/transformers/pipelines/image_classification.rb +110 -0
- data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
- data/lib/transformers/pipelines/pt_utils.rb +53 -0
- data/lib/transformers/pipelines/question_answering.rb +508 -0
- data/lib/transformers/pipelines/text_classification.rb +123 -0
- data/lib/transformers/pipelines/token_classification.rb +282 -0
- data/lib/transformers/ruby_utils.rb +33 -0
- data/lib/transformers/sentence_transformer.rb +37 -0
- data/lib/transformers/tokenization_utils.rb +152 -0
- data/lib/transformers/tokenization_utils_base.rb +937 -0
- data/lib/transformers/tokenization_utils_fast.rb +386 -0
- data/lib/transformers/torch_utils.rb +25 -0
- data/lib/transformers/utils/_init.rb +31 -0
- data/lib/transformers/utils/generic.rb +107 -0
- data/lib/transformers/utils/hub.rb +209 -0
- data/lib/transformers/utils/import_utils.rb +45 -0
- data/lib/transformers/utils/logging.rb +52 -0
- data/lib/transformers/version.rb +3 -0
- data/lib/transformers-rb.rb +1 -0
- data/lib/transformers.rb +100 -0
- data/licenses/LICENSE-huggingface-hub.txt +201 -0
- data/licenses/LICENSE-sentence-transformers.txt +201 -0
- data/licenses/NOTICE-sentence-transformers.txt +5 -0
- metadata +161 -0
@@ -0,0 +1,836 @@
|
|
1
|
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
2
|
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
module Transformers
|
17
|
+
module Bert
|
18
|
+
class BertEmbeddings < Torch::NN::Module
  # Builds BERT input embeddings: word embeddings + token-type embeddings
  # (+ absolute position embeddings when position_embedding_type is "absolute"),
  # followed by LayerNorm and dropout.
  #
  # @param config [BertConfig] reads vocab_size, hidden_size, pad_token_id,
  #   max_position_embeddings, type_vocab_size, layer_norm_eps, hidden_dropout_prob
  #   and position_embedding_type (defaults to "absolute" when nil)
  def initialize(config)
    super()
    @word_embeddings = Torch::NN::Embedding.new(config.vocab_size, config.hidden_size, padding_idx: config.pad_token_id)
    @position_embeddings = Torch::NN::Embedding.new(config.max_position_embeddings, config.hidden_size)
    @token_type_embeddings = Torch::NN::Embedding.new(config.type_vocab_size, config.hidden_size)

    # NOTE(review): presumably named `LayerNorm` so parameter names match upstream checkpoints
    @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
    @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
    @position_embedding_type = config.position_embedding_type || "absolute"

    # position_ids is (1, max_position_embeddings); both buffers are registered with
    # persistent: false so they are not part of the persisted state
    register_buffer(
      "position_ids", Torch.arange(config.max_position_embeddings).expand([1, -1]), persistent: false
    )
    register_buffer(
      "token_type_ids", Torch.zeros(position_ids.size, dtype: Torch.long), persistent: false
    )
  end

  # @param input_ids [Torch::Tensor, nil] token ids; either this or inputs_embeds must be given
  # @param token_type_ids [Torch::Tensor, nil] segment ids; defaults to zeros from the
  #   registered buffer
  # @param position_ids [Torch::Tensor, nil] defaults to a slice of the position_ids buffer
  #   offset by past_key_values_length
  # @param inputs_embeds [Torch::Tensor, nil] precomputed word embeddings
  # @param past_key_values_length [Integer] number of cached positions already processed
  # @return [Torch::Tensor] the embedded, normalized and dropped-out inputs
  def forward(
    input_ids: nil,
    token_type_ids: nil,
    position_ids: nil,
    inputs_embeds: nil,
    past_key_values_length: 0
  )
    if !input_ids.nil?
      input_shape = input_ids.size
    else
      input_shape = inputs_embeds.size[...-1]
    end

    seq_length = input_shape[1]

    if position_ids.nil?
      position_ids = @position_ids[0.., past_key_values_length...(seq_length + past_key_values_length)]
    end

    # Default token_type_ids to the all-zeros buffer registered in the constructor,
    # sliced to the current sequence length and expanded over the batch. Using the
    # registered buffer helps when tracing the model without token_type_ids
    # (upstream issue #5664).
    if token_type_ids.nil?
      buffered_token_type_ids = @token_type_ids[0.., 0...seq_length]
      token_type_ids = buffered_token_type_ids.expand(input_shape[0], seq_length)
    end

    if inputs_embeds.nil?
      inputs_embeds = @word_embeddings.(input_ids)
    end
    token_type_embeddings = @token_type_embeddings.(token_type_ids)

    embeddings = inputs_embeds + token_type_embeddings
    if @position_embedding_type == "absolute"
      position_embeddings = @position_embeddings.(position_ids)
      embeddings += position_embeddings
    end
    embeddings = @LayerNorm.(embeddings)
    embeddings = @dropout.(embeddings)
    embeddings
  end
end
|
78
|
+
|
79
|
+
class BertSelfAttention < Torch::NN::Module
  # Multi-head scaled dot-product attention (the "eager" implementation).
  #
  # @param config [BertConfig] reads hidden_size, num_attention_heads, embedding_size,
  #   attention_probs_dropout_prob, position_embedding_type, max_position_embeddings
  #   and is_decoder
  # @param position_embedding_type [String, nil] overrides config.position_embedding_type
  #   (falls back to "absolute")
  # @raise [ArgumentError] when hidden_size is not a multiple of num_attention_heads and
  #   config.embedding_size is not set
  def initialize(config, position_embedding_type: nil)
    super()
    if config.hidden_size % config.num_attention_heads != 0 && !config.embedding_size
      raise ArgumentError,
        "The hidden size (#{config.hidden_size}) is not a multiple of the number of attention " +
        "heads (#{config.num_attention_heads})"
    end

    @num_attention_heads = config.num_attention_heads
    @attention_head_size = (config.hidden_size / config.num_attention_heads).to_i
    @all_head_size = @num_attention_heads * @attention_head_size

    @query = Torch::NN::Linear.new(config.hidden_size, @all_head_size)
    @key = Torch::NN::Linear.new(config.hidden_size, @all_head_size)
    @value = Torch::NN::Linear.new(config.hidden_size, @all_head_size)

    @dropout = Torch::NN::Dropout.new(p: config.attention_probs_dropout_prob)
    @position_embedding_type = position_embedding_type || config.position_embedding_type || "absolute"
    if @position_embedding_type == "relative_key" || @position_embedding_type == "relative_key_query"
      @max_position_embeddings = config.max_position_embeddings
      # One embedding per relative distance in [-(max - 1), max - 1]
      @distance_embedding = Torch::NN::Embedding.new(2 * config.max_position_embeddings - 1, @attention_head_size)
    end

    @is_decoder = config.is_decoder
  end

  # Splits the last dimension into (num_attention_heads, attention_head_size) and
  # moves the head axis to position 1 via permute(0, 2, 1, 3).
  # NOTE(review): expects a rank-3 input (presumably batch, seq, all_head_size).
  def transpose_for_scores(x)
    new_x_shape = x.size[...-1] + [@num_attention_heads, @attention_head_size]
    x = x.view(new_x_shape)
    x.permute(0, 2, 1, 3)
  end

  # @return [Array] [context_layer], plus attention_probs when output_attentions is true,
  #   plus [key_layer, value_layer] as the last element when this is a decoder
  def forward(
    hidden_states,
    attention_mask: nil,
    head_mask: nil,
    encoder_hidden_states: nil,
    encoder_attention_mask: nil,
    past_key_value: nil,
    output_attentions: false
  )
    mixed_query_layer = @query.(hidden_states)

    # If this is instantiated as a cross-attention module, the keys
    # and values come from an encoder; the attention mask needs to be
    # such that the encoder's padding tokens are not attended to.
    is_cross_attention = !encoder_hidden_states.nil?

    if is_cross_attention && !past_key_value.nil?
      # reuse k,v, cross_attentions
      key_layer = past_key_value[0]
      value_layer = past_key_value[1]
      attention_mask = encoder_attention_mask
    elsif is_cross_attention
      key_layer = transpose_for_scores(@key.(encoder_hidden_states))
      value_layer = transpose_for_scores(@value.(encoder_hidden_states))
      attention_mask = encoder_attention_mask
    elsif !past_key_value.nil?
      # Append the new keys/values to the cached ones along the sequence axis (dim 2)
      key_layer = transpose_for_scores(@key.(hidden_states))
      value_layer = transpose_for_scores(@value.(hidden_states))
      key_layer = Torch.cat([past_key_value[0], key_layer], dim: 2)
      value_layer = Torch.cat([past_key_value[1], value_layer], dim: 2)
    else
      key_layer = transpose_for_scores(@key.(hidden_states))
      value_layer = transpose_for_scores(@value.(hidden_states))
    end

    query_layer = transpose_for_scores(mixed_query_layer)

    if @is_decoder
      # Decoders return the key/value pair so later calls can reuse (cross-attention)
      # or extend (self-attention) it via past_key_value.
      past_key_value = [key_layer, value_layer]
    end

    # Take the dot product between "query" and "key" to get the raw attention scores.
    attention_scores = Torch.matmul(query_layer, key_layer.transpose(-1, -2))

    if @position_embedding_type == "relative_key" || @position_embedding_type == "relative_key_query"
      # Relative position scores are not ported yet
      raise Todo
    end

    attention_scores = attention_scores / Math.sqrt(@attention_head_size)
    if !attention_mask.nil?
      # The mask is additive (precomputed once in BertModel#forward)
      attention_scores = attention_scores + attention_mask
    end

    # Normalize the attention scores to probabilities.
    attention_probs = Torch::NN::Functional.softmax(attention_scores, dim: -1)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = @dropout.(attention_probs)

    # Mask heads if we want to
    if !head_mask.nil?
      attention_probs = attention_probs * head_mask
    end

    context_layer = Torch.matmul(attention_probs, value_layer)

    # Undo transpose_for_scores: move heads back next to head_size and merge them
    context_layer = context_layer.permute(0, 2, 1, 3).contiguous
    new_context_layer_shape = context_layer.size[...-2] + [@all_head_size]
    context_layer = context_layer.view(new_context_layer_shape)

    outputs = output_attentions ? [context_layer, attention_probs] : [context_layer]

    if @is_decoder
      outputs = outputs + [past_key_value]
    end
    outputs
  end
end
|
200
|
+
|
201
|
+
class BertSelfOutput < Torch::NN::Module
  # Output block of self-attention: dense projection (hidden_size -> hidden_size),
  # dropout, then LayerNorm over the residual sum with the attention input.
  def initialize(config)
    super()
    width = config.hidden_size
    @dense = Torch::NN::Linear.new(width, width)
    @LayerNorm = Torch::NN::LayerNorm.new(width, eps: config.layer_norm_eps)
    @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
  end

  # @param hidden_states [Torch::Tensor] attention context to project
  # @param input_tensor [Torch::Tensor] residual input added before normalization
  # @return [Torch::Tensor]
  def forward(hidden_states, input_tensor)
    projected = @dropout.call(@dense.call(hidden_states))
    @LayerNorm.call(projected + input_tensor)
  end
end
|
216
|
+
|
217
|
+
# Lookup from config._attn_implementation to the self-attention class that
# BertAttention instantiates; only "eager" is available.
BERT_SELF_ATTENTION_CLASSES = {"eager" => BertSelfAttention}
|
220
|
+
|
221
|
+
class BertAttention < Torch::NN::Module
  # Self-attention followed by its output (projection + residual LayerNorm) block.
  #
  # @param config [BertConfig] config._attn_implementation selects the class from
  #   BERT_SELF_ATTENTION_CLASSES (Hash#fetch raises KeyError for unknown values)
  # @param position_embedding_type [String, nil] forwarded to the self-attention class
  def initialize(config, position_embedding_type: nil)
    super()
    attention_class = BERT_SELF_ATTENTION_CLASSES.fetch(config._attn_implementation)
    @self = attention_class.new(config, position_embedding_type: position_embedding_type)
    @output = BertSelfOutput.new(config)
    @pruned_heads = Set.new
  end

  # @return [Array] the projected attention output followed by any extra values the
  #   self-attention module returned (attention weights and/or cached key/values)
  def forward(
    hidden_states,
    attention_mask: nil,
    head_mask: nil,
    encoder_hidden_states: nil,
    encoder_attention_mask: nil,
    past_key_value: nil,
    output_attentions: false
  )
    context, *extras = @self.call(
      hidden_states,
      attention_mask: attention_mask,
      head_mask: head_mask,
      encoder_hidden_states: encoder_hidden_states,
      encoder_attention_mask: encoder_attention_mask,
      past_key_value: past_key_value,
      output_attentions: output_attentions
    )
    [@output.call(context, hidden_states), *extras]
  end
end
|
254
|
+
|
255
|
+
class BertIntermediate < Torch::NN::Module
  # Feed-forward expansion: hidden_size -> intermediate_size followed by the activation.
  # config.hidden_act is either a String key into ACT2FN or a callable used directly.
  def initialize(config)
    super()
    @dense = Torch::NN::Linear.new(config.hidden_size, config.intermediate_size)
    activation = config.hidden_act
    @intermediate_act_fn = activation.is_a?(String) ? ACT2FN[activation] : activation
  end

  # @param hidden_states [Torch::Tensor]
  # @return [Torch::Tensor] activated intermediate representation
  def forward(hidden_states)
    @intermediate_act_fn.call(@dense.call(hidden_states))
  end
end
|
272
|
+
|
273
|
+
class BertOutput < Torch::NN::Module
  # Feed-forward contraction: intermediate_size -> hidden_size, dropout, then
  # LayerNorm over the residual sum with the block input.
  def initialize(config)
    super()
    @dense = Torch::NN::Linear.new(config.intermediate_size, config.hidden_size)
    @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
    @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
  end

  # @param hidden_states [Torch::Tensor] intermediate activations
  # @param input_tensor [Torch::Tensor] residual input
  # @return [Torch::Tensor]
  def forward(hidden_states, input_tensor)
    residual_sum = @dropout.call(@dense.call(hidden_states)) + input_tensor
    @LayerNorm.call(residual_sum)
  end
end
|
288
|
+
|
289
|
+
class BertLayer < Torch::NN::Module
  # One transformer layer: self-attention, optional cross-attention (decoders only,
  # forward not yet ported), then the chunked feed-forward block.
  #
  # @raise [ArgumentError] when config.add_cross_attention is set on a non-decoder config
  def initialize(config)
    super()
    @chunk_size_feed_forward = config.chunk_size_feed_forward
    @seq_len_dim = 1
    @attention = BertAttention.new(config)
    @is_decoder = config.is_decoder
    @add_cross_attention = config.add_cross_attention
    if @add_cross_attention
      raise ArgumentError, "#{self} should be used as a decoder model if cross attention is added" unless @is_decoder

      @crossattention = BertAttention.new(config, position_embedding_type: "absolute")
    end
    @intermediate = BertIntermediate.new(config)
    @output = BertOutput.new(config)
  end

  # @return [Array] [layer_output, *attention weights], with the self-attention
  #   key/value cache appended last when this is a decoder
  def forward(
    hidden_states,
    attention_mask: nil,
    head_mask: nil,
    encoder_hidden_states: nil,
    encoder_attention_mask: nil,
    past_key_value: nil,
    output_attentions: false
  )
    # The first two entries of a decoder cache are the self-attention key/value pair
    self_attn_past_key_value = past_key_value.nil? ? nil : past_key_value[...2]
    self_attention_outputs = @attention.call(
      hidden_states,
      attention_mask: attention_mask,
      head_mask: head_mask,
      output_attentions: output_attentions,
      past_key_value: self_attn_past_key_value
    )
    attention_output = self_attention_outputs.first

    if @is_decoder
      # Decoder output ends with the self-attention cache; anything between is weights
      extra_outputs = self_attention_outputs[1...-1]
      present_key_value = self_attention_outputs.last
    else
      extra_outputs = self_attention_outputs.drop(1)
    end

    # Cross-attention in the forward pass is not ported yet
    raise Todo if @is_decoder && !encoder_hidden_states.nil?

    layer_output = TorchUtils.apply_chunking_to_forward(
      method(:feed_forward_chunk), @chunk_size_feed_forward, @seq_len_dim, attention_output
    )

    result = [layer_output, *extra_outputs]
    result << present_key_value if @is_decoder
    result
  end

  # Intermediate expansion followed by the output block (with residual on attention_output).
  def feed_forward_chunk(attention_output)
    @output.call(@intermediate.call(attention_output), attention_output)
  end
end
|
359
|
+
|
360
|
+
class BertEncoder < Torch::NN::Module
  # Stack of config.num_hidden_layers BertLayer modules applied in sequence.
  def initialize(config)
    super()
    @config = config
    @layer = Torch::NN::ModuleList.new(config.num_hidden_layers.times.map { BertLayer.new(config) })
    @gradient_checkpointing = false
  end

  # Runs hidden_states through every layer.
  #
  # @param head_mask [Array, nil] per-layer head masks, indexed by layer
  # @param past_key_values [Array, nil] per-layer caches, indexed by layer
  # @param use_cache [Boolean, nil] collect each layer's last output as its cache
  # @param output_attentions [Boolean] collect per-layer attention weights
  # @param output_hidden_states [Boolean] collect the input of every layer plus the final output
  # @param return_dict [Boolean] when false, return an Array of the non-nil results
  #   [hidden_states, cache, all_hidden_states, attentions, cross_attentions]
  # @return [BaseModelOutputWithPastAndCrossAttentions, Array]
  # @raise [Todo] gradient checkpointing during training is not supported yet
  def forward(
    hidden_states,
    attention_mask: nil,
    head_mask: nil,
    encoder_hidden_states: nil,
    encoder_attention_mask: nil,
    past_key_values: nil,
    use_cache: nil,
    output_attentions: false,
    output_hidden_states: false,
    return_dict: true
  )
    all_hidden_states = output_hidden_states ? [] : nil
    all_self_attentions = output_attentions ? [] : nil
    all_cross_attentions = output_attentions && @config.add_cross_attention ? [] : nil

    # Gradient checkpointing is only relevant in training mode and is not ported yet.
    raise Todo if @gradient_checkpointing && @training

    next_decoder_cache = use_cache ? [] : nil
    @layer.each_with_index do |layer_module, i|
      all_hidden_states << hidden_states if output_hidden_states

      layer_head_mask = head_mask.nil? ? nil : head_mask[i]
      past_key_value = past_key_values.nil? ? nil : past_key_values[i]

      raise Todo if @gradient_checkpointing && @training

      layer_outputs = layer_module.(
        hidden_states,
        attention_mask: attention_mask,
        head_mask: layer_head_mask,
        encoder_hidden_states: encoder_hidden_states,
        encoder_attention_mask: encoder_attention_mask,
        past_key_value: past_key_value,
        output_attentions: output_attentions
      )

      hidden_states = layer_outputs[0]
      # A decoder layer returns its key/value cache as the last element
      next_decoder_cache << layer_outputs[-1] if use_cache
      if output_attentions
        all_self_attentions << layer_outputs[1]
        all_cross_attentions << layer_outputs[2] if @config.add_cross_attention
      end
    end

    all_hidden_states << hidden_states if output_hidden_states

    unless return_dict
      return [
        hidden_states,
        next_decoder_cache,
        all_hidden_states,
        all_self_attentions,
        all_cross_attentions
      ].compact
    end

    BaseModelOutputWithPastAndCrossAttentions.new(
      last_hidden_state: hidden_states,
      past_key_values: next_decoder_cache,
      hidden_states: all_hidden_states,
      attentions: all_self_attentions,
      cross_attentions: all_cross_attentions
    )
  end
end
|
439
|
+
|
440
|
+
class BertPooler < Torch::NN::Module
  # Pools a sequence by passing the hidden state of its first token (index 0 of
  # dimension 1) through a dense layer and tanh.
  def initialize(config)
    super()
    @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
    @activation = Torch::NN::Tanh.new
  end

  # NOTE(review): hidden_states is presumably (batch, seq, hidden) — confirm.
  # @return [Torch::Tensor] pooled representation
  def forward(hidden_states)
    first_token_tensor = hidden_states[0.., 0]
    @activation.call(@dense.call(first_token_tensor))
  end
end
|
456
|
+
|
457
|
+
class BertPredictionHeadTransform < Torch::NN::Module
  # Dense (hidden_size -> hidden_size), activation, LayerNorm — applied before the
  # vocabulary projection. config.hidden_act is a String key into ACT2FN or a callable.
  def initialize(config)
    super()
    @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
    activation = config.hidden_act
    @transform_act_fn = activation.is_a?(String) ? ACT2FN[activation] : activation
    @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
  end

  # Applies dense, activation and LayerNorm in that order.
  def forward(hidden_states)
    [@dense, @transform_act_fn, @LayerNorm].reduce(hidden_states) { |acc, step| step.call(acc) }
  end
end
|
476
|
+
|
477
|
+
class BertLMPredictionHead < Torch::NN::Module
  # Masked-LM head: transform, then project hidden_size -> vocab_size. The projection is
  # created without its own bias; a separate zero-initialized @bias parameter is
  # attached to it instead.
  def initialize(config)
    super()
    @transform = BertPredictionHeadTransform.new(config)
    @decoder = Torch::NN::Linear.new(config.hidden_size, config.vocab_size, bias: false)
    @bias = Torch::NN::Parameter.new(Torch.zeros(config.vocab_size))

    # Link the two variables so the bias is resized together with the decoder
    # (e.g. by resize_token_embeddings)
    _tie_weights
  end

  # Points the decoder's @bias instance variable at this head's @bias parameter.
  def _tie_weights
    @decoder.instance_variable_set(:@bias, @bias)
  end

  # @return [Torch::Tensor] vocabulary logits
  def forward(hidden_states)
    @decoder.call(@transform.call(hidden_states))
  end
end
|
502
|
+
|
503
|
+
class BertOnlyMLMHead < Torch::NN::Module
  # Thin wrapper exposing BertLMPredictionHead under the `predictions` name.
  def initialize(config)
    super()
    @predictions = BertLMPredictionHead.new(config)
  end

  # @return [Torch::Tensor] prediction scores over the vocabulary
  def forward(sequence_output)
    @predictions.call(sequence_output)
  end
end
|
514
|
+
|
515
|
+
class BertPreTrainedModel < PreTrainedModel
  self.config_class = BertConfig
  self.base_model_prefix = "bert"

  # Initializes the weights of a single submodule:
  # - Linear: weight ~ N(0, initializer_range); bias zeroed when present
  # - Embedding: weight ~ N(0, initializer_range); row at padding_idx zeroed when set
  # - LayerNorm: bias = 0, weight = 1
  # Any other module type is left untouched.
  def _init_weights(mod)
    case mod
    when Torch::NN::Linear
      mod.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
      mod.bias.data.zero! unless mod.bias.nil?
    when Torch::NN::Embedding
      mod.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
      padding_idx = mod.instance_variable_get(:@padding_idx)
      mod.weight.data[padding_idx].zero! unless padding_idx.nil?
    when Torch::NN::LayerNorm
      mod.bias.data.zero!
      mod.weight.data.fill!(1.0)
    end
  end
end
|
536
|
+
|
537
|
+
class BertModel < BertPreTrainedModel
  # Bare BERT encoder: embeddings + encoder stack + optional pooler.
  #
  # @param config [BertConfig]
  # @param add_pooling_layer [Boolean] when false there is no pooler and pooler_output is nil
  def initialize(config, add_pooling_layer: true)
    super(config)
    @config = config

    @embeddings = BertEmbeddings.new(config)
    @encoder = BertEncoder.new(config)
    @pooler = add_pooling_layer ? BertPooler.new(config) : nil

    @attn_implementation = config._attn_implementation
    @position_embedding_type = config.position_embedding_type

    # Initialize weights and apply final processing
    post_init
  end

  # @param heads_to_prune [Hash] layer index => heads to prune in that layer
  # NOTE(review): BertAttention#prune_heads is not defined in this file as shown — confirm
  # before relying on this.
  def _prune_heads(heads_to_prune)
    heads_to_prune.each { |layer_index, heads| @encoder.layer[layer_index].attention.prune_heads(heads) }
  end

  # Unset output_attentions / output_hidden_states / return_dict fall back to config;
  # use_cache is forced to false unless config.is_decoder.
  #
  # @raise [ArgumentError] when both or neither of input_ids / inputs_embeds are given
  # @return [BaseModelOutputWithPoolingAndCrossAttentions]
  def forward(
    input_ids: nil,
    attention_mask: nil,
    token_type_ids: nil,
    position_ids: nil,
    head_mask: nil,
    inputs_embeds: nil,
    encoder_hidden_states: nil,
    encoder_attention_mask: nil,
    past_key_values: nil,
    use_cache: nil,
    output_attentions: nil,
    output_hidden_states: nil,
    return_dict: nil
  )
    output_attentions = @config.output_attentions if output_attentions.nil?
    output_hidden_states = @config.output_hidden_states if output_hidden_states.nil?
    return_dict = @config.use_return_dict if return_dict.nil?

    use_cache =
      if @config.is_decoder
        use_cache.nil? ? @config.use_cache : use_cache
      else
        false
      end

    unless input_ids.nil? || inputs_embeds.nil?
      raise ArgumentError, "You cannot specify both input_ids and inputs_embeds at the same time"
    end
    if input_ids.nil? && inputs_embeds.nil?
      raise ArgumentError, "You have to specify either input_ids or inputs_embeds"
    end
    input_shape = input_ids.nil? ? inputs_embeds.size[...-1] : input_ids.size

    batch_size, seq_length = input_shape
    device = (input_ids.nil? ? inputs_embeds : input_ids).device

    # Number of positions already held in the cache (dim 2 of the first cached key)
    past_key_values_length = past_key_values.nil? ? 0 : past_key_values[0][0].shape[2]

    if token_type_ids.nil?
      token_type_ids =
        if @embeddings.token_type_ids
          @embeddings.token_type_ids[0.., 0...seq_length].expand(batch_size, seq_length)
        else
          Torch.zeros(input_shape, dtype: Torch.long, device: device)
        end
    end

    embedding_output = @embeddings.call(
      input_ids: input_ids,
      position_ids: position_ids,
      token_type_ids: token_type_ids,
      inputs_embeds: inputs_embeds,
      past_key_values_length: past_key_values_length
    )

    # Attend to every position (cached and new) when no mask is given
    attention_mask = Torch.ones([batch_size, seq_length + past_key_values_length], device: device) if attention_mask.nil?

    use_sdpa_attention_masks =
      @attn_implementation == "sdpa" && @position_embedding_type == "absolute" && head_mask.nil? && !output_attentions

    # SDPA mask expansion is not ported yet
    raise Todo if use_sdpa_attention_masks

    # Make the mask broadcastable to all heads (helper inherited from PreTrainedModel)
    extended_attention_mask = get_extended_attention_mask(attention_mask, input_shape)

    encoder_extended_attention_mask = nil
    if @config.is_decoder && !encoder_hidden_states.nil?
      encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size
      encoder_hidden_shape = [encoder_batch_size, encoder_sequence_length]
      encoder_attention_mask = Torch.ones(encoder_hidden_shape, device: device) if encoder_attention_mask.nil?

      encoder_extended_attention_mask =
        if use_sdpa_attention_masks
          _prepare_4d_attention_mask_for_sdpa(
            encoder_attention_mask, embedding_output.dtype, tgt_len: seq_length
          )
        else
          invert_attention_mask(encoder_attention_mask)
        end
    end

    # Normalize head_mask into a per-layer form (1.0 keeps a head); helper inherited
    # from PreTrainedModel
    head_mask = get_head_mask(head_mask, @config.num_hidden_layers)

    encoder_outputs = @encoder.call(
      embedding_output,
      attention_mask: extended_attention_mask,
      head_mask: head_mask,
      encoder_hidden_states: encoder_hidden_states,
      encoder_attention_mask: encoder_extended_attention_mask,
      past_key_values: past_key_values,
      use_cache: use_cache,
      output_attentions: output_attentions,
      output_hidden_states: output_hidden_states,
      return_dict: return_dict
    )
    sequence_output = encoder_outputs[0]
    pooled_output = @pooler&.call(sequence_output)

    # Tuple outputs are not ported yet
    raise Todo unless return_dict

    BaseModelOutputWithPoolingAndCrossAttentions.new(
      last_hidden_state: sequence_output,
      pooler_output: pooled_output,
      past_key_values: encoder_outputs.past_key_values,
      hidden_states: encoder_outputs.hidden_states,
      attentions: encoder_outputs.attentions,
      cross_attentions: encoder_outputs.cross_attentions
    )
  end
end
|
700
|
+
|
701
|
+
class BertForMaskedLM < BertPreTrainedModel
  # Builds a BERT encoder (without the pooling layer) topped with a
  # masked-language-modeling head.
  #
  # @param config the model configuration (reads is_decoder here; vocab_size
  #   and use_return_dict are read in #forward)
  def initialize(config)
    super(config)

    # Masked LM needs bi-directional self-attention, so a decoder config is
    # most likely a mistake. Warn, but still build the model.
    if config.is_decoder
      Transformers.logger.warn(
        "If you want to use `BertForMaskedLM` make sure `config.is_decoder: false` for " +
        "bi-directional self-attention."
      )
    end

    @bert = BertModel.new(config, add_pooling_layer: false)
    @cls = BertOnlyMLMHead.new(config)

    # Initialize weights and apply final processing (the same finalization
    # step the other task heads in this file perform).
    post_init
  end

  # Runs BERT and projects each token's hidden state through the MLM head.
  #
  # @param labels target token ids, or nil. When given, a cross-entropy loss
  #   is computed over all positions. Torch::NN::CrossEntropyLoss ignores
  #   positions labelled with its ignore_index (default -100).
  # @param return_dict falls back to @config.use_return_dict when nil
  # @return [MaskedLMOutput] loss (nil without labels), logits, hidden
  #   states and attentions
  # @raise [Todo] when a tuple-style output is requested (return_dict false)
  def forward(
    input_ids: nil,
    attention_mask: nil,
    token_type_ids: nil,
    position_ids: nil,
    head_mask: nil,
    inputs_embeds: nil,
    encoder_hidden_states: nil,
    encoder_attention_mask: nil,
    labels: nil,
    output_attentions: nil,
    output_hidden_states: nil,
    return_dict: nil
  )
    return_dict = @config.use_return_dict if return_dict.nil?

    outputs = @bert.(
      input_ids: input_ids,
      attention_mask: attention_mask,
      token_type_ids: token_type_ids,
      position_ids: position_ids,
      head_mask: head_mask,
      inputs_embeds: inputs_embeds,
      encoder_hidden_states: encoder_hidden_states,
      encoder_attention_mask: encoder_attention_mask,
      output_attentions: output_attentions,
      output_hidden_states: output_hidden_states,
      return_dict: return_dict
    )

    sequence_output = outputs[0]
    prediction_scores = @cls.(sequence_output)

    masked_lm_loss = nil
    unless labels.nil?
      # Flatten the leading dims so each token becomes one row of logits.
      # Assumes the MLM head's last dim is vocab_size (TODO confirm in BertOnlyMLMHead).
      loss_fct = Torch::NN::CrossEntropyLoss.new
      masked_lm_loss = loss_fct.(prediction_scores.view(-1, @config.vocab_size), labels.view(-1))
    end

    # Tuple-style outputs are not supported yet.
    raise Todo unless return_dict

    MaskedLMOutput.new(
      loss: masked_lm_loss,
      logits: prediction_scores,
      hidden_states: outputs.hidden_states,
      attentions: outputs.attentions
    )
  end
end
|
766
|
+
|
767
|
+
class BertForTokenClassification < BertPreTrainedModel
  # Builds a BERT encoder (without the pooling layer) followed by dropout and
  # a per-token linear classifier with config.num_labels outputs.
  #
  # @param config the model configuration (reads num_labels, hidden_size,
  #   classifier_dropout and hidden_dropout_prob)
  def initialize(config)
    super(config)
    @num_labels = config.num_labels

    @bert = BertModel.new(config, add_pooling_layer: false)
    # A classifier-specific dropout overrides the encoder's hidden dropout when set.
    classifier_dropout =
      config.classifier_dropout.nil? ? config.hidden_dropout_prob : config.classifier_dropout
    @dropout = Torch::NN::Dropout.new(p: classifier_dropout)
    @classifier = Torch::NN::Linear.new(config.hidden_size, config.num_labels)

    # Initialize weights and apply final processing
    post_init
  end

  # Runs BERT and classifies every token of the sequence output.
  #
  # @param labels per-token class ids, or nil. When given, a cross-entropy
  #   loss over all tokens is returned.
  # @param return_dict falls back to @config.use_return_dict when nil
  # @return [TokenClassifierOutput] loss (nil without labels), logits, hidden
  #   states and attentions
  # @raise [Todo] when a tuple-style output is requested (return_dict false)
  def forward(
    input_ids: nil,
    attention_mask: nil,
    token_type_ids: nil,
    position_ids: nil,
    head_mask: nil,
    inputs_embeds: nil,
    labels: nil,
    output_attentions: nil,
    output_hidden_states: nil,
    return_dict: nil
  )
    return_dict = @config.use_return_dict if return_dict.nil?

    outputs = @bert.(
      input_ids: input_ids,
      attention_mask: attention_mask,
      token_type_ids: token_type_ids,
      position_ids: position_ids,
      head_mask: head_mask,
      inputs_embeds: inputs_embeds,
      output_attentions: output_attentions,
      output_hidden_states: output_hidden_states,
      return_dict: return_dict
    )

    sequence_output = @dropout.(outputs[0])
    logits = @classifier.(sequence_output)

    loss = nil
    unless labels.nil?
      # Fully qualified: a bare `CrossEntropyLoss` does not resolve from this
      # namespace and would raise NameError as soon as labels are passed.
      loss_fct = Torch::NN::CrossEntropyLoss.new
      # Flatten so each token is one row of num_labels logits.
      loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
    end

    # Tuple-style outputs are not supported yet.
    raise Todo unless return_dict

    TokenClassifierOutput.new(
      loss: loss,
      logits: logits,
      hidden_states: outputs.hidden_states,
      attentions: outputs.attentions
    )
  end
end
|
832
|
+
end
|
833
|
+
|
834
|
+
# Re-export the BERT model classes defined in the Bert namespace at the level
# of the enclosing namespace, so all task heads are reachable the same way.
BertModel = Bert::BertModel
BertForMaskedLM = Bert::BertForMaskedLM
BertForTokenClassification = Bert::BertForTokenClassification
|
836
|
+
end
|