transformers-rb 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +203 -0
- data/README.md +163 -0
- data/lib/transformers/activations.rb +57 -0
- data/lib/transformers/configuration_utils.rb +285 -0
- data/lib/transformers/convert_slow_tokenizer.rb +90 -0
- data/lib/transformers/data/processors/squad.rb +115 -0
- data/lib/transformers/dynamic_module_utils.rb +25 -0
- data/lib/transformers/feature_extraction_utils.rb +110 -0
- data/lib/transformers/hf_hub/constants.rb +71 -0
- data/lib/transformers/hf_hub/errors.rb +11 -0
- data/lib/transformers/hf_hub/file_download.rb +764 -0
- data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
- data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
- data/lib/transformers/image_processing_base.rb +169 -0
- data/lib/transformers/image_processing_utils.rb +63 -0
- data/lib/transformers/image_transforms.rb +208 -0
- data/lib/transformers/image_utils.rb +165 -0
- data/lib/transformers/modeling_outputs.rb +81 -0
- data/lib/transformers/modeling_utils.rb +888 -0
- data/lib/transformers/models/auto/auto_factory.rb +138 -0
- data/lib/transformers/models/auto/configuration_auto.rb +61 -0
- data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
- data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
- data/lib/transformers/models/auto/modeling_auto.rb +80 -0
- data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
- data/lib/transformers/models/bert/configuration_bert.rb +65 -0
- data/lib/transformers/models/bert/modeling_bert.rb +836 -0
- data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
- data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
- data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
- data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
- data/lib/transformers/models/vit/configuration_vit.rb +60 -0
- data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
- data/lib/transformers/models/vit/modeling_vit.rb +506 -0
- data/lib/transformers/pipelines/_init.rb +348 -0
- data/lib/transformers/pipelines/base.rb +301 -0
- data/lib/transformers/pipelines/feature_extraction.rb +47 -0
- data/lib/transformers/pipelines/image_classification.rb +110 -0
- data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
- data/lib/transformers/pipelines/pt_utils.rb +53 -0
- data/lib/transformers/pipelines/question_answering.rb +508 -0
- data/lib/transformers/pipelines/text_classification.rb +123 -0
- data/lib/transformers/pipelines/token_classification.rb +282 -0
- data/lib/transformers/ruby_utils.rb +33 -0
- data/lib/transformers/sentence_transformer.rb +37 -0
- data/lib/transformers/tokenization_utils.rb +152 -0
- data/lib/transformers/tokenization_utils_base.rb +937 -0
- data/lib/transformers/tokenization_utils_fast.rb +386 -0
- data/lib/transformers/torch_utils.rb +25 -0
- data/lib/transformers/utils/_init.rb +31 -0
- data/lib/transformers/utils/generic.rb +107 -0
- data/lib/transformers/utils/hub.rb +209 -0
- data/lib/transformers/utils/import_utils.rb +45 -0
- data/lib/transformers/utils/logging.rb +52 -0
- data/lib/transformers/version.rb +3 -0
- data/lib/transformers-rb.rb +1 -0
- data/lib/transformers.rb +100 -0
- data/licenses/LICENSE-huggingface-hub.txt +201 -0
- data/licenses/LICENSE-sentence-transformers.txt +201 -0
- data/licenses/NOTICE-sentence-transformers.txt +5 -0
- metadata +161 -0
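The manifest above shows the layout of the gem: auto classes, BERT/DistilBERT/ViT model code, tokenizers, and pipelines for feature extraction, text classification, token classification, question answering, and image classification. As a rough orientation, the sketch below shows how these pieces would typically be exercised together. It assumes transformers-rb mirrors the Python transformers pipeline API (a `Transformers.pipeline` entry point); the task names and call patterns are illustrative assumptions based on that convention, not something shown in this diff.

require "transformers-rb"

# Hypothetical usage, assuming a Python-style pipeline API: a sentiment /
# text-classification pipeline would be backed by the
# DistilBertForSequenceClassification class defined in the file below.
classifier = Transformers.pipeline("sentiment-analysis")
classifier.("This gem brings BERT-style models to Ruby")

# Likewise, extractive question answering would be backed by
# DistilBertForQuestionAnswering plus the question_answering pipeline.
qa = Transformers.pipeline("question-answering")
qa.(question: "What runtime does the gem target?", context: "transformers-rb runs DistilBERT models on torch.rb.")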
data/lib/transformers/models/distilbert/modeling_distilbert.rb
@@ -0,0 +1,616 @@
# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  module Distilbert
    class Embeddings < Torch::NN::Module
      def initialize(config)
        super()
        @word_embeddings = Torch::NN::Embedding.new(config.vocab_size, config.dim, padding_idx: config.pad_token_id)
        @position_embeddings = Torch::NN::Embedding.new(config.max_position_embeddings, config.dim)

        @LayerNorm = Torch::NN::LayerNorm.new(config.dim, eps: 1e-12)
        @dropout = Torch::NN::Dropout.new(p: config.dropout)
        register_buffer(
          "position_ids", Torch.arange(config.max_position_embeddings).expand([1, -1]), persistent: false
        )
      end

      def forward(input_ids, input_embeds)
        if !input_ids.nil?
          input_embeds = @word_embeddings.(input_ids) # (bs, max_seq_length, dim)
        end

        seq_length = input_embeds.size(1)

        # Using the position ids registered as a buffer in the constructor helps
        # when tracing the model without passing position ids, and avoids
        # issues similar to issue #5664
        if @position_ids
          position_ids = @position_ids[0.., 0...seq_length]
        else
          position_ids = Torch.arange(seq_length, dtype: :long, device: input_ids.device) # (max_seq_length)
          position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # (bs, max_seq_length)
        end

        position_embeddings = @position_embeddings.(position_ids) # (bs, max_seq_length, dim)

        embeddings = input_embeds + position_embeddings # (bs, max_seq_length, dim)
        embeddings = @LayerNorm.(embeddings) # (bs, max_seq_length, dim)
        embeddings = @dropout.(embeddings) # (bs, max_seq_length, dim)
        embeddings
      end
    end

    class MultiHeadSelfAttention < Torch::NN::Module
      def initialize(config)
        super()
        @config = config

        @n_heads = config.n_heads
        @dim = config.dim
        @dropout = Torch::NN::Dropout.new(p: config.attention_dropout)
        @is_causal = false

        # The number of attention heads must evenly divide the hidden dimension
        if @dim % @n_heads != 0
          # Raise an error when the heads do not divide the dimension evenly
          raise ArgumentError, "self.n_heads: #{@n_heads} must divide self.dim: #{@dim} evenly"
        end

        @q_lin = Torch::NN::Linear.new(config.dim, config.dim)
        @k_lin = Torch::NN::Linear.new(config.dim, config.dim)
        @v_lin = Torch::NN::Linear.new(config.dim, config.dim)
        @out_lin = Torch::NN::Linear.new(config.dim, config.dim)

        @pruned_heads = Set.new
        @attention_head_size = @dim.div(@n_heads)
      end

      def prune_heads(heads)
        if heads.length == 0
          return
        end
        raise Todo
      end

      def forward(
        query:,
        key:,
        value:,
        mask:,
        head_mask: nil,
        output_attentions: false
      )
        bs, _q_length, dim = query.size
        k_length = key.size(1)
        if dim != @dim
          raise "Dimensions do not match: #{dim} input vs #{@dim} configured"
        end
        if key.size != value.size
          raise Todo
        end

        dim_per_head = @dim.div(@n_heads)

        mask_reshp = [bs, 1, 1, k_length]

        shape = lambda do |x|
          x.view(bs, -1, @n_heads, dim_per_head).transpose(1, 2)
        end

        unshape = lambda do |x|
          x.transpose(1, 2).contiguous.view(bs, -1, @n_heads * dim_per_head)
        end

        q = shape.(@q_lin.(query)) # (bs, n_heads, q_length, dim_per_head)
        k = shape.(@k_lin.(key)) # (bs, n_heads, k_length, dim_per_head)
        v = shape.(@v_lin.(value)) # (bs, n_heads, k_length, dim_per_head)

        q = q / Math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head)
        scores = Torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)
        mask = (mask.eq(0)).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length)
        scores =
          scores.masked_fill(
            # TODO use Torch.finfo
            mask, Torch.tensor(0)
          ) # (bs, n_heads, q_length, k_length)

        weights = Torch::NN::Functional.softmax(scores, dim: -1) # (bs, n_heads, q_length, k_length)
        weights = @dropout.(weights) # (bs, n_heads, q_length, k_length)

        # Mask heads if we want to
        if !head_mask.nil?
          weights = weights * head_mask
        end

        context = Torch.matmul(weights, v) # (bs, n_heads, q_length, dim_per_head)
        context = unshape.(context) # (bs, q_length, dim)
        context = @out_lin.(context) # (bs, q_length, dim)

        if output_attentions
          [context, weights]
        else
          [context]
        end
      end
    end

    class DistilBertFlashAttention2 < MultiHeadSelfAttention
    end

    class FFN < Torch::NN::Module
      def initialize(config)
        super()
        @dropout = Torch::NN::Dropout.new(p: config.dropout)
        @chunk_size_feed_forward = config.chunk_size_feed_forward
        @seq_len_dim = 1
        @lin1 = Torch::NN::Linear.new(config.dim, config.hidden_dim)
        @lin2 = Torch::NN::Linear.new(config.hidden_dim, config.dim)
        @activation = Activations.get_activation(config.activation)
      end

      def forward(input)
        TorchUtils.apply_chunking_to_forward(method(:ff_chunk), @chunk_size_feed_forward, @seq_len_dim, input)
      end

      def ff_chunk(input)
        x = @lin1.(input)
        x = @activation.(x)
        x = @lin2.(x)
        x = @dropout.(x)
        x
      end
    end

    DISTILBERT_ATTENTION_CLASSES = {
      "eager" => MultiHeadSelfAttention,
      "flash_attention_2" => DistilBertFlashAttention2
    }

    class TransformerBlock < Torch::NN::Module
      def initialize(config)
        super()

        # The number of attention heads must evenly divide the hidden dimension
        if config.dim % config.n_heads != 0
          raise ArgumentError, "config.n_heads #{config.n_heads} must divide config.dim #{config.dim} evenly"
        end

        @attention = DISTILBERT_ATTENTION_CLASSES[config._attn_implementation].new(config)
        @sa_layer_norm = Torch::NN::LayerNorm.new(config.dim, eps: 1e-12)

        @ffn = FFN.new(config)
        @output_layer_norm = Torch::NN::LayerNorm.new(config.dim, eps: 1e-12)
      end

      def forward(
        x:,
        attn_mask: nil,
        head_mask: nil,
        output_attentions: false
      )
        # Self-Attention
        sa_output =
          @attention.(
            query: x,
            key: x,
            value: x,
            mask: attn_mask,
            head_mask: head_mask,
            output_attentions: output_attentions,
          )
        if output_attentions
          sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
        else # handle the case where self-attention returns a tuple
          if !sa_output.is_a?(Array)
            raise TypeError, "sa_output must be an array but it is #{sa_output.class.name} type"
          end

          sa_output = sa_output[0]
        end
        sa_output = @sa_layer_norm.(sa_output + x) # (bs, seq_length, dim)

        # Feed Forward Network
        ffn_output = @ffn.(sa_output) # (bs, seq_length, dim)
        ffn_output = @output_layer_norm.(ffn_output + sa_output) # (bs, seq_length, dim)

        output = [ffn_output]
        if output_attentions
          output = [sa_weights] + output
        end
        output
      end
    end

    class Transformer < Torch::NN::Module
      def initialize(config)
        super()
        @n_layers = config.n_layers
        @layer = Torch::NN::ModuleList.new(config.n_layers.times.map { TransformerBlock.new(config) })
        @gradient_checkpointing = false
      end

      def forward(
        x:,
        attn_mask: nil,
        head_mask: nil,
        output_attentions: false,
        output_hidden_states: false,
        return_dict: nil
      )
        all_hidden_states = output_hidden_states ? [] : nil
        all_attentions = output_attentions ? [] : nil

        hidden_state = x
        @layer.each_with_index do |layer_module, i|
          if output_hidden_states
            all_hidden_states = all_hidden_states + [hidden_state]
          end

          if @gradient_checkpointing && training
            layer_outputs =
              _gradient_checkpointing_func(
                layer_module.__call__,
                hidden_state,
                attn_mask,
                head_mask[i],
                output_attentions,
              )
          else
            layer_outputs =
              layer_module.(
                x: hidden_state,
                attn_mask: attn_mask,
                head_mask: head_mask[i],
                output_attentions: output_attentions
              )
          end

          hidden_state = layer_outputs[-1]

          if output_attentions
            if layer_outputs.length != 2
              raise ArgumentError, "The length of the layer_outputs should be 2, but it is #{layer_outputs.length}"
            end

            attentions = layer_outputs[0]
            all_attentions = all_attentions + [attentions]
          else
            if layer_outputs.length != 1
              raise ArgumentError, "The length of the layer_outputs should be 1, but it is #{layer_outputs.length}"
            end
          end
        end

        # Add last layer
        if output_hidden_states
          all_hidden_states = all_hidden_states + [hidden_state]
        end

        if !return_dict
          raise Todo
        end
        BaseModelOutput.new(
          last_hidden_state: hidden_state, hidden_states: all_hidden_states, attentions: all_attentions
        )
      end
    end

    class DistilBertPreTrainedModel < PreTrainedModel
      self.config_class = DistilBertConfig
      self.base_model_prefix = "distilbert"

      def _init_weights(mod)
        if mod.is_a?(Torch::NN::Linear)
          mod.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
          if !mod.bias.nil?
            mod.bias.data.zero!
          end
        elsif mod.is_a?(Torch::NN::Embedding)
          mod.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
          if !mod.instance_variable_get(:@padding_idx).nil?
            mod.weight.data[mod.instance_variable_get(:@padding_idx)].zero!
          end
        elsif mod.is_a?(Torch::NN::LayerNorm)
          mod.bias.data.zero!
          mod.weight.data.fill!(1.0)
        elsif mod.is_a?(Embeddings) && @config.sinusoidal_pos_embds
          create_sinusoidal_embeddings(
            @config.max_position_embeddings, @config.dim, mod.position_embeddings.weight
          )
        end
      end

      private

      def create_sinusoidal_embeddings(n_pos, dim, out)
        # TODO
      end
    end

    class DistilBertModel < DistilBertPreTrainedModel
      def initialize(config)
        super(config)

        @embeddings = Embeddings.new(config) # Embeddings
        @transformer = Transformer.new(config) # Encoder
        @use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

        # Initialize weights and apply final processing
        post_init
      end

      def get_position_embeddings
        @embeddings.position_embeddings
      end

      def get_input_embeddings
        @embeddings.word_embeddings
      end

      def _prune_heads(heads_to_prune)
        heads_to_prune.each do |layer, heads|
          @transformer.layer[layer].attention.prune_heads(heads)
        end
      end

      def forward(
        input_ids: nil,
        attention_mask: nil,
        head_mask: nil,
        inputs_embeds: nil,
        output_attentions: nil,
        output_hidden_states: nil,
        return_dict: nil
      )
        output_attentions = !output_attentions.nil? ? output_attentions : @config.output_attentions
        output_hidden_states = (
          !output_hidden_states.nil? ? output_hidden_states : @config.output_hidden_states
        )
        return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

        if !input_ids.nil? && !inputs_embeds.nil?
          raise ArgumentError, "You cannot specify both input_ids and inputs_embeds at the same time"
        elsif !input_ids.nil?
          warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
          input_shape = input_ids.size
        elsif !inputs_embeds.nil?
          input_shape = inputs_embeds.size[...-1]
        else
          raise ArgumentError, "You have to specify either input_ids or inputs_embeds"
        end

        device = !input_ids.nil? ? input_ids.device : inputs_embeds.device

        # Prepare head mask if needed
        head_mask = get_head_mask(head_mask, @config.num_hidden_layers)

        embeddings = @embeddings.(input_ids, inputs_embeds) # (bs, seq_length, dim)

        if @use_flash_attention_2
          raise Todo
        else
          if attention_mask.nil?
            attention_mask = Torch.ones(input_shape, device: device) # (bs, seq_length)
          end
        end

        @transformer.(
          x: embeddings,
          attn_mask: attention_mask,
          head_mask: head_mask,
          output_attentions: output_attentions,
          output_hidden_states: output_hidden_states,
          return_dict: return_dict
        )
      end
    end

    class DistilBertForMaskedLM < DistilBertPreTrainedModel
      self._tied_weights_keys = ["vocab_projector.weight"]

      def initialize(config)
        super(config)

        @activation = get_activation(config.activation)

        @distilbert = DistilBertModel.new(config)
        @vocab_transform = Torch::NN::Linear.new(config.dim, config.dim)
        @vocab_layer_norm = Torch::NN::LayerNorm.new(config.dim, eps: 1e-12)
        @vocab_projector = Torch::NN::Linear.new(config.dim, config.vocab_size)

        # Initialize weights and apply final processing
        post_init

        @mlm_loss_fct = Torch::NN::CrossEntropyLoss.new
      end

      def forward(
        input_ids: nil,
        attention_mask: nil,
        head_mask: nil,
        inputs_embeds: nil,
        labels: nil,
        output_attentions: nil,
        output_hidden_states: nil,
        return_dict: nil
      )
        return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

        dlbrt_output = @distilbert.(
          input_ids: input_ids,
          attention_mask: attention_mask,
          head_mask: head_mask,
          inputs_embeds: inputs_embeds,
          output_attentions: output_attentions,
          output_hidden_states: output_hidden_states,
          return_dict: return_dict
        )
        hidden_states = dlbrt_output[0] # (bs, seq_length, dim)
        prediction_logits = @vocab_transform.(hidden_states) # (bs, seq_length, dim)
        prediction_logits = @activation.(prediction_logits) # (bs, seq_length, dim)
        prediction_logits = @vocab_layer_norm.(prediction_logits) # (bs, seq_length, dim)
        prediction_logits = @vocab_projector.(prediction_logits) # (bs, seq_length, vocab_size)

        mlm_loss = nil
        if !labels.nil?
          mlm_loss = @mlm_loss_fct.(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))
        end

        if !return_dict
          raise Todo
        end

        MaskedLMOutput.new(
          loss: mlm_loss,
          logits: prediction_logits,
          hidden_states: dlbrt_output.hidden_states,
          attentions: dlbrt_output.attentions
        )
      end
    end

    class DistilBertForSequenceClassification < DistilBertPreTrainedModel
      def initialize(config)
        super(config)
        @num_labels = config.num_labels
        @config = config

        @distilbert = DistilBertModel.new(config)
        @pre_classifier = Torch::NN::Linear.new(config.dim, config.dim)
        @classifier = Torch::NN::Linear.new(config.dim, config.num_labels)
        @dropout = Torch::NN::Dropout.new(p: config.seq_classif_dropout)

        # Initialize weights and apply final processing
        post_init
      end

      def forward(
        input_ids: nil,
        attention_mask: nil,
        head_mask: nil,
        inputs_embeds: nil,
        labels: nil,
        output_attentions: nil,
        output_hidden_states: nil,
        return_dict: nil
      )
        return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

        distilbert_output =
          @distilbert.(
            input_ids: input_ids,
            attention_mask: attention_mask,
            head_mask: head_mask,
            inputs_embeds: inputs_embeds,
            output_attentions: output_attentions,
            output_hidden_states: output_hidden_states,
            return_dict: return_dict
          )
        hidden_state = distilbert_output[0] # (bs, seq_len, dim)
        pooled_output = hidden_state[0.., 0] # (bs, dim)
        pooled_output = @pre_classifier.(pooled_output) # (bs, dim)
        pooled_output = Torch::NN::ReLU.new.(pooled_output) # (bs, dim)
        pooled_output = @dropout.(pooled_output) # (bs, dim)
        logits = @classifier.(pooled_output) # (bs, num_labels)

        loss = nil
        if !labels.nil?
          raise Todo
        end

        if !return_dict
          raise Todo
        end

        SequenceClassifierOutput.new(
          loss: loss,
          logits: logits,
          hidden_states: distilbert_output.hidden_states,
          attentions: distilbert_output.attentions
        )
      end
    end

    class DistilBertForQuestionAnswering < DistilBertPreTrainedModel
      def initialize(config)
        super(config)

        @distilbert = DistilBertModel.new(config)
        @qa_outputs = Torch::NN::Linear.new(config.dim, config.num_labels)
        if config.num_labels != 2
          raise ArgumentError, "config.num_labels should be 2, but it is #{config.num_labels}"
        end

        @dropout = Torch::NN::Dropout.new(p: config.qa_dropout)

        # Initialize weights and apply final processing
        post_init
      end

      def forward(
        input_ids: nil,
        attention_mask: nil,
        head_mask: nil,
        inputs_embeds: nil,
        start_positions: nil,
        end_positions: nil,
        output_attentions: nil,
        output_hidden_states: nil,
        return_dict: nil
      )
        return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

        distilbert_output = @distilbert.(
          input_ids: input_ids,
          attention_mask: attention_mask,
          head_mask: head_mask,
          inputs_embeds: inputs_embeds,
          output_attentions: output_attentions,
          output_hidden_states: output_hidden_states,
          return_dict: return_dict
        )
        hidden_states = distilbert_output[0] # (bs, max_query_len, dim)

        hidden_states = @dropout.(hidden_states) # (bs, max_query_len, dim)
        logits = @qa_outputs.(hidden_states) # (bs, max_query_len, 2)
        start_logits, end_logits = logits.split(1, dim: -1)
        start_logits = start_logits.squeeze(-1).contiguous # (bs, max_query_len)
        end_logits = end_logits.squeeze(-1).contiguous # (bs, max_query_len)

        total_loss = nil
        if !start_positions.nil? && !end_positions.nil?
          raise Todo
        end

        if !return_dict
          raise Todo
        end

        QuestionAnsweringModelOutput.new(
          loss: total_loss,
          start_logits: start_logits,
          end_logits: end_logits,
          hidden_states: distilbert_output.hidden_states,
          attentions: distilbert_output.attentions
        )
      end
    end
  end

  DistilBertForMaskedLM = Distilbert::DistilBertForMaskedLM
  DistilBertForSequenceClassification = Distilbert::DistilBertForSequenceClassification
  DistilBertForQuestionAnswering = Distilbert::DistilBertForQuestionAnswering
end
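For readers skimming the attention code above, the standalone snippet below replays the tensor bookkeeping performed by MultiHeadSelfAttention#forward and its `shape`/`unshape` lambdas, using only torch.rb calls that also appear in the file. It is an illustrative sketch (the sizes are made up and the attention-mask fill step is omitted), not part of the package.

require "torch"

bs, seq_len, n_heads, dim_per_head = 2, 5, 4, 3
dim = n_heads * dim_per_head # plays the role of config.dim (12)

hidden = Torch.randn(bs, seq_len, dim) # (bs, seq_len, dim)

# "shape": split the hidden size into heads and move heads ahead of the sequence axis
split_heads = ->(x) { x.view(bs, -1, n_heads, dim_per_head).transpose(1, 2) }
q = split_heads.(hidden) # (bs, n_heads, seq_len, dim_per_head)
k = split_heads.(hidden)
v = split_heads.(hidden)

q = q / Math.sqrt(dim_per_head) # scale queries before the dot product
scores = Torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, seq_len, seq_len)
weights = Torch::NN::Functional.softmax(scores, dim: -1)
context = Torch.matmul(weights, v) # (bs, n_heads, seq_len, dim_per_head)

# "unshape": merge the heads back into a single hidden dimension
context = context.transpose(1, 2).contiguous.view(bs, -1, n_heads * dim_per_head)
p context.size # => [2, 5, 12]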