transformers-rb 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +203 -0
  4. data/README.md +163 -0
  5. data/lib/transformers/activations.rb +57 -0
  6. data/lib/transformers/configuration_utils.rb +285 -0
  7. data/lib/transformers/convert_slow_tokenizer.rb +90 -0
  8. data/lib/transformers/data/processors/squad.rb +115 -0
  9. data/lib/transformers/dynamic_module_utils.rb +25 -0
  10. data/lib/transformers/feature_extraction_utils.rb +110 -0
  11. data/lib/transformers/hf_hub/constants.rb +71 -0
  12. data/lib/transformers/hf_hub/errors.rb +11 -0
  13. data/lib/transformers/hf_hub/file_download.rb +764 -0
  14. data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
  15. data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
  16. data/lib/transformers/image_processing_base.rb +169 -0
  17. data/lib/transformers/image_processing_utils.rb +63 -0
  18. data/lib/transformers/image_transforms.rb +208 -0
  19. data/lib/transformers/image_utils.rb +165 -0
  20. data/lib/transformers/modeling_outputs.rb +81 -0
  21. data/lib/transformers/modeling_utils.rb +888 -0
  22. data/lib/transformers/models/auto/auto_factory.rb +138 -0
  23. data/lib/transformers/models/auto/configuration_auto.rb +61 -0
  24. data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
  25. data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
  26. data/lib/transformers/models/auto/modeling_auto.rb +80 -0
  27. data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
  28. data/lib/transformers/models/bert/configuration_bert.rb +65 -0
  29. data/lib/transformers/models/bert/modeling_bert.rb +836 -0
  30. data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
  31. data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
  32. data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
  33. data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
  34. data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
  35. data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
  36. data/lib/transformers/models/vit/configuration_vit.rb +60 -0
  37. data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
  38. data/lib/transformers/models/vit/modeling_vit.rb +506 -0
  39. data/lib/transformers/pipelines/_init.rb +348 -0
  40. data/lib/transformers/pipelines/base.rb +301 -0
  41. data/lib/transformers/pipelines/feature_extraction.rb +47 -0
  42. data/lib/transformers/pipelines/image_classification.rb +110 -0
  43. data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
  44. data/lib/transformers/pipelines/pt_utils.rb +53 -0
  45. data/lib/transformers/pipelines/question_answering.rb +508 -0
  46. data/lib/transformers/pipelines/text_classification.rb +123 -0
  47. data/lib/transformers/pipelines/token_classification.rb +282 -0
  48. data/lib/transformers/ruby_utils.rb +33 -0
  49. data/lib/transformers/sentence_transformer.rb +37 -0
  50. data/lib/transformers/tokenization_utils.rb +152 -0
  51. data/lib/transformers/tokenization_utils_base.rb +937 -0
  52. data/lib/transformers/tokenization_utils_fast.rb +386 -0
  53. data/lib/transformers/torch_utils.rb +25 -0
  54. data/lib/transformers/utils/_init.rb +31 -0
  55. data/lib/transformers/utils/generic.rb +107 -0
  56. data/lib/transformers/utils/hub.rb +209 -0
  57. data/lib/transformers/utils/import_utils.rb +45 -0
  58. data/lib/transformers/utils/logging.rb +52 -0
  59. data/lib/transformers/version.rb +3 -0
  60. data/lib/transformers-rb.rb +1 -0
  61. data/lib/transformers.rb +100 -0
  62. data/licenses/LICENSE-huggingface-hub.txt +201 -0
  63. data/licenses/LICENSE-sentence-transformers.txt +201 -0
  64. data/licenses/NOTICE-sentence-transformers.txt +5 -0
  65. metadata +161 -0
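The gem bundles BERT, DistilBERT, and ViT models together with tokenizers, pipelines, and a Hugging Face Hub client, as the file list above shows. For orientation before the diff below, here is a minimal usage sketch; it assumes the top-level `Transformers.pipeline` helper defined in data/lib/transformers/pipelines/_init.rb mirrors the upstream Python API, and the task names and default checkpoints are illustrative only.

  require "transformers-rb"

  # Token classification (NER) would run through BertForTokenClassification,
  # which is defined in the modeling_bert.rb file shown below.
  ner = Transformers.pipeline("ner")
  ner.("Ruby is a programming language created by Matz")

  # Question answering uses the question_answering pipeline listed above.
  qa = Transformers.pipeline("question-answering")
  qa.(question: "Who invented Ruby?", context: "Ruby was created by Matz.")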
data/lib/transformers/models/bert/modeling_bert.rb
@@ -0,0 +1,836 @@
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   module Bert
+     class BertEmbeddings < Torch::NN::Module
+       def initialize(config)
+         super()
+         @word_embeddings = Torch::NN::Embedding.new(config.vocab_size, config.hidden_size, padding_idx: config.pad_token_id)
+         @position_embeddings = Torch::NN::Embedding.new(config.max_position_embeddings, config.hidden_size)
+         @token_type_embeddings = Torch::NN::Embedding.new(config.type_vocab_size, config.hidden_size)
+
+         @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
+         @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
+         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+         @position_embedding_type = config.position_embedding_type || "absolute"
+         register_buffer(
+           "position_ids", Torch.arange(config.max_position_embeddings).expand([1, -1]), persistent: false
+         )
+         register_buffer(
+           "token_type_ids", Torch.zeros(position_ids.size, dtype: Torch.long), persistent: false
+         )
+       end
+
+       def forward(
+         input_ids: nil,
+         token_type_ids: nil,
+         position_ids: nil,
+         inputs_embeds: nil,
+         past_key_values_length: 0
+       )
+         if !input_ids.nil?
+           input_shape = input_ids.size
+         else
+           input_shape = inputs_embeds.size[...-1]
+         end
+
+         seq_length = input_shape[1]
+
+         if position_ids.nil?
+           position_ids = @position_ids[0.., past_key_values_length...(seq_length + past_key_values_length)]
+         end
+
+         # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
+         # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
+         # issue #5664
+         if token_type_ids.nil?
+           raise Todo
+         end
+
+         if inputs_embeds.nil?
+           inputs_embeds = @word_embeddings.(input_ids)
+         end
+         token_type_embeddings = @token_type_embeddings.(token_type_ids)
+
+         embeddings = inputs_embeds + token_type_embeddings
+         if @position_embedding_type == "absolute"
+           position_embeddings = @position_embeddings.(position_ids)
+           embeddings += position_embeddings
+         end
+         embeddings = @LayerNorm.(embeddings)
+         embeddings = @dropout.(embeddings)
+         embeddings
+       end
+     end
+
+     class BertSelfAttention < Torch::NN::Module
+       def initialize(config, position_embedding_type: nil)
+         super()
+         if config.hidden_size % config.num_attention_heads != 0 && !config.embedding_size
+           raise ArgumentError,
+             "The hidden size (#{config.hidden_size}) is not a multiple of the number of attention " +
+             "heads (#{config.num_attention_heads})"
+         end
+
+         @num_attention_heads = config.num_attention_heads
+         @attention_head_size = (config.hidden_size / config.num_attention_heads).to_i
+         @all_head_size = @num_attention_heads * @attention_head_size
+
+         @query = Torch::NN::Linear.new(config.hidden_size, @all_head_size)
+         @key = Torch::NN::Linear.new(config.hidden_size, @all_head_size)
+         @value = Torch::NN::Linear.new(config.hidden_size, @all_head_size)
+
+         @dropout = Torch::NN::Dropout.new(p: config.attention_probs_dropout_prob)
+         @position_embedding_type = position_embedding_type || config.position_embedding_type || "absolute"
+         if @position_embedding_type == "relative_key" || @position_embedding_type == "relative_key_query"
+           @max_position_embeddings = config.max_position_embeddings
+           @distance_embedding = Torch::NN::Embedding.new(2 * config.max_position_embeddings - 1, @attention_head_size)
+         end
+
+         @is_decoder = config.is_decoder
+       end
+
+       def transpose_for_scores(x)
+         new_x_shape = x.size[...-1] + [@num_attention_heads, @attention_head_size]
+         x = x.view(new_x_shape)
+         x.permute(0, 2, 1, 3)
+       end
+
+       def forward(
+         hidden_states,
+         attention_mask: nil,
+         head_mask: nil,
+         encoder_hidden_states: nil,
+         encoder_attention_mask: nil,
+         past_key_value: nil,
+         output_attentions: false
+       )
+         mixed_query_layer = @query.(hidden_states)
+
+         # If this is instantiated as a cross-attention module, the keys
+         # and values come from an encoder; the attention mask needs to be
+         # such that the encoder's padding tokens are not attended to.
+         is_cross_attention = !encoder_hidden_states.nil?
+
+         if is_cross_attention && !past_key_value.nil?
+           # reuse k,v, cross_attentions
+           key_layer = past_key_value[0]
+           value_layer = past_key_value[1]
+           attention_mask = encoder_attention_mask
+         elsif is_cross_attention
+           key_layer = transpose_for_scores(@key.(encoder_hidden_states))
+           value_layer = transpose_for_scores(@value.(encoder_hidden_states))
+           attention_mask = encoder_attention_mask
+         elsif !past_key_value.nil?
+           key_layer = transpose_for_scores(@key.(hidden_states))
+           value_layer = transpose_for_scores(@value.(hidden_states))
+           key_layer = Torch.cat([past_key_value[0], key_layer], dim: 2)
+           value_layer = Torch.cat([past_key_value[1], value_layer], dim: 2)
+         else
+           key_layer = transpose_for_scores(@key.(hidden_states))
+           value_layer = transpose_for_scores(@value.(hidden_states))
+         end
+
+         query_layer = transpose_for_scores(mixed_query_layer)
+
+         _use_cache = !past_key_value.nil?
+         if @is_decoder
+           # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+           # Further calls to cross_attention layer can then reuse all cross-attention
+           # key/value_states (first "if" case)
+           # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+           # all previous decoder key/value_states. Further calls to uni-directional self-attention
+           # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+           # if encoder bi-directional self-attention `past_key_value` is always `None`
+           past_key_value = [key_layer, value_layer]
+         end
+
+         # Take the dot product between "query" and "key" to get the raw attention scores.
+         attention_scores = Torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+         if @position_embedding_type == "relative_key" || @position_embedding_type == "relative_key_query"
+           raise Todo
+         end
+
+         attention_scores = attention_scores / Math.sqrt(@attention_head_size)
+         if !attention_mask.nil?
+           # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+           attention_scores = attention_scores + attention_mask
+         end
+
+         # Normalize the attention scores to probabilities.
+         attention_probs = Torch::NN::Functional.softmax(attention_scores, dim: -1)
+
+         # This is actually dropping out entire tokens to attend to, which might
+         # seem a bit unusual, but is taken from the original Transformer paper.
+         attention_probs = @dropout.(attention_probs)
+
+         # Mask heads if we want to
+         if !head_mask.nil?
+           attention_probs = attention_probs * head_mask
+         end
+
+         context_layer = Torch.matmul(attention_probs, value_layer)
+
+         context_layer = context_layer.permute(0, 2, 1, 3).contiguous
+         new_context_layer_shape = context_layer.size[...-2] + [@all_head_size]
+         context_layer = context_layer.view(new_context_layer_shape)
+
+         outputs = output_attentions ? [context_layer, attention_probs] : [context_layer]
+
+         if @is_decoder
+           outputs = outputs + [past_key_value]
+         end
+         outputs
+       end
+     end
+
+     class BertSelfOutput < Torch::NN::Module
+       def initialize(config)
+         super()
+         @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
+         @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
+         @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
+       end
+
+       def forward(hidden_states, input_tensor)
+         hidden_states = @dense.(hidden_states)
+         hidden_states = @dropout.(hidden_states)
+         hidden_states = @LayerNorm.(hidden_states + input_tensor)
+         hidden_states
+       end
+     end
+
+     BERT_SELF_ATTENTION_CLASSES = {
+       "eager" => BertSelfAttention
+     }
+
+     class BertAttention < Torch::NN::Module
+       def initialize(config, position_embedding_type: nil)
+         super()
+         @self = BERT_SELF_ATTENTION_CLASSES.fetch(config._attn_implementation).new(
+           config, position_embedding_type: position_embedding_type
+         )
+         @output = BertSelfOutput.new(config)
+         @pruned_heads = Set.new
+       end
+
+       def forward(
+         hidden_states,
+         attention_mask: nil,
+         head_mask: nil,
+         encoder_hidden_states: nil,
+         encoder_attention_mask: nil,
+         past_key_value: nil,
+         output_attentions: false
+       )
+         self_outputs = @self.(
+           hidden_states,
+           attention_mask: attention_mask,
+           head_mask: head_mask,
+           encoder_hidden_states: encoder_hidden_states,
+           encoder_attention_mask: encoder_attention_mask,
+           past_key_value: past_key_value,
+           output_attentions: output_attentions
+         )
+         attention_output = @output.(self_outputs[0], hidden_states)
+         outputs = [attention_output] + self_outputs[1..] # add attentions if we output them
+         outputs
+       end
+     end
+
+     class BertIntermediate < Torch::NN::Module
+       def initialize(config)
+         super()
+         @dense = Torch::NN::Linear.new(config.hidden_size, config.intermediate_size)
+         if config.hidden_act.is_a?(String)
+           @intermediate_act_fn = ACT2FN[config.hidden_act]
+         else
+           @intermediate_act_fn = config.hidden_act
+         end
+       end
+
+       def forward(hidden_states)
+         hidden_states = @dense.(hidden_states)
+         hidden_states = @intermediate_act_fn.(hidden_states)
+         hidden_states
+       end
+     end
+
+     class BertOutput < Torch::NN::Module
+       def initialize(config)
+         super()
+         @dense = Torch::NN::Linear.new(config.intermediate_size, config.hidden_size)
+         @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
+         @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
+       end
+
+       def forward(hidden_states, input_tensor)
+         hidden_states = @dense.(hidden_states)
+         hidden_states = @dropout.(hidden_states)
+         hidden_states = @LayerNorm.(hidden_states + input_tensor)
+         hidden_states
+       end
+     end
+
+     class BertLayer < Torch::NN::Module
+       def initialize(config)
+         super()
+         @chunk_size_feed_forward = config.chunk_size_feed_forward
+         @seq_len_dim = 1
+         @attention = BertAttention.new(config)
+         @is_decoder = config.is_decoder
+         @add_cross_attention = config.add_cross_attention
+         if @add_cross_attention
+           if !@is_decoder
+             raise ArgumentError, "#{self} should be used as a decoder model if cross attention is added"
+           end
+           @crossattention = BertAttention.new(config, position_embedding_type: "absolute")
+         end
+         @intermediate = BertIntermediate.new(config)
+         @output = BertOutput.new(config)
+       end
+
+       def forward(
+         hidden_states,
+         attention_mask: nil,
+         head_mask: nil,
+         encoder_hidden_states: nil,
+         encoder_attention_mask: nil,
+         past_key_value: nil,
+         output_attentions: false
+       )
+         # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+         self_attn_past_key_value = !past_key_value.nil? ? past_key_value[...2] : nil
+         self_attention_outputs = @attention.(
+           hidden_states,
+           attention_mask: attention_mask,
+           head_mask: head_mask,
+           output_attentions: output_attentions,
+           past_key_value: self_attn_past_key_value
+         )
+         attention_output = self_attention_outputs[0]
+
+         # if decoder, the last output is tuple of self-attn cache
+         if @is_decoder
+           outputs = self_attention_outputs[1...-1]
+           present_key_value = self_attention_outputs[-1]
+         else
+           outputs = self_attention_outputs[1..] # add self attentions if we output attention weights
+         end
+
+         _cross_attn_present_key_value = nil
+         if @is_decoder && !encoder_hidden_states.nil?
+           raise Todo
+         end
+
+         layer_output = TorchUtils.apply_chunking_to_forward(
+           method(:feed_forward_chunk), @chunk_size_feed_forward, @seq_len_dim, attention_output
+         )
+         outputs = [layer_output] + outputs
+
+         # if decoder, return the attn key/values as the last output
+         if @is_decoder
+           outputs = outputs + [present_key_value]
+         end
+
+         outputs
+       end
+
+       def feed_forward_chunk(attention_output)
+         intermediate_output = @intermediate.(attention_output)
+         layer_output = @output.(intermediate_output, attention_output)
+         layer_output
+       end
+     end
+
+     class BertEncoder < Torch::NN::Module
+       def initialize(config)
+         super()
+         @config = config
+         @layer = Torch::NN::ModuleList.new(config.num_hidden_layers.times.map { BertLayer.new(config) })
+         @gradient_checkpointing = false
+       end
+
+       def forward(
+         hidden_states,
+         attention_mask: nil,
+         head_mask: nil,
+         encoder_hidden_states: nil,
+         encoder_attention_mask: nil,
+         past_key_values: nil,
+         use_cache: nil,
+         output_attentions: false,
+         output_hidden_states: false,
+         return_dict: true
+       )
+         all_hidden_states = output_hidden_states ? [] : nil
+         all_self_attentions = output_attentions ? [] : nil
+         all_cross_attentions = output_attentions && @config.add_cross_attention ? [] : nil
+
+         if @gradient_checkpointing && @training
+           raise Todo
+         end
+
+         next_decoder_cache = use_cache ? [] : nil
+         @layer.each_with_index do |layer_module, i|
+           if output_hidden_states
+             all_hidden_states = all_hidden_states + [hidden_states]
+           end
+
+           layer_head_mask = !head_mask.nil? ? head_mask[i] : nil
+           past_key_value = !past_key_values.nil? ? past_key_values[i] : nil
+
+           if @gradient_checkpointing && @training
+             raise Todo
+           else
+             layer_outputs = layer_module.(
+               hidden_states,
+               attention_mask: attention_mask,
+               head_mask: layer_head_mask,
+               encoder_hidden_states: encoder_hidden_states,
+               encoder_attention_mask: encoder_attention_mask,
+               past_key_value: past_key_value,
+               output_attentions: output_attentions
+             )
+           end
+
+           hidden_states = layer_outputs[0]
+           if use_cache
+             next_decoder_cache += [layer_outputs[-1]]
+           end
+           if output_attentions
+             all_self_attentions = all_self_attentions + [layer_outputs[1]]
+             if @config.add_cross_attention
+               all_cross_attentions = all_cross_attentions + [layer_outputs[2]]
+             end
+           end
+         end
+
+         if output_hidden_states
+           all_hidden_states = all_hidden_states + [hidden_states]
+         end
+
+         if !return_dict
+           raise Todo
+         end
+         BaseModelOutputWithPastAndCrossAttentions.new(
+           last_hidden_state: hidden_states,
+           past_key_values: next_decoder_cache,
+           hidden_states: all_hidden_states,
+           attentions: all_self_attentions,
+           cross_attentions: all_cross_attentions
+         )
+       end
+     end
+
+     class BertPooler < Torch::NN::Module
+       def initialize(config)
+         super()
+         @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
+         @activation = Torch::NN::Tanh.new
+       end
+
+       def forward(hidden_states)
+         # We "pool" the model by simply taking the hidden state corresponding
+         # to the first token.
+         first_token_tensor = hidden_states[0.., 0]
+         pooled_output = @dense.(first_token_tensor)
+         pooled_output = @activation.(pooled_output)
+         pooled_output
+       end
+     end
+
+     class BertPredictionHeadTransform < Torch::NN::Module
+       def initialize(config)
+         super()
+         @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
+         if config.hidden_act.is_a?(String)
+           @transform_act_fn = ACT2FN[config.hidden_act]
+         else
+           @transform_act_fn = config.hidden_act
+         end
+         @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
+       end
+
+       def forward(hidden_states)
+         hidden_states = @dense.(hidden_states)
+         hidden_states = @transform_act_fn.(hidden_states)
+         hidden_states = @LayerNorm.(hidden_states)
+         hidden_states
+       end
+     end
+
+     class BertLMPredictionHead < Torch::NN::Module
+       def initialize(config)
+         super()
+         @transform = BertPredictionHeadTransform.new(config)
+
+         # The output weights are the same as the input embeddings, but there is
+         # an output-only bias for each token.
+         @decoder = Torch::NN::Linear.new(config.hidden_size, config.vocab_size, bias: false)
+
+         @bias = Torch::NN::Parameter.new(Torch.zeros(config.vocab_size))
+
+         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+         @decoder.instance_variable_set(:@bias, @bias)
+       end
+
+       def _tie_weights
+         @decoder.instance_variable_set(:@bias, @bias)
+       end
+
+       def forward(hidden_states)
+         hidden_states = @transform.(hidden_states)
+         hidden_states = @decoder.(hidden_states)
+         hidden_states
+       end
+     end
+
+     class BertOnlyMLMHead < Torch::NN::Module
+       def initialize(config)
+         super()
+         @predictions = BertLMPredictionHead.new(config)
+       end
+
+       def forward(sequence_output)
+         prediction_scores = @predictions.(sequence_output)
+         prediction_scores
+       end
+     end
+
+     class BertPreTrainedModel < PreTrainedModel
+       self.config_class = BertConfig
+       self.base_model_prefix = "bert"
+
+       def _init_weights(mod)
+         if mod.is_a?(Torch::NN::Linear)
+           mod.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
+           if !mod.bias.nil?
+             mod.bias.data.zero!
+           end
+         elsif mod.is_a?(Torch::NN::Embedding)
+           mod.weight.data.normal!(mean: 0.0, std: @config.initializer_range)
+           if !mod.instance_variable_get(:@padding_idx).nil?
+             mod.weight.data[mod.instance_variable_get(:@padding_idx)].zero!
+           end
+         elsif mod.is_a?(Torch::NN::LayerNorm)
+           mod.bias.data.zero!
+           mod.weight.data.fill!(1.0)
+         end
+       end
+     end
+
+     class BertModel < BertPreTrainedModel
+       def initialize(config, add_pooling_layer: true)
+         super(config)
+         @config = config
+
+         @embeddings = BertEmbeddings.new(config)
+         @encoder = BertEncoder.new(config)
+
+         @pooler = add_pooling_layer ? BertPooler.new(config) : nil
+
+         @attn_implementation = config._attn_implementation
+         @position_embedding_type = config.position_embedding_type
+
+         # Initialize weights and apply final processing
+         post_init
+       end
+
+       def _prune_heads(heads_to_prune)
+         heads_to_prune.each do |layer, heads|
+           @encoder.layer[layer].attention.prune_heads(heads)
+         end
+       end
+
+       def forward(
+         input_ids: nil,
+         attention_mask: nil,
+         token_type_ids: nil,
+         position_ids: nil,
+         head_mask: nil,
+         inputs_embeds: nil,
+         encoder_hidden_states: nil,
+         encoder_attention_mask: nil,
+         past_key_values: nil,
+         use_cache: nil,
+         output_attentions: nil,
+         output_hidden_states: nil,
+         return_dict: nil
+       )
+         output_attentions = !output_attentions.nil? ? output_attentions : @config.output_attentions
+         output_hidden_states = (
+           !output_hidden_states.nil? ? output_hidden_states : @config.output_hidden_states
+         )
+         return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
+
+         if @config.is_decoder
+           use_cache = !use_cache.nil? ? use_cache : @config.use_cache
+         else
+           use_cache = false
+         end
+
+         if !input_ids.nil? && !inputs_embeds.nil?
+           raise ArgumentError, "You cannot specify both input_ids and inputs_embeds at the same time"
+         elsif !input_ids.nil?
+           # self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+           input_shape = input_ids.size
+         elsif !inputs_embeds.nil?
+           input_shape = inputs_embeds.size[...-1]
+         else
+           raise ArgumentError, "You have to specify either input_ids or inputs_embeds"
+         end
+
+         batch_size, seq_length = input_shape
+         device = !input_ids.nil? ? input_ids.device : inputs_embeds.device
+
+         # past_key_values_length
+         past_key_values_length = !past_key_values.nil? ? past_key_values[0][0].shape[2] : 0
+
+         if token_type_ids.nil?
+           if @embeddings.token_type_ids
+             buffered_token_type_ids = @embeddings.token_type_ids[0.., 0...seq_length]
+             buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+             token_type_ids = buffered_token_type_ids_expanded
+           else
+             token_type_ids = Torch.zeros(input_shape, dtype: Torch.long, device: device)
+           end
+         end
+
+         embedding_output = @embeddings.(
+           input_ids: input_ids,
+           position_ids: position_ids,
+           token_type_ids: token_type_ids,
+           inputs_embeds: inputs_embeds,
+           past_key_values_length: past_key_values_length
+         )
+
+         if attention_mask.nil?
+           attention_mask = Torch.ones([batch_size, seq_length + past_key_values_length], device: device)
+         end
+
+         use_sdpa_attention_masks = (
+           @attn_implementation == "sdpa" &&
+           @position_embedding_type == "absolute" &&
+           head_mask.nil? &&
+           !output_attentions
+         )
+
+         # Expand the attention mask
+         if use_sdpa_attention_masks
+           raise Todo
+         else
+           # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+           # ourselves in which case we just need to make it broadcastable to all heads.
+           extended_attention_mask = get_extended_attention_mask(attention_mask, input_shape)
+         end
+
+         # If a 2D or 3D attention mask is provided for the cross-attention
+         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+         if @config.is_decoder && !encoder_hidden_states.nil?
+           encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size
+           encoder_hidden_shape = [encoder_batch_size, encoder_sequence_length]
+           if encoder_attention_mask.nil?
+             encoder_attention_mask = Torch.ones(encoder_hidden_shape, device: device)
+           end
+
+           if use_sdpa_attention_masks
+             # Expand the attention mask for SDPA.
+             # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
+             encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+               encoder_attention_mask, embedding_output.dtype, tgt_len: seq_length
+             )
+           else
+             encoder_extended_attention_mask = invert_attention_mask(encoder_attention_mask)
+           end
+         else
+           encoder_extended_attention_mask = nil
+         end
+
+         # Prepare head mask if needed
+         # 1.0 in head_mask indicate we keep the head
+         # attention_probs has shape bsz x n_heads x N x N
+         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+         head_mask = get_head_mask(head_mask, @config.num_hidden_layers)
+
+         encoder_outputs = @encoder.(
+           embedding_output,
+           attention_mask: extended_attention_mask,
+           head_mask: head_mask,
+           encoder_hidden_states: encoder_hidden_states,
+           encoder_attention_mask: encoder_extended_attention_mask,
+           past_key_values: past_key_values,
+           use_cache: use_cache,
+           output_attentions: output_attentions,
+           output_hidden_states: output_hidden_states,
+           return_dict: return_dict
+         )
+         sequence_output = encoder_outputs[0]
+         pooled_output = !@pooler.nil? ? @pooler.(sequence_output) : nil
+
+         if !return_dict
+           raise Todo
+         end
+
+         BaseModelOutputWithPoolingAndCrossAttentions.new(
+           last_hidden_state: sequence_output,
+           pooler_output: pooled_output,
+           past_key_values: encoder_outputs.past_key_values,
+           hidden_states: encoder_outputs.hidden_states,
+           attentions: encoder_outputs.attentions,
+           cross_attentions: encoder_outputs.cross_attentions
+         )
+       end
+     end
+
+     class BertForMaskedLM < BertPreTrainedModel
+       def initialize(config)
+         super(config)
+
+         if config.is_decoder
+           Transformers.logger.warn(
+             "If you want to use `BertForMaskedLM` make sure `config.is_decoder: false` for " +
+             "bi-directional self-attention."
+           )
+         end
+
+         @bert = BertModel.new(config, add_pooling_layer: false)
+         @cls = BertOnlyMLMHead.new(config)
+       end
+
+       def forward(
+         input_ids: nil,
+         attention_mask: nil,
+         token_type_ids: nil,
+         position_ids: nil,
+         head_mask: nil,
+         inputs_embeds: nil,
+         encoder_hidden_states: nil,
+         encoder_attention_mask: nil,
+         labels: nil,
+         output_attentions: nil,
+         output_hidden_states: nil,
+         return_dict: nil
+       )
+         return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
+
+         outputs = @bert.(
+           input_ids: input_ids,
+           attention_mask: attention_mask,
+           token_type_ids: token_type_ids,
+           position_ids: position_ids,
+           head_mask: head_mask,
+           inputs_embeds: inputs_embeds,
+           encoder_hidden_states: encoder_hidden_states,
+           encoder_attention_mask: encoder_attention_mask,
+           output_attentions: output_attentions,
+           output_hidden_states: output_hidden_states,
+           return_dict: return_dict
+         )
+
+         sequence_output = outputs[0]
+         prediction_scores = @cls.(sequence_output)
+
+         masked_lm_loss = nil
+         if !labels.nil?
+           raise Todo
+         end
+
+         if !return_dict
+           raise Todo
+         end
+
+         MaskedLMOutput.new(
+           loss: masked_lm_loss,
+           logits: prediction_scores,
+           hidden_states: outputs.hidden_states,
+           attentions: outputs.attentions
+         )
+       end
+     end
+
+     class BertForTokenClassification < BertPreTrainedModel
+       def initialize(config)
+         super(config)
+         @num_labels = config.num_labels
+
+         @bert = BertModel.new(config, add_pooling_layer: false)
+         classifier_dropout = (
+           !config.classifier_dropout.nil? ? config.classifier_dropout : config.hidden_dropout_prob
+         )
+         @dropout = Torch::NN::Dropout.new(p: classifier_dropout)
+         @classifier = Torch::NN::Linear.new(config.hidden_size, config.num_labels)
+
+         # Initialize weights and apply final processing
+         post_init
+       end
+
+       def forward(
+         input_ids: nil,
+         attention_mask: nil,
+         token_type_ids: nil,
+         position_ids: nil,
+         head_mask: nil,
+         inputs_embeds: nil,
+         labels: nil,
+         output_attentions: nil,
+         output_hidden_states: nil,
+         return_dict: nil
+       )
+         return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
+
+         outputs = @bert.(
+           input_ids: input_ids,
+           attention_mask: attention_mask,
+           token_type_ids: token_type_ids,
+           position_ids: position_ids,
+           head_mask: head_mask,
+           inputs_embeds: inputs_embeds,
+           output_attentions: output_attentions,
+           output_hidden_states: output_hidden_states,
+           return_dict: return_dict
+         )
+
+         sequence_output = outputs[0]
+
+         sequence_output = @dropout.(sequence_output)
+         logits = @classifier.(sequence_output)
+
+         loss = nil
+         if !labels.nil?
+           loss_fct = Torch::NN::CrossEntropyLoss.new
+           loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
+         end
+
+         if !return_dict
+           raise Todo
+         end
+
+         TokenClassifierOutput.new(
+           loss: loss,
+           logits: logits,
+           hidden_states: outputs.hidden_states,
+           attentions: outputs.attentions
+         )
+       end
+     end
+   end
+
+   BertModel = Bert::BertModel
+   BertForTokenClassification = Bert::BertForTokenClassification
+ end
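
The last lines of the diff alias the nested classes as Transformers::BertModel and Transformers::BertForTokenClassification, so they can also be driven directly rather than through a pipeline. Below is a minimal sketch of that path; it assumes Transformers::AutoTokenizer.from_pretrained (from tokenization_auto.rb) and the from_pretrained class method implemented in modeling_utils.rb behave like their Python counterparts, and the tokenizer call/return conventions and checkpoint name are illustrative assumptions, not confirmed API.

  require "transformers-rb"

  # Hypothetical direct use of the classes defined in modeling_bert.rb above.
  tokenizer = Transformers::AutoTokenizer.from_pretrained("bert-base-uncased")
  model = Transformers::BertModel.from_pretrained("bert-base-uncased")

  # Exact key names and tensor options are assumptions about the tokenizer output.
  encoding = tokenizer.("Ruby is nice", return_tensors: "pt")
  output = model.(input_ids: encoding[:input_ids], attention_mask: encoding[:attention_mask])

  # BertModel#forward returns a BaseModelOutputWithPoolingAndCrossAttentions;
  # last_hidden_state has shape [batch_size, seq_length, hidden_size].
  p output.last_hidden_state.shape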