keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +21 -3
- keras_hub/api/models/__init__.py +71 -12
- keras_hub/api/tokenizers/__init__.py +1 -1
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
- keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
- keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
- keras_hub/src/layers/modeling/rms_normalization.py +36 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
- keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
- keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
- keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
- keras_hub/src/layers/preprocessing/image_converter.py +170 -34
- keras_hub/src/metrics/bleu.py +4 -3
- keras_hub/src/models/albert/albert_presets.py +4 -12
- keras_hub/src/models/albert/albert_text_classifier.py +7 -7
- keras_hub/src/models/backbone.py +3 -14
- keras_hub/src/models/bart/bart_backbone.py +4 -4
- keras_hub/src/models/bart/bart_presets.py +3 -9
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
- keras_hub/src/models/basnet/__init__.py +5 -0
- keras_hub/src/models/basnet/basnet.py +122 -0
- keras_hub/src/models/basnet/basnet_backbone.py +366 -0
- keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
- keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
- keras_hub/src/models/basnet/basnet_presets.py +17 -0
- keras_hub/src/models/bert/bert_presets.py +14 -32
- keras_hub/src/models/bert/bert_text_classifier.py +3 -3
- keras_hub/src/models/bloom/bloom_presets.py +8 -24
- keras_hub/src/models/causal_lm.py +56 -12
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
- keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
- keras_hub/src/models/densenet/densenet_backbone.py +6 -4
- keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
- keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
- keras_hub/src/models/densenet/densenet_presets.py +9 -15
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
- keras_hub/src/models/efficientnet/mbconv.py +53 -22
- keras_hub/src/models/electra/electra_backbone.py +2 -2
- keras_hub/src/models/electra/electra_presets.py +6 -18
- keras_hub/src/models/f_net/f_net_presets.py +2 -6
- keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
- keras_hub/src/models/falcon/falcon_backbone.py +5 -3
- keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
- keras_hub/src/models/falcon/falcon_presets.py +1 -3
- keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
- keras_hub/src/models/feature_pyramid_backbone.py +1 -1
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +496 -0
- keras_hub/src/models/flux/flux_maths.py +225 -0
- keras_hub/src/models/flux/flux_model.py +236 -0
- keras_hub/src/models/flux/flux_presets.py +3 -0
- keras_hub/src/models/flux/flux_text_to_image.py +146 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_backbone.py +35 -20
- keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
- keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
- keras_hub/src/models/gemma/gemma_presets.py +29 -63
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
- keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
- keras_hub/src/models/image_classifier.py +147 -2
- keras_hub/src/models/image_classifier_preprocessor.py +6 -3
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_segmenter.py +0 -5
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/image_to_image.py +417 -0
- keras_hub/src/models/inpaint.py +520 -0
- keras_hub/src/models/llama/llama_backbone.py +138 -12
- keras_hub/src/models/llama/llama_causal_lm.py +3 -1
- keras_hub/src/models/llama/llama_presets.py +10 -20
- keras_hub/src/models/llama3/llama3_backbone.py +12 -11
- keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
- keras_hub/src/models/llama3/llama3_presets.py +4 -12
- keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
- keras_hub/src/models/mistral/mistral_backbone.py +16 -15
- keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
- keras_hub/src/models/mistral/mistral_presets.py +3 -9
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
- keras_hub/src/models/mit/__init__.py +6 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
- keras_hub/src/models/mit/mit_image_classifier.py +12 -0
- keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/mit/mit_image_converter.py +8 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
- keras_hub/src/models/mit/mit_presets.py +139 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
- keras_hub/src/models/opt/opt_causal_lm.py +2 -2
- keras_hub/src/models/opt/opt_presets.py +4 -12
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
- keras_hub/src/models/phi3/phi3_decoder.py +0 -1
- keras_hub/src/models/phi3/phi3_presets.py +2 -6
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
- keras_hub/src/models/preprocessor.py +25 -11
- keras_hub/src/models/resnet/resnet_backbone.py +3 -14
- keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
- keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
- keras_hub/src/models/resnet/resnet_presets.py +127 -18
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
- keras_hub/src/models/roberta/roberta_backbone.py +2 -2
- keras_hub/src/models/roberta/roberta_presets.py +6 -8
- keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
- keras_hub/src/models/sam/__init__.py +5 -0
- keras_hub/src/models/sam/sam_backbone.py +2 -3
- keras_hub/src/models/sam/sam_image_converter.py +2 -4
- keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
- keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
- keras_hub/src/models/sam/sam_layers.py +5 -3
- keras_hub/src/models/sam/sam_presets.py +3 -9
- keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
- keras_hub/src/models/sam/sam_transformer.py +5 -4
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +167 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +136 -0
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +47 -19
- keras_hub/src/models/task.py +47 -39
- keras_hub/src/models/text_classifier.py +2 -2
- keras_hub/src/models/text_to_image.py +106 -41
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +184 -0
- keras_hub/src/models/vae/vae_layers.py +739 -0
- keras_hub/src/models/vgg/__init__.py +5 -0
- keras_hub/src/models/vgg/vgg_backbone.py +4 -24
- keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
- keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
- keras_hub/src/models/vgg/vgg_presets.py +48 -0
- keras_hub/src/models/vit/__init__.py +5 -0
- keras_hub/src/models/vit/vit_backbone.py +152 -0
- keras_hub/src/models/vit/vit_image_classifier.py +187 -0
- keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vit/vit_image_converter.py +73 -0
- keras_hub/src/models/vit/vit_layers.py +391 -0
- keras_hub/src/models/vit/vit_presets.py +126 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
- keras_hub/src/models/vit_det/vit_layers.py +3 -3
- keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
- keras_hub/src/models/whisper/whisper_backbone.py +6 -5
- keras_hub/src/models/whisper/whisper_decoder.py +3 -5
- keras_hub/src/models/whisper/whisper_presets.py +10 -30
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
- keras_hub/src/models/xlnet/relative_attention.py +20 -19
- keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
- keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
- keras_hub/src/samplers/contrastive_sampler.py +2 -3
- keras_hub/src/samplers/sampler.py +2 -1
- keras_hub/src/tests/test_case.py +41 -6
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
- keras_hub/src/tokenizers/tokenizer.py +10 -13
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
- keras_hub/src/utils/keras_utils.py +2 -13
- keras_hub/src/utils/pipeline_model.py +3 -3
- keras_hub/src/utils/preset_utils.py +196 -144
- keras_hub/src/utils/tensor_utils.py +4 -4
- keras_hub/src/utils/timm/convert_densenet.py +6 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
- keras_hub/src/utils/timm/convert_resnet.py +1 -1
- keras_hub/src/utils/timm/convert_vgg.py +85 -0
- keras_hub/src/utils/timm/preset_loader.py +14 -9
- keras_hub/src/utils/transformers/convert_llama3.py +21 -5
- keras_hub/src/utils/transformers/convert_vit.py +150 -0
- keras_hub/src/utils/transformers/preset_loader.py +23 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
- keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
- keras_hub/src/models/mix_transformer/__init__.py +0 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/keras_hub/src/models/efficientnet/efficientnet_presets.py
@@ -0,0 +1,193 @@
+"""EfficientNet preset configurations."""
+
+backbone_presets = {
+    "efficientnet_b0_ra_imagenet": {
+        "metadata": {
+            "description": (
+                "EfficientNet B0 model pre-trained on the ImageNet 1k dataset "
+                "with RandAugment recipe."
+            ),
+            "params": 5288548,
+            "path": "efficientnet",
+        },
+        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b0_ra_imagenet/2",
+    },
+    "efficientnet_b0_ra4_e3600_r224_imagenet": {
+        "metadata": {
+            "description": (
+                "EfficientNet B0 model pre-trained on the ImageNet 1k dataset "
+                "by Ross Wightman. Trained with timm scripts using "
+                "hyper-parameters inspired by the MobileNet-V4 small, mixed "
+                "with go-to hparams from timm and 'ResNet Strikes Back'."
+            ),
+            "params": 5288548,
+            "path": "efficientnet",
+        },
+        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b0_ra4_e3600_r224_imagenet/2",
+    },
+    "efficientnet_b1_ft_imagenet": {
+        "metadata": {
+            "description": (
+                "EfficientNet B1 model fine-tuned on the ImageNet 1k dataset."
+            ),
+            "params": 7794184,
+            "path": "efficientnet",
+        },
+        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b1_ft_imagenet/5",
+    },
+    "efficientnet_b1_ra4_e3600_r240_imagenet": {
+        "metadata": {
+            "description": (
+                "EfficientNet B1 model pre-trained on the ImageNet 1k dataset "
+                "by Ross Wightman. Trained with timm scripts using "
+                "hyper-parameters inspired by the MobileNet-V4 small, mixed "
+                "with go-to hparams from timm and 'ResNet Strikes Back'."
+            ),
+            "params": 7794184,
+            "path": "efficientnet",
+        },
+        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b1_ra4_e3600_r240_imagenet/2",
+    },
+    "efficientnet_b2_ra_imagenet": {
+        "metadata": {
+            "description": (
+                "EfficientNet B2 model pre-trained on the ImageNet 1k dataset "
+                "with RandAugment recipe."
+            ),
+            "params": 9109994,
+            "path": "efficientnet",
+        },
+        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b2_ra_imagenet/2",
+    },
+    "efficientnet_b3_ra2_imagenet": {
+        "metadata": {
+            "description": (
+                "EfficientNet B3 model pre-trained on the ImageNet 1k dataset "
+                "with RandAugment2 recipe."
+            ),
+            "params": 12233232,
+            "path": "efficientnet",
+        },
+        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b3_ra2_imagenet/2",
+    },
+    "efficientnet_b4_ra2_imagenet": {
+        "metadata": {
+            "description": (
+                "EfficientNet B4 model pre-trained on the ImageNet 1k dataset "
+                "with RandAugment2 recipe."
+            ),
+            "params": 19341616,
+            "path": "efficientnet",
+        },
+        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b4_ra2_imagenet/2",
+    },
+    "efficientnet_b5_sw_imagenet": {
+        "metadata": {
+            "description": (
+                "EfficientNet B5 model pre-trained on the ImageNet 12k dataset "
+                "by Ross Wightman. Based on Swin Transformer train / pretrain "
+                "recipe with modifications (related to both DeiT and ConvNeXt "
+                "recipes)."
+            ),
+            "params": 30389784,
+            "path": "efficientnet",
+        },
+        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b5_sw_imagenet/2",
+    },
+    "efficientnet_b5_sw_ft_imagenet": {
+        "metadata": {
+            "description": (
+                "EfficientNet B5 model pre-trained on the ImageNet 12k dataset "
+                "and fine-tuned on ImageNet-1k by Ross Wightman. Based on Swin "
+                "Transformer train / pretrain recipe with modifications "
+                "(related to both DeiT and ConvNeXt recipes)."
+            ),
+            "params": 30389784,
+            "path": "efficientnet",
+        },
+        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b5_sw_ft_imagenet/2",
+    },
+    "efficientnet_el_ra_imagenet": {
+        "metadata": {
+            "description": (
+                "EfficientNet-EdgeTPU Large model trained on the ImageNet 1k "
+                "dataset with RandAugment recipe."
+            ),
+            "params": 10589712,
+            "path": "efficientnet",
+        },
+        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b1_ft_imagenet/5",
+    },
+    "efficientnet_em_ra2_imagenet": {
+        "metadata": {
+            "description": (
+                "EfficientNet-EdgeTPU Medium model trained on the ImageNet 1k "
+                "dataset with RandAugment2 recipe."
+            ),
+            "params": 6899496,
+            "path": "efficientnet",
+        },
+        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b1_ft_imagenet/5",
+    },
+    "efficientnet_es_ra_imagenet": {
+        "metadata": {
+            "description": (
+                "EfficientNet-EdgeTPU Small model trained on the ImageNet 1k "
+                "dataset with RandAugment recipe."
+            ),
+            "params": 5438392,
+            "path": "efficientnet",
+        },
+        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b1_ft_imagenet/5",
+    },
+    "efficientnet2_rw_m_agc_imagenet": {
+        "metadata": {
+            "description": (
+                "EfficientNet-v2 Medium model trained on the ImageNet 1k "
+                "dataset with adaptive gradient clipping."
+            ),
+            "params": 53236442,
+            "official_name": "EfficientNet",
+            "path": "efficientnet",
+            "model_card": "https://arxiv.org/abs/2104.00298",
+        },
+        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet2_rw_m_agc_imagenet/2",
+    },
+    "efficientnet2_rw_s_ra2_imagenet": {
+        "metadata": {
+            "description": (
+                "EfficientNet-v2 Small model trained on the ImageNet 1k "
+                "dataset with RandAugment2 recipe."
+            ),
+            "params": 23941296,
+            "official_name": "EfficientNet",
+            "path": "efficientnet",
+            "model_card": "https://arxiv.org/abs/2104.00298",
+        },
+        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet2_rw_s_ra2_imagenet/2",
+    },
+    "efficientnet2_rw_t_ra2_imagenet": {
+        "metadata": {
+            "description": (
+                "EfficientNet-v2 Tiny model trained on the ImageNet 1k "
+                "dataset with RandAugment2 recipe."
+            ),
+            "params": 13649388,
+            "official_name": "EfficientNet",
+            "path": "efficientnet",
+            "model_card": "https://arxiv.org/abs/2104.00298",
+        },
+        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet2_rw_t_ra2_imagenet/2",
+    },
+    "efficientnet_lite0_ra_imagenet": {
+        "metadata": {
+            "description": (
+                "EfficientNet-Lite model fine-trained on the ImageNet 1k "
+                "dataset with RandAugment recipe."
+            ),
+            "params": 4652008,
+            "path": "efficientnet",
+        },
+        "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_lite0_ra_imagenet/2",
+    },
+}
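Each entry above pairs a preset name with the `kaggle_handle` it resolves to. As a minimal sketch of how these registry entries are typically consumed (standard keras-hub `from_preset` API; weights download on first use):

```python
import keras_hub

# Preset names are the keys of `backbone_presets`; keras-hub resolves the
# matching `kaggle_handle` and fetches the weights on first use.
backbone = keras_hub.models.Backbone.from_preset("efficientnet_b0_ra_imagenet")

# The loaded parameter count should line up with the "params" metadata field.
print(backbone.count_params())  # ~5.3M for the B0 variant
```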
--- a/keras_hub/src/models/efficientnet/fusedmbconv.py
+++ b/keras_hub/src/models/efficientnet/fusedmbconv.py
@@ -2,24 +2,13 @@ import keras
 
 BN_AXIS = 3
 
-CONV_KERNEL_INITIALIZER = {
-    "class_name": "VarianceScaling",
-    "config": {
-        "scale": 2.0,
-        "mode": "fan_out",
-        "distribution": "truncated_normal",
-    },
-}
-
 
 class FusedMBConvBlock(keras.layers.Layer):
     """Implementation of the FusedMBConv block
 
     Also known as a Fused Mobile Inverted Residual Bottleneck block from:
-
-
-    [EfficientNetV2: Smaller Models and Faster Training]
-    (https://arxiv.org/abs/2104.00298v3).
+    [EfficientNet-EdgeTPU](https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html)
+    [EfficientNetV2: Smaller Models and Faster Training](https://arxiv.org/abs/2104.00298v3).
 
     FusedMBConv blocks are based on MBConv blocks, and replace the depthwise and
     1x1 output convolution blocks with a single 3x3 convolution block, fusing
@@ -44,13 +33,24 @@ class FusedMBConvBlock(keras.layers.Layer):
             convolutions
         strides: default 1, the strides to apply to the expansion phase
             convolutions
+        data_format: str, channels_last (default) or channels_first, expects
+            tensors to be of shape (N, H, W, C) or (N, C, H, W) respectively
         se_ratio: default 0.0, The filters used in the Squeeze-Excitation phase,
             and are chosen as the maximum between 1 and input_filters*se_ratio
         batch_norm_momentum: default 0.9, the BatchNormalization momentum
+        batch_norm_epsilon: default 1e-3, float, epsilon for batch norm
+            calcualtions. Used in denominator for calculations to prevent divide
+            by 0 errors.
         activation: default "swish", the activation function used between
             convolution operations
+        projection_activation: default None, the activation function to use
+            after the output projection convoultion
         dropout: float, the optional dropout rate to apply before the output
             convolution, defaults to 0.2
+        nores: bool, default False, forces no residual connection if True,
+            otherwise allows it if False.
+        projection_kernel_size: default 1, the kernel_size to apply to the
+            output projection phase convolution
 
     Returns:
         A tensor representing a feature map, passed through the FusedMBConv
@@ -67,11 +67,16 @@ class FusedMBConvBlock(keras.layers.Layer):
         expand_ratio=1,
         kernel_size=3,
         strides=1,
+        data_format="channels_last",
         se_ratio=0.0,
         batch_norm_momentum=0.9,
+        batch_norm_epsilon=1e-3,
         activation="swish",
+        projection_activation=None,
         dropout=0.2,
-        **kwargs
+        nores=False,
+        projection_kernel_size=1,
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.input_filters = input_filters
@@ -79,44 +84,50 @@ class FusedMBConvBlock(keras.layers.Layer):
         self.expand_ratio = expand_ratio
         self.kernel_size = kernel_size
         self.strides = strides
+        self.data_format = data_format
         self.se_ratio = se_ratio
         self.batch_norm_momentum = batch_norm_momentum
+        self.batch_norm_epsilon = batch_norm_epsilon
         self.activation = activation
+        self.projection_activation = projection_activation
         self.dropout = dropout
+        self.nores = nores
+        self.projection_kernel_size = projection_kernel_size
         self.filters = self.input_filters * self.expand_ratio
         self.filters_se = max(1, int(input_filters * se_ratio))
 
+        padding_pixels = kernel_size // 2
+        self.conv1_pad = keras.layers.ZeroPadding2D(
+            padding=(padding_pixels, padding_pixels),
+            name=self.name + "expand_conv_pad",
+        )
         self.conv1 = keras.layers.Conv2D(
             filters=self.filters,
             kernel_size=kernel_size,
             strides=strides,
-            kernel_initializer=CONV_KERNEL_INITIALIZER,
-            padding="same",
-            data_format="channels_last",
+            kernel_initializer=self._conv_kernel_initializer(),
+            padding="valid",
+            data_format=data_format,
             use_bias=False,
             name=self.name + "expand_conv",
         )
         self.bn1 = keras.layers.BatchNormalization(
             axis=BN_AXIS,
             momentum=self.batch_norm_momentum,
+            epsilon=self.batch_norm_epsilon,
             name=self.name + "expand_bn",
         )
         self.act = keras.layers.Activation(
             self.activation, name=self.name + "expand_activation"
         )
 
-        self.bn2 = keras.layers.BatchNormalization(
-            axis=BN_AXIS,
-            momentum=self.batch_norm_momentum,
-            name=self.name + "bn",
-        )
-
         self.se_conv1 = keras.layers.Conv2D(
             self.filters_se,
             1,
             padding="same",
+            data_format=data_format,
             activation=self.activation,
-            kernel_initializer=CONV_KERNEL_INITIALIZER,
+            kernel_initializer=self._conv_kernel_initializer(),
             name=self.name + "se_reduce",
         )
 
@@ -124,28 +135,40 @@ class FusedMBConvBlock(keras.layers.Layer):
             self.filters,
             1,
             padding="same",
+            data_format=data_format,
             activation="sigmoid",
-            kernel_initializer=CONV_KERNEL_INITIALIZER,
+            kernel_initializer=self._conv_kernel_initializer(),
             name=self.name + "se_expand",
         )
 
+        padding_pixels = projection_kernel_size // 2
+        self.output_conv_pad = keras.layers.ZeroPadding2D(
+            padding=(padding_pixels, padding_pixels),
+            name=self.name + "project_conv_pad",
+        )
         self.output_conv = keras.layers.Conv2D(
             filters=self.output_filters,
-            kernel_size=1,
+            kernel_size=projection_kernel_size,
             strides=1,
-            kernel_initializer=CONV_KERNEL_INITIALIZER,
-            padding="same",
-            data_format="channels_last",
+            kernel_initializer=self._conv_kernel_initializer(),
+            padding="valid",
+            data_format=data_format,
             use_bias=False,
             name=self.name + "project_conv",
         )
 
-        self.bn3 = keras.layers.BatchNormalization(
+        self.bn2 = keras.layers.BatchNormalization(
             axis=BN_AXIS,
             momentum=self.batch_norm_momentum,
+            epsilon=self.batch_norm_epsilon,
             name=self.name + "project_bn",
         )
 
+        if self.projection_activation:
+            self.projection_act = keras.layers.Activation(
+                self.projection_activation, name=self.name + "projection_act"
+            )
+
         if self.dropout:
             self.dropout_layer = keras.layers.Dropout(
                 self.dropout,
@@ -153,23 +176,33 @@ class FusedMBConvBlock(keras.layers.Layer):
                 name=self.name + "drop",
             )
 
+    def _conv_kernel_initializer(
+        self,
+        scale=2.0,
+        mode="fan_out",
+        distribution="truncated_normal",
+        seed=None,
+    ):
+        return keras.initializers.VarianceScaling(
+            scale=scale, mode=mode, distribution=distribution, seed=seed
+        )
+
     def build(self, input_shape):
         if self.name is None:
            self.name = keras.backend.get_uid("block0")
 
     def call(self, inputs):
         # Expansion phase
-        if self.expand_ratio != 1:
-            x = self.conv1(inputs)
-            x = self.bn1(x)
-            x = self.act(x)
-        else:
-            x = inputs
+        x = self.conv1_pad(inputs)
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.act(x)
 
         # Squeeze and excite
         if 0 < self.se_ratio <= 1:
             se = keras.layers.GlobalAveragePooling2D(
-                name=self.name + "se_squeeze"
+                name=self.name + "se_squeeze",
+                data_format=self.data_format,
             )(x)
             if BN_AXIS == 1:
                 se_shape = (self.filters, 1, 1)
@@ -186,13 +219,18 @@ class FusedMBConvBlock(keras.layers.Layer):
             x = keras.layers.multiply([x, se], name=self.name + "se_excite")
 
         # Output phase:
+        x = self.output_conv_pad(x)
         x = self.output_conv(x)
-        x = self.bn3(x)
-        if self.expand_ratio == 1:
-            x = self.act(x)
+        x = self.bn2(x)
+        if self.expand_ratio == 1 and self.projection_activation:
+            x = self.projection_act(x)
 
         # Residual:
-        if self.strides == 1 and self.input_filters == self.output_filters:
+        if (
+            self.strides == 1
+            and self.input_filters == self.output_filters
+            and not self.nores
+        ):
             if self.dropout:
                 x = self.dropout_layer(x)
             x = keras.layers.Add(name=self.name + "add")([x, inputs])
@@ -205,10 +243,15 @@ class FusedMBConvBlock(keras.layers.Layer):
             "expand_ratio": self.expand_ratio,
             "kernel_size": self.kernel_size,
             "strides": self.strides,
+            "data_format": self.data_format,
             "se_ratio": self.se_ratio,
             "batch_norm_momentum": self.batch_norm_momentum,
+            "batch_norm_epsilon": self.batch_norm_epsilon,
             "activation": self.activation,
+            "projection_activation": self.projection_activation,
             "dropout": self.dropout,
+            "nores": self.nores,
+            "projection_kernel_size": self.projection_kernel_size,
         }
 
         base_config = super().get_config()
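A pattern repeated throughout the hunks above: `padding="same"` convolutions are replaced by an explicit `ZeroPadding2D` of `kernel_size // 2` plus a `padding="valid"` convolution. For an odd kernel at stride 1 the two are shape-equivalent, and the explicit form keeps the padding symmetric and deterministic when `strides > 1`. A standalone sketch in plain Keras (not part of the diff) illustrating the equivalence:

```python
import numpy as np
import keras

kernel_size = 3
x = np.random.rand(1, 32, 32, 8).astype("float32")

# Implicit padding, as in the old code.
same_conv = keras.layers.Conv2D(16, kernel_size, padding="same")

# Explicit symmetric padding + "valid" convolution, as in the new code.
pad = keras.layers.ZeroPadding2D(padding=(kernel_size // 2, kernel_size // 2))
valid_conv = keras.layers.Conv2D(16, kernel_size, padding="valid")

print(same_conv(x).shape)        # (1, 32, 32, 16)
print(valid_conv(pad(x)).shape)  # (1, 32, 32, 16)
```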
--- a/keras_hub/src/models/efficientnet/mbconv.py
+++ b/keras_hub/src/models/efficientnet/mbconv.py
@@ -2,15 +2,6 @@ import keras
 
 BN_AXIS = 3
 
-CONV_KERNEL_INITIALIZER = {
-    "class_name": "VarianceScaling",
-    "config": {
-        "scale": 2.0,
-        "mode": "fan_out",
-        "distribution": "truncated_normal",
-    },
-}
-
 
 class MBConvBlock(keras.layers.Layer):
     def __init__(
@@ -20,11 +11,14 @@ class MBConvBlock(keras.layers.Layer):
         expand_ratio=1,
         kernel_size=3,
         strides=1,
+        data_format="channels_last",
         se_ratio=0.0,
         batch_norm_momentum=0.9,
+        batch_norm_epsilon=1e-3,
         activation="swish",
         dropout=0.2,
-        **kwargs
+        nores=False,
+        **kwargs,
     ):
         """Implementation of the MBConv block
 
@@ -59,6 +53,9 @@ class MBConvBlock(keras.layers.Layer):
                 is above 0. The filters used in this phase are chosen as the
                 maximum between 1 and input_filters*se_ratio
             batch_norm_momentum: default 0.9, the BatchNormalization momentum
+            batch_norm_epsilon: default 1e-3, float, epsilon for batch norm
+                calcualtions. Used in denominator for calculations to prevent
+                divide by 0 errors.
             activation: default "swish", the activation function used between
                 convolution operations
             dropout: float, the optional dropout rate to apply before the output
@@ -79,10 +76,13 @@ class MBConvBlock(keras.layers.Layer):
         self.expand_ratio = expand_ratio
         self.kernel_size = kernel_size
         self.strides = strides
+        self.data_format = data_format
         self.se_ratio = se_ratio
         self.batch_norm_momentum = batch_norm_momentum
+        self.batch_norm_epsilon = batch_norm_epsilon
         self.activation = activation
         self.dropout = dropout
+        self.nores = nores
         self.filters = self.input_filters * self.expand_ratio
         self.filters_se = max(1, int(input_filters * se_ratio))
 
@@ -90,15 +90,16 @@ class MBConvBlock(keras.layers.Layer):
             filters=self.filters,
             kernel_size=1,
             strides=1,
-            kernel_initializer=CONV_KERNEL_INITIALIZER,
+            kernel_initializer=self._conv_kernel_initializer(),
             padding="same",
-            data_format="channels_last",
+            data_format=data_format,
             use_bias=False,
             name=self.name + "expand_conv",
         )
         self.bn1 = keras.layers.BatchNormalization(
             axis=BN_AXIS,
             momentum=self.batch_norm_momentum,
+            epsilon=self.batch_norm_epsilon,
             name=self.name + "expand_bn",
         )
         self.act = keras.layers.Activation(
@@ -107,9 +108,9 @@ class MBConvBlock(keras.layers.Layer):
         self.depthwise = keras.layers.DepthwiseConv2D(
             kernel_size=self.kernel_size,
             strides=self.strides,
-            depthwise_initializer=CONV_KERNEL_INITIALIZER,
+            depthwise_initializer=self._conv_kernel_initializer(),
             padding="same",
-            data_format="channels_last",
+            data_format=data_format,
             use_bias=False,
             name=self.name + "dwconv2",
         )
@@ -117,6 +118,7 @@ class MBConvBlock(keras.layers.Layer):
         self.bn2 = keras.layers.BatchNormalization(
             axis=BN_AXIS,
             momentum=self.batch_norm_momentum,
+            epsilon=self.batch_norm_epsilon,
             name=self.name + "bn",
         )
 
@@ -124,8 +126,9 @@ class MBConvBlock(keras.layers.Layer):
             self.filters_se,
             1,
             padding="same",
+            data_format=data_format,
             activation=self.activation,
-            kernel_initializer=CONV_KERNEL_INITIALIZER,
+            kernel_initializer=self._conv_kernel_initializer(),
             name=self.name + "se_reduce",
         )
 
@@ -133,18 +136,25 @@ class MBConvBlock(keras.layers.Layer):
             self.filters,
             1,
             padding="same",
+            data_format=data_format,
             activation="sigmoid",
-            kernel_initializer=CONV_KERNEL_INITIALIZER,
+            kernel_initializer=self._conv_kernel_initializer(),
             name=self.name + "se_expand",
         )
 
+        projection_kernel_size = 1 if expand_ratio != 1 else kernel_size
+        padding_pixels = projection_kernel_size // 2
+        self.output_conv_pad = keras.layers.ZeroPadding2D(
+            padding=(padding_pixels, padding_pixels),
+            name=self.name + "project_conv_pad",
+        )
         self.output_conv = keras.layers.Conv2D(
             filters=self.output_filters,
-            kernel_size=1,
+            kernel_size=projection_kernel_size,
             strides=1,
-            kernel_initializer=CONV_KERNEL_INITIALIZER,
-            padding="same",
-            data_format="channels_last",
+            kernel_initializer=self._conv_kernel_initializer(),
+            padding="valid",
+            data_format=data_format,
             use_bias=False,
             name=self.name + "project_conv",
         )
@@ -152,6 +162,7 @@ class MBConvBlock(keras.layers.Layer):
         self.bn3 = keras.layers.BatchNormalization(
             axis=BN_AXIS,
             momentum=self.batch_norm_momentum,
+            epsilon=self.batch_norm_epsilon,
             name=self.name + "project_bn",
         )
 
@@ -162,6 +173,17 @@ class MBConvBlock(keras.layers.Layer):
                 name=self.name + "drop",
             )
 
+    def _conv_kernel_initializer(
+        self,
+        scale=2.0,
+        mode="fan_out",
+        distribution="truncated_normal",
+        seed=None,
+    ):
+        return keras.initializers.VarianceScaling(
+            scale=scale, mode=mode, distribution=distribution, seed=seed
+        )
+
     def build(self, input_shape):
         if self.name is None:
             self.name = keras.backend.get_uid("block0")
@@ -183,7 +205,8 @@ class MBConvBlock(keras.layers.Layer):
         # Squeeze and excite
         if 0 < self.se_ratio <= 1:
             se = keras.layers.GlobalAveragePooling2D(
-                name=self.name + "se_squeeze"
+                name=self.name + "se_squeeze",
+                data_format=self.data_format,
             )(x)
             if BN_AXIS == 1:
                 se_shape = (self.filters, 1, 1)
@@ -199,10 +222,15 @@ class MBConvBlock(keras.layers.Layer):
             x = keras.layers.multiply([x, se], name=self.name + "se_excite")
 
         # Output phase
+        x = self.output_conv_pad(x)
         x = self.output_conv(x)
         x = self.bn3(x)
 
-        if self.strides == 1 and self.input_filters == self.output_filters:
+        if (
+            self.strides == 1
+            and self.input_filters == self.output_filters
+            and not self.nores
+        ):
             if self.dropout:
                 x = self.dropout_layer(x)
             x = keras.layers.Add(name=self.name + "add")([x, inputs])
@@ -215,10 +243,13 @@ class MBConvBlock(keras.layers.Layer):
             "expand_ratio": self.expand_ratio,
             "kernel_size": self.kernel_size,
             "strides": self.strides,
+            "data_format": self.data_format,
             "se_ratio": self.se_ratio,
             "batch_norm_momentum": self.batch_norm_momentum,
+            "batch_norm_epsilon": self.batch_norm_epsilon,
             "activation": self.activation,
             "dropout": self.dropout,
+            "nores": self.nores,
         }
         base_config = super().get_config()
         return dict(list(base_config.items()) + list(config.items()))
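Both blocks now thread each new constructor argument (`data_format`, `batch_norm_epsilon`, `nores`, and for FusedMBConv the projection options) through `get_config()`, which is what lets Keras rebuild the layer from its config dict alone when a saved model is reloaded. A hedged sketch of that round-trip (the `keras_hub.src...` import path is internal and may change between releases):

```python
# Internal import path; assumed stable only within this nightly build.
from keras_hub.src.models.efficientnet.mbconv import MBConvBlock

block = MBConvBlock(
    input_filters=32,
    output_filters=32,
    batch_norm_epsilon=1e-3,
    nores=True,
)

# get_config() must include every constructor argument, or this rebuild
# would silently drop the new options.
clone = MBConvBlock.from_config(block.get_config())
assert clone.nores is True
assert clone.batch_norm_epsilon == 1e-3
```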
--- a/keras_hub/src/models/electra/electra_backbone.py
+++ b/keras_hub/src/models/electra/electra_backbone.py
@@ -186,8 +186,8 @@ class ElectraBackbone(Backbone):
         # Index of classification token in the vocabulary
         cls_token_index = 0
         sequence_output = x
-        # Construct the two ELECTRA outputs. The pooled output is a dense layer
-        # top of the [CLS] token.
+        # Construct the two ELECTRA outputs. The pooled output is a dense layer
+        # on top of the [CLS] token.
         pooled_output = self.pooled_dense(x[:, cls_token_index, :])
         super().__init__(
             inputs={