transformers-rb 0.1.0

Files changed (65)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +203 -0
  4. data/README.md +163 -0
  5. data/lib/transformers/activations.rb +57 -0
  6. data/lib/transformers/configuration_utils.rb +285 -0
  7. data/lib/transformers/convert_slow_tokenizer.rb +90 -0
  8. data/lib/transformers/data/processors/squad.rb +115 -0
  9. data/lib/transformers/dynamic_module_utils.rb +25 -0
  10. data/lib/transformers/feature_extraction_utils.rb +110 -0
  11. data/lib/transformers/hf_hub/constants.rb +71 -0
  12. data/lib/transformers/hf_hub/errors.rb +11 -0
  13. data/lib/transformers/hf_hub/file_download.rb +764 -0
  14. data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
  15. data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
  16. data/lib/transformers/image_processing_base.rb +169 -0
  17. data/lib/transformers/image_processing_utils.rb +63 -0
  18. data/lib/transformers/image_transforms.rb +208 -0
  19. data/lib/transformers/image_utils.rb +165 -0
  20. data/lib/transformers/modeling_outputs.rb +81 -0
  21. data/lib/transformers/modeling_utils.rb +888 -0
  22. data/lib/transformers/models/auto/auto_factory.rb +138 -0
  23. data/lib/transformers/models/auto/configuration_auto.rb +61 -0
  24. data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
  25. data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
  26. data/lib/transformers/models/auto/modeling_auto.rb +80 -0
  27. data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
  28. data/lib/transformers/models/bert/configuration_bert.rb +65 -0
  29. data/lib/transformers/models/bert/modeling_bert.rb +836 -0
  30. data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
  31. data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
  32. data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
  33. data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
  34. data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
  35. data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
  36. data/lib/transformers/models/vit/configuration_vit.rb +60 -0
  37. data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
  38. data/lib/transformers/models/vit/modeling_vit.rb +506 -0
  39. data/lib/transformers/pipelines/_init.rb +348 -0
  40. data/lib/transformers/pipelines/base.rb +301 -0
  41. data/lib/transformers/pipelines/feature_extraction.rb +47 -0
  42. data/lib/transformers/pipelines/image_classification.rb +110 -0
  43. data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
  44. data/lib/transformers/pipelines/pt_utils.rb +53 -0
  45. data/lib/transformers/pipelines/question_answering.rb +508 -0
  46. data/lib/transformers/pipelines/text_classification.rb +123 -0
  47. data/lib/transformers/pipelines/token_classification.rb +282 -0
  48. data/lib/transformers/ruby_utils.rb +33 -0
  49. data/lib/transformers/sentence_transformer.rb +37 -0
  50. data/lib/transformers/tokenization_utils.rb +152 -0
  51. data/lib/transformers/tokenization_utils_base.rb +937 -0
  52. data/lib/transformers/tokenization_utils_fast.rb +386 -0
  53. data/lib/transformers/torch_utils.rb +25 -0
  54. data/lib/transformers/utils/_init.rb +31 -0
  55. data/lib/transformers/utils/generic.rb +107 -0
  56. data/lib/transformers/utils/hub.rb +209 -0
  57. data/lib/transformers/utils/import_utils.rb +45 -0
  58. data/lib/transformers/utils/logging.rb +52 -0
  59. data/lib/transformers/version.rb +3 -0
  60. data/lib/transformers-rb.rb +1 -0
  61. data/lib/transformers.rb +100 -0
  62. data/licenses/LICENSE-huggingface-hub.txt +201 -0
  63. data/licenses/LICENSE-sentence-transformers.txt +201 -0
  64. data/licenses/NOTICE-sentence-transformers.txt +5 -0
  65. metadata +161 -0
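
This release ships Ruby ports of the BERT, DistilBERT, and ViT model code plus the pipeline layer listed above. As a quick orientation before the modeling_bert.rb diff below, here is a minimal usage sketch; the Transformers.pipeline entry point comes from pipelines/_init.rb, but the task aliases, default models, and call signatures shown here are assumptions based on the Python API rather than anything verified against this release:

  require "transformers-rb"

  # Token classification (NER), backed by BertForTokenClassification below
  # (default model download from the Hugging Face Hub is assumed)
  ner = Transformers.pipeline("ner")
  ner.("Ruby is a programming language created by Matz")

  # Text classification via the text_classification pipeline
  classifier = Transformers.pipeline("sentiment-analysis")
  classifier.("This gem makes BERT easy to use from Ruby")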
data/lib/transformers/models/bert/modeling_bert.rb
@@ -0,0 +1,836 @@
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   module Bert
+     class BertEmbeddings < Torch::NN::Module
+       def initialize(config)
+         super()
+         @word_embeddings = Torch::NN::Embedding.new(config.vocab_size, config.hidden_size, padding_idx: config.pad_token_id)
+         @position_embeddings = Torch::NN::Embedding.new(config.max_position_embeddings, config.hidden_size)
+         @token_type_embeddings = Torch::NN::Embedding.new(config.type_vocab_size, config.hidden_size)
+
+         @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
+         @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
+         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+         @position_embedding_type = config.position_embedding_type || "absolute"
+         register_buffer(
+           "position_ids", Torch.arange(config.max_position_embeddings).expand([1, -1]), persistent: false
+         )
+         register_buffer(
+           "token_type_ids", Torch.zeros(position_ids.size, dtype: Torch.long), persistent: false
+         )
+       end
+
+       def forward(
+         input_ids: nil,
+         token_type_ids: nil,
+         position_ids: nil,
+         inputs_embeds: nil,
+         past_key_values_length: 0
+       )
+         if !input_ids.nil?
+           input_shape = input_ids.size
+         else
+           input_shape = inputs_embeds.size[...-1]
+         end
+
+         seq_length = input_shape[1]
+
+         if position_ids.nil?
+           position_ids = @position_ids[0.., past_key_values_length...(seq_length + past_key_values_length)]
+         end
+
+         # Default token_type_ids to the all-zeros buffer registered in the constructor, which is the usual
+         # case when they are auto-generated; the registered buffer lets users trace the model without
+         # passing token_type_ids (solves issue #5664)
+         if token_type_ids.nil?
+           raise Todo
+         end
+
+         if inputs_embeds.nil?
+           inputs_embeds = @word_embeddings.(input_ids)
+         end
+         token_type_embeddings = @token_type_embeddings.(token_type_ids)
+
+         embeddings = inputs_embeds + token_type_embeddings
+         if @position_embedding_type == "absolute"
+           position_embeddings = @position_embeddings.(position_ids)
+           embeddings += position_embeddings
+         end
+         embeddings = @LayerNorm.(embeddings)
+         embeddings = @dropout.(embeddings)
+         embeddings
+       end
+     end
+
+     class BertSelfAttention < Torch::NN::Module
+       def initialize(config, position_embedding_type: nil)
+         super()
+         if config.hidden_size % config.num_attention_heads != 0 && !config.embedding_size
+           raise ArgumentError,
+             "The hidden size (#{config.hidden_size}) is not a multiple of the number of attention " +
+             "heads (#{config.num_attention_heads})"
+         end
+
+         @num_attention_heads = config.num_attention_heads
+         @attention_head_size = (config.hidden_size / config.num_attention_heads).to_i
+         @all_head_size = @num_attention_heads * @attention_head_size
+
+         @query = Torch::NN::Linear.new(config.hidden_size, @all_head_size)
+         @key = Torch::NN::Linear.new(config.hidden_size, @all_head_size)
+         @value = Torch::NN::Linear.new(config.hidden_size, @all_head_size)
+
+         @dropout = Torch::NN::Dropout.new(p: config.attention_probs_dropout_prob)
+         @position_embedding_type = position_embedding_type || config.position_embedding_type || "absolute"
+         if @position_embedding_type == "relative_key" || @position_embedding_type == "relative_key_query"
+           @max_position_embeddings = config.max_position_embeddings
+           @distance_embedding = Torch::NN::Embedding.new(2 * config.max_position_embeddings - 1, @attention_head_size)
+         end
+
+         @is_decoder = config.is_decoder
+       end
+
+       def transpose_for_scores(x)
+         new_x_shape = x.size[...-1] + [@num_attention_heads, @attention_head_size]
+         x = x.view(new_x_shape)
+         x.permute(0, 2, 1, 3)
+       end
+
+       def forward(
+         hidden_states,
+         attention_mask: nil,
+         head_mask: nil,
+         encoder_hidden_states: nil,
+         encoder_attention_mask: nil,
+         past_key_value: nil,
+         output_attentions: false
+       )
+         mixed_query_layer = @query.(hidden_states)
+
+         # If this is instantiated as a cross-attention module, the keys
+         # and values come from an encoder; the attention mask needs to be
+         # such that the encoder's padding tokens are not attended to.
+         is_cross_attention = !encoder_hidden_states.nil?
+
+         if is_cross_attention && !past_key_value.nil?
+           # reuse k,v, cross_attentions
+           key_layer = past_key_value[0]
+           value_layer = past_key_value[1]
+           attention_mask = encoder_attention_mask
+         elsif is_cross_attention
+           key_layer = transpose_for_scores(@key.(encoder_hidden_states))
+           value_layer = transpose_for_scores(@value.(encoder_hidden_states))
+           attention_mask = encoder_attention_mask
+         elsif !past_key_value.nil?
+           key_layer = transpose_for_scores(@key.(hidden_states))
+           value_layer = transpose_for_scores(@value.(hidden_states))
+           key_layer = Torch.cat([past_key_value[0], key_layer], dim: 2)
+           value_layer = Torch.cat([past_key_value[1], value_layer], dim: 2)
+         else
+           key_layer = transpose_for_scores(@key.(hidden_states))
+           value_layer = transpose_for_scores(@value.(hidden_states))
+         end
+
+         query_layer = transpose_for_scores(mixed_query_layer)
+
+         _use_cache = !past_key_value.nil?
+         if @is_decoder
+           # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+           # Further calls to cross_attention layer can then reuse all cross-attention
+           # key/value_states (first "if" case)
+           # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+           # all previous decoder key/value_states. Further calls to uni-directional self-attention
+           # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+           # if encoder bi-directional self-attention `past_key_value` is always `None`
+           past_key_value = [key_layer, value_layer]
+         end
+
+         # Take the dot product between "query" and "key" to get the raw attention scores.
+         attention_scores = Torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+         if @position_embedding_type == "relative_key" || @position_embedding_type == "relative_key_query"
+           raise Todo
+         end
+
+         attention_scores = attention_scores / Math.sqrt(@attention_head_size)
+         if !attention_mask.nil?
+           # Apply the attention mask (precomputed for all layers in the BertModel forward function)
+           attention_scores = attention_scores + attention_mask
+         end
+
+         # Normalize the attention scores to probabilities.
+         attention_probs = Torch::NN::Functional.softmax(attention_scores, dim: -1)
+
+         # This is actually dropping out entire tokens to attend to, which might
+         # seem a bit unusual, but is taken from the original Transformer paper.
+         attention_probs = @dropout.(attention_probs)
+
+         # Mask heads if we want to
+         if !head_mask.nil?
+           attention_probs = attention_probs * head_mask
+         end
+
+         context_layer = Torch.matmul(attention_probs, value_layer)
+
+         context_layer = context_layer.permute(0, 2, 1, 3).contiguous
+         new_context_layer_shape = context_layer.size[...-2] + [@all_head_size]
+         context_layer = context_layer.view(new_context_layer_shape)
+
+         outputs = output_attentions ? [context_layer, attention_probs] : [context_layer]
+
+         if @is_decoder
+           outputs = outputs + [past_key_value]
+         end
+         outputs
+       end
+     end
+
+     class BertSelfOutput < Torch::NN::Module
+       def initialize(config)
+         super()
+         @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
+         @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
+         @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
+       end
+
+       def forward(hidden_states, input_tensor)
+         hidden_states = @dense.(hidden_states)
+         hidden_states = @dropout.(hidden_states)
+         hidden_states = @LayerNorm.(hidden_states + input_tensor)
+         hidden_states
+       end
+     end
+
+     BERT_SELF_ATTENTION_CLASSES = {
+       "eager" => BertSelfAttention
+     }
+
+     class BertAttention < Torch::NN::Module
+       def initialize(config, position_embedding_type: nil)
+         super()
+         @self = BERT_SELF_ATTENTION_CLASSES.fetch(config._attn_implementation).new(
+           config, position_embedding_type: position_embedding_type
+         )
+         @output = BertSelfOutput.new(config)
+         @pruned_heads = Set.new
+       end
+
+       def forward(
+         hidden_states,
+         attention_mask: nil,
+         head_mask: nil,
+         encoder_hidden_states: nil,
+         encoder_attention_mask: nil,
+         past_key_value: nil,
+         output_attentions: false
+       )
+         self_outputs = @self.(
+           hidden_states,
+           attention_mask: attention_mask,
+           head_mask: head_mask,
+           encoder_hidden_states: encoder_hidden_states,
+           encoder_attention_mask: encoder_attention_mask,
+           past_key_value: past_key_value,
+           output_attentions: output_attentions
+         )
+         attention_output = @output.(self_outputs[0], hidden_states)
+         outputs = [attention_output] + self_outputs[1..] # add attentions if we output them
+         outputs
+       end
+     end
+
+     class BertIntermediate < Torch::NN::Module
+       def initialize(config)
+         super()
+         @dense = Torch::NN::Linear.new(config.hidden_size, config.intermediate_size)
+         if config.hidden_act.is_a?(String)
+           @intermediate_act_fn = ACT2FN[config.hidden_act]
+         else
+           @intermediate_act_fn = config.hidden_act
+         end
+       end
+
+       def forward(hidden_states)
+         hidden_states = @dense.(hidden_states)
+         hidden_states = @intermediate_act_fn.(hidden_states)
+         hidden_states
+       end
+     end
+
+     class BertOutput < Torch::NN::Module
+       def initialize(config)
+         super()
+         @dense = Torch::NN::Linear.new(config.intermediate_size, config.hidden_size)
+         @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
+         @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
+       end
+
+       def forward(hidden_states, input_tensor)
+         hidden_states = @dense.(hidden_states)
+         hidden_states = @dropout.(hidden_states)
+         hidden_states = @LayerNorm.(hidden_states + input_tensor)
+         hidden_states
+       end
+     end
+
+     class BertLayer < Torch::NN::Module
+       def initialize(config)
+         super()
+         @chunk_size_feed_forward = config.chunk_size_feed_forward
+         @seq_len_dim = 1
+         @attention = BertAttention.new(config)
+         @is_decoder = config.is_decoder
+         @add_cross_attention = config.add_cross_attention
+         if @add_cross_attention
+           if !@is_decoder
+             raise ArgumentError, "#{self} should be used as a decoder model if cross attention is added"
+           end
+           @crossattention = BertAttention.new(config, position_embedding_type: "absolute")
+         end
+         @intermediate = BertIntermediate.new(config)
+         @output = BertOutput.new(config)
+       end
+
+       def forward(
+         hidden_states,
+         attention_mask: nil,
+         head_mask: nil,
+         encoder_hidden_states: nil,
+         encoder_attention_mask: nil,
+         past_key_value: nil,
+         output_attentions: false
+       )
+         # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+         self_attn_past_key_value = !past_key_value.nil? ? past_key_value[...2] : nil
+         self_attention_outputs = @attention.(
+           hidden_states,
+           attention_mask: attention_mask,
+           head_mask: head_mask,
+           output_attentions: output_attentions,
+           past_key_value: self_attn_past_key_value
+         )
+         attention_output = self_attention_outputs[0]
+
+         # if decoder, the last output is tuple of self-attn cache
+         if @is_decoder
+           outputs = self_attention_outputs[1...-1]
+           present_key_value = self_attention_outputs[-1]
+         else
+           outputs = self_attention_outputs[1..] # add self attentions if we output attention weights
+         end
+
+         _cross_attn_present_key_value = nil
+         if @is_decoder && !encoder_hidden_states.nil?
+           raise Todo
+         end
+
+         layer_output = TorchUtils.apply_chunking_to_forward(
+           method(:feed_forward_chunk), @chunk_size_feed_forward, @seq_len_dim, attention_output
+         )
+         outputs = [layer_output] + outputs
+
+         # if decoder, return the attn key/values as the last output
+         if @is_decoder
+           outputs = outputs + [present_key_value]
+         end
+
+         outputs
+       end
+
+       def feed_forward_chunk(attention_output)
+         intermediate_output = @intermediate.(attention_output)
+         layer_output = @output.(intermediate_output, attention_output)
+         return layer_output
+       end
+     end
+
+     class BertEncoder < Torch::NN::Module
+       def initialize(config)
+         super()
+         @config = config
+         @layer = Torch::NN::ModuleList.new(config.num_hidden_layers.times.map { BertLayer.new(config) })
+         @gradient_checkpointing = false
+       end
+
+       def forward(
+         hidden_states,
+         attention_mask: nil,
+         head_mask: nil,
+         encoder_hidden_states: nil,
+         encoder_attention_mask: nil,
+         past_key_values: nil,
+         use_cache: nil,
+         output_attentions: false,
+         output_hidden_states: false,
+         return_dict: true
+       )
+         all_hidden_states = output_hidden_states ? [] : nil
+         all_self_attentions = output_attentions ? [] : nil
+         all_cross_attentions = output_attentions && @config.add_cross_attention ? [] : nil
+
+         if @gradient_checkpointing && @training
+           raise Todo
+         end
+
+         next_decoder_cache = use_cache ? [] : nil
+         @layer.each_with_index do |layer_module, i|
+           if output_hidden_states
+             all_hidden_states = all_hidden_states + [hidden_states]
+           end
+
+           layer_head_mask = !head_mask.nil? ? head_mask[i] : nil
+           past_key_value = !past_key_values.nil? ? past_key_values[i] : nil
+
+           if @gradient_checkpointing && @training
+             raise Todo
+           else
+             layer_outputs = layer_module.(
+               hidden_states,
+               attention_mask: attention_mask,
+               head_mask: layer_head_mask,
+               encoder_hidden_states: encoder_hidden_states,
+               encoder_attention_mask: encoder_attention_mask,
+               past_key_value: past_key_value,
+               output_attentions: output_attentions
+             )
+           end
+
+           hidden_states = layer_outputs[0]
+           if use_cache
+             next_decoder_cache += [layer_outputs[-1]]
+           end
+           if output_attentions
+             all_self_attentions = all_self_attentions + [layer_outputs[1]]
+             if @config.add_cross_attention
+               all_cross_attentions = all_cross_attentions + [layer_outputs[2]]
+             end
+           end
+         end
+
+         if output_hidden_states
+           all_hidden_states = all_hidden_states + [hidden_states]
+         end
+
+         if !return_dict
+           raise Todo
+         end
+         BaseModelOutputWithPastAndCrossAttentions.new(
+           last_hidden_state: hidden_states,
+           past_key_values: next_decoder_cache,
+           hidden_states: all_hidden_states,
+           attentions: all_self_attentions,
+           cross_attentions: all_cross_attentions
+         )
+       end
+     end
+
+     class BertPooler < Torch::NN::Module
+       def initialize(config)
+         super()
+         @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
+         @activation = Torch::NN::Tanh.new
+       end
+
+       def forward(hidden_states)
+         # We "pool" the model by simply taking the hidden state corresponding
+         # to the first token.
+         first_token_tensor = hidden_states[0.., 0]
+         pooled_output = @dense.(first_token_tensor)
+         pooled_output = @activation.(pooled_output)
+         pooled_output
+       end
+     end
+
+     class BertPredictionHeadTransform < Torch::NN::Module
+       def initialize(config)
+         super()
+         @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
+         if config.hidden_act.is_a?(String)
+           @transform_act_fn = ACT2FN[config.hidden_act]
+         else
+           @transform_act_fn = config.hidden_act
+         end
+         @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
+       end
+
+       def forward(hidden_states)
+         hidden_states = @dense.(hidden_states)
+         hidden_states = @transform_act_fn.(hidden_states)
+         hidden_states = @LayerNorm.(hidden_states)
+         hidden_states
+       end
+     end
+
+     class BertLMPredictionHead < Torch::NN::Module
+       def initialize(config)
+         super()
+         @transform = BertPredictionHeadTransform.new(config)
+
+         # The output weights are the same as the input embeddings, but there is
+         # an output-only bias for each token.
+         @decoder = Torch::NN::Linear.new(config.hidden_size, config.vocab_size, bias: false)
+
+         @bias = Torch::NN::Parameter.new(Torch.zeros(config.vocab_size))
+
+         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+         @decoder.instance_variable_set(:@bias, @bias)
+       end
+
+       def _tie_weights
+         @decoder.instance_variable_set(:@bias, @bias)
+       end
+
+       def forward(hidden_states)
+         hidden_states = @transform.(hidden_states)
+         hidden_states = @decoder.(hidden_states)
+         hidden_states
+       end
+     end
+
+     class BertOnlyMLMHead < Torch::NN::Module
+       def initialize(config)
+         super()
+         @predictions = BertLMPredictionHead.new(config)
+       end
+
+       def forward(sequence_output)
+         prediction_scores = @predictions.(sequence_output)
+         prediction_scores
+       end
+     end
+
+     class BertPreTrainedModel < PreTrainedModel
+       self.config_class = BertConfig
+       self.base_model_prefix = "bert"
+
+       def _init_weights(mod)
+         if mod.is_a?(Torch::NN::Linear)
+           mod.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
+           if !mod.bias.nil?
+             mod.bias.data.zero!
+           end
+         elsif mod.is_a?(Torch::NN::Embedding)
+           mod.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
+           if !mod.instance_variable_get(:@padding_idx).nil?
+             mod.weight.data[mod.instance_variable_get(:@padding_idx)].zero!
+           end
+         elsif mod.is_a?(Torch::NN::LayerNorm)
+           mod.bias.data.zero!
+           mod.weight.data.fill!(1.0)
+         end
+       end
+     end
+
+     class BertModel < BertPreTrainedModel
+       def initialize(config, add_pooling_layer: true)
+         super(config)
+         @config = config
+
+         @embeddings = BertEmbeddings.new(config)
+         @encoder = BertEncoder.new(config)
+
+         @pooler = add_pooling_layer ? BertPooler.new(config) : nil
+
+         @attn_implementation = config._attn_implementation
+         @position_embedding_type = config.position_embedding_type
+
+         # Initialize weights and apply final processing
+         post_init
+       end
+
+       def _prune_heads(heads_to_prune)
+         heads_to_prune.each do |layer, heads|
+           @encoder.layer[layer].attention.prune_heads(heads)
+         end
+       end
+
+       def forward(
+         input_ids: nil,
+         attention_mask: nil,
+         token_type_ids: nil,
+         position_ids: nil,
+         head_mask: nil,
+         inputs_embeds: nil,
+         encoder_hidden_states: nil,
+         encoder_attention_mask: nil,
+         past_key_values: nil,
+         use_cache: nil,
+         output_attentions: nil,
+         output_hidden_states: nil,
+         return_dict: nil
+       )
+         output_attentions = !output_attentions.nil? ? output_attentions : @config.output_attentions
+         output_hidden_states = (
+           !output_hidden_states.nil? ? output_hidden_states : @config.output_hidden_states
+         )
+         return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
+
+         if @config.is_decoder
+           use_cache = !use_cache.nil? ? use_cache : @config.use_cache
+         else
+           use_cache = false
+         end
+
+         if !input_ids.nil? && !inputs_embeds.nil?
+           raise ArgumentError, "You cannot specify both input_ids and inputs_embeds at the same time"
+         elsif !input_ids.nil?
+           # self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+           input_shape = input_ids.size
+         elsif !inputs_embeds.nil?
+           input_shape = inputs_embeds.size[...-1]
+         else
+           raise ArgumentError, "You have to specify either input_ids or inputs_embeds"
+         end
+
+         batch_size, seq_length = input_shape
+         device = !input_ids.nil? ? input_ids.device : inputs_embeds.device
+
+         # past_key_values_length
+         past_key_values_length = !past_key_values.nil? ? past_key_values[0][0].shape[2] : 0
+
+         if token_type_ids.nil?
+           if @embeddings.token_type_ids
+             buffered_token_type_ids = @embeddings.token_type_ids[0.., 0...seq_length]
+             buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+             token_type_ids = buffered_token_type_ids_expanded
+           else
+             token_type_ids = Torch.zeros(input_shape, dtype: Torch.long, device: device)
+           end
+         end
+
+         embedding_output = @embeddings.(
+           input_ids: input_ids,
+           position_ids: position_ids,
+           token_type_ids: token_type_ids,
+           inputs_embeds: inputs_embeds,
+           past_key_values_length: past_key_values_length
+         )
+
+         if attention_mask.nil?
+           attention_mask = Torch.ones([batch_size, seq_length + past_key_values_length], device: device)
+         end
+
+         use_sdpa_attention_masks = (
+           @attn_implementation == "sdpa" &&
+           @position_embedding_type == "absolute" &&
+           head_mask.nil? &&
+           !output_attentions
+         )
+
+         # Expand the attention mask
+         if use_sdpa_attention_masks
+           raise Todo
+         else
+           # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+           # ourselves in which case we just need to make it broadcastable to all heads.
+           extended_attention_mask = get_extended_attention_mask(attention_mask, input_shape)
+         end
+
+         # If a 2D or 3D attention mask is provided for the cross-attention,
+         # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+         if @config.is_decoder && !encoder_hidden_states.nil?
+           encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size
+           encoder_hidden_shape = [encoder_batch_size, encoder_sequence_length]
+           if encoder_attention_mask.nil?
+             encoder_attention_mask = Torch.ones(encoder_hidden_shape, device: device)
+           end
+
+           if use_sdpa_attention_masks
+             # Expand the attention mask for SDPA.
+             # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
+             encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+               encoder_attention_mask, embedding_output.dtype, tgt_len: seq_length
+             )
+           else
+             encoder_extended_attention_mask = invert_attention_mask(encoder_attention_mask)
+           end
+         else
+           encoder_extended_attention_mask = nil
+         end
+
+         # Prepare head mask if needed
+         # 1.0 in head_mask indicates we keep the head
+         # attention_probs has shape bsz x n_heads x N x N
+         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+         head_mask = get_head_mask(head_mask, @config.num_hidden_layers)
+
+         encoder_outputs = @encoder.(
+           embedding_output,
+           attention_mask: extended_attention_mask,
+           head_mask: head_mask,
+           encoder_hidden_states: encoder_hidden_states,
+           encoder_attention_mask: encoder_extended_attention_mask,
+           past_key_values: past_key_values,
+           use_cache: use_cache,
+           output_attentions: output_attentions,
+           output_hidden_states: output_hidden_states,
+           return_dict: return_dict
+         )
+         sequence_output = encoder_outputs[0]
+         pooled_output = !@pooler.nil? ? @pooler.(sequence_output) : nil
+
+         if !return_dict
+           raise Todo
+         end
+
+         BaseModelOutputWithPoolingAndCrossAttentions.new(
+           last_hidden_state: sequence_output,
+           pooler_output: pooled_output,
+           past_key_values: encoder_outputs.past_key_values,
+           hidden_states: encoder_outputs.hidden_states,
+           attentions: encoder_outputs.attentions,
+           cross_attentions: encoder_outputs.cross_attentions
+         )
+       end
+     end
+
+     class BertForMaskedLM < BertPreTrainedModel
+       def initialize(config)
+         super(config)
+
+         if config.is_decoder
+           Transformers.logger.warn(
+             "If you want to use `BertForMaskedLM` make sure `config.is_decoder: false` for " +
+             "bi-directional self-attention."
+           )
+         end
+
+         @bert = BertModel.new(config, add_pooling_layer: false)
+         @cls = BertOnlyMLMHead.new(config)
+       end
+
+       def forward(
+         input_ids: nil,
+         attention_mask: nil,
+         token_type_ids: nil,
+         position_ids: nil,
+         head_mask: nil,
+         inputs_embeds: nil,
+         encoder_hidden_states: nil,
+         encoder_attention_mask: nil,
+         labels: nil,
+         output_attentions: nil,
+         output_hidden_states: nil,
+         return_dict: nil
+       )
+         return_dict = !return_dict.nil? ? return_dict : config.use_return_dict
+
+         outputs = @bert.(
+           input_ids: input_ids,
+           attention_mask: attention_mask,
+           token_type_ids: token_type_ids,
+           position_ids: position_ids,
+           head_mask: head_mask,
+           inputs_embeds: inputs_embeds,
+           encoder_hidden_states: encoder_hidden_states,
+           encoder_attention_mask: encoder_attention_mask,
+           output_attentions: output_attentions,
+           output_hidden_states: output_hidden_states,
+           return_dict: return_dict
+         )
+
+         sequence_output = outputs[0]
+         prediction_scores = @cls.(sequence_output)
+
+         masked_lm_loss = nil
+         if !labels.nil?
+           raise Todo
+         end
+
+         if !return_dict
+           raise Todo
+         end
+
+         MaskedLMOutput.new(
+           loss: masked_lm_loss,
+           logits: prediction_scores,
+           hidden_states: outputs.hidden_states,
+           attentions: outputs.attentions
+         )
+       end
+     end
+
+     class BertForTokenClassification < BertPreTrainedModel
+       def initialize(config)
+         super(config)
+         @num_labels = config.num_labels
+
+         @bert = BertModel.new(config, add_pooling_layer: false)
+         classifier_dropout = (
+           !config.classifier_dropout.nil? ? config.classifier_dropout : config.hidden_dropout_prob
+         )
+         @dropout = Torch::NN::Dropout.new(p: classifier_dropout)
+         @classifier = Torch::NN::Linear.new(config.hidden_size, config.num_labels)
+
+         # Initialize weights and apply final processing
+         post_init
+       end
+
+       def forward(
+         input_ids: nil,
+         attention_mask: nil,
+         token_type_ids: nil,
+         position_ids: nil,
+         head_mask: nil,
+         inputs_embeds: nil,
+         labels: nil,
+         output_attentions: nil,
+         output_hidden_states: nil,
+         return_dict: nil
+       )
+         return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
+
+         outputs = @bert.(
+           input_ids: input_ids,
+           attention_mask: attention_mask,
+           token_type_ids: token_type_ids,
+           position_ids: position_ids,
+           head_mask: head_mask,
+           inputs_embeds: inputs_embeds,
+           output_attentions: output_attentions,
+           output_hidden_states: output_hidden_states,
+           return_dict: return_dict
+         )
+
+         sequence_output = outputs[0]
+
+         sequence_output = @dropout.(sequence_output)
+         logits = @classifier.(sequence_output)
+
+         loss = nil
+         if !labels.nil?
+           loss_fct = CrossEntropyLoss.new
+           loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
+         end
+
+         if !return_dict
+           raise Todo
+         end
+
+         TokenClassifierOutput.new(
+           loss: loss,
+           logits: logits,
+           hidden_states: outputs.hidden_states,
+           attentions: outputs.attentions
+         )
+       end
+     end
+   end
+
+   BertModel = Bert::BertModel
+   BertForTokenClassification = Bert::BertForTokenClassification
+ end
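
For sentence embeddings, sentence_transformer.rb (listed above) wraps BertModel and its pooled/feature-extraction output behind a small class. The sketch below assumes it mirrors the Python sentence-transformers API (a constructor taking a Hub model id and an encode method); the class interface and model id shown are assumptions, not verified against this release:

  require "transformers-rb"

  # Sentence embeddings on top of BertModel (API and model id assumed, see note above)
  model = Transformers::SentenceTransformer.new("sentence-transformers/all-MiniLM-L6-v2")
  embeddings = model.encode(["This is an example sentence", "Each sentence is converted"])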