transformers-rb 0.1.0

This diff shows the content of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (65)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +203 -0
  4. data/README.md +163 -0
  5. data/lib/transformers/activations.rb +57 -0
  6. data/lib/transformers/configuration_utils.rb +285 -0
  7. data/lib/transformers/convert_slow_tokenizer.rb +90 -0
  8. data/lib/transformers/data/processors/squad.rb +115 -0
  9. data/lib/transformers/dynamic_module_utils.rb +25 -0
  10. data/lib/transformers/feature_extraction_utils.rb +110 -0
  11. data/lib/transformers/hf_hub/constants.rb +71 -0
  12. data/lib/transformers/hf_hub/errors.rb +11 -0
  13. data/lib/transformers/hf_hub/file_download.rb +764 -0
  14. data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
  15. data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
  16. data/lib/transformers/image_processing_base.rb +169 -0
  17. data/lib/transformers/image_processing_utils.rb +63 -0
  18. data/lib/transformers/image_transforms.rb +208 -0
  19. data/lib/transformers/image_utils.rb +165 -0
  20. data/lib/transformers/modeling_outputs.rb +81 -0
  21. data/lib/transformers/modeling_utils.rb +888 -0
  22. data/lib/transformers/models/auto/auto_factory.rb +138 -0
  23. data/lib/transformers/models/auto/configuration_auto.rb +61 -0
  24. data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
  25. data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
  26. data/lib/transformers/models/auto/modeling_auto.rb +80 -0
  27. data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
  28. data/lib/transformers/models/bert/configuration_bert.rb +65 -0
  29. data/lib/transformers/models/bert/modeling_bert.rb +836 -0
  30. data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
  31. data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
  32. data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
  33. data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
  34. data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
  35. data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
  36. data/lib/transformers/models/vit/configuration_vit.rb +60 -0
  37. data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
  38. data/lib/transformers/models/vit/modeling_vit.rb +506 -0
  39. data/lib/transformers/pipelines/_init.rb +348 -0
  40. data/lib/transformers/pipelines/base.rb +301 -0
  41. data/lib/transformers/pipelines/feature_extraction.rb +47 -0
  42. data/lib/transformers/pipelines/image_classification.rb +110 -0
  43. data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
  44. data/lib/transformers/pipelines/pt_utils.rb +53 -0
  45. data/lib/transformers/pipelines/question_answering.rb +508 -0
  46. data/lib/transformers/pipelines/text_classification.rb +123 -0
  47. data/lib/transformers/pipelines/token_classification.rb +282 -0
  48. data/lib/transformers/ruby_utils.rb +33 -0
  49. data/lib/transformers/sentence_transformer.rb +37 -0
  50. data/lib/transformers/tokenization_utils.rb +152 -0
  51. data/lib/transformers/tokenization_utils_base.rb +937 -0
  52. data/lib/transformers/tokenization_utils_fast.rb +386 -0
  53. data/lib/transformers/torch_utils.rb +25 -0
  54. data/lib/transformers/utils/_init.rb +31 -0
  55. data/lib/transformers/utils/generic.rb +107 -0
  56. data/lib/transformers/utils/hub.rb +209 -0
  57. data/lib/transformers/utils/import_utils.rb +45 -0
  58. data/lib/transformers/utils/logging.rb +52 -0
  59. data/lib/transformers/version.rb +3 -0
  60. data/lib/transformers-rb.rb +1 -0
  61. data/lib/transformers.rb +100 -0
  62. data/licenses/LICENSE-huggingface-hub.txt +201 -0
  63. data/licenses/LICENSE-sentence-transformers.txt +201 -0
  64. data/licenses/NOTICE-sentence-transformers.txt +5 -0
  65. metadata +161 -0
data/lib/transformers/models/distilbert/modeling_distilbert.rb
@@ -0,0 +1,616 @@
+ # Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   module Distilbert
+     class Embeddings < Torch::NN::Module
+       def initialize(config)
+         super()
+         @word_embeddings = Torch::NN::Embedding.new(config.vocab_size, config.dim, padding_idx: config.pad_token_id)
+         @position_embeddings = Torch::NN::Embedding.new(config.max_position_embeddings, config.dim)
+
+         @LayerNorm = Torch::NN::LayerNorm.new(config.dim, eps: 1e-12)
+         @dropout = Torch::NN::Dropout.new(p: config.dropout)
+         register_buffer(
+           "position_ids", Torch.arange(config.max_position_embeddings).expand([1, -1]), persistent: false
+         )
+       end
+
+       def forward(input_ids, input_embeds)
+         if !input_ids.nil?
+           input_embeds = @word_embeddings.(input_ids) # (bs, max_seq_length, dim)
+         end
+
+         seq_length = input_embeds.size(1)
+
+         # Setting the position ids to the buffer registered in the constructor helps
+         # when tracing the model without passing position ids, and solves
+         # issues similar to issue #5664
+         if @position_ids
+           position_ids = @position_ids[0.., 0...seq_length]
+         else
+           position_ids = Torch.arange(seq_length, dtype: :long, device: input_ids.device) # (max_seq_length)
+           position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # (bs, max_seq_length)
+         end
+
+         position_embeddings = @position_embeddings.(position_ids) # (bs, max_seq_length, dim)
+
+         embeddings = input_embeds + position_embeddings # (bs, max_seq_length, dim)
+         embeddings = @LayerNorm.(embeddings) # (bs, max_seq_length, dim)
+         embeddings = @dropout.(embeddings) # (bs, max_seq_length, dim)
+         embeddings
+       end
+     end
+
+     class MultiHeadSelfAttention < Torch::NN::Module
+       def initialize(config)
+         super()
+         @config = config
+
+         @n_heads = config.n_heads
+         @dim = config.dim
+         @dropout = Torch::NN::Dropout.new(p: config.attention_dropout)
+         @is_causal = false
+
+         # The number of attention heads must evenly divide the hidden dimension
+         if @dim % @n_heads != 0
+           # Raise when the head count does not divide the dimension evenly
+           raise ArgumentError, "self.n_heads: #{@n_heads} must divide self.dim: #{@dim} evenly"
+         end
+
+         @q_lin = Torch::NN::Linear.new(config.dim, config.dim)
+         @k_lin = Torch::NN::Linear.new(config.dim, config.dim)
+         @v_lin = Torch::NN::Linear.new(config.dim, config.dim)
+         @out_lin = Torch::NN::Linear.new(config.dim, config.dim)
+
+         @pruned_heads = Set.new
+         @attention_head_size = @dim.div(@n_heads)
+       end
+
+       def prune_heads(heads)
+         if heads.length == 0
+           return
+         end
+         raise Todo
+       end
+
+       def forward(
+         query:,
+         key:,
+         value:,
+         mask:,
+         head_mask: nil,
+         output_attentions: false
+       )
+         bs, _q_length, dim = query.size
+         k_length = key.size(1)
+         if dim != @dim
+           raise "Dimensions do not match: #{dim} input vs #{@dim} configured"
+         end
+         if key.size != value.size
+           raise Todo
+         end
+
+         dim_per_head = @dim.div(@n_heads)
+
+         mask_reshp = [bs, 1, 1, k_length]
+
+         shape = lambda do |x|
+           x.view(bs, -1, @n_heads, dim_per_head).transpose(1, 2)
+         end
+
+         unshape = lambda do |x|
+           x.transpose(1, 2).contiguous.view(bs, -1, @n_heads * dim_per_head)
+         end
+
+         q = shape.(@q_lin.(query)) # (bs, n_heads, q_length, dim_per_head)
+         k = shape.(@k_lin.(key)) # (bs, n_heads, k_length, dim_per_head)
+         v = shape.(@v_lin.(value)) # (bs, n_heads, k_length, dim_per_head)
+
+         q = q / Math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head)
+         scores = Torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)
+         mask = (mask.eq(0)).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length)
+         scores =
+           scores.masked_fill(
+             # TODO use Torch.finfo
+             mask, Torch.tensor(0)
+           ) # (bs, n_heads, q_length, k_length)
+
+         weights = Torch::NN::Functional.softmax(scores, dim: -1) # (bs, n_heads, q_length, k_length)
+         weights = @dropout.(weights) # (bs, n_heads, q_length, k_length)
+
+         # Mask heads if we want to
+         if !head_mask.nil?
+           weights = weights * head_mask
+         end
+
+         context = Torch.matmul(weights, v) # (bs, n_heads, q_length, dim_per_head)
+         context = unshape.(context) # (bs, q_length, dim)
+         context = @out_lin.(context) # (bs, q_length, dim)
+
+         if output_attentions
+           [context, weights]
+         else
+           [context]
+         end
+       end
+     end
+
+     class DistilBertFlashAttention2 < MultiHeadSelfAttention
+     end
+
+     class FFN < Torch::NN::Module
+       def initialize(config)
+         super()
+         @dropout = Torch::NN::Dropout.new(p: config.dropout)
+         @chunk_size_feed_forward = config.chunk_size_feed_forward
+         @seq_len_dim = 1
+         @lin1 = Torch::NN::Linear.new(config.dim, config.hidden_dim)
+         @lin2 = Torch::NN::Linear.new(config.hidden_dim, config.dim)
+         @activation = Activations.get_activation(config.activation)
+       end
+
+       def forward(input)
+         TorchUtils.apply_chunking_to_forward(method(:ff_chunk), @chunk_size_feed_forward, @seq_len_dim, input)
+       end
+
+       def ff_chunk(input)
+         x = @lin1.(input)
+         x = @activation.(x)
+         x = @lin2.(x)
+         x = @dropout.(x)
+         x
+       end
+     end
+
+     DISTILBERT_ATTENTION_CLASSES = {
+       "eager" => MultiHeadSelfAttention,
+       "flash_attention_2" => DistilBertFlashAttention2
+     }
+
+     class TransformerBlock < Torch::NN::Module
+       def initialize(config)
+         super()
+
+         # The number of attention heads must evenly divide the hidden dimension
+         if config.dim % config.n_heads != 0
+           raise ArgumentError, "config.n_heads #{config.n_heads} must divide config.dim #{config.dim} evenly"
+         end
+
+         @attention = DISTILBERT_ATTENTION_CLASSES[config._attn_implementation].new(config)
+         @sa_layer_norm = Torch::NN::LayerNorm.new(config.dim, eps: 1e-12)
+
+         @ffn = FFN.new(config)
+         @output_layer_norm = Torch::NN::LayerNorm.new(config.dim, eps: 1e-12)
+       end
+
+       def forward(
+         x:,
+         attn_mask: nil,
+         head_mask: nil,
+         output_attentions: false
+       )
+         # Self-Attention
+         sa_output =
+           @attention.(
+             query: x,
+             key: x,
+             value: x,
+             mask: attn_mask,
+             head_mask: head_mask,
+             output_attentions: output_attentions,
+           )
+         if output_attentions
+           sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
+         else # To handle the `output_attentions` or `output_hidden_states` cases returning tuples
+           if !sa_output.is_a?(Array)
+             raise TypeError, "sa_output must be an array but it is #{sa_output.class.name} type"
+           end
+
+           sa_output = sa_output[0]
+         end
+         sa_output = @sa_layer_norm.(sa_output + x) # (bs, seq_length, dim)
+
+         # Feed Forward Network
+         ffn_output = @ffn.(sa_output) # (bs, seq_length, dim)
+         ffn_output = @output_layer_norm.(ffn_output + sa_output) # (bs, seq_length, dim)
+
+         output = [ffn_output]
+         if output_attentions
+           output = [sa_weights] + output
+         end
+         output
+       end
+     end
+
+     class Transformer < Torch::NN::Module
+       def initialize(config)
+         super()
+         @n_layers = config.n_layers
+         @layer = Torch::NN::ModuleList.new(config.n_layers.times.map { TransformerBlock.new(config) })
+         @gradient_checkpointing = false
+       end
+
+       def forward(
+         x:,
+         attn_mask: nil,
+         head_mask: nil,
+         output_attentions: false,
+         output_hidden_states: false,
+         return_dict: nil
+       )
+         all_hidden_states = output_hidden_states ? [] : nil
+         all_attentions = output_attentions ? [] : nil
+
+         hidden_state = x
+         @layer.each_with_index do |layer_module, i|
+           if output_hidden_states
+             all_hidden_states = all_hidden_states + [hidden_state]
+           end
+
+           if @gradient_checkpointing && training
+             layer_outputs =
+               _gradient_checkpointing_func(
+                 layer_module.__call__,
+                 hidden_state,
+                 attn_mask,
+                 head_mask[i],
+                 output_attentions,
+               )
+           else
+             layer_outputs =
+               layer_module.(
+                 x: hidden_state,
+                 attn_mask: attn_mask,
+                 head_mask: head_mask[i],
+                 output_attentions: output_attentions
+               )
+           end
+
+           hidden_state = layer_outputs[-1]
+
+           if output_attentions
+             if layer_outputs.length != 2
+               raise ArgumentError, "The length of the layer_outputs should be 2, but it is #{layer_outputs.length}"
+             end
+
+             attentions = layer_outputs[0]
+             all_attentions = all_attentions + [attentions]
+           else
+             if layer_outputs.length != 1
+               raise ArgumentError, "The length of the layer_outputs should be 1, but it is #{layer_outputs.length}"
+             end
+           end
+         end
+
+         # Add last layer
+         if output_hidden_states
+           all_hidden_states = all_hidden_states + [hidden_state]
+         end
+
+         if !return_dict
+           raise Todo
+         end
+         BaseModelOutput.new(
+           last_hidden_state: hidden_state, hidden_states: all_hidden_states, attentions: all_attentions
+         )
+       end
+     end
+
+     class DistilBertPreTrainedModel < PreTrainedModel
+       self.config_class = DistilBertConfig
+       self.base_model_prefix = "distilbert"
+
+       def _init_weights(mod)
+         if mod.is_a?(Torch::NN::Linear)
+           mod.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
+           if !mod.bias.nil?
+             mod.bias.data.zero!
+           end
+         elsif mod.is_a?(Torch::NN::Embedding)
+           mod.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
+           if !mod.instance_variable_get(:@padding_idx).nil?
+             mod.weight.data[mod.instance_variable_get(:@padding_idx)].zero!
+           end
+         elsif mod.is_a?(Torch::NN::LayerNorm)
+           mod.bias.data.zero!
+           mod.weight.data.fill!(1.0)
+         elsif mod.is_a?(Embeddings) && @config.sinusoidal_pos_embds
+           create_sinusoidal_embeddings(
+             @config.max_position_embeddings, @config.dim, mod.position_embeddings.weight
+           )
+         end
+       end
+
+       private
+
+       def create_sinusoidal_embeddings(n_pos, dim, out)
+         # TODO
+       end
+     end
+
+     class DistilBertModel < DistilBertPreTrainedModel
+       def initialize(config)
+         super(config)
+
+         @embeddings = Embeddings.new(config) # Embeddings
+         @transformer = Transformer.new(config) # Encoder
+         @use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+
+         # Initialize weights and apply final processing
+         post_init
+       end
+
+       def get_position_embeddings
+         @embeddings.position_embeddings
+       end
+
+       def get_input_embeddings
+         @embeddings.word_embeddings
+       end
+
+       def _prune_heads(heads_to_prune)
+         heads_to_prune.each do |layer, heads|
+           @transformer.layer[layer].attention.prune_heads(heads)
+         end
+       end
+
+       def forward(
+         input_ids: nil,
+         attention_mask: nil,
+         head_mask: nil,
+         inputs_embeds: nil,
+         output_attentions: nil,
+         output_hidden_states: nil,
+         return_dict: nil
+       )
+         output_attentions = !output_attentions.nil? ? output_attentions : @config.output_attentions
+         output_hidden_states = (
+           !output_hidden_states.nil? ? output_hidden_states : @config.output_hidden_states
+         )
+         return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
+
+         if !input_ids.nil? && !inputs_embeds.nil?
+           raise ArgumentError, "You cannot specify both input_ids and inputs_embeds at the same time"
+         elsif !input_ids.nil?
+           warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+           input_shape = input_ids.size
+         elsif !inputs_embeds.nil?
+           input_shape = inputs_embeds.size[...-1]
+         else
+           raise ArgumentError, "You have to specify either input_ids or inputs_embeds"
+         end
+
+         device = !input_ids.nil? ? input_ids.device : inputs_embeds.device
+
+         # Prepare head mask if needed
+         head_mask = get_head_mask(head_mask, @config.num_hidden_layers)
+
+         embeddings = @embeddings.(input_ids, inputs_embeds) # (bs, seq_length, dim)
+
+         if @use_flash_attention_2
+           raise Todo
+         else
+           if attention_mask.nil?
+             attention_mask = Torch.ones(input_shape, device: device) # (bs, seq_length)
+           end
+         end
+
+         @transformer.(
+           x: embeddings,
+           attn_mask: attention_mask,
+           head_mask: head_mask,
+           output_attentions: output_attentions,
+           output_hidden_states: output_hidden_states,
+           return_dict: return_dict
+         )
+       end
+     end
+
+     class DistilBertForMaskedLM < DistilBertPreTrainedModel
+       self._tied_weights_keys = ["vocab_projector.weight"]
+
+       def initialize(config)
+         super(config)
+
+         @activation = get_activation(config.activation)
+
+         @distilbert = DistilBertModel.new(config)
+         @vocab_transform = Torch::NN::Linear.new(config.dim, config.dim)
+         @vocab_layer_norm = Torch::NN::LayerNorm.new(config.dim, eps: 1e-12)
+         @vocab_projector = Torch::NN::Linear.new(config.dim, config.vocab_size)
+
+         # Initialize weights and apply final processing
+         post_init
+
+         @mlm_loss_fct = Torch::NN::CrossEntropyLoss.new
+       end
+
+       def forward(
+         input_ids: nil,
+         attention_mask: nil,
+         head_mask: nil,
+         inputs_embeds: nil,
+         labels: nil,
+         output_attentions: nil,
+         output_hidden_states: nil,
+         return_dict: nil
+       )
+         return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
+
+         dlbrt_output = @distilbert.(
+           input_ids: input_ids,
+           attention_mask: attention_mask,
+           head_mask: head_mask,
+           inputs_embeds: inputs_embeds,
+           output_attentions: output_attentions,
+           output_hidden_states: output_hidden_states,
+           return_dict: return_dict
+         )
+         hidden_states = dlbrt_output[0] # (bs, seq_length, dim)
+         prediction_logits = @vocab_transform.(hidden_states) # (bs, seq_length, dim)
+         prediction_logits = @activation.(prediction_logits) # (bs, seq_length, dim)
+         prediction_logits = @vocab_layer_norm.(prediction_logits) # (bs, seq_length, dim)
+         prediction_logits = @vocab_projector.(prediction_logits) # (bs, seq_length, vocab_size)
+
+         mlm_loss = nil
+         if !labels.nil?
+           mlm_loss = @mlm_loss_fct.(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))
+         end
+
+         if !return_dict
+           raise Todo
+         end
+
+         MaskedLMOutput.new(
+           loss: mlm_loss,
+           logits: prediction_logits,
+           hidden_states: dlbrt_output.hidden_states,
+           attentions: dlbrt_output.attentions
+         )
+       end
+     end
+
+     class DistilBertForSequenceClassification < DistilBertPreTrainedModel
+       def initialize(config)
+         super(config)
+         @num_labels = config.num_labels
+         @config = config
+
+         @distilbert = DistilBertModel.new(config)
+         @pre_classifier = Torch::NN::Linear.new(config.dim, config.dim)
+         @classifier = Torch::NN::Linear.new(config.dim, config.num_labels)
+         @dropout = Torch::NN::Dropout.new(p: config.seq_classif_dropout)
+
+         # Initialize weights and apply final processing
+         post_init
+       end
+
+       def forward(
+         input_ids: nil,
+         attention_mask: nil,
+         head_mask: nil,
+         inputs_embeds: nil,
+         labels: nil,
+         output_attentions: nil,
+         output_hidden_states: nil,
+         return_dict: nil
+       )
+         return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
+
+         distilbert_output =
+           @distilbert.(
+             input_ids: input_ids,
+             attention_mask: attention_mask,
+             head_mask: head_mask,
+             inputs_embeds: inputs_embeds,
+             output_attentions: output_attentions,
+             output_hidden_states: output_hidden_states,
+             return_dict: return_dict
+           )
+         hidden_state = distilbert_output[0] # (bs, seq_len, dim)
+         pooled_output = hidden_state[0.., 0] # (bs, dim)
+         pooled_output = @pre_classifier.(pooled_output) # (bs, dim)
+         pooled_output = Torch::NN::ReLU.new.(pooled_output) # (bs, dim)
+         pooled_output = @dropout.(pooled_output) # (bs, dim)
+         logits = @classifier.(pooled_output) # (bs, num_labels)
+
+         loss = nil
+         if !labels.nil?
+           raise Todo
+         end
+
+         if !return_dict
+           raise Todo
+         end
+
+         SequenceClassifierOutput.new(
+           loss: loss,
+           logits: logits,
+           hidden_states: distilbert_output.hidden_states,
+           attentions: distilbert_output.attentions
+         )
+       end
+     end
+
+     class DistilBertForQuestionAnswering < DistilBertPreTrainedModel
+       def initialize(config)
+         super(config)
+
+         @distilbert = DistilBertModel.new(config)
+         @qa_outputs = Torch::NN::Linear.new(config.dim, config.num_labels)
+         if config.num_labels != 2
+           raise ArgumentError, "config.num_labels should be 2, but it is #{config.num_labels}"
+         end
+
+         @dropout = Torch::NN::Dropout.new(p: config.qa_dropout)
+
+         # Initialize weights and apply final processing
+         post_init
+       end
+
+       def forward(
+         input_ids: nil,
+         attention_mask: nil,
+         head_mask: nil,
+         inputs_embeds: nil,
+         start_positions: nil,
+         end_positions: nil,
+         output_attentions: nil,
+         output_hidden_states: nil,
+         return_dict: nil
+       )
+         return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
+
+         distilbert_output = @distilbert.(
+           input_ids: input_ids,
+           attention_mask: attention_mask,
+           head_mask: head_mask,
+           inputs_embeds: inputs_embeds,
+           output_attentions: output_attentions,
+           output_hidden_states: output_hidden_states,
+           return_dict: return_dict
+         )
+         hidden_states = distilbert_output[0] # (bs, max_query_len, dim)
+
+         hidden_states = @dropout.(hidden_states) # (bs, max_query_len, dim)
+         logits = @qa_outputs.(hidden_states) # (bs, max_query_len, 2)
+         start_logits, end_logits = logits.split(1, dim: -1)
+         start_logits = start_logits.squeeze(-1).contiguous # (bs, max_query_len)
+         end_logits = end_logits.squeeze(-1).contiguous # (bs, max_query_len)
+
+         total_loss = nil
+         if !start_positions.nil? && !end_positions.nil?
+           raise Todo
+         end
+
+         if !return_dict
+           raise Todo
+         end
+
+         QuestionAnsweringModelOutput.new(
+           loss: total_loss,
+           start_logits: start_logits,
+           end_logits: end_logits,
+           hidden_states: distilbert_output.hidden_states,
+           attentions: distilbert_output.attentions
+         )
+       end
+     end
+   end
+
+   DistilBertForMaskedLM = Distilbert::DistilBertForMaskedLM
+   DistilBertForSequenceClassification = Distilbert::DistilBertForSequenceClassification
+   DistilBertForQuestionAnswering = Distilbert::DistilBertForQuestionAnswering
+ end
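
For orientation, the task heads defined in this file are normally reached through the gem's auto classes and pipelines (see data/lib/transformers/pipelines/_init.rb and data/lib/transformers/models/auto/modeling_auto.rb in the file list above) rather than constructed directly. A minimal usage sketch, assuming torch-rb and libtorch are installed and that the Ruby port keeps the upstream Python task aliases and default DistilBERT checkpoints; the outputs shown are illustrative, not guaranteed by this diff:

    require "transformers-rb"

    # Text classification exercises DistilBertForSequenceClassification when a
    # DistilBERT checkpoint is selected (the upstream default for this task).
    classifier = Transformers.pipeline("sentiment-analysis")
    classifier.("This gem works well.")
    # => {label: "POSITIVE", score: 0.99} (illustrative)

    # Extractive question answering exercises DistilBertForQuestionAnswering.
    qa = Transformers.pipeline("question-answering")
    qa.(question: "Who created Ruby?", context: "Ruby was created by Yukihiro Matsumoto.")
    # => {answer: "Yukihiro Matsumoto", ...} (illustrative)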