transformers-rb 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +3 -0
  3. data/LICENSE.txt +203 -0
  4. data/README.md +163 -0
  5. data/lib/transformers/activations.rb +57 -0
  6. data/lib/transformers/configuration_utils.rb +285 -0
  7. data/lib/transformers/convert_slow_tokenizer.rb +90 -0
  8. data/lib/transformers/data/processors/squad.rb +115 -0
  9. data/lib/transformers/dynamic_module_utils.rb +25 -0
  10. data/lib/transformers/feature_extraction_utils.rb +110 -0
  11. data/lib/transformers/hf_hub/constants.rb +71 -0
  12. data/lib/transformers/hf_hub/errors.rb +11 -0
  13. data/lib/transformers/hf_hub/file_download.rb +764 -0
  14. data/lib/transformers/hf_hub/utils/_errors.rb +94 -0
  15. data/lib/transformers/hf_hub/utils/_headers.rb +109 -0
  16. data/lib/transformers/image_processing_base.rb +169 -0
  17. data/lib/transformers/image_processing_utils.rb +63 -0
  18. data/lib/transformers/image_transforms.rb +208 -0
  19. data/lib/transformers/image_utils.rb +165 -0
  20. data/lib/transformers/modeling_outputs.rb +81 -0
  21. data/lib/transformers/modeling_utils.rb +888 -0
  22. data/lib/transformers/models/auto/auto_factory.rb +138 -0
  23. data/lib/transformers/models/auto/configuration_auto.rb +61 -0
  24. data/lib/transformers/models/auto/feature_extraction_auto.rb +20 -0
  25. data/lib/transformers/models/auto/image_processing_auto.rb +104 -0
  26. data/lib/transformers/models/auto/modeling_auto.rb +80 -0
  27. data/lib/transformers/models/auto/tokenization_auto.rb +160 -0
  28. data/lib/transformers/models/bert/configuration_bert.rb +65 -0
  29. data/lib/transformers/models/bert/modeling_bert.rb +836 -0
  30. data/lib/transformers/models/bert/tokenization_bert.rb +115 -0
  31. data/lib/transformers/models/bert/tokenization_bert_fast.rb +52 -0
  32. data/lib/transformers/models/distilbert/configuration_distilbert.rb +63 -0
  33. data/lib/transformers/models/distilbert/modeling_distilbert.rb +616 -0
  34. data/lib/transformers/models/distilbert/tokenization_distilbert.rb +114 -0
  35. data/lib/transformers/models/distilbert/tokenization_distilbert_fast.rb +71 -0
  36. data/lib/transformers/models/vit/configuration_vit.rb +60 -0
  37. data/lib/transformers/models/vit/image_processing_vit.rb +170 -0
  38. data/lib/transformers/models/vit/modeling_vit.rb +506 -0
  39. data/lib/transformers/pipelines/_init.rb +348 -0
  40. data/lib/transformers/pipelines/base.rb +301 -0
  41. data/lib/transformers/pipelines/feature_extraction.rb +47 -0
  42. data/lib/transformers/pipelines/image_classification.rb +110 -0
  43. data/lib/transformers/pipelines/image_feature_extraction.rb +56 -0
  44. data/lib/transformers/pipelines/pt_utils.rb +53 -0
  45. data/lib/transformers/pipelines/question_answering.rb +508 -0
  46. data/lib/transformers/pipelines/text_classification.rb +123 -0
  47. data/lib/transformers/pipelines/token_classification.rb +282 -0
  48. data/lib/transformers/ruby_utils.rb +33 -0
  49. data/lib/transformers/sentence_transformer.rb +37 -0
  50. data/lib/transformers/tokenization_utils.rb +152 -0
  51. data/lib/transformers/tokenization_utils_base.rb +937 -0
  52. data/lib/transformers/tokenization_utils_fast.rb +386 -0
  53. data/lib/transformers/torch_utils.rb +25 -0
  54. data/lib/transformers/utils/_init.rb +31 -0
  55. data/lib/transformers/utils/generic.rb +107 -0
  56. data/lib/transformers/utils/hub.rb +209 -0
  57. data/lib/transformers/utils/import_utils.rb +45 -0
  58. data/lib/transformers/utils/logging.rb +52 -0
  59. data/lib/transformers/version.rb +3 -0
  60. data/lib/transformers-rb.rb +1 -0
  61. data/lib/transformers.rb +100 -0
  62. data/licenses/LICENSE-huggingface-hub.txt +201 -0
  63. data/licenses/LICENSE-sentence-transformers.txt +201 -0
  64. data/licenses/NOTICE-sentence-transformers.txt +5 -0
  65. metadata +161 -0
data/lib/transformers/models/vit/modeling_vit.rb
@@ -0,0 +1,506 @@
+ # Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ module Transformers
+   module Vit
+     class ViTEmbeddings < Torch::NN::Module
+       def initialize(config, use_mask_token: false)
+         super()
+
+         @cls_token = Torch::NN::Parameter.new(Torch.randn(1, 1, config.hidden_size))
+         @mask_token = use_mask_token ? Torch::NN::Parameter.new(Torch.zeros(1, 1, config.hidden_size)) : nil
+         @patch_embeddings = ViTPatchEmbeddings.new(config)
+         num_patches = @patch_embeddings.num_patches
+         @position_embeddings = Torch::NN::Parameter.new(Torch.randn(1, num_patches + 1, config.hidden_size))
+         @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
+         @config = config
+       end
+
+       def forward(
+         pixel_values,
+         bool_masked_pos: nil,
+         interpolate_pos_encoding: false
+       )
+         batch_size, _num_channels, height, width = pixel_values.shape
+         embeddings = @patch_embeddings.(pixel_values, interpolate_pos_encoding: interpolate_pos_encoding)
+
+         if !bool_masked_pos.nil?
+           seq_length = embeddings.shape[1]
+           mask_tokens = @mask_token.expand(batch_size, seq_length, -1)
+           # replace the masked visual tokens by mask_tokens
+           mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+           embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+         end
+
+         # add the [CLS] token to the embedded patch tokens
+         cls_tokens = @cls_token.expand(batch_size, -1, -1)
+         embeddings = Torch.cat([cls_tokens, embeddings], dim: 1)
+
+         # add positional encoding to each token
+         if interpolate_pos_encoding
+           embeddings = embeddings + @interpolate_pos_encoding.(embeddings, height, width)
+         else
+           embeddings = embeddings + @position_embeddings
+         end
+
+         embeddings = @dropout.(embeddings)
+
+         embeddings
+       end
+     end
+
+     class ViTPatchEmbeddings < Torch::NN::Module
+       attr_reader :num_patches
+
+       def initialize(config)
+         super()
+         image_size, patch_size = config.image_size, config.patch_size
+         num_channels, hidden_size = config.num_channels, config.hidden_size
+
+         image_size = image_size.is_a?(Enumerable) ? image_size : [image_size, image_size]
+         patch_size = patch_size.is_a?(Enumerable) ? patch_size : [patch_size, patch_size]
+         num_patches = image_size[1].div(patch_size[1]) * image_size[0].div(patch_size[0])
+         @image_size = image_size
+         @patch_size = patch_size
+         @num_channels = num_channels
+         @num_patches = num_patches
+
+         @projection = Torch::NN::Conv2d.new(num_channels, hidden_size, patch_size, stride: patch_size)
+       end
+
+       def forward(pixel_values, interpolate_pos_encoding: false)
+         _batch_size, num_channels, height, width = pixel_values.shape
+         if num_channels != @num_channels
+           raise ArgumentError,
+             "Make sure that the channel dimension of the pixel values match with the one set in the configuration." +
+             " Expected #{@num_channels} but got #{num_channels}."
+         end
+         if !interpolate_pos_encoding
+           if height != @image_size[0] || width != @image_size[1]
+             raise ArgumentError,
+               "Input image size (#{height}*#{width}) doesn't match model" +
+               " (#{@image_size[0]}*#{@image_size[1]})."
+           end
+         end
+         embeddings = @projection.(pixel_values).flatten(2).transpose(1, 2)
+         embeddings
+       end
+     end
+
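For orientation, the patch arithmetic above works out as follows for the common ViT-Base/16 settings (224x224 images, 16x16 patches). These numbers are illustrative defaults, not values read from this gem:

    # Patch-grid arithmetic for a hypothetical 224x224 input with 16x16 patches.
    image_size = [224, 224]
    patch_size = [16, 16]
    num_patches = image_size[1].div(patch_size[1]) * image_size[0].div(patch_size[0])
    puts num_patches      # => 196 patch tokens (a 14x14 grid)
    puts num_patches + 1  # => 197 positions once ViTEmbeddings prepends the [CLS] token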
+     class ViTSelfAttention < Torch::NN::Module
+       def initialize(config)
+         super()
+         if config.hidden_size % config.num_attention_heads != 0 && !config.instance_variable_defined?(:@embedding_size)
+           raise ArgumentError,
+             "The hidden size #{config.hidden_size} is not a multiple of the number of attention " +
+             "heads #{config.num_attention_heads}."
+         end
+
+         @num_attention_heads = config.num_attention_heads
+         @attention_head_size = (config.hidden_size / config.num_attention_heads).to_i
+         @all_head_size = @num_attention_heads * @attention_head_size
+
+         @query = Torch::NN::Linear.new(config.hidden_size, @all_head_size, bias: config.qkv_bias)
+         @key = Torch::NN::Linear.new(config.hidden_size, @all_head_size, bias: config.qkv_bias)
+         @value = Torch::NN::Linear.new(config.hidden_size, @all_head_size, bias: config.qkv_bias)
+
+         @dropout = Torch::NN::Dropout.new(p: config.attention_probs_dropout_prob)
+       end
+
+       def transpose_for_scores(x)
+         new_x_shape = x.size[...-1] + [@num_attention_heads, @attention_head_size]
+         x = x.view(new_x_shape)
+         x.permute(0, 2, 1, 3)
+       end
+
+       def forward(
+         hidden_states, head_mask: nil, output_attentions: false
+       )
+         mixed_query_layer = @query.(hidden_states)
+
+         key_layer = transpose_for_scores(@key.(hidden_states))
+         value_layer = transpose_for_scores(@value.(hidden_states))
+         query_layer = transpose_for_scores(mixed_query_layer)
+
+         # Take the dot product between "query" and "key" to get the raw attention scores.
+         attention_scores = Torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+         attention_scores = attention_scores / Math.sqrt(@attention_head_size)
+
+         # Normalize the attention scores to probabilities.
+         attention_probs = Torch::NN::Functional.softmax(attention_scores, dim: -1)
+
+         # This is actually dropping out entire tokens to attend to, which might
+         # seem a bit unusual, but is taken from the original Transformer paper.
+         attention_probs = @dropout.(attention_probs)
+
+         # Mask heads if we want to
+         if !head_mask.nil?
+           attention_probs = attention_probs * head_mask
+         end
+
+         context_layer = Torch.matmul(attention_probs, value_layer)
+
+         context_layer = context_layer.permute(0, 2, 1, 3).contiguous
+         new_context_layer_shape = context_layer.size[...-2] + [@all_head_size]
+         context_layer = context_layer.view(new_context_layer_shape)
+
+         outputs = output_attentions ? [context_layer, attention_probs] : [context_layer]
+
+         outputs
+       end
+     end
+
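To make the reshaping in transpose_for_scores concrete, here is a standalone sketch with ViT-Base-like sizes (12 heads, head size 64, 197 tokens). It assumes only the torch-rb gem; the sizes are illustrative, not taken from this file:

    require "torch"

    x = Torch.randn(1, 197, 768)   # [batch, tokens, hidden]
    x = x.view([1, 197, 12, 64])   # split hidden into [heads, head_size]
    x = x.permute(0, 2, 1, 3)      # [batch, heads, tokens, head_size]
    p x.shape                      # => [1, 12, 197, 64]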
+     class ViTSelfOutput < Torch::NN::Module
+       def initialize(config)
+         super()
+         @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
+         @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
+       end
+
+       def forward(hidden_states, input_tensor)
+         hidden_states = @dense.(hidden_states)
+         hidden_states = @dropout.(hidden_states)
+
+         hidden_states
+       end
+     end
+
+     class ViTAttention < Torch::NN::Module
+       def initialize(config)
+         super()
+         @attention = ViTSelfAttention.new(config)
+         @output = ViTSelfOutput.new(config)
+         @pruned_heads = Set.new
+       end
+
+       def prune_heads(heads)
+         raise Todo
+       end
+
+       def forward(
+         hidden_states,
+         head_mask: nil,
+         output_attentions: false
+       )
+         self_outputs = @attention.(hidden_states, head_mask: head_mask, output_attentions: output_attentions)
+
+         attention_output = @output.(self_outputs[0], hidden_states)
+
+         outputs = [attention_output] + self_outputs[1..] # add attentions if we output them
+         outputs
+       end
+     end
+
+     class ViTIntermediate < Torch::NN::Module
+       def initialize(config)
+         super()
+         @dense = Torch::NN::Linear.new(config.hidden_size, config.intermediate_size)
+         if config.hidden_act.is_a?(String)
+           @intermediate_act_fn = ACT2FN[config.hidden_act]
+         else
+           @intermediate_act_fn = config.hidden_act
+         end
+       end
+
+       def forward(hidden_states)
+         hidden_states = @dense.(hidden_states)
+         hidden_states = @intermediate_act_fn.(hidden_states)
+
+         hidden_states
+       end
+     end
+
+     class ViTOutput < Torch::NN::Module
+       def initialize(config)
+         super()
+         @dense = Torch::NN::Linear.new(config.intermediate_size, config.hidden_size)
+         @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
+       end
+
+       def forward(hidden_states, input_tensor)
+         hidden_states = @dense.(hidden_states)
+         hidden_states = @dropout.(hidden_states)
+
+         hidden_states = hidden_states + input_tensor
+
+         hidden_states
+       end
+     end
+
+     VIT_ATTENTION_CLASSES = {
+       "eager" => ViTAttention
+     }
+
+     class ViTLayer < Torch::NN::Module
+       def initialize(config)
+         super()
+         @chunk_size_feed_forward = config.chunk_size_feed_forward
+         @seq_len_dim = 1
+         @attention = VIT_ATTENTION_CLASSES.fetch(config._attn_implementation).new(config)
+         @intermediate = ViTIntermediate.new(config)
+         @output = ViTOutput.new(config)
+         @layernorm_before = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
+         @layernorm_after = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
+       end
+
+       def forward(
+         hidden_states,
+         head_mask: nil,
+         output_attentions: false
+       )
+         self_attention_outputs = @attention.(
+           @layernorm_before.(hidden_states), # in ViT, layernorm is applied before self-attention
+           head_mask: head_mask,
+           output_attentions: output_attentions
+         )
+         attention_output = self_attention_outputs[0]
+         outputs = self_attention_outputs[1..] # add self attentions if we output attention weights
+
+         # first residual connection
+         hidden_states = attention_output + hidden_states
+
+         # in ViT, layernorm is also applied after self-attention
+         layer_output = @layernorm_after.(hidden_states)
+         layer_output = @intermediate.(layer_output)
+
+         # second residual connection is done here
+         layer_output = @output.(layer_output, hidden_states)
+
+         outputs = [layer_output] + outputs
+
+         outputs
+       end
+     end
+
+     class ViTEncoder < Torch::NN::Module
+       def initialize(config)
+         super()
+         @config = config
+         @layer = Torch::NN::ModuleList.new(config.num_hidden_layers.times.map { ViTLayer.new(config) })
+         @gradient_checkpointing = false
+       end
+
+       def forward(
+         hidden_states,
+         head_mask: nil,
+         output_attentions: false,
+         output_hidden_states: false,
+         return_dict: true
+       )
+         all_hidden_states = output_hidden_states ? [] : nil
+         all_self_attentions = output_attentions ? [] : nil
+
+         @layer.each_with_index do |layer_module, i|
+           if output_hidden_states
+             all_hidden_states = all_hidden_states + [hidden_states]
+           end
+
+           layer_head_mask = !head_mask.nil? ? head_mask[i] : nil
+
+           if @gradient_checkpointing && @training
+             raise Todo
+           else
+             layer_outputs = layer_module.(hidden_states, head_mask: layer_head_mask, output_attentions: output_attentions)
+           end
+
+           hidden_states = layer_outputs[0]
+
+           if output_attentions
+             all_self_attentions = all_self_attentions + [layer_outputs[1]]
+           end
+         end
+
+         if output_hidden_states
+           all_hidden_states = all_hidden_states + [hidden_states]
+         end
+
+         if !return_dict
+           raise Todo
+         end
+         BaseModelOutput.new(
+           last_hidden_state: hidden_states,
+           hidden_states: all_hidden_states,
+           attentions: all_self_attentions
+         )
+       end
+     end
+
+     class ViTPreTrainedModel < PreTrainedModel
+       self.config_class = ViTConfig
+       self.base_model_prefix = "vit"
+       self.main_input_name = "pixel_values"
+
+       def _init_weights(mod)
+         # TODO
+       end
+     end
+
+     class ViTModel < ViTPreTrainedModel
+       def initialize(config, add_pooling_layer: true, use_mask_token: false)
+         super(config)
+         @config = config
+
+         @embeddings = ViTEmbeddings.new(config, use_mask_token: use_mask_token)
+         @encoder = ViTEncoder.new(config)
+
+         @layernorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
+         @pooler = add_pooling_layer ? ViTPooler.new(config) : nil
+
+         # Initialize weights and apply final processing
+         post_init
+       end
+
+       def _prune_heads(heads_to_prune)
+         heads_to_prune.each do |layer, heads|
+           @encoder.layer[layer].attention.prune_heads(heads)
+         end
+       end
+
+       def forward(
+         pixel_values: nil,
+         bool_masked_pos: nil,
+         head_mask: nil,
+         output_attentions: nil,
+         output_hidden_states: nil,
+         interpolate_pos_encoding: nil,
+         return_dict: nil
+       )
+         output_attentions = !output_attentions.nil? ? output_attentions : @config.output_attentions
+         output_hidden_states = (
+           !output_hidden_states.nil? ? output_hidden_states : @config.output_hidden_states
+         )
+         return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
+
+         if pixel_values.nil?
+           raise ArgumentError, "You have to specify pixel_values"
+         end
+
+         # Prepare head mask if needed
+         # 1.0 in head_mask indicate we keep the head
+         # attention_probs has shape bsz x n_heads x N x N
+         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+         head_mask = get_head_mask(head_mask, @config.num_hidden_layers)
+
+         # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
+         expected_dtype = @embeddings.patch_embeddings.projection.weight.dtype
+         if pixel_values.dtype != expected_dtype
+           pixel_values = pixel_values.to(expected_dtype)
+         end
+
+         embedding_output = @embeddings.(
+           pixel_values, bool_masked_pos: bool_masked_pos, interpolate_pos_encoding: interpolate_pos_encoding
+         )
+
+         encoder_outputs = @encoder.(
+           embedding_output,
+           head_mask: head_mask,
+           output_attentions: output_attentions,
+           output_hidden_states: output_hidden_states,
+           return_dict: return_dict
+         )
+         sequence_output = encoder_outputs[0]
+         sequence_output = @layernorm.(sequence_output)
+         pooled_output = @pooler ? @pooler.(sequence_output) : nil
+
+         if !return_dict
+           raise Todo
+         end
+
+         BaseModelOutputWithPooling.new(
+           last_hidden_state: sequence_output,
+           pooler_output: pooled_output,
+           hidden_states: encoder_outputs.hidden_states,
+           attentions: encoder_outputs.attentions
+         )
+       end
+     end
+
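A minimal sketch of driving ViTModel directly. The from_pretrained call and checkpoint name are assumed to mirror the Python transformers API (modeling_utils.rb and the gem README are the authoritative references); the shapes are the expected ones for a ViT-Base/16 checkpoint at 224x224:

    require "transformers-rb"

    # Hypothetical checkpoint; any ViT weights compatible with this port should work.
    model = Transformers::Vit::ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

    pixel_values = Torch.randn(1, 3, 224, 224)   # one fake, already-normalized RGB image
    output = model.(pixel_values: pixel_values)

    p output.last_hidden_state.shape   # expected [1, 197, 768]
    p output.pooler_output.shape       # expected [1, 768]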
+     class ViTPooler < Torch::NN::Module
+       def initialize(config)
+         super()
+         @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
+         @activation = Torch::NN::Tanh.new
+       end
+
+       def forward(hidden_states)
+         # We "pool" the model by simply taking the hidden state corresponding
+         # to the first token.
+         first_token_tensor = hidden_states[0.., 0]
+         pooled_output = @dense.(first_token_tensor)
+         pooled_output = @activation.(pooled_output)
+         pooled_output
+       end
+     end
+
+     class ViTForImageClassification < ViTPreTrainedModel
+       def initialize(config)
+         super(config)
+
+         @num_labels = config.num_labels
+         @vit = ViTModel.new(config, add_pooling_layer: false)
+
+         # Classifier head
+         @classifier = config.num_labels > 0 ? Torch::NN::Linear.new(config.hidden_size, config.num_labels) : Torch::NN::Identity.new
+
+         # Initialize weights and apply final processing
+         post_init
+       end
+
+       def forward(
+         pixel_values: nil,
+         head_mask: nil,
+         labels: nil,
+         output_attentions: nil,
+         output_hidden_states: nil,
+         interpolate_pos_encoding: nil,
+         return_dict: nil
+       )
+         return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
+
+         outputs = @vit.(
+           pixel_values: pixel_values,
+           head_mask: head_mask,
+           output_attentions: output_attentions,
+           output_hidden_states: output_hidden_states,
+           interpolate_pos_encoding: interpolate_pos_encoding,
+           return_dict: return_dict
+         )
+
+         sequence_output = outputs[0]
+
+         logits = @classifier.(sequence_output[0.., 0, 0..])
+
+         loss = nil
+         if !labels.nil?
+           raise Todo
+         end
+
+         if !return_dict
+           raise Todo
+         end
+
+         ImageClassifierOutput.new(
+           loss: loss,
+           logits: logits,
+           hidden_states: outputs.hidden_states,
+           attentions: outputs.attentions
+         )
+       end
+     end
+   end
+
+   ViTForImageClassification = Vit::ViTForImageClassification
+ end
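End to end, the gem exposes this model through its pipeline layer (see data/lib/transformers/pipelines/image_classification.rb in the file list above). A sketch of that usage, assuming the task string and keyword options mirror the Python library; consult the gem README for what is actually supported:

    require "transformers-rb"

    # Hypothetical usage and checkpoint name.
    classifier = Transformers.pipeline("image-classification", model: "google/vit-base-patch16-224")
    result = classifier.("path/to/image.jpg")
    # => something like [{label: "tabby, tabby cat", score: 0.94}, ...]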