transformers-rb 0.1.0

Files changed (65)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +203 -0
  4. data/README.md +163 -0
  5. data/lib/transformers/activations.rb +57 -0
  6. data/lib/transformers/configuration_utils.rb +285 -0
  7. data/lib/transformers/convert_slow_tokenizer.rb +90 -0
  8. data/lib/transformers/data/processors/squad.rb +115 -0
  9. data/lib/transformers/dynamic_module_utils.rb +25 -0
  10. data/lib/transformers/feature_extraction_utils.rb +110 -0
  11. data/lib/transformers/hf_hub/constants.rb +71 -0
  12. data/lib/transformers/hf_hub/errors.rb +11 -0
  13. data/lib/transformers/hf_hub/file_download.rb +764 -0
  14. data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
  15. data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
  16. data/lib/transformers/image_processing_base.rb +169 -0
  17. data/lib/transformers/image_processing_utils.rb +63 -0
  18. data/lib/transformers/image_transforms.rb +208 -0
  19. data/lib/transformers/image_utils.rb +165 -0
  20. data/lib/transformers/modeling_outputs.rb +81 -0
  21. data/lib/transformers/modeling_utils.rb +888 -0
  22. data/lib/transformers/models/auto/auto_factory.rb +138 -0
  23. data/lib/transformers/models/auto/configuration_auto.rb +61 -0
  24. data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
  25. data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
  26. data/lib/transformers/models/auto/modeling_auto.rb +80 -0
  27. data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
  28. data/lib/transformers/models/bert/configuration_bert.rb +65 -0
  29. data/lib/transformers/models/bert/modeling_bert.rb +836 -0
  30. data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
  31. data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
  32. data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
  33. data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
  34. data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
  35. data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
  36. data/lib/transformers/models/vit/configuration_vit.rb +60 -0
  37. data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
  38. data/lib/transformers/models/vit/modeling_vit.rb +506 -0
  39. data/lib/transformers/pipelines/_init.rb +348 -0
  40. data/lib/transformers/pipelines/base.rb +301 -0
  41. data/lib/transformers/pipelines/feature_extraction.rb +47 -0
  42. data/lib/transformers/pipelines/image_classification.rb +110 -0
  43. data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
  44. data/lib/transformers/pipelines/pt_utils.rb +53 -0
  45. data/lib/transformers/pipelines/question_answering.rb +508 -0
  46. data/lib/transformers/pipelines/text_classification.rb +123 -0
  47. data/lib/transformers/pipelines/token_classification.rb +282 -0
  48. data/lib/transformers/ruby_utils.rb +33 -0
  49. data/lib/transformers/sentence_transformer.rb +37 -0
  50. data/lib/transformers/tokenization_utils.rb +152 -0
  51. data/lib/transformers/tokenization_utils_base.rb +937 -0
  52. data/lib/transformers/tokenization_utils_fast.rb +386 -0
  53. data/lib/transformers/torch_utils.rb +25 -0
  54. data/lib/transformers/utils/_init.rb +31 -0
  55. data/lib/transformers/utils/generic.rb +107 -0
  56. data/lib/transformers/utils/hub.rb +209 -0
  57. data/lib/transformers/utils/import_utils.rb +45 -0
  58. data/lib/transformers/utils/logging.rb +52 -0
  59. data/lib/transformers/version.rb +3 -0
  60. data/lib/transformers-rb.rb +1 -0
  61. data/lib/transformers.rb +100 -0
  62. data/licenses/LICENSE-huggingface-hub.txt +201 -0
  63. data/licenses/LICENSE-sentence-transformers.txt +201 -0
  64. data/licenses/NOTICE-sentence-transformers.txt +5 -0
  65. metadata +161 -0
data/lib/transformers/models/distilbert/modeling_distilbert.rb
@@ -0,0 +1,616 @@
+ # Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   module Distilbert
+     class Embeddings < Torch::NN::Module
+       def initialize(config)
+         super()
+         @word_embeddings = Torch::NN::Embedding.new(config.vocab_size, config.dim, padding_idx: config.pad_token_id)
+         @position_embeddings = Torch::NN::Embedding.new(config.max_position_embeddings, config.dim)
+
+         @LayerNorm = Torch::NN::LayerNorm.new(config.dim, eps: 1e-12)
+         @dropout = Torch::NN::Dropout.new(p: config.dropout)
+         register_buffer(
+           "position_ids", Torch.arange(config.max_position_embeddings).expand([1, -1]), persistent: false
+         )
+       end
+
+       def forward(input_ids, input_embeds)
+         if !input_ids.nil?
+           input_embeds = @word_embeddings.(input_ids) # (bs, max_seq_length, dim)
+         end
+
+         seq_length = input_embeds.size(1)
+
+         # Using the position ids from the buffer registered in the constructor helps
+         # when tracing the model without passing position ids, and avoids
+         # issues similar to issue #5664
+         if @position_ids
+           position_ids = @position_ids[0.., 0...seq_length]
+         else
+           position_ids = Torch.arange(seq_length, dtype: :long, device: input_ids.device) # (max_seq_length)
+           position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # (bs, max_seq_length)
+         end
+
+         position_embeddings = @position_embeddings.(position_ids) # (bs, max_seq_length, dim)
+
+         embeddings = input_embeds + position_embeddings # (bs, max_seq_length, dim)
+         embeddings = @LayerNorm.(embeddings) # (bs, max_seq_length, dim)
+         embeddings = @dropout.(embeddings) # (bs, max_seq_length, dim)
+         embeddings
+       end
+     end
+
+     class MultiHeadSelfAttention < Torch::NN::Module
+       def initialize(config)
+         super()
+         @config = config
+
+         @n_heads = config.n_heads
+         @dim = config.dim
+         @dropout = Torch::NN::Dropout.new(p: config.attention_dropout)
+         @is_causal = false
+
+         # The number of attention heads must evenly divide the hidden dimension
+         if @dim % @n_heads != 0
+           # Raise an error when the heads do not divide the dimension evenly
+           raise ArgumentError, "self.n_heads: #{@n_heads} must divide self.dim: #{@dim} evenly"
+         end
+
+         @q_lin = Torch::NN::Linear.new(config.dim, config.dim)
+         @k_lin = Torch::NN::Linear.new(config.dim, config.dim)
+         @v_lin = Torch::NN::Linear.new(config.dim, config.dim)
+         @out_lin = Torch::NN::Linear.new(config.dim, config.dim)
+
+         @pruned_heads = Set.new
+         @attention_head_size = @dim.div(@n_heads)
+       end
+
+       def prune_heads(heads)
+         if heads.length == 0
+           return
+         end
+         raise Todo
+       end
+
+       def forward(
+         query:,
+         key:,
+         value:,
+         mask:,
+         head_mask: nil,
+         output_attentions: false
+       )
+         bs, _q_length, dim = query.size
+         k_length = key.size(1)
+         if dim != @dim
+           raise "Dimensions do not match: #{dim} input vs #{@dim} configured"
+         end
+         if key.size != value.size
+           raise Todo
+         end
+
+         dim_per_head = @dim.div(@n_heads)
+
+         mask_reshp = [bs, 1, 1, k_length]
+
+         shape = lambda do |x|
+           x.view(bs, -1, @n_heads, dim_per_head).transpose(1, 2)
+         end
+
+         unshape = lambda do |x|
+           x.transpose(1, 2).contiguous.view(bs, -1, @n_heads * dim_per_head)
+         end
+
+         q = shape.(@q_lin.(query)) # (bs, n_heads, q_length, dim_per_head)
+         k = shape.(@k_lin.(key)) # (bs, n_heads, k_length, dim_per_head)
+         v = shape.(@v_lin.(value)) # (bs, n_heads, k_length, dim_per_head)
+
+         q = q / Math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head)
+         scores = Torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)
+         mask = (mask.eq(0)).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length)
+         scores =
+           scores.masked_fill(
+             # TODO use Torch.finfo
+             mask, Torch.tensor(0)
+           ) # (bs, n_heads, q_length, k_length)
+
+         weights = Torch::NN::Functional.softmax(scores, dim: -1) # (bs, n_heads, q_length, k_length)
+         weights = @dropout.(weights) # (bs, n_heads, q_length, k_length)
+
+         # Mask heads if we want to
+         if !head_mask.nil?
+           weights = weights * head_mask
+         end
+
+         context = Torch.matmul(weights, v) # (bs, n_heads, q_length, dim_per_head)
+         context = unshape.(context) # (bs, q_length, dim)
+         context = @out_lin.(context) # (bs, q_length, dim)
+
+         if output_attentions
+           [context, weights]
+         else
+           [context]
+         end
+       end
+     end
+
+     class DistilBertFlashAttention2 < MultiHeadSelfAttention
+     end
+
+     class FFN < Torch::NN::Module
+       def initialize(config)
+         super()
+         @dropout = Torch::NN::Dropout.new(p: config.dropout)
+         @chunk_size_feed_forward = config.chunk_size_feed_forward
+         @seq_len_dim = 1
+         @lin1 = Torch::NN::Linear.new(config.dim, config.hidden_dim)
+         @lin2 = Torch::NN::Linear.new(config.hidden_dim, config.dim)
+         @activation = Activations.get_activation(config.activation)
+       end
+
+       def forward(input)
+         TorchUtils.apply_chunking_to_forward(method(:ff_chunk), @chunk_size_feed_forward, @seq_len_dim, input)
+       end
+
+       def ff_chunk(input)
+         x = @lin1.(input)
+         x = @activation.(x)
+         x = @lin2.(x)
+         x = @dropout.(x)
+         x
+       end
+     end
+
+     DISTILBERT_ATTENTION_CLASSES = {
+       "eager" => MultiHeadSelfAttention,
+       "flash_attention_2" => DistilBertFlashAttention2
+     }
+
+     class TransformerBlock < Torch::NN::Module
+       def initialize(config)
+         super()
+
+         # config.n_heads must evenly divide config.dim
+         if config.dim % config.n_heads != 0
+           raise ArgumentError, "config.n_heads #{config.n_heads} must divide config.dim #{config.dim} evenly"
+         end
+
+         @attention = DISTILBERT_ATTENTION_CLASSES[config._attn_implementation].new(config)
+         @sa_layer_norm = Torch::NN::LayerNorm.new(config.dim, eps: 1e-12)
+
+         @ffn = FFN.new(config)
+         @output_layer_norm = Torch::NN::LayerNorm.new(config.dim, eps: 1e-12)
+       end
+
+       def forward(
+         x:,
+         attn_mask: nil,
+         head_mask: nil,
+         output_attentions: false
+       )
+         # Self-Attention
+         sa_output =
+           @attention.(
+             query: x,
+             key: x,
+             value: x,
+             mask: attn_mask,
+             head_mask: head_mask,
+             output_attentions: output_attentions,
+           )
+         if output_attentions
+           sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
+         else # handle the case where `output_attentions` or `output_hidden_states` return tuples
+           if !sa_output.is_a?(Array)
+             raise TypeError, "sa_output must be an array but it is #{sa_output.class.name} type"
+           end
+
+           sa_output = sa_output[0]
+         end
+         sa_output = @sa_layer_norm.(sa_output + x) # (bs, seq_length, dim)
+
+         # Feed Forward Network
+         ffn_output = @ffn.(sa_output) # (bs, seq_length, dim)
+         ffn_output = @output_layer_norm.(ffn_output + sa_output) # (bs, seq_length, dim)
+
+         output = [ffn_output]
+         if output_attentions
+           output = [sa_weights] + output
+         end
+         output
+       end
+     end
+
+     class Transformer < Torch::NN::Module
+       def initialize(config)
+         super()
+         @n_layers = config.n_layers
+         @layer = Torch::NN::ModuleList.new(config.n_layers.times.map { TransformerBlock.new(config) })
+         @gradient_checkpointing = false
+       end
+
+       def forward(
+         x:,
+         attn_mask: nil,
+         head_mask: nil,
+         output_attentions: false,
+         output_hidden_states: false,
+         return_dict: nil
+       )
+         all_hidden_states = output_hidden_states ? [] : nil
+         all_attentions = output_attentions ? [] : nil
+
+         hidden_state = x
+         @layer.each_with_index do |layer_module, i|
+           if output_hidden_states
+             all_hidden_states = all_hidden_states + [hidden_state]
+           end
+
+           if @gradient_checkpointing && training
+             layer_outputs =
+               _gradient_checkpointing_func(
+                 layer_module.__call__,
+                 hidden_state,
+                 attn_mask,
+                 head_mask[i],
+                 output_attentions,
+               )
+           else
+             layer_outputs =
+               layer_module.(
+                 x: hidden_state,
+                 attn_mask: attn_mask,
+                 head_mask: head_mask[i],
+                 output_attentions: output_attentions
+               )
+           end
+
+           hidden_state = layer_outputs[-1]
+
+           if output_attentions
+             if layer_outputs.length != 2
+               raise ArgumentError, "The length of the layer_outputs should be 2, but it is #{layer_outputs.length}"
+             end
+
+             attentions = layer_outputs[0]
+             all_attentions = all_attentions + [attentions]
+           else
+             if layer_outputs.length != 1
+               raise ArgumentError, "The length of the layer_outputs should be 1, but it is #{layer_outputs.length}"
+             end
+           end
+         end
+
+         # Add last layer
+         if output_hidden_states
+           all_hidden_states = all_hidden_states + [hidden_state]
+         end
+
+         if !return_dict
+           raise Todo
+         end
+         BaseModelOutput.new(
+           last_hidden_state: hidden_state, hidden_states: all_hidden_states, attentions: all_attentions
+         )
+       end
+     end
+
+     class DistilBertPreTrainedModel < PreTrainedModel
+       self.config_class = DistilBertConfig
+       self.base_model_prefix = "distilbert"
+
+       def _init_weights(mod)
+         if mod.is_a?(Torch::NN::Linear)
+           mod.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
+           if !mod.bias.nil?
+             mod.bias.data.zero!
+           end
+         elsif mod.is_a?(Torch::NN::Embedding)
+           mod.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
+           if !mod.instance_variable_get(:@padding_idx).nil?
+             mod.weight.data[mod.instance_variable_get(:@padding_idx)].zero!
+           end
+         elsif mod.is_a?(Torch::NN::LayerNorm)
+           mod.bias.data.zero!
+           mod.weight.data.fill!(1.0)
+         elsif mod.is_a?(Embeddings) && @config.sinusoidal_pos_embds
+           create_sinusoidal_embeddings(
+             @config.max_position_embeddings, @config.dim, mod.position_embeddings.weight
+           )
+         end
+       end
+
+       private
+
+       def create_sinusoidal_embeddings(n_pos, dim, out)
+         # TODO
+       end
+     end
+
+     class DistilBertModel < DistilBertPreTrainedModel
+       def initialize(config)
+         super(config)
+
+         @embeddings = Embeddings.new(config) # Embeddings
+         @transformer = Transformer.new(config) # Encoder
+         @use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+
+         # Initialize weights and apply final processing
+         post_init
+       end
+
+       def get_position_embeddings
+         @embeddings.position_embeddings
+       end
+
+       def get_input_embeddings
+         @embeddings.word_embeddings
+       end
+
+       def _prune_heads(heads_to_prune)
+         heads_to_prune.each do |layer, heads|
+           @transformer.layer[layer].attention.prune_heads(heads)
+         end
+       end
+
+       def forward(
+         input_ids: nil,
+         attention_mask: nil,
+         head_mask: nil,
+         inputs_embeds: nil,
+         output_attentions: nil,
+         output_hidden_states: nil,
+         return_dict: nil
+       )
+         output_attentions = !output_attentions.nil? ? output_attentions : @config.output_attentions
+         output_hidden_states = (
+           !output_hidden_states.nil? ? output_hidden_states : @config.output_hidden_states
+         )
+         return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
+
+         if !input_ids.nil? && !inputs_embeds.nil?
+           raise ArgumentError, "You cannot specify both input_ids and inputs_embeds at the same time"
+         elsif !input_ids.nil?
+           warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+           input_shape = input_ids.size
+         elsif !inputs_embeds.nil?
+           input_shape = inputs_embeds.size[...-1]
+         else
+           raise ArgumentError, "You have to specify either input_ids or inputs_embeds"
+         end
+
+         device = !input_ids.nil? ? input_ids.device : inputs_embeds.device
+
+         # Prepare head mask if needed
+         head_mask = get_head_mask(head_mask, @config.num_hidden_layers)
+
+         embeddings = @embeddings.(input_ids, inputs_embeds) # (bs, seq_length, dim)
+
+         if @use_flash_attention_2
+           raise Todo
+         else
+           if attention_mask.nil?
+             attention_mask = Torch.ones(input_shape, device: device) # (bs, seq_length)
+           end
+         end
+
+         @transformer.(
+           x: embeddings,
+           attn_mask: attention_mask,
+           head_mask: head_mask,
+           output_attentions: output_attentions,
+           output_hidden_states: output_hidden_states,
+           return_dict: return_dict
+         )
+       end
+     end
+
+     class DistilBertForMaskedLM < DistilBertPreTrainedModel
+       self._tied_weights_keys = ["vocab_projector.weight"]
+
+       def initialize(config)
+         super(config)
+
+         @activation = Activations.get_activation(config.activation)
+
+         @distilbert = DistilBertModel.new(config)
+         @vocab_transform = Torch::NN::Linear.new(config.dim, config.dim)
+         @vocab_layer_norm = Torch::NN::LayerNorm.new(config.dim, eps: 1e-12)
+         @vocab_projector = Torch::NN::Linear.new(config.dim, config.vocab_size)
+
+         # Initialize weights and apply final processing
+         post_init
+
+         @mlm_loss_fct = Torch::NN::CrossEntropyLoss.new
+       end
+
+       def forward(
+         input_ids: nil,
+         attention_mask: nil,
+         head_mask: nil,
+         inputs_embeds: nil,
+         labels: nil,
+         output_attentions: nil,
+         output_hidden_states: nil,
+         return_dict: nil
+       )
+         return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
+
+         dlbrt_output = @distilbert.(
+           input_ids: input_ids,
+           attention_mask: attention_mask,
+           head_mask: head_mask,
+           inputs_embeds: inputs_embeds,
+           output_attentions: output_attentions,
+           output_hidden_states: output_hidden_states,
+           return_dict: return_dict
+         )
+         hidden_states = dlbrt_output[0] # (bs, seq_length, dim)
+         prediction_logits = @vocab_transform.(hidden_states) # (bs, seq_length, dim)
+         prediction_logits = @activation.(prediction_logits) # (bs, seq_length, dim)
+         prediction_logits = @vocab_layer_norm.(prediction_logits) # (bs, seq_length, dim)
+         prediction_logits = @vocab_projector.(prediction_logits) # (bs, seq_length, vocab_size)
+
+         mlm_loss = nil
+         if !labels.nil?
+           mlm_loss = @mlm_loss_fct.(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))
+         end
+
+         if !return_dict
+           raise Todo
+         end
+
+         MaskedLMOutput.new(
+           loss: mlm_loss,
+           logits: prediction_logits,
+           hidden_states: dlbrt_output.hidden_states,
+           attentions: dlbrt_output.attentions
+         )
+       end
+     end
+
+     class DistilBertForSequenceClassification < DistilBertPreTrainedModel
+       def initialize(config)
+         super(config)
+         @num_labels = config.num_labels
+         @config = config
+
+         @distilbert = DistilBertModel.new(config)
+         @pre_classifier = Torch::NN::Linear.new(config.dim, config.dim)
+         @classifier = Torch::NN::Linear.new(config.dim, config.num_labels)
+         @dropout = Torch::NN::Dropout.new(p: config.seq_classif_dropout)
+
+         # Initialize weights and apply final processing
+         post_init
+       end
+
+       def forward(
+         input_ids: nil,
+         attention_mask: nil,
+         head_mask: nil,
+         inputs_embeds: nil,
+         labels: nil,
+         output_attentions: nil,
+         output_hidden_states: nil,
+         return_dict: nil
+       )
+         return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
+
+         distilbert_output =
+           @distilbert.(
+             input_ids: input_ids,
+             attention_mask: attention_mask,
+             head_mask: head_mask,
+             inputs_embeds: inputs_embeds,
+             output_attentions: output_attentions,
+             output_hidden_states: output_hidden_states,
+             return_dict: return_dict
+           )
+         hidden_state = distilbert_output[0] # (bs, seq_len, dim)
+         pooled_output = hidden_state[0.., 0] # (bs, dim)
+         pooled_output = @pre_classifier.(pooled_output) # (bs, dim)
+         pooled_output = Torch::NN::ReLU.new.(pooled_output) # (bs, dim)
+         pooled_output = @dropout.(pooled_output) # (bs, dim)
+         logits = @classifier.(pooled_output) # (bs, num_labels)
+
+         loss = nil
+         if !labels.nil?
+           raise Todo
+         end
+
+         if !return_dict
+           raise Todo
+         end
+
+         SequenceClassifierOutput.new(
+           loss: loss,
+           logits: logits,
+           hidden_states: distilbert_output.hidden_states,
+           attentions: distilbert_output.attentions
+         )
+       end
+     end
+
+     class DistilBertForQuestionAnswering < DistilBertPreTrainedModel
+       def initialize(config)
+         super(config)
+
+         @distilbert = DistilBertModel.new(config)
+         @qa_outputs = Torch::NN::Linear.new(config.dim, config.num_labels)
+         if config.num_labels != 2
+           raise ArgumentError, "config.num_labels should be 2, but it is #{config.num_labels}"
+         end
+
+         @dropout = Torch::NN::Dropout.new(p: config.qa_dropout)
+
+         # Initialize weights and apply final processing
+         post_init
+       end
+
+       def forward(
+         input_ids: nil,
+         attention_mask: nil,
+         head_mask: nil,
+         inputs_embeds: nil,
+         start_positions: nil,
+         end_positions: nil,
+         output_attentions: nil,
+         output_hidden_states: nil,
+         return_dict: nil
+       )
+         return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
+
+         distilbert_output = @distilbert.(
+           input_ids: input_ids,
+           attention_mask: attention_mask,
+           head_mask: head_mask,
+           inputs_embeds: inputs_embeds,
+           output_attentions: output_attentions,
+           output_hidden_states: output_hidden_states,
+           return_dict: return_dict
+         )
+         hidden_states = distilbert_output[0] # (bs, max_query_len, dim)
+
+         hidden_states = @dropout.(hidden_states) # (bs, max_query_len, dim)
+         logits = @qa_outputs.(hidden_states) # (bs, max_query_len, 2)
+         start_logits, end_logits = logits.split(1, dim: -1)
+         start_logits = start_logits.squeeze(-1).contiguous # (bs, max_query_len)
+         end_logits = end_logits.squeeze(-1).contiguous # (bs, max_query_len)
+
+         total_loss = nil
+         if !start_positions.nil? && !end_positions.nil?
+           raise Todo
+         end
+
+         if !return_dict
+           raise Todo
+         end
+
+         QuestionAnsweringModelOutput.new(
+           loss: total_loss,
+           start_logits: start_logits,
+           end_logits: end_logits,
+           hidden_states: distilbert_output.hidden_states,
+           attentions: distilbert_output.attentions
+         )
+       end
+     end
+   end
+
+   DistilBertForMaskedLM = Distilbert::DistilBertForMaskedLM
+   DistilBertForSequenceClassification = Distilbert::DistilBertForSequenceClassification
+   DistilBertForQuestionAnswering = Distilbert::DistilBertForQuestionAnswering
+ end
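
For context, the model classes in this file are normally reached through the gem's auto classes and pipelines (see models/auto/modeling_auto.rb and pipelines/_init.rb in the file list) rather than instantiated directly. A minimal usage sketch, assuming the Ruby Transformers.pipeline helper mirrors the Python library's task names and model: option; the task name, option, and checkpoint id below are assumptions based on the Python API this gem ports, so consult the gem's README for the exact supported calls:

    require "transformers-rb"

    # With a DistilBERT checkpoint, a "text-classification" pipeline would resolve to
    # Transformers::DistilBertForSequenceClassification via the auto model classes.
    # The checkpoint id is illustrative and is fetched from the Hugging Face Hub.
    classifier = Transformers.pipeline(
      "text-classification",
      model: "distilbert-base-uncased-finetuned-sst-2-english"
    )
    p classifier.("Ruby bindings for Transformers are handy.")

The first call would download the checkpoint through the bundled hf_hub code (hf_hub/file_download.rb) and load its weights into the classes defined above.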