transformers-rb 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +203 -0
- data/README.md +163 -0
- data/lib/transformers/activations.rb +57 -0
- data/lib/transformers/configuration_utils.rb +285 -0
- data/lib/transformers/convert_slow_tokenizer.rb +90 -0
- data/lib/transformers/data/processors/squad.rb +115 -0
- data/lib/transformers/dynamic_module_utils.rb +25 -0
- data/lib/transformers/feature_extraction_utils.rb +110 -0
- data/lib/transformers/hf_hub/constants.rb +71 -0
- data/lib/transformers/hf_hub/errors.rb +11 -0
- data/lib/transformers/hf_hub/file_download.rb +764 -0
- data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
- data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
- data/lib/transformers/image_processing_base.rb +169 -0
- data/lib/transformers/image_processing_utils.rb +63 -0
- data/lib/transformers/image_transforms.rb +208 -0
- data/lib/transformers/image_utils.rb +165 -0
- data/lib/transformers/modeling_outputs.rb +81 -0
- data/lib/transformers/modeling_utils.rb +888 -0
- data/lib/transformers/models/auto/auto_factory.rb +138 -0
- data/lib/transformers/models/auto/configuration_auto.rb +61 -0
- data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
- data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
- data/lib/transformers/models/auto/modeling_auto.rb +80 -0
- data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
- data/lib/transformers/models/bert/configuration_bert.rb +65 -0
- data/lib/transformers/models/bert/modeling_bert.rb +836 -0
- data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
- data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
- data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
- data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
- data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
- data/lib/transformers/models/vit/configuration_vit.rb +60 -0
- data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
- data/lib/transformers/models/vit/modeling_vit.rb +506 -0
- data/lib/transformers/pipelines/_init.rb +348 -0
- data/lib/transformers/pipelines/base.rb +301 -0
- data/lib/transformers/pipelines/feature_extraction.rb +47 -0
- data/lib/transformers/pipelines/image_classification.rb +110 -0
- data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
- data/lib/transformers/pipelines/pt_utils.rb +53 -0
- data/lib/transformers/pipelines/question_answering.rb +508 -0
- data/lib/transformers/pipelines/text_classification.rb +123 -0
- data/lib/transformers/pipelines/token_classification.rb +282 -0
- data/lib/transformers/ruby_utils.rb +33 -0
- data/lib/transformers/sentence_transformer.rb +37 -0
- data/lib/transformers/tokenization_utils.rb +152 -0
- data/lib/transformers/tokenization_utils_base.rb +937 -0
- data/lib/transformers/tokenization_utils_fast.rb +386 -0
- data/lib/transformers/torch_utils.rb +25 -0
- data/lib/transformers/utils/_init.rb +31 -0
- data/lib/transformers/utils/generic.rb +107 -0
- data/lib/transformers/utils/hub.rb +209 -0
- data/lib/transformers/utils/import_utils.rb +45 -0
- data/lib/transformers/utils/logging.rb +52 -0
- data/lib/transformers/version.rb +3 -0
- data/lib/transformers-rb.rb +1 -0
- data/lib/transformers.rb +100 -0
- data/licenses/LICENSE-huggingface-hub.txt +201 -0
- data/licenses/LICENSE-sentence-transformers.txt +201 -0
- data/licenses/NOTICE-sentence-transformers.txt +5 -0
- metadata +161 -0
data/lib/transformers/models/vit/modeling_vit.rb (new file)
@@ -0,0 +1,506 @@
# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module Transformers
  module Vit
    class ViTEmbeddings < Torch::NN::Module
      def initialize(config, use_mask_token: false)
        super()

        @cls_token = Torch::NN::Parameter.new(Torch.randn(1, 1, config.hidden_size))
        @mask_token = use_mask_token ? Torch::NN::Parameter.new(Torch.zeros(1, 1, config.hidden_size)) : nil
        @patch_embeddings = ViTPatchEmbeddings.new(config)
        num_patches = @patch_embeddings.num_patches
        @position_embeddings = Torch::NN::Parameter.new(Torch.randn(1, num_patches + 1, config.hidden_size))
        @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
        @config = config
      end

      def forward(
        pixel_values,
        bool_masked_pos: nil,
        interpolate_pos_encoding: false
      )
        batch_size, _num_channels, height, width = pixel_values.shape
        embeddings = @patch_embeddings.(pixel_values, interpolate_pos_encoding: interpolate_pos_encoding)

        if !bool_masked_pos.nil?
          seq_length = embeddings.shape[1]
          mask_tokens = @mask_token.expand(batch_size, seq_length, -1)
          # replace the masked visual tokens by mask_tokens
          mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
          embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
        end

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = @cls_token.expand(batch_size, -1, -1)
        embeddings = Torch.cat([cls_tokens, embeddings], dim: 1)

        # add positional encoding to each token
        if interpolate_pos_encoding
          embeddings = embeddings + @interpolate_pos_encoding.(embeddings, height, width)
        else
          embeddings = embeddings + @position_embeddings
        end

        embeddings = @dropout.(embeddings)

        embeddings
      end
    end

    class ViTPatchEmbeddings < Torch::NN::Module
      attr_reader :num_patches

      def initialize(config)
        super()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        image_size = image_size.is_a?(Enumerable) ? image_size : [image_size, image_size]
        patch_size = patch_size.is_a?(Enumerable) ? patch_size : [patch_size, patch_size]
        num_patches = image_size[1].div(patch_size[1]) * image_size[0].div(patch_size[0])
        @image_size = image_size
        @patch_size = patch_size
        @num_channels = num_channels
        @num_patches = num_patches

        @projection = Torch::NN::Conv2d.new(num_channels, hidden_size, patch_size, stride: patch_size)
      end

      def forward(pixel_values, interpolate_pos_encoding: false)
        _batch_size, num_channels, height, width = pixel_values.shape
        if num_channels != @num_channels
          raise ArgumentError,
            "Make sure that the channel dimension of the pixel values match with the one set in the configuration." +
            " Expected #{@num_channels} but got #{num_channels}."
        end
        if !interpolate_pos_encoding
          if height != @image_size[0] || width != @image_size[1]
            raise ArgumentError,
              "Input image size (#{height}*#{width}) doesn't match model" +
              " (#{@image_size[0]}*#{@image_size[1]})."
          end
        end
        embeddings = @projection.(pixel_values).flatten(2).transpose(1, 2)
        embeddings
      end
    end

    class ViTSelfAttention < Torch::NN::Module
      def initialize(config)
        super()
        if config.hidden_size % config.num_attention_heads != 0 && !config.instance_variable_defined?(:@embedding_size)
          raise ArgumentError,
            "The hidden size #{config.hidden_size} is not a multiple of the number of attention " +
            "heads #{config.num_attention_heads}."
        end

        @num_attention_heads = config.num_attention_heads
        @attention_head_size = (config.hidden_size / config.num_attention_heads).to_i
        @all_head_size = @num_attention_heads * @attention_head_size

        @query = Torch::NN::Linear.new(config.hidden_size, @all_head_size, bias: config.qkv_bias)
        @key = Torch::NN::Linear.new(config.hidden_size, @all_head_size, bias: config.qkv_bias)
        @value = Torch::NN::Linear.new(config.hidden_size, @all_head_size, bias: config.qkv_bias)

        @dropout = Torch::NN::Dropout.new(p: config.attention_probs_dropout_prob)
      end

      def transpose_for_scores(x)
        new_x_shape = x.size[...-1] + [@num_attention_heads, @attention_head_size]
        x = x.view(new_x_shape)
        x.permute(0, 2, 1, 3)
      end

      def forward(
        hidden_states, head_mask: nil, output_attentions: false
      )
        mixed_query_layer = @query.(hidden_states)

        key_layer = transpose_for_scores(@key.(hidden_states))
        value_layer = transpose_for_scores(@value.(hidden_states))
        query_layer = transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = Torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / Math.sqrt(@attention_head_size)

        # Normalize the attention scores to probabilities.
        attention_probs = Torch::NN::Functional.softmax(attention_scores, dim: -1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = @dropout.(attention_probs)

        # Mask heads if we want to
        if !head_mask.nil?
          attention_probs = attention_probs * head_mask
        end

        context_layer = Torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous
        new_context_layer_shape = context_layer.size[...-2] + [@all_head_size]
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = output_attentions ? [context_layer, attention_probs] : [context_layer]

        outputs
      end
    end

    class ViTSelfOutput < Torch::NN::Module
      def initialize(config)
        super()
        @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
        @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
      end

      def forward(hidden_states, input_tensor)
        hidden_states = @dense.(hidden_states)
        hidden_states = @dropout.(hidden_states)

        hidden_states
      end
    end

    class ViTAttention < Torch::NN::Module
      def initialize(config)
        super()
        @attention = ViTSelfAttention.new(config)
        @output = ViTSelfOutput.new(config)
        @pruned_heads = Set.new
      end

      def prune_heads(heads)
        raise Todo
      end

      def forward(
        hidden_states,
        head_mask: nil,
        output_attentions: false
      )
        self_outputs = @attention.(hidden_states, head_mask: head_mask, output_attentions: output_attentions)

        attention_output = @output.(self_outputs[0], hidden_states)

        outputs = [attention_output] + self_outputs[1..] # add attentions if we output them
        outputs
      end
    end

    class ViTIntermediate < Torch::NN::Module
      def initialize(config)
        super()
        @dense = Torch::NN::Linear.new(config.hidden_size, config.intermediate_size)
        if config.hidden_act.is_a?(String)
          @intermediate_act_fn = ACT2FN[config.hidden_act]
        else
          @intermediate_act_fn = config.hidden_act
        end
      end

      def forward(hidden_states)
        hidden_states = @dense.(hidden_states)
        hidden_states = @intermediate_act_fn.(hidden_states)

        hidden_states
      end
    end

    class ViTOutput < Torch::NN::Module
      def initialize(config)
        super()
        @dense = Torch::NN::Linear.new(config.intermediate_size, config.hidden_size)
        @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
      end

      def forward(hidden_states, input_tensor)
        hidden_states = @dense.(hidden_states)
        hidden_states = @dropout.(hidden_states)

        hidden_states = hidden_states + input_tensor

        hidden_states
      end
    end

    VIT_ATTENTION_CLASSES = {
      "eager" => ViTAttention
    }

    class ViTLayer < Torch::NN::Module
      def initialize(config)
        super()
        @chunk_size_feed_forward = config.chunk_size_feed_forward
        @seq_len_dim = 1
        @attention = VIT_ATTENTION_CLASSES.fetch(config._attn_implementation).new(config)
        @intermediate = ViTIntermediate.new(config)
        @output = ViTOutput.new(config)
        @layernorm_before = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
        @layernorm_after = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
      end

      def forward(
        hidden_states,
        head_mask: nil,
        output_attentions: false
      )
        self_attention_outputs = @attention.(
          @layernorm_before.(hidden_states), # in ViT, layernorm is applied before self-attention
          head_mask: head_mask,
          output_attentions: output_attentions
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1..] # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + hidden_states

        # in ViT, layernorm is also applied after self-attention
        layer_output = @layernorm_after.(hidden_states)
        layer_output = @intermediate.(layer_output)

        # second residual connection is done here
        layer_output = @output.(layer_output, hidden_states)

        outputs = [layer_output] + outputs

        outputs
      end
    end

    class ViTEncoder < Torch::NN::Module
      def initialize(config)
        super()
        @config = config
        @layer = Torch::NN::ModuleList.new(config.num_hidden_layers.times.map { ViTLayer.new(config) })
        @gradient_checkpointing = false
      end

      def forward(
        hidden_states,
        head_mask: nil,
        output_attentions: false,
        output_hidden_states: false,
        return_dict: true
      )
        all_hidden_states = output_hidden_states ? [] : nil
        all_self_attentions = output_attentions ? [] : nil

        @layer.each_with_index do |layer_module, i|
          if output_hidden_states
            all_hidden_states = all_hidden_states + [hidden_states]
          end

          layer_head_mask = !head_mask.nil? ? head_mask[i] : nil

          if @gradient_checkpointing && @training
            raise Todo
          else
            layer_outputs = layer_module.(hidden_states, head_mask: layer_head_mask, output_attentions: output_attentions)
          end

          hidden_states = layer_outputs[0]

          if output_attentions
            all_self_attentions = all_self_attentions + [layer_outputs[1]]
          end
        end

        if output_hidden_states
          all_hidden_states = all_hidden_states + [hidden_states]
        end

        if !return_dict
          raise Todo
        end
        BaseModelOutput.new(
          last_hidden_state: hidden_states,
          hidden_states: all_hidden_states,
          attentions: all_self_attentions
        )
      end
    end

    class ViTPreTrainedModel < PreTrainedModel
      self.config_class = ViTConfig
      self.base_model_prefix = "vit"
      self.main_input_name = "pixel_values"

      def _init_weights(mod)
        # TODO
      end
    end

    class ViTModel < ViTPreTrainedModel
      def initialize(config, add_pooling_layer: true, use_mask_token: false)
        super(config)
        @config = config

        @embeddings = ViTEmbeddings.new(config, use_mask_token: use_mask_token)
        @encoder = ViTEncoder.new(config)

        @layernorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
        @pooler = add_pooling_layer ? ViTPooler.new(config) : nil

        # Initialize weights and apply final processing
        post_init
      end

      def _prune_heads(heads_to_prune)
        heads_to_prune.each do |layer, heads|
          @encoder.layer[layer].attention.prune_heads(heads)
        end
      end

      def forward(
        pixel_values: nil,
        bool_masked_pos: nil,
        head_mask: nil,
        output_attentions: nil,
        output_hidden_states: nil,
        interpolate_pos_encoding: nil,
        return_dict: nil
      )
        output_attentions = !output_attentions.nil? ? output_attentions : @config.output_attentions
        output_hidden_states = (
          !output_hidden_states.nil? ? output_hidden_states : @config.output_hidden_states
        )
        return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

        if pixel_values.nil?
          raise ArgumentError, "You have to specify pixel_values"
        end

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = get_head_mask(head_mask, @config.num_hidden_layers)

        # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
        expected_dtype = @embeddings.patch_embeddings.projection.weight.dtype
        if pixel_values.dtype != expected_dtype
          pixel_values = pixel_values.to(expected_dtype)
        end

        embedding_output = @embeddings.(
          pixel_values, bool_masked_pos: bool_masked_pos, interpolate_pos_encoding: interpolate_pos_encoding
        )

        encoder_outputs = @encoder.(
          embedding_output,
          head_mask: head_mask,
          output_attentions: output_attentions,
          output_hidden_states: output_hidden_states,
          return_dict: return_dict
        )
        sequence_output = encoder_outputs[0]
        sequence_output = @layernorm.(sequence_output)
        pooled_output = @pooler ? @pooler.(sequence_output) : nil

        if !return_dict
          raise Todo
        end

        BaseModelOutputWithPooling.new(
          last_hidden_state: sequence_output,
          pooler_output: pooled_output,
          hidden_states: encoder_outputs.hidden_states,
          attentions: encoder_outputs.attentions
        )
      end
    end

    class ViTPooler < Torch::NN::Module
      def initialize(config)
        super()
        @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
        @activation = Torch::NN::Tanh.new
      end

      def forward(hidden_states)
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[0.., 0]
        pooled_output = @dense.(first_token_tensor)
        pooled_output = @activation.(pooled_output)
        pooled_output
      end
    end

    class ViTForImageClassification < ViTPreTrainedModel
      def initialize(config)
        super(config)

        @num_labels = config.num_labels
        @vit = ViTModel.new(config, add_pooling_layer: false)

        # Classifier head
        @classifier = config.num_labels > 0 ? Torch::NN::Linear.new(config.hidden_size, config.num_labels) : Torch::NN::Identity.new

        # Initialize weights and apply final processing
        post_init
      end

      def forward(
        pixel_values: nil,
        head_mask: nil,
        labels: nil,
        output_attentions: nil,
        output_hidden_states: nil,
        interpolate_pos_encoding: nil,
        return_dict: nil
      )
        return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

        outputs = @vit.(
          pixel_values: pixel_values,
          head_mask: head_mask,
          output_attentions: output_attentions,
          output_hidden_states: output_hidden_states,
          interpolate_pos_encoding: interpolate_pos_encoding,
          return_dict: return_dict
        )

        sequence_output = outputs[0]

        logits = @classifier.(sequence_output[0.., 0, 0..])

        loss = nil
        if !labels.nil?
          raise Todo
        end

        if !return_dict
          raise Todo
        end

        ImageClassifierOutput.new(
          loss: loss,
          logits: logits,
          hidden_states: outputs.hidden_states,
          attentions: outputs.attentions
        )
      end
    end
  end

  ViTForImageClassification = Vit::ViTForImageClassification
end
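
The file above ports the Python ViT stack (patch embeddings, encoder, pooler, and classification head) to torch.rb. A minimal usage sketch follows, assuming the gem wires these classes into a Python-style `Transformers.pipeline` entry point through `data/lib/transformers/pipelines/_init.rb` and `image_classification.rb` listed earlier; the task name, model id, and image path are illustrative and not confirmed by this diff.

```ruby
require "transformers-rb"

# Hedged sketch: assumes an "image-classification" pipeline backed by
# Vit::ViTForImageClassification and the ViT image processor, mirroring the Python API.
classifier = Transformers.pipeline("image-classification", model: "google/vit-base-patch16-224")

# Classify a local image; the exact output format depends on the pipeline implementation.
p classifier.("cat.jpg")
```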