transformers-rb 0.1.0
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +203 -0
- data/README.md +163 -0
- data/lib/transformers/activations.rb +57 -0
- data/lib/transformers/configuration_utils.rb +285 -0
- data/lib/transformers/convert_slow_tokenizer.rb +90 -0
- data/lib/transformers/data/processors/squad.rb +115 -0
- data/lib/transformers/dynamic_module_utils.rb +25 -0
- data/lib/transformers/feature_extraction_utils.rb +110 -0
- data/lib/transformers/hf_hub/constants.rb +71 -0
- data/lib/transformers/hf_hub/errors.rb +11 -0
- data/lib/transformers/hf_hub/file_download.rb +764 -0
- data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
- data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
- data/lib/transformers/image_processing_base.rb +169 -0
- data/lib/transformers/image_processing_utils.rb +63 -0
- data/lib/transformers/image_transforms.rb +208 -0
- data/lib/transformers/image_utils.rb +165 -0
- data/lib/transformers/modeling_outputs.rb +81 -0
- data/lib/transformers/modeling_utils.rb +888 -0
- data/lib/transformers/models/auto/auto_factory.rb +138 -0
- data/lib/transformers/models/auto/configuration_auto.rb +61 -0
- data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
- data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
- data/lib/transformers/models/auto/modeling_auto.rb +80 -0
- data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
- data/lib/transformers/models/bert/configuration_bert.rb +65 -0
- data/lib/transformers/models/bert/modeling_bert.rb +836 -0
- data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
- data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
- data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
- data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
- data/lib/transformers/models/vit/configuration_vit.rb +60 -0
- data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
- data/lib/transformers/models/vit/modeling_vit.rb +506 -0
- data/lib/transformers/pipelines/_init.rb +348 -0
- data/lib/transformers/pipelines/base.rb +301 -0
- data/lib/transformers/pipelines/feature_extraction.rb +47 -0
- data/lib/transformers/pipelines/image_classification.rb +110 -0
- data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
- data/lib/transformers/pipelines/pt_utils.rb +53 -0
- data/lib/transformers/pipelines/question_answering.rb +508 -0
- data/lib/transformers/pipelines/text_classification.rb +123 -0
- data/lib/transformers/pipelines/token_classification.rb +282 -0
- data/lib/transformers/ruby_utils.rb +33 -0
- data/lib/transformers/sentence_transformer.rb +37 -0
- data/lib/transformers/tokenization_utils.rb +152 -0
- data/lib/transformers/tokenization_utils_base.rb +937 -0
- data/lib/transformers/tokenization_utils_fast.rb +386 -0
- data/lib/transformers/torch_utils.rb +25 -0
- data/lib/transformers/utils/_init.rb +31 -0
- data/lib/transformers/utils/generic.rb +107 -0
- data/lib/transformers/utils/hub.rb +209 -0
- data/lib/transformers/utils/import_utils.rb +45 -0
- data/lib/transformers/utils/logging.rb +52 -0
- data/lib/transformers/version.rb +3 -0
- data/lib/transformers-rb.rb +1 -0
- data/lib/transformers.rb +100 -0
- data/licenses/LICENSE-huggingface-hub.txt +201 -0
- data/licenses/LICENSE-sentence-transformers.txt +201 -0
- data/licenses/NOTICE-sentence-transformers.txt +5 -0
- metadata +161 -0
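The hunk below is the new file data/lib/transformers/models/vit/modeling_vit.rb (the +506 entry above), a Ruby port of the Hugging Face Vision Transformer (ViT) model that backs the gem's image-classification and image-feature-extraction pipelines. As a rough sketch of how this code is typically reached from user code (the pipeline name follows the gem's README conventions; the image path and the exact output format are illustrative, not taken from this diff):

    require "transformers-rb"

    # loads a default ViT checkpoint from the Hugging Face Hub on first use
    classifier = Transformers.pipeline("image-classification")
    classifier.("path/to/image.jpg") # => array of label/score predictions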
@@ -0,0 +1,506 @@
# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  module Vit
    class ViTEmbeddings < Torch::NN::Module
      def initialize(config, use_mask_token: false)
        super()

        @cls_token = Torch::NN::Parameter.new(Torch.randn(1, 1, config.hidden_size))
        @mask_token = use_mask_token ? Torch::NN::Parameter.new(Torch.zeros(1, 1, config.hidden_size)) : nil
        @patch_embeddings = ViTPatchEmbeddings.new(config)
        num_patches = @patch_embeddings.num_patches
        @position_embeddings = Torch::NN::Parameter.new(Torch.randn(1, num_patches + 1, config.hidden_size))
        @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
        @config = config
      end

      def forward(
        pixel_values,
        bool_masked_pos: nil,
        interpolate_pos_encoding: false
      )
        batch_size, _num_channels, height, width = pixel_values.shape
        embeddings = @patch_embeddings.(pixel_values, interpolate_pos_encoding: interpolate_pos_encoding)

        if !bool_masked_pos.nil?
          seq_length = embeddings.shape[1]
          mask_tokens = @mask_token.expand(batch_size, seq_length, -1)
          # replace the masked visual tokens by mask_tokens
          mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
          embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
        end

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = @cls_token.expand(batch_size, -1, -1)
        embeddings = Torch.cat([cls_tokens, embeddings], dim: 1)
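        # embeddings is now (batch_size, num_patches + 1, hidden_size), matching @position_embeddings below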

        # add positional encoding to each token
        if interpolate_pos_encoding
          embeddings = embeddings + @interpolate_pos_encoding.(embeddings, height, width)
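          # NOTE: nothing in this file defines @interpolate_pos_encoding, so this branch appears unsupported in 0.1.0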
        else
          embeddings = embeddings + @position_embeddings
        end

        embeddings = @dropout.(embeddings)

        embeddings
      end
    end

    class ViTPatchEmbeddings < Torch::NN::Module
      attr_reader :num_patches

      def initialize(config)
        super()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        image_size = image_size.is_a?(Enumerable) ? image_size : [image_size, image_size]
        patch_size = patch_size.is_a?(Enumerable) ? patch_size : [patch_size, patch_size]
        num_patches = image_size[1].div(patch_size[1]) * image_size[0].div(patch_size[0])
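        # e.g. a 224x224 image with 16x16 patches gives (224 / 16) * (224 / 16) = 196 patches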
        @image_size = image_size
        @patch_size = patch_size
        @num_channels = num_channels
        @num_patches = num_patches

        @projection = Torch::NN::Conv2d.new(num_channels, hidden_size, patch_size, stride: patch_size)
      end

      def forward(pixel_values, interpolate_pos_encoding: false)
        _batch_size, num_channels, height, width = pixel_values.shape
        if num_channels != @num_channels
          raise ArgumentError,
            "Make sure that the channel dimension of the pixel values match with the one set in the configuration." +
            " Expected #{@num_channels} but got #{num_channels}."
        end
        if !interpolate_pos_encoding
          if height != @image_size[0] || width != @image_size[1]
            raise ArgumentError,
              "Input image size (#{height}*#{width}) doesn't match model" +
              " (#{@image_size[0]}*#{@image_size[1]})."
          end
        end
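        # a strided conv projects each patch to hidden_size: (batch, channels, height, width)
        # -> (batch, hidden_size, height / patch, width / patch); flatten(2) and transpose then
        # give the (batch, num_patches, hidden_size) sequence the encoder expects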
        embeddings = @projection.(pixel_values).flatten(2).transpose(1, 2)
        embeddings
      end
    end

    class ViTSelfAttention < Torch::NN::Module
      def initialize(config)
        super()
        if config.hidden_size % config.num_attention_heads != 0 && !config.instance_variable_defined?(:@embedding_size)
          raise ArgumentError,
            "The hidden size #{config.hidden_size} is not a multiple of the number of attention " +
            "heads #{config.num_attention_heads}."
        end

        @num_attention_heads = config.num_attention_heads
        @attention_head_size = (config.hidden_size / config.num_attention_heads).to_i
        @all_head_size = @num_attention_heads * @attention_head_size

        @query = Torch::NN::Linear.new(config.hidden_size, @all_head_size, bias: config.qkv_bias)
        @key = Torch::NN::Linear.new(config.hidden_size, @all_head_size, bias: config.qkv_bias)
        @value = Torch::NN::Linear.new(config.hidden_size, @all_head_size, bias: config.qkv_bias)

        @dropout = Torch::NN::Dropout.new(p: config.attention_probs_dropout_prob)
      end

      def transpose_for_scores(x)
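        # reshape (batch_size, seq_length, all_head_size) into (batch_size, num_attention_heads, seq_length, attention_head_size)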
        new_x_shape = x.size[...-1] + [@num_attention_heads, @attention_head_size]
        x = x.view(new_x_shape)
        x.permute(0, 2, 1, 3)
      end

      def forward(
        hidden_states, head_mask: nil, output_attentions: false
      )
        mixed_query_layer = @query.(hidden_states)

        key_layer = transpose_for_scores(@key.(hidden_states))
        value_layer = transpose_for_scores(@value.(hidden_states))
        query_layer = transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = Torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / Math.sqrt(@attention_head_size)
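        # with the softmax and the matmul against value_layer below, this computes softmax(Q * K^T / sqrt(head_size)) * V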

        # Normalize the attention scores to probabilities.
        attention_probs = Torch::NN::Functional.softmax(attention_scores, dim: -1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = @dropout.(attention_probs)

        # Mask heads if we want to
        if !head_mask.nil?
          attention_probs = attention_probs * head_mask
        end

        context_layer = Torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous
        new_context_layer_shape = context_layer.size[...-2] + [@all_head_size]
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = output_attentions ? [context_layer, attention_probs] : [context_layer]

        outputs
      end
    end

    class ViTSelfOutput < Torch::NN::Module
      def initialize(config)
        super()
        @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
        @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
      end

      def forward(hidden_states, input_tensor)
        hidden_states = @dense.(hidden_states)
        hidden_states = @dropout.(hidden_states)

        hidden_states
      end
    end

    class ViTAttention < Torch::NN::Module
      def initialize(config)
        super()
        @attention = ViTSelfAttention.new(config)
        @output = ViTSelfOutput.new(config)
        @pruned_heads = Set.new
      end

      def prune_heads(heads)
        raise Todo
      end

      def forward(
        hidden_states,
        head_mask: nil,
        output_attentions: false
      )
        self_outputs = @attention.(hidden_states, head_mask: head_mask, output_attentions: output_attentions)

        attention_output = @output.(self_outputs[0], hidden_states)

        outputs = [attention_output] + self_outputs[1..] # add attentions if we output them
        outputs
      end
    end

    class ViTIntermediate < Torch::NN::Module
      def initialize(config)
        super()
        @dense = Torch::NN::Linear.new(config.hidden_size, config.intermediate_size)
        if config.hidden_act.is_a?(String)
          @intermediate_act_fn = ACT2FN[config.hidden_act]
        else
          @intermediate_act_fn = config.hidden_act
        end
      end

      def forward(hidden_states)
        hidden_states = @dense.(hidden_states)
        hidden_states = @intermediate_act_fn.(hidden_states)

        hidden_states
      end
    end

    class ViTOutput < Torch::NN::Module
      def initialize(config)
        super()
        @dense = Torch::NN::Linear.new(config.intermediate_size, config.hidden_size)
        @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
      end

      def forward(hidden_states, input_tensor)
        hidden_states = @dense.(hidden_states)
        hidden_states = @dropout.(hidden_states)

        hidden_states = hidden_states + input_tensor

        hidden_states
      end
    end

    VIT_ATTENTION_CLASSES = {
      "eager" => ViTAttention
    }

    class ViTLayer < Torch::NN::Module
      def initialize(config)
        super()
        @chunk_size_feed_forward = config.chunk_size_feed_forward
        @seq_len_dim = 1
        @attention = VIT_ATTENTION_CLASSES.fetch(config._attn_implementation).new(config)
        @intermediate = ViTIntermediate.new(config)
        @output = ViTOutput.new(config)
        @layernorm_before = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
        @layernorm_after = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
      end

      def forward(
        hidden_states,
        head_mask: nil,
        output_attentions: false
      )
        self_attention_outputs = @attention.(
          @layernorm_before.(hidden_states), # in ViT, layernorm is applied before self-attention
          head_mask: head_mask,
          output_attentions: output_attentions
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1..] # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + hidden_states

        # in ViT, layernorm is also applied after self-attention
        layer_output = @layernorm_after.(hidden_states)
        layer_output = @intermediate.(layer_output)

        # second residual connection is done here
        layer_output = @output.(layer_output, hidden_states)

        outputs = [layer_output] + outputs

        outputs
      end
    end

    class ViTEncoder < Torch::NN::Module
      def initialize(config)
        super()
        @config = config
        @layer = Torch::NN::ModuleList.new(config.num_hidden_layers.times.map { ViTLayer.new(config) })
        @gradient_checkpointing = false
      end

      def forward(
        hidden_states,
        head_mask: nil,
        output_attentions: false,
        output_hidden_states: false,
        return_dict: true
      )
        all_hidden_states = output_hidden_states ? [] : nil
        all_self_attentions = output_attentions ? [] : nil

        @layer.each_with_index do |layer_module, i|
          if output_hidden_states
            all_hidden_states = all_hidden_states + [hidden_states]
          end

          layer_head_mask = !head_mask.nil? ? head_mask[i] : nil

          if @gradient_checkpointing && @training
            raise Todo
          else
            layer_outputs = layer_module.(hidden_states, head_mask: layer_head_mask, output_attentions: output_attentions)
          end

          hidden_states = layer_outputs[0]

          if output_attentions
            all_self_attentions = all_self_attentions + [layer_outputs[1]]
          end
        end

        if output_hidden_states
          all_hidden_states = all_hidden_states + [hidden_states]
        end

        if !return_dict
          raise Todo
        end
        BaseModelOutput.new(
          last_hidden_state: hidden_states,
          hidden_states: all_hidden_states,
          attentions: all_self_attentions
        )
      end
    end

    class ViTPreTrainedModel < PreTrainedModel
      self.config_class = ViTConfig
      self.base_model_prefix = "vit"
      self.main_input_name = "pixel_values"

      def _init_weights(mod)
        # TODO
      end
    end

    class ViTModel < ViTPreTrainedModel
      def initialize(config, add_pooling_layer: true, use_mask_token: false)
        super(config)
        @config = config

        @embeddings = ViTEmbeddings.new(config, use_mask_token: use_mask_token)
        @encoder = ViTEncoder.new(config)

        @layernorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
        @pooler = add_pooling_layer ? ViTPooler.new(config) : nil

        # Initialize weights and apply final processing
        post_init
      end

      def _prune_heads(heads_to_prune)
        heads_to_prune.each do |layer, heads|
          @encoder.layer[layer].attention.prune_heads(heads)
        end
      end

      def forward(
        pixel_values: nil,
        bool_masked_pos: nil,
        head_mask: nil,
        output_attentions: nil,
        output_hidden_states: nil,
        interpolate_pos_encoding: nil,
        return_dict: nil
      )
        output_attentions = !output_attentions.nil? ? output_attentions : @config.output_attentions
        output_hidden_states = (
          !output_hidden_states.nil? ? output_hidden_states : @config.output_hidden_states
        )
        return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

        if pixel_values.nil?
          raise ArgumentError, "You have to specify pixel_values"
        end

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = get_head_mask(head_mask, @config.num_hidden_layers)

        # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
        expected_dtype = @embeddings.patch_embeddings.projection.weight.dtype
        if pixel_values.dtype != expected_dtype
          pixel_values = pixel_values.to(expected_dtype)
        end

        embedding_output = @embeddings.(
          pixel_values, bool_masked_pos: bool_masked_pos, interpolate_pos_encoding: interpolate_pos_encoding
        )

        encoder_outputs = @encoder.(
          embedding_output,
          head_mask: head_mask,
          output_attentions: output_attentions,
          output_hidden_states: output_hidden_states,
          return_dict: return_dict
        )
        sequence_output = encoder_outputs[0]
        sequence_output = @layernorm.(sequence_output)
        pooled_output = @pooler ? @pooler.(sequence_output) : nil

        if !return_dict
          raise Todo
        end

        BaseModelOutputWithPooling.new(
          last_hidden_state: sequence_output,
          pooler_output: pooled_output,
          hidden_states: encoder_outputs.hidden_states,
          attentions: encoder_outputs.attentions
        )
      end
    end

    class ViTPooler < Torch::NN::Module
      def initialize(config)
        super()
        @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
        @activation = Torch::NN::Tanh.new
      end

      def forward(hidden_states)
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[0.., 0]
        pooled_output = @dense.(first_token_tensor)
        pooled_output = @activation.(pooled_output)
        pooled_output
      end
    end

    class ViTForImageClassification < ViTPreTrainedModel
      def initialize(config)
        super(config)

        @num_labels = config.num_labels
        @vit = ViTModel.new(config, add_pooling_layer: false)

        # Classifier head
        @classifier = config.num_labels > 0 ? Torch::NN::Linear.new(config.hidden_size, config.num_labels) : Torch::NN::Identity.new

        # Initialize weights and apply final processing
        post_init
      end

      def forward(
        pixel_values: nil,
        head_mask: nil,
        labels: nil,
        output_attentions: nil,
        output_hidden_states: nil,
        interpolate_pos_encoding: nil,
        return_dict: nil
      )
        return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

        outputs = @vit.(
          pixel_values: pixel_values,
          head_mask: head_mask,
          output_attentions: output_attentions,
          output_hidden_states: output_hidden_states,
          interpolate_pos_encoding: interpolate_pos_encoding,
          return_dict: return_dict
        )

        sequence_output = outputs[0]

        logits = @classifier.(sequence_output[0.., 0, 0..])
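        # the classifier head runs on the final hidden state of the [CLS] token (sequence position 0)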

        loss = nil
        if !labels.nil?
          raise Todo
        end

        if !return_dict
          raise Todo
        end

        ImageClassifierOutput.new(
          loss: loss,
          logits: logits,
          hidden_states: outputs.hidden_states,
          attentions: outputs.attentions
        )
      end
    end
  end

  ViTForImageClassification = Vit::ViTForImageClassification
end