keras-hub-nightly 0.16.1.dev202410200345__py3-none-any.whl → 0.19.0.dev202412070351__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. keras_hub/api/layers/__init__.py +12 -0
  2. keras_hub/api/models/__init__.py +32 -0
  3. keras_hub/src/bounding_box/__init__.py +2 -0
  4. keras_hub/src/bounding_box/converters.py +102 -12
  5. keras_hub/src/layers/modeling/rms_normalization.py +34 -0
  6. keras_hub/src/layers/modeling/transformer_encoder.py +27 -7
  7. keras_hub/src/layers/preprocessing/image_converter.py +5 -0
  8. keras_hub/src/models/albert/albert_presets.py +0 -8
  9. keras_hub/src/models/bart/bart_presets.py +0 -6
  10. keras_hub/src/models/bert/bert_presets.py +0 -20
  11. keras_hub/src/models/bloom/bloom_presets.py +0 -16
  12. keras_hub/src/models/clip/__init__.py +5 -0
  13. keras_hub/src/models/clip/clip_backbone.py +286 -0
  14. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  15. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  16. keras_hub/src/models/clip/clip_presets.py +93 -0
  17. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  18. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  19. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  20. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  21. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +0 -10
  22. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +0 -2
  23. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +5 -3
  24. keras_hub/src/models/densenet/densenet_backbone.py +1 -1
  25. keras_hub/src/models/densenet/densenet_presets.py +0 -6
  26. keras_hub/src/models/distil_bert/distil_bert_presets.py +0 -6
  27. keras_hub/src/models/efficientnet/__init__.py +9 -0
  28. keras_hub/src/models/efficientnet/cba.py +141 -0
  29. keras_hub/src/models/efficientnet/efficientnet_backbone.py +139 -56
  30. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  31. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  32. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  33. keras_hub/src/models/efficientnet/efficientnet_presets.py +192 -0
  34. keras_hub/src/models/efficientnet/fusedmbconv.py +81 -36
  35. keras_hub/src/models/efficientnet/mbconv.py +52 -21
  36. keras_hub/src/models/electra/electra_presets.py +0 -12
  37. keras_hub/src/models/f_net/f_net_presets.py +0 -4
  38. keras_hub/src/models/falcon/falcon_presets.py +0 -2
  39. keras_hub/src/models/flux/__init__.py +5 -0
  40. keras_hub/src/models/flux/flux_layers.py +494 -0
  41. keras_hub/src/models/flux/flux_maths.py +218 -0
  42. keras_hub/src/models/flux/flux_model.py +231 -0
  43. keras_hub/src/models/flux/flux_presets.py +14 -0
  44. keras_hub/src/models/flux/flux_text_to_image.py +142 -0
  45. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  46. keras_hub/src/models/gemma/gemma_presets.py +0 -40
  47. keras_hub/src/models/gpt2/gpt2_presets.py +0 -9
  48. keras_hub/src/models/image_object_detector.py +87 -0
  49. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  50. keras_hub/src/models/image_to_image.py +16 -10
  51. keras_hub/src/models/inpaint.py +20 -13
  52. keras_hub/src/models/llama/llama_backbone.py +1 -1
  53. keras_hub/src/models/llama/llama_presets.py +5 -15
  54. keras_hub/src/models/llama3/llama3_presets.py +0 -8
  55. keras_hub/src/models/mistral/mistral_presets.py +0 -6
  56. keras_hub/src/models/mit/mit_backbone.py +41 -27
  57. keras_hub/src/models/mit/mit_layers.py +9 -7
  58. keras_hub/src/models/mit/mit_presets.py +12 -24
  59. keras_hub/src/models/opt/opt_presets.py +0 -8
  60. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +61 -11
  61. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  62. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +166 -10
  63. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +12 -11
  64. keras_hub/src/models/phi3/phi3_presets.py +0 -4
  65. keras_hub/src/models/resnet/resnet_presets.py +10 -42
  66. keras_hub/src/models/retinanet/__init__.py +5 -0
  67. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  68. keras_hub/src/models/retinanet/feature_pyramid.py +99 -36
  69. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  70. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  71. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  72. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  73. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  74. keras_hub/src/models/retinanet/retinanet_object_detector.py +382 -0
  75. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  76. keras_hub/src/models/retinanet/retinanet_presets.py +15 -0
  77. keras_hub/src/models/roberta/roberta_presets.py +0 -4
  78. keras_hub/src/models/sam/sam_backbone.py +0 -1
  79. keras_hub/src/models/sam/sam_image_segmenter.py +9 -10
  80. keras_hub/src/models/sam/sam_presets.py +0 -6
  81. keras_hub/src/models/segformer/__init__.py +8 -0
  82. keras_hub/src/models/segformer/segformer_backbone.py +163 -0
  83. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  84. keras_hub/src/models/segformer/segformer_image_segmenter.py +171 -0
  85. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  86. keras_hub/src/models/segformer/segformer_presets.py +124 -0
  87. keras_hub/src/models/stable_diffusion_3/mmdit.py +41 -0
  88. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +38 -21
  89. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +3 -3
  90. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +3 -3
  91. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +28 -4
  92. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -1
  93. keras_hub/src/models/t5/t5_backbone.py +5 -4
  94. keras_hub/src/models/t5/t5_presets.py +41 -13
  95. keras_hub/src/models/text_to_image.py +13 -5
  96. keras_hub/src/models/vgg/vgg_backbone.py +1 -1
  97. keras_hub/src/models/vgg/vgg_presets.py +0 -8
  98. keras_hub/src/models/whisper/whisper_audio_converter.py +1 -1
  99. keras_hub/src/models/whisper/whisper_presets.py +0 -20
  100. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +0 -4
  101. keras_hub/src/tests/test_case.py +25 -0
  102. keras_hub/src/utils/preset_utils.py +17 -4
  103. keras_hub/src/utils/timm/convert_efficientnet.py +449 -0
  104. keras_hub/src/utils/timm/preset_loader.py +3 -0
  105. keras_hub/src/version_utils.py +1 -1
  106. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/METADATA +15 -26
  107. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/RECORD +109 -76
  108. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/WHEEL +1 -1
  109. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,192 @@
1
+ """EfficientNet preset configurations."""
2
+
3
+ backbone_presets = {
4
+ "efficientnet_b0_ra_imagenet": {
5
+ "metadata": {
6
+ "description": (
7
+ "EfficientNet B0 model pre-trained on the ImageNet 1k dataset "
8
+ "with RandAugment recipe."
9
+ ),
10
+ "params": 5288548,
11
+ "path": "efficientnet",
12
+ },
13
+ "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b0_ra_imagenet/1",
14
+ },
15
+ "efficientnet_b0_ra4_e3600_r224_imagenet": {
16
+ "metadata": {
17
+ "description": (
18
+ "EfficientNet B0 model pre-trained on the ImageNet 1k dataset by"
19
+ " Ross Wightman. Trained with timm scripts using hyper-parameters"
20
+ " inspired by the MobileNet-V4 small, mixed with go-to hparams "
21
+ 'from timm and "ResNet Strikes Back".'
22
+ ),
23
+ "params": 5288548,
24
+ "path": "efficientnet",
25
+ },
26
+ "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b0_ra4_e3600_r224_imagenet/1",
27
+ },
28
+ "efficientnet_b1_ft_imagenet": {
29
+ "metadata": {
30
+ "description": (
31
+ "EfficientNet B1 model fine-tuned on the ImageNet 1k dataset."
32
+ ),
33
+ "params": 7794184,
34
+ "path": "efficientnet",
35
+ },
36
+ "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b1_ft_imagenet/1",
37
+ },
38
+ "efficientnet_b1_ra4_e3600_r240_imagenet": {
39
+ "metadata": {
40
+ "description": (
41
+ "EfficientNet B1 model pre-trained on the ImageNet 1k dataset by"
42
+ " Ross Wightman. Trained with timm scripts using hyper-parameters"
43
+ " inspired by the MobileNet-V4 small, mixed with go-to hparams "
44
+ 'from timm and "ResNet Strikes Back".'
45
+ ),
46
+ "params": 7794184,
47
+ "path": "efficientnet",
48
+ },
49
+ "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b1_ra4_e3600_r240_imagenet/1",
50
+ },
51
+ "efficientnet_b2_ra_imagenet": {
52
+ "metadata": {
53
+ "description": (
54
+ "EfficientNet B2 model pre-trained on the ImageNet 1k dataset "
55
+ "with RandAugment recipe."
56
+ ),
57
+ "params": 9109994,
58
+ "path": "efficientnet",
59
+ },
60
+ "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b2_ra_imagenet/1",
61
+ },
62
+ "efficientnet_b3_ra2_imagenet": {
63
+ "metadata": {
64
+ "description": (
65
+ "EfficientNet B3 model pre-trained on the ImageNet 1k dataset "
66
+ "with RandAugment2 recipe."
67
+ ),
68
+ "params": 12233232,
69
+ "path": "efficientnet",
70
+ },
71
+ "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b3_ra2_imagenet/1",
72
+ },
73
+ "efficientnet_b4_ra2_imagenet": {
74
+ "metadata": {
75
+ "description": (
76
+ "EfficientNet B4 model pre-trained on the ImageNet 1k dataset "
77
+ "with RandAugment2 recipe."
78
+ ),
79
+ "params": 19341616,
80
+ "path": "efficientnet",
81
+ },
82
+ "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b4_ra2_imagenet/1",
83
+ },
84
+ "efficientnet_b5_sw_imagenet": {
85
+ "metadata": {
86
+ "description": (
87
+ "EfficientNet B5 model pre-trained on the ImageNet 12k dataset "
88
+ "by Ross Wightman. Based on Swin Transformer train / pretrain "
89
+ "recipe with modifications (related to both DeiT and ConvNeXt recipes)."
90
+ ),
91
+ "params": 30389784,
92
+ "path": "efficientnet",
93
+ },
94
+ "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b5_sw_imagenet/1",
95
+ },
96
+ "efficientnet_b5_sw_ft_imagenet": {
97
+ "metadata": {
98
+ "description": (
99
+ "EfficientNet B5 model pre-trained on the ImageNet 12k dataset "
100
+ "and fine-tuned on ImageNet-1k by Ross Wightman. Based on Swin "
101
+ "Transformer train / pretrain recipe with modifications "
102
+ "(related to both DeiT and ConvNeXt recipes)."
103
+ ),
104
+ "params": 30389784,
105
+ "path": "efficientnet",
106
+ },
107
+ "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b5_sw_ft_imagenet/1",
108
+ },
109
+ "efficientnet_el_ra_imagenet": {
110
+ "metadata": {
111
+ "description": (
112
+ "EfficientNet-EdgeTPU Large model trained on the ImageNet 1k "
113
+ "dataset with RandAugment recipe."
114
+ ),
115
+ "params": 10589712,
116
+ "path": "efficientnet",
117
+ },
118
+ "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b1_ft_imagenet",
119
+ },
120
+ "efficientnet_em_ra2_imagenet": {
121
+ "metadata": {
122
+ "description": (
123
+ "EfficientNet-EdgeTPU Medium model trained on the ImageNet 1k "
124
+ "dataset with RandAugment2 recipe."
125
+ ),
126
+ "params": 6899496,
127
+ "path": "efficientnet",
128
+ },
129
+ "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b1_ft_imagenet",
130
+ },
131
+ "efficientnet_es_ra_imagenet": {
132
+ "metadata": {
133
+ "description": (
134
+ "EfficientNet-EdgeTPU Small model trained on the ImageNet 1k "
135
+ "dataset with RandAugment recipe."
136
+ ),
137
+ "params": 5438392,
138
+ "path": "efficientnet",
139
+ },
140
+ "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b1_ft_imagenet",
141
+ },
142
+ "efficientnet2_rw_m_agc_imagenet": {
143
+ "metadata": {
144
+ "description": (
145
+ "EfficientNet-v2 Medium model trained on the ImageNet 1k "
146
+ "dataset with adaptive gradient clipping."
147
+ ),
148
+ "params": 53236442,
149
+ "official_name": "EfficientNet",
150
+ "path": "efficientnet",
151
+ "model_card": "https://arxiv.org/abs/2104.00298",
152
+ },
153
+ "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet2_rw_m_agc_imagenet",
154
+ },
155
+ "efficientnet2_rw_s_ra2_imagenet": {
156
+ "metadata": {
157
+ "description": (
158
+ "EfficientNet-v2 Small model trained on the ImageNet 1k "
159
+ "dataset with RandAugment2 recipe."
160
+ ),
161
+ "params": 23941296,
162
+ "official_name": "EfficientNet",
163
+ "path": "efficientnet",
164
+ "model_card": "https://arxiv.org/abs/2104.00298",
165
+ },
166
+ "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet2_rw_s_ra2_imagenet",
167
+ },
168
+ "efficientnet2_rw_t_ra2_imagenet": {
169
+ "metadata": {
170
+ "description": (
171
+ "EfficientNet-v2 Tiny model trained on the ImageNet 1k "
172
+ "dataset with RandAugment2 recipe."
173
+ ),
174
+ "params": 13649388,
175
+ "official_name": "EfficientNet",
176
+ "path": "efficientnet",
177
+ "model_card": "https://arxiv.org/abs/2104.00298",
178
+ },
179
+ "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet2_rw_t_ra2_imagenet",
180
+ },
181
+ "efficientnet_lite0_ra_imagenet": {
182
+ "metadata": {
183
+ "description": (
184
+ "EfficientNet-Lite model fine-trained on the ImageNet 1k dataset "
185
+ "with RandAugment recipe."
186
+ ),
187
+ "params": 4652008,
188
+ "path": "efficientnet",
189
+ },
190
+ "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_lite0_ra_imagenet",
191
+ },
192
+ }
@@ -2,15 +2,6 @@ import keras
2
2
 
3
3
  BN_AXIS = 3
4
4
 
5
- CONV_KERNEL_INITIALIZER = {
6
- "class_name": "VarianceScaling",
7
- "config": {
8
- "scale": 2.0,
9
- "mode": "fan_out",
10
- "distribution": "truncated_normal",
11
- },
12
- }
13
-
14
5
 
15
6
  class FusedMBConvBlock(keras.layers.Layer):
16
7
  """Implementation of the FusedMBConv block
@@ -44,13 +35,24 @@ class FusedMBConvBlock(keras.layers.Layer):
44
35
  convolutions
45
36
  strides: default 1, the strides to apply to the expansion phase
46
37
  convolutions
38
+ data_format: str, channels_last (default) or channels_first, expects
39
+ tensors to be of shape (N, H, W, C) or (N, C, H, W) respectively
47
40
  se_ratio: default 0.0, The filters used in the Squeeze-Excitation phase,
48
41
  and are chosen as the maximum between 1 and input_filters*se_ratio
49
42
  batch_norm_momentum: default 0.9, the BatchNormalization momentum
43
+ batch_norm_epsilon: default 1e-3, float, epsilon for batch norm
44
+ calcualtions. Used in denominator for calculations to prevent divide
45
+ by 0 errors.
50
46
  activation: default "swish", the activation function used between
51
47
  convolution operations
48
+ projection_activation: default None, the activation function to use
49
+ after the output projection convoultion
52
50
  dropout: float, the optional dropout rate to apply before the output
53
51
  convolution, defaults to 0.2
52
+ nores: bool, default False, forces no residual connection if True,
53
+ otherwise allows it if False.
54
+ projection_kernel_size: default 1, the kernel_size to apply to the
55
+ output projection phase convolution
54
56
 
55
57
  Returns:
56
58
  A tensor representing a feature map, passed through the FusedMBConv
@@ -67,10 +69,15 @@ class FusedMBConvBlock(keras.layers.Layer):
67
69
  expand_ratio=1,
68
70
  kernel_size=3,
69
71
  strides=1,
72
+ data_format="channels_last",
70
73
  se_ratio=0.0,
71
74
  batch_norm_momentum=0.9,
75
+ batch_norm_epsilon=1e-3,
72
76
  activation="swish",
77
+ projection_activation=None,
73
78
  dropout=0.2,
79
+ nores=False,
80
+ projection_kernel_size=1,
74
81
  **kwargs
75
82
  ):
76
83
  super().__init__(**kwargs)
@@ -79,44 +86,50 @@ class FusedMBConvBlock(keras.layers.Layer):
79
86
  self.expand_ratio = expand_ratio
80
87
  self.kernel_size = kernel_size
81
88
  self.strides = strides
89
+ self.data_format = data_format
82
90
  self.se_ratio = se_ratio
83
91
  self.batch_norm_momentum = batch_norm_momentum
92
+ self.batch_norm_epsilon = batch_norm_epsilon
84
93
  self.activation = activation
94
+ self.projection_activation = projection_activation
85
95
  self.dropout = dropout
96
+ self.nores = nores
97
+ self.projection_kernel_size = projection_kernel_size
86
98
  self.filters = self.input_filters * self.expand_ratio
87
99
  self.filters_se = max(1, int(input_filters * se_ratio))
88
100
 
101
+ padding_pixels = kernel_size // 2
102
+ self.conv1_pad = keras.layers.ZeroPadding2D(
103
+ padding=(padding_pixels, padding_pixels),
104
+ name=self.name + "expand_conv_pad",
105
+ )
89
106
  self.conv1 = keras.layers.Conv2D(
90
107
  filters=self.filters,
91
108
  kernel_size=kernel_size,
92
109
  strides=strides,
93
- kernel_initializer=CONV_KERNEL_INITIALIZER,
94
- padding="same",
95
- data_format="channels_last",
110
+ kernel_initializer=self._conv_kernel_initializer(),
111
+ padding="valid",
112
+ data_format=data_format,
96
113
  use_bias=False,
97
114
  name=self.name + "expand_conv",
98
115
  )
99
116
  self.bn1 = keras.layers.BatchNormalization(
100
117
  axis=BN_AXIS,
101
118
  momentum=self.batch_norm_momentum,
119
+ epsilon=self.batch_norm_epsilon,
102
120
  name=self.name + "expand_bn",
103
121
  )
104
122
  self.act = keras.layers.Activation(
105
123
  self.activation, name=self.name + "expand_activation"
106
124
  )
107
125
 
108
- self.bn2 = keras.layers.BatchNormalization(
109
- axis=BN_AXIS,
110
- momentum=self.batch_norm_momentum,
111
- name=self.name + "bn",
112
- )
113
-
114
126
  self.se_conv1 = keras.layers.Conv2D(
115
127
  self.filters_se,
116
128
  1,
117
129
  padding="same",
130
+ data_format=data_format,
118
131
  activation=self.activation,
119
- kernel_initializer=CONV_KERNEL_INITIALIZER,
132
+ kernel_initializer=self._conv_kernel_initializer(),
120
133
  name=self.name + "se_reduce",
121
134
  )
122
135
 
@@ -124,28 +137,40 @@ class FusedMBConvBlock(keras.layers.Layer):
124
137
  self.filters,
125
138
  1,
126
139
  padding="same",
140
+ data_format=data_format,
127
141
  activation="sigmoid",
128
- kernel_initializer=CONV_KERNEL_INITIALIZER,
142
+ kernel_initializer=self._conv_kernel_initializer(),
129
143
  name=self.name + "se_expand",
130
144
  )
131
145
 
146
+ padding_pixels = projection_kernel_size // 2
147
+ self.output_conv_pad = keras.layers.ZeroPadding2D(
148
+ padding=(padding_pixels, padding_pixels),
149
+ name=self.name + "project_conv_pad",
150
+ )
132
151
  self.output_conv = keras.layers.Conv2D(
133
152
  filters=self.output_filters,
134
- kernel_size=1 if expand_ratio != 1 else kernel_size,
153
+ kernel_size=projection_kernel_size,
135
154
  strides=1,
136
- kernel_initializer=CONV_KERNEL_INITIALIZER,
137
- padding="same",
138
- data_format="channels_last",
155
+ kernel_initializer=self._conv_kernel_initializer(),
156
+ padding="valid",
157
+ data_format=data_format,
139
158
  use_bias=False,
140
159
  name=self.name + "project_conv",
141
160
  )
142
161
 
143
- self.bn3 = keras.layers.BatchNormalization(
162
+ self.bn2 = keras.layers.BatchNormalization(
144
163
  axis=BN_AXIS,
145
164
  momentum=self.batch_norm_momentum,
165
+ epsilon=self.batch_norm_epsilon,
146
166
  name=self.name + "project_bn",
147
167
  )
148
168
 
169
+ if self.projection_activation:
170
+ self.projection_act = keras.layers.Activation(
171
+ self.projection_activation, name=self.name + "projection_act"
172
+ )
173
+
149
174
  if self.dropout:
150
175
  self.dropout_layer = keras.layers.Dropout(
151
176
  self.dropout,
@@ -153,23 +178,33 @@ class FusedMBConvBlock(keras.layers.Layer):
153
178
  name=self.name + "drop",
154
179
  )
155
180
 
181
+ def _conv_kernel_initializer(
182
+ self,
183
+ scale=2.0,
184
+ mode="fan_out",
185
+ distribution="truncated_normal",
186
+ seed=None,
187
+ ):
188
+ return keras.initializers.VarianceScaling(
189
+ scale=scale, mode=mode, distribution=distribution, seed=seed
190
+ )
191
+
156
192
  def build(self, input_shape):
157
193
  if self.name is None:
158
194
  self.name = keras.backend.get_uid("block0")
159
195
 
160
196
  def call(self, inputs):
161
197
  # Expansion phase
162
- if self.expand_ratio != 1:
163
- x = self.conv1(inputs)
164
- x = self.bn1(x)
165
- x = self.act(x)
166
- else:
167
- x = inputs
198
+ x = self.conv1_pad(inputs)
199
+ x = self.conv1(x)
200
+ x = self.bn1(x)
201
+ x = self.act(x)
168
202
 
169
203
  # Squeeze and excite
170
204
  if 0 < self.se_ratio <= 1:
171
205
  se = keras.layers.GlobalAveragePooling2D(
172
- name=self.name + "se_squeeze"
206
+ name=self.name + "se_squeeze",
207
+ data_format=self.data_format,
173
208
  )(x)
174
209
  if BN_AXIS == 1:
175
210
  se_shape = (self.filters, 1, 1)
@@ -186,13 +221,18 @@ class FusedMBConvBlock(keras.layers.Layer):
186
221
  x = keras.layers.multiply([x, se], name=self.name + "se_excite")
187
222
 
188
223
  # Output phase:
224
+ x = self.output_conv_pad(x)
189
225
  x = self.output_conv(x)
190
- x = self.bn3(x)
191
- if self.expand_ratio == 1:
192
- x = self.act(x)
226
+ x = self.bn2(x)
227
+ if self.expand_ratio == 1 and self.projection_activation:
228
+ x = self.projection_act(x)
193
229
 
194
230
  # Residual:
195
- if self.strides == 1 and self.input_filters == self.output_filters:
231
+ if (
232
+ self.strides == 1
233
+ and self.input_filters == self.output_filters
234
+ and not self.nores
235
+ ):
196
236
  if self.dropout:
197
237
  x = self.dropout_layer(x)
198
238
  x = keras.layers.Add(name=self.name + "add")([x, inputs])
@@ -205,10 +245,15 @@ class FusedMBConvBlock(keras.layers.Layer):
205
245
  "expand_ratio": self.expand_ratio,
206
246
  "kernel_size": self.kernel_size,
207
247
  "strides": self.strides,
248
+ "data_format": self.data_format,
208
249
  "se_ratio": self.se_ratio,
209
250
  "batch_norm_momentum": self.batch_norm_momentum,
251
+ "batch_norm_epsilon": self.batch_norm_epsilon,
210
252
  "activation": self.activation,
253
+ "projection_activation": self.projection_activation,
211
254
  "dropout": self.dropout,
255
+ "nores": self.nores,
256
+ "projection_kernel_size": self.projection_kernel_size,
212
257
  }
213
258
 
214
259
  base_config = super().get_config()
@@ -2,15 +2,6 @@ import keras
2
2
 
3
3
  BN_AXIS = 3
4
4
 
5
- CONV_KERNEL_INITIALIZER = {
6
- "class_name": "VarianceScaling",
7
- "config": {
8
- "scale": 2.0,
9
- "mode": "fan_out",
10
- "distribution": "truncated_normal",
11
- },
12
- }
13
-
14
5
 
15
6
  class MBConvBlock(keras.layers.Layer):
16
7
  def __init__(
@@ -20,10 +11,13 @@ class MBConvBlock(keras.layers.Layer):
20
11
  expand_ratio=1,
21
12
  kernel_size=3,
22
13
  strides=1,
14
+ data_format="channels_last",
23
15
  se_ratio=0.0,
24
16
  batch_norm_momentum=0.9,
17
+ batch_norm_epsilon=1e-3,
25
18
  activation="swish",
26
19
  dropout=0.2,
20
+ nores=False,
27
21
  **kwargs
28
22
  ):
29
23
  """Implementation of the MBConv block
@@ -59,6 +53,9 @@ class MBConvBlock(keras.layers.Layer):
59
53
  is above 0. The filters used in this phase are chosen as the
60
54
  maximum between 1 and input_filters*se_ratio
61
55
  batch_norm_momentum: default 0.9, the BatchNormalization momentum
56
+ batch_norm_epsilon: default 1e-3, float, epsilon for batch norm
57
+ calcualtions. Used in denominator for calculations to prevent
58
+ divide by 0 errors.
62
59
  activation: default "swish", the activation function used between
63
60
  convolution operations
64
61
  dropout: float, the optional dropout rate to apply before the output
@@ -79,10 +76,13 @@ class MBConvBlock(keras.layers.Layer):
79
76
  self.expand_ratio = expand_ratio
80
77
  self.kernel_size = kernel_size
81
78
  self.strides = strides
79
+ self.data_format = data_format
82
80
  self.se_ratio = se_ratio
83
81
  self.batch_norm_momentum = batch_norm_momentum
82
+ self.batch_norm_epsilon = batch_norm_epsilon
84
83
  self.activation = activation
85
84
  self.dropout = dropout
85
+ self.nores = nores
86
86
  self.filters = self.input_filters * self.expand_ratio
87
87
  self.filters_se = max(1, int(input_filters * se_ratio))
88
88
 
@@ -90,15 +90,16 @@ class MBConvBlock(keras.layers.Layer):
90
90
  filters=self.filters,
91
91
  kernel_size=1,
92
92
  strides=1,
93
- kernel_initializer=CONV_KERNEL_INITIALIZER,
93
+ kernel_initializer=self._conv_kernel_initializer(),
94
94
  padding="same",
95
- data_format="channels_last",
95
+ data_format=data_format,
96
96
  use_bias=False,
97
97
  name=self.name + "expand_conv",
98
98
  )
99
99
  self.bn1 = keras.layers.BatchNormalization(
100
100
  axis=BN_AXIS,
101
101
  momentum=self.batch_norm_momentum,
102
+ epsilon=self.batch_norm_epsilon,
102
103
  name=self.name + "expand_bn",
103
104
  )
104
105
  self.act = keras.layers.Activation(
@@ -107,9 +108,9 @@ class MBConvBlock(keras.layers.Layer):
107
108
  self.depthwise = keras.layers.DepthwiseConv2D(
108
109
  kernel_size=self.kernel_size,
109
110
  strides=self.strides,
110
- depthwise_initializer=CONV_KERNEL_INITIALIZER,
111
+ depthwise_initializer=self._conv_kernel_initializer(),
111
112
  padding="same",
112
- data_format="channels_last",
113
+ data_format=data_format,
113
114
  use_bias=False,
114
115
  name=self.name + "dwconv2",
115
116
  )
@@ -117,6 +118,7 @@ class MBConvBlock(keras.layers.Layer):
117
118
  self.bn2 = keras.layers.BatchNormalization(
118
119
  axis=BN_AXIS,
119
120
  momentum=self.batch_norm_momentum,
121
+ epsilon=self.batch_norm_epsilon,
120
122
  name=self.name + "bn",
121
123
  )
122
124
 
@@ -124,8 +126,9 @@ class MBConvBlock(keras.layers.Layer):
124
126
  self.filters_se,
125
127
  1,
126
128
  padding="same",
129
+ data_format=data_format,
127
130
  activation=self.activation,
128
- kernel_initializer=CONV_KERNEL_INITIALIZER,
131
+ kernel_initializer=self._conv_kernel_initializer(),
129
132
  name=self.name + "se_reduce",
130
133
  )
131
134
 
@@ -133,18 +136,25 @@ class MBConvBlock(keras.layers.Layer):
133
136
  self.filters,
134
137
  1,
135
138
  padding="same",
139
+ data_format=data_format,
136
140
  activation="sigmoid",
137
- kernel_initializer=CONV_KERNEL_INITIALIZER,
141
+ kernel_initializer=self._conv_kernel_initializer(),
138
142
  name=self.name + "se_expand",
139
143
  )
140
144
 
145
+ projection_kernel_size = 1 if expand_ratio != 1 else kernel_size
146
+ padding_pixels = projection_kernel_size // 2
147
+ self.output_conv_pad = keras.layers.ZeroPadding2D(
148
+ padding=(padding_pixels, padding_pixels),
149
+ name=self.name + "project_conv_pad",
150
+ )
141
151
  self.output_conv = keras.layers.Conv2D(
142
152
  filters=self.output_filters,
143
- kernel_size=1 if expand_ratio != 1 else kernel_size,
153
+ kernel_size=projection_kernel_size,
144
154
  strides=1,
145
- kernel_initializer=CONV_KERNEL_INITIALIZER,
146
- padding="same",
147
- data_format="channels_last",
155
+ kernel_initializer=self._conv_kernel_initializer(),
156
+ padding="valid",
157
+ data_format=data_format,
148
158
  use_bias=False,
149
159
  name=self.name + "project_conv",
150
160
  )
@@ -152,6 +162,7 @@ class MBConvBlock(keras.layers.Layer):
152
162
  self.bn3 = keras.layers.BatchNormalization(
153
163
  axis=BN_AXIS,
154
164
  momentum=self.batch_norm_momentum,
165
+ epsilon=self.batch_norm_epsilon,
155
166
  name=self.name + "project_bn",
156
167
  )
157
168
 
@@ -162,6 +173,17 @@ class MBConvBlock(keras.layers.Layer):
162
173
  name=self.name + "drop",
163
174
  )
164
175
 
176
+ def _conv_kernel_initializer(
177
+ self,
178
+ scale=2.0,
179
+ mode="fan_out",
180
+ distribution="truncated_normal",
181
+ seed=None,
182
+ ):
183
+ return keras.initializers.VarianceScaling(
184
+ scale=scale, mode=mode, distribution=distribution, seed=seed
185
+ )
186
+
165
187
  def build(self, input_shape):
166
188
  if self.name is None:
167
189
  self.name = keras.backend.get_uid("block0")
@@ -183,7 +205,8 @@ class MBConvBlock(keras.layers.Layer):
183
205
  # Squeeze and excite
184
206
  if 0 < self.se_ratio <= 1:
185
207
  se = keras.layers.GlobalAveragePooling2D(
186
- name=self.name + "se_squeeze"
208
+ name=self.name + "se_squeeze",
209
+ data_format=self.data_format,
187
210
  )(x)
188
211
  if BN_AXIS == 1:
189
212
  se_shape = (self.filters, 1, 1)
@@ -199,10 +222,15 @@ class MBConvBlock(keras.layers.Layer):
199
222
  x = keras.layers.multiply([x, se], name=self.name + "se_excite")
200
223
 
201
224
  # Output phase
225
+ x = self.output_conv_pad(x)
202
226
  x = self.output_conv(x)
203
227
  x = self.bn3(x)
204
228
 
205
- if self.strides == 1 and self.input_filters == self.output_filters:
229
+ if (
230
+ self.strides == 1
231
+ and self.input_filters == self.output_filters
232
+ and not self.nores
233
+ ):
206
234
  if self.dropout:
207
235
  x = self.dropout_layer(x)
208
236
  x = keras.layers.Add(name=self.name + "add")([x, inputs])
@@ -215,10 +243,13 @@ class MBConvBlock(keras.layers.Layer):
215
243
  "expand_ratio": self.expand_ratio,
216
244
  "kernel_size": self.kernel_size,
217
245
  "strides": self.strides,
246
+ "data_format": self.data_format,
218
247
  "se_ratio": self.se_ratio,
219
248
  "batch_norm_momentum": self.batch_norm_momentum,
249
+ "batch_norm_epsilon": self.batch_norm_epsilon,
220
250
  "activation": self.activation,
221
251
  "dropout": self.dropout,
252
+ "nores": self.nores,
222
253
  }
223
254
  base_config = super().get_config()
224
255
  return dict(list(base_config.items()) + list(config.items()))