keras-hub-nightly 0.16.1.dev202410200345__py3-none-any.whl → 0.19.0.dev202412070351__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. keras_hub/api/layers/__init__.py +12 -0
  2. keras_hub/api/models/__init__.py +32 -0
  3. keras_hub/src/bounding_box/__init__.py +2 -0
  4. keras_hub/src/bounding_box/converters.py +102 -12
  5. keras_hub/src/layers/modeling/rms_normalization.py +34 -0
  6. keras_hub/src/layers/modeling/transformer_encoder.py +27 -7
  7. keras_hub/src/layers/preprocessing/image_converter.py +5 -0
  8. keras_hub/src/models/albert/albert_presets.py +0 -8
  9. keras_hub/src/models/bart/bart_presets.py +0 -6
  10. keras_hub/src/models/bert/bert_presets.py +0 -20
  11. keras_hub/src/models/bloom/bloom_presets.py +0 -16
  12. keras_hub/src/models/clip/__init__.py +5 -0
  13. keras_hub/src/models/clip/clip_backbone.py +286 -0
  14. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  15. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  16. keras_hub/src/models/clip/clip_presets.py +93 -0
  17. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  18. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  19. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  20. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  21. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +0 -10
  22. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +0 -2
  23. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +5 -3
  24. keras_hub/src/models/densenet/densenet_backbone.py +1 -1
  25. keras_hub/src/models/densenet/densenet_presets.py +0 -6
  26. keras_hub/src/models/distil_bert/distil_bert_presets.py +0 -6
  27. keras_hub/src/models/efficientnet/__init__.py +9 -0
  28. keras_hub/src/models/efficientnet/cba.py +141 -0
  29. keras_hub/src/models/efficientnet/efficientnet_backbone.py +139 -56
  30. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  31. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  32. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  33. keras_hub/src/models/efficientnet/efficientnet_presets.py +192 -0
  34. keras_hub/src/models/efficientnet/fusedmbconv.py +81 -36
  35. keras_hub/src/models/efficientnet/mbconv.py +52 -21
  36. keras_hub/src/models/electra/electra_presets.py +0 -12
  37. keras_hub/src/models/f_net/f_net_presets.py +0 -4
  38. keras_hub/src/models/falcon/falcon_presets.py +0 -2
  39. keras_hub/src/models/flux/__init__.py +5 -0
  40. keras_hub/src/models/flux/flux_layers.py +494 -0
  41. keras_hub/src/models/flux/flux_maths.py +218 -0
  42. keras_hub/src/models/flux/flux_model.py +231 -0
  43. keras_hub/src/models/flux/flux_presets.py +14 -0
  44. keras_hub/src/models/flux/flux_text_to_image.py +142 -0
  45. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  46. keras_hub/src/models/gemma/gemma_presets.py +0 -40
  47. keras_hub/src/models/gpt2/gpt2_presets.py +0 -9
  48. keras_hub/src/models/image_object_detector.py +87 -0
  49. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  50. keras_hub/src/models/image_to_image.py +16 -10
  51. keras_hub/src/models/inpaint.py +20 -13
  52. keras_hub/src/models/llama/llama_backbone.py +1 -1
  53. keras_hub/src/models/llama/llama_presets.py +5 -15
  54. keras_hub/src/models/llama3/llama3_presets.py +0 -8
  55. keras_hub/src/models/mistral/mistral_presets.py +0 -6
  56. keras_hub/src/models/mit/mit_backbone.py +41 -27
  57. keras_hub/src/models/mit/mit_layers.py +9 -7
  58. keras_hub/src/models/mit/mit_presets.py +12 -24
  59. keras_hub/src/models/opt/opt_presets.py +0 -8
  60. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +61 -11
  61. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  62. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +166 -10
  63. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +12 -11
  64. keras_hub/src/models/phi3/phi3_presets.py +0 -4
  65. keras_hub/src/models/resnet/resnet_presets.py +10 -42
  66. keras_hub/src/models/retinanet/__init__.py +5 -0
  67. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  68. keras_hub/src/models/retinanet/feature_pyramid.py +99 -36
  69. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  70. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  71. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  72. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  73. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  74. keras_hub/src/models/retinanet/retinanet_object_detector.py +382 -0
  75. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  76. keras_hub/src/models/retinanet/retinanet_presets.py +15 -0
  77. keras_hub/src/models/roberta/roberta_presets.py +0 -4
  78. keras_hub/src/models/sam/sam_backbone.py +0 -1
  79. keras_hub/src/models/sam/sam_image_segmenter.py +9 -10
  80. keras_hub/src/models/sam/sam_presets.py +0 -6
  81. keras_hub/src/models/segformer/__init__.py +8 -0
  82. keras_hub/src/models/segformer/segformer_backbone.py +163 -0
  83. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  84. keras_hub/src/models/segformer/segformer_image_segmenter.py +171 -0
  85. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  86. keras_hub/src/models/segformer/segformer_presets.py +124 -0
  87. keras_hub/src/models/stable_diffusion_3/mmdit.py +41 -0
  88. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +38 -21
  89. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +3 -3
  90. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +3 -3
  91. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +28 -4
  92. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -1
  93. keras_hub/src/models/t5/t5_backbone.py +5 -4
  94. keras_hub/src/models/t5/t5_presets.py +41 -13
  95. keras_hub/src/models/text_to_image.py +13 -5
  96. keras_hub/src/models/vgg/vgg_backbone.py +1 -1
  97. keras_hub/src/models/vgg/vgg_presets.py +0 -8
  98. keras_hub/src/models/whisper/whisper_audio_converter.py +1 -1
  99. keras_hub/src/models/whisper/whisper_presets.py +0 -20
  100. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +0 -4
  101. keras_hub/src/tests/test_case.py +25 -0
  102. keras_hub/src/utils/preset_utils.py +17 -4
  103. keras_hub/src/utils/timm/convert_efficientnet.py +449 -0
  104. keras_hub/src/utils/timm/preset_loader.py +3 -0
  105. keras_hub/src/version_utils.py +1 -1
  106. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/METADATA +15 -26
  107. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/RECORD +109 -76
  108. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/WHEEL +1 -1
  109. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,159 @@
1
+ from keras import layers
2
+
3
+ from keras_hub.src.api_export import keras_hub_export
4
+ from keras_hub.src.models.backbone import Backbone
5
+ from keras_hub.src.models.clip.clip_encoder_block import CLIPEncoderBlock
6
+ from keras_hub.src.models.clip.clip_vision_embedding import CLIPVisionEmbedding
7
+ from keras_hub.src.utils.keras_utils import standardize_data_format
8
+
9
+
10
@keras_hub_export("keras_hub.models.CLIPVisionEncoder")
class CLIPVisionEncoder(Backbone):
    """CLIP vision core network with hyperparameters.

    The network applies a `CLIPVisionEmbedding` to the input images, followed
    by a layer normalization, `num_layers` `CLIPEncoderBlock`s (without a
    causal mask) and a final layer normalization.

    Args:
        patch_size: int. The size of each square patch in the input image.
        hidden_dim: int. The size of the transformer hidden state at the end
            of each transformer layer.
        num_layers: int. The number of transformer layers.
        num_heads: int. The number of attention heads for each transformer.
        intermediate_dim: int. The output dimension of the first Dense layer in
            a two-layer feedforward network for each transformer.
        intermediate_activation: activation function. The activation that
            is used for the first Dense layer in a two-layer feedforward network
            for each transformer.
        intermediate_output_index: optional int. The index of the intermediate
            output. Negative values count from the last transformer layer.
            If specified, the output will become a dictionary with two
            keys `"sequence_output"` and `"intermediate_output"`.
        image_shape: tuple. The input shape without the batch size. Defaults to
            `(224, 224, 3)`.
        data_format: `None` or str. If specified, either `"channels_last"` or
            `"channels_first"`. The ordering of the dimensions in the
            inputs. `"channels_last"` corresponds to inputs with shape
            `(batch_size, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch_size, channels, height, width)`. It defaults to the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`. If you never set it, then it will be
            `"channels_last"`.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for the model's computations and weights. Note that some
            computations, such as softmax and layer normalization will always
            be done at float32 precision regardless of dtype.
        name: optional str. The name of the model. When given, it is also
            used as a prefix for the names of the inner layers.

    Raises:
        ValueError: If the height and width in `image_shape` differ, or if
            `intermediate_output_index` does not refer to one of the
            `num_layers` transformer layers.
    """

    def __init__(
        self,
        patch_size,
        hidden_dim,
        num_layers,
        num_heads,
        intermediate_dim,
        intermediate_activation="quick_gelu",
        intermediate_output_index=None,
        image_shape=(224, 224, 3),
        data_format=None,
        dtype=None,
        name=None,
        **kwargs,
    ):
        data_format = standardize_data_format(data_format)
        if data_format == "channels_last":
            height, width = image_shape[0], image_shape[1]
        else:
            height, width = image_shape[1], image_shape[2]
        if height != width:
            raise ValueError(
                "`CLIPVisionEncoder` expects the height and width to be the "
                f"same in `image_shape`. Received: image_shape={image_shape}"
            )

        if intermediate_output_index is not None:
            received_index = intermediate_output_index
            # Normalize negative indices so that they count from the end.
            if intermediate_output_index < 0:
                intermediate_output_index += num_layers
            # Without this check an out-of-range index would silently leave
            # `intermediate_output` as `None` in the model outputs.
            if not 0 <= intermediate_output_index < num_layers:
                raise ValueError(
                    "`intermediate_output_index` must refer to one of the "
                    f"{num_layers} transformer layers. Received: "
                    f"intermediate_output_index={received_index}"
                )

        # `prefix` is used to prevent duplicate name when utilizing multiple
        # CLIP models within a single model, such as in StableDiffusion3.
        prefix = str(name) + "_" if name is not None else ""

        # === Layers ===
        self.embedding = CLIPVisionEmbedding(
            hidden_dim=hidden_dim,
            patch_size=patch_size,
            image_size=height,
            data_format=data_format,
            dtype=dtype,
            name=f"{prefix}embedding",
        )
        self.pre_layer_norm = layers.LayerNormalization(
            epsilon=1e-5, dtype=dtype, name=f"{prefix}pre_layer_norm"
        )
        self.encoder_layers = [
            CLIPEncoderBlock(
                hidden_dim,
                num_heads,
                intermediate_dim,
                intermediate_activation,
                use_causal_mask=False,  # `False` in the vision encoder.
                dtype=dtype,
                name=f"{prefix}encoder_block_{i}",
            )
            for i in range(num_layers)
        ]
        self.layer_norm = layers.LayerNormalization(
            epsilon=1e-5, dtype=dtype, name=f"{prefix}layer_norm"
        )

        # === Functional Model ===
        image_input = layers.Input(shape=image_shape, name="images")
        x = self.embedding(image_input)
        x = self.pre_layer_norm(x)
        intermediate_output = None
        for i, block in enumerate(self.encoder_layers):
            x = block(x)
            # Capture the block output before the final layer normalization.
            if i == intermediate_output_index:
                intermediate_output = x
        sequence_output = self.layer_norm(x)

        if intermediate_output_index is not None:
            outputs = {
                "sequence_output": sequence_output,
                "intermediate_output": intermediate_output,
            }
        else:
            outputs = sequence_output
        super().__init__(
            inputs={"images": image_input},
            outputs=outputs,
            dtype=dtype,
            name=name,
            **kwargs,
        )

        # === Config ===
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.intermediate_dim = intermediate_dim
        self.intermediate_activation = intermediate_activation
        # Stored in its normalized (non-negative) form.
        self.intermediate_output_index = intermediate_output_index
        self.image_shape = image_shape

    def get_config(self):
        """Return the constructor arguments needed to rebuild this model."""
        config = super().get_config()
        config.update(
            {
                "patch_size": self.patch_size,
                "hidden_dim": self.hidden_dim,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "intermediate_dim": self.intermediate_dim,
                "intermediate_activation": self.intermediate_activation,
                "intermediate_output_index": self.intermediate_output_index,
                "image_shape": self.image_shape,
            }
        )
        return config
@@ -8,9 +8,7 @@ backbone_presets = {
8
8
  "Trained on English Wikipedia, BookCorpus and OpenWebText."
9
9
  ),
10
10
  "params": 70682112,
11
- "official_name": "DeBERTaV3",
12
11
  "path": "deberta_v3",
13
- "model_card": "https://huggingface.co/microsoft/deberta-v3-xsmall",
14
12
  },
15
13
  "kaggle_handle": "kaggle://keras/deberta_v3/keras/deberta_v3_extra_small_en/2",
16
14
  },
@@ -21,9 +19,7 @@ backbone_presets = {
21
19
  "Trained on English Wikipedia, BookCorpus and OpenWebText."
22
20
  ),
23
21
  "params": 141304320,
24
- "official_name": "DeBERTaV3",
25
22
  "path": "deberta_v3",
26
- "model_card": "https://huggingface.co/microsoft/deberta-v3-small",
27
23
  },
28
24
  "kaggle_handle": "kaggle://keras/deberta_v3/keras/deberta_v3_small_en/2",
29
25
  },
@@ -34,9 +30,7 @@ backbone_presets = {
34
30
  "Trained on English Wikipedia, BookCorpus and OpenWebText."
35
31
  ),
36
32
  "params": 183831552,
37
- "official_name": "DeBERTaV3",
38
33
  "path": "deberta_v3",
39
- "model_card": "https://huggingface.co/microsoft/deberta-v3-base",
40
34
  },
41
35
  "kaggle_handle": "kaggle://keras/deberta_v3/keras/deberta_v3_base_en/2",
42
36
  },
@@ -47,9 +41,7 @@ backbone_presets = {
47
41
  "Trained on English Wikipedia, BookCorpus and OpenWebText."
48
42
  ),
49
43
  "params": 434012160,
50
- "official_name": "DeBERTaV3",
51
44
  "path": "deberta_v3",
52
- "model_card": "https://huggingface.co/microsoft/deberta-v3-large",
53
45
  },
54
46
  "kaggle_handle": "kaggle://keras/deberta_v3/keras/deberta_v3_large_en/2",
55
47
  },
@@ -60,9 +52,7 @@ backbone_presets = {
60
52
  "Trained on the 2.5TB multilingual CC100 dataset."
61
53
  ),
62
54
  "params": 278218752,
63
- "official_name": "DeBERTaV3",
64
55
  "path": "deberta_v3",
65
- "model_card": "https://huggingface.co/microsoft/mdeberta-v3-base",
66
56
  },
67
57
  "kaggle_handle": "kaggle://keras/deberta_v3/keras/deberta_v3_base_multi/2",
68
58
  },
@@ -9,9 +9,7 @@ backbone_presets = {
9
9
  "which is having categorical accuracy of 90.01 and 0.63 Mean IoU."
10
10
  ),
11
11
  "params": 39190656,
12
- "official_name": "DeepLabV3",
13
12
  "path": "deeplab_v3",
14
- "model_card": "https://arxiv.org/abs/1802.02611",
15
13
  },
16
14
  "kaggle_handle": "kaggle://keras/deeplabv3plus/keras/deeplab_v3_plus_resnet50_pascalvoc/3",
17
15
  },
@@ -31,9 +31,9 @@ class DeepLabV3ImageSegmenter(ImageSegmenter):
31
31
  Load a DeepLabV3 preset with the pretrained 21-class segmentation head.
32
32
  ```python
33
33
  images = np.ones(shape=(1, 96, 96, 3))
34
- labels = np.zeros(shape=(1, 96, 96, 1))
34
+ labels = np.zeros(shape=(1, 96, 96, 2))
35
35
  segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset(
36
- "deeplabv3_resnet50_pascalvoc",
36
+ "deeplab_v3_plus_resnet50_pascalvoc",
37
37
  )
38
38
  segmenter.predict(images)
39
39
  ```
@@ -41,12 +41,14 @@ class DeepLabV3ImageSegmenter(ImageSegmenter):
41
41
  Specify `num_classes` to load randomly initialized segmentation head.
42
42
  ```python
43
43
  segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset(
44
- "deeplabv3_resnet50_pascalvoc",
44
+ "deeplab_v3_plus_resnet50_pascalvoc",
45
45
  num_classes=2,
46
46
  )
47
+ segmenter.preprocessor.image_size = (96, 96)
47
48
  segmenter.fit(images, labels, epochs=3)
48
49
  segmenter.predict(images) # Trained 2 class segmentation.
49
50
  ```
51
+
50
52
  Load DeepLabv3+ presets, an extension of DeepLabv3 that adds a simple yet
51
53
  effective decoder module to refine the segmentation results especially
52
54
  along object boundaries.
@@ -29,7 +29,7 @@ class DenseNetBackbone(FeaturePyramidBackbone):
29
29
  input_data = np.ones(shape=(8, 224, 224, 3))
30
30
 
31
31
  # Pretrained backbone
32
- model = keras_hub.models.DenseNetBackbone.from_preset("densenet121_imagenet")
32
+ model = keras_hub.models.DenseNetBackbone.from_preset("densenet_121_imagenet")
33
33
  model(input_data)
34
34
 
35
35
  # Randomly initialized backbone with a custom config
@@ -8,9 +8,7 @@ backbone_presets = {
8
8
  "at a 224x224 resolution."
9
9
  ),
10
10
  "params": 7037504,
11
- "official_name": "DenseNet",
12
11
  "path": "densenet",
13
- "model_card": "https://arxiv.org/abs/1608.06993",
14
12
  },
15
13
  "kaggle_handle": "kaggle://keras/densenet/keras/densenet_121_imagenet/2",
16
14
  },
@@ -21,9 +19,7 @@ backbone_presets = {
21
19
  "at a 224x224 resolution."
22
20
  ),
23
21
  "params": 12642880,
24
- "official_name": "DenseNet",
25
22
  "path": "densenet",
26
- "model_card": "https://arxiv.org/abs/1608.06993",
27
23
  },
28
24
  "kaggle_handle": "kaggle://keras/densenet/keras/densenet_169_imagenet/2",
29
25
  },
@@ -34,9 +30,7 @@ backbone_presets = {
34
30
  "at a 224x224 resolution."
35
31
  ),
36
32
  "params": 18321984,
37
- "official_name": "DenseNet",
38
33
  "path": "densenet",
39
- "model_card": "https://arxiv.org/abs/1608.06993",
40
34
  },
41
35
  "kaggle_handle": "kaggle://keras/densenet/keras/densenet_201_imagenet/2",
42
36
  },
@@ -9,9 +9,7 @@ backbone_presets = {
9
9
  "teacher model."
10
10
  ),
11
11
  "params": 66362880,
12
- "official_name": "DistilBERT",
13
12
  "path": "distil_bert",
14
- "model_card": "https://huggingface.co/distilbert-base-uncased",
15
13
  },
16
14
  "kaggle_handle": "kaggle://keras/distil_bert/keras/distil_bert_base_en_uncased/2",
17
15
  },
@@ -23,9 +21,7 @@ backbone_presets = {
23
21
  "teacher model."
24
22
  ),
25
23
  "params": 65190912,
26
- "official_name": "DistilBERT",
27
24
  "path": "distil_bert",
28
- "model_card": "https://huggingface.co/distilbert-base-cased",
29
25
  },
30
26
  "kaggle_handle": "kaggle://keras/distil_bert/keras/distil_bert_base_en/2",
31
27
  },
@@ -35,9 +31,7 @@ backbone_presets = {
35
31
  "6-layer DistilBERT model where case is maintained. Trained on Wikipedias of 104 languages"
36
32
  ),
37
33
  "params": 134734080,
38
- "official_name": "DistilBERT",
39
34
  "path": "distil_bert",
40
- "model_card": "https://huggingface.co/distilbert-base-multilingual-cased",
41
35
  },
42
36
  "kaggle_handle": "kaggle://keras/distil_bert/keras/distil_bert_base_multi/2",
43
37
  },
@@ -0,0 +1,9 @@
1
+ from keras_hub.src.models.efficientnet.efficientnet_backbone import (
2
+ EfficientNetBackbone,
3
+ )
4
+ from keras_hub.src.models.efficientnet.efficientnet_presets import (
5
+ backbone_presets,
6
+ )
7
+ from keras_hub.src.utils.preset_utils import register_presets
8
+
9
+ register_presets(backbone_presets, EfficientNetBackbone)
@@ -0,0 +1,141 @@
1
+ import keras
2
+
3
+ BN_AXIS = 3
4
+
5
+
6
class CBABlock(keras.layers.Layer):
    """ConvBNAct block: convolution, batch normalization and activation.

    The input is zero-padded by `kernel_size // 2` pixels, passed through a
    bias-free `Conv2D` with `"valid"` padding, a `BatchNormalization` and an
    activation. When `strides == 1`, `input_filters == output_filters` and
    `nores` is `False`, the (optionally dropped-out) result is added to the
    block input as a residual connection.

    Args:
        input_filters: int, the number of input filters
        output_filters: int, the number of output filters
        kernel_size: default 3, the kernel_size of the convolution
        strides: default 1, the strides of the convolution
        data_format: str, channels_last (default) or channels_first, expects
            tensors to be of shape (N, H, W, C) or (N, C, H, W) respectively
        batch_norm_momentum: default 0.9, the BatchNormalization momentum
        batch_norm_epsilon: default 1e-3, the BatchNormalization epsilon
        activation: default "swish", the activation function applied after
            the batch normalization
        dropout: float, the optional dropout rate to apply to the block
            output before it is added to the residual, defaults to 0.2
        nores: bool, default False, forces no residual connection if True,
            otherwise allows it if False.

    Returns:
        A tensor representing a feature map, passed through the ConvBNAct
        block

    Note:
        Not intended to be used outside of the EfficientNet architecture.
    """

    def __init__(
        self,
        input_filters,
        output_filters,
        kernel_size=3,
        strides=1,
        data_format="channels_last",
        batch_norm_momentum=0.9,
        batch_norm_epsilon=1e-3,
        activation="swish",
        dropout=0.2,
        nores=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.input_filters = input_filters
        self.output_filters = output_filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.data_format = data_format
        self.batch_norm_momentum = batch_norm_momentum
        self.batch_norm_epsilon = batch_norm_epsilon
        self.activation = activation
        self.dropout = dropout
        self.nores = nores

        # Batch normalization has to run over the channel axis, which is
        # axis 1 for (N, C, H, W) inputs and `BN_AXIS` for (N, H, W, C).
        channel_axis = 1 if data_format == "channels_first" else BN_AXIS

        # Explicit padding followed by a "valid" convolution.
        padding_pixels = kernel_size // 2
        self.conv1_pad = keras.layers.ZeroPadding2D(
            padding=(padding_pixels, padding_pixels),
            # Pad the spatial axes of the same layout the convolution uses.
            data_format=data_format,
            name=self.name + "conv_pad",
        )
        self.conv1 = keras.layers.Conv2D(
            filters=self.output_filters,
            kernel_size=kernel_size,
            strides=strides,
            kernel_initializer=self._conv_kernel_initializer(),
            padding="valid",
            data_format=data_format,
            use_bias=False,
            name=self.name + "conv",
        )
        self.bn1 = keras.layers.BatchNormalization(
            axis=channel_axis,
            momentum=self.batch_norm_momentum,
            epsilon=self.batch_norm_epsilon,
            name=self.name + "bn",
        )
        self.act = keras.layers.Activation(
            self.activation, name=self.name + "activation"
        )

        # The dropout layer only exists when a non-zero rate is requested.
        if self.dropout:
            self.dropout_layer = keras.layers.Dropout(
                self.dropout,
                # One mask value per sample, broadcast over the other axes.
                noise_shape=(None, 1, 1, 1),
                name=self.name + "drop",
            )

    def _conv_kernel_initializer(
        self,
        scale=2.0,
        mode="fan_out",
        distribution="truncated_normal",
        seed=None,
    ):
        """Return the `VarianceScaling` initializer for the conv kernel."""
        return keras.initializers.VarianceScaling(
            scale=scale, mode=mode, distribution=distribution, seed=seed
        )

    def build(self, input_shape):
        # NOTE: `__init__` already concatenates `self.name` with suffixes, so
        # a `None` name would have failed before this point; the branch is
        # kept only to preserve the original behavior.
        if self.name is None:
            self.name = keras.backend.get_uid("block0")

    def call(self, inputs):
        """Apply pad -> conv -> batch norm -> activation (+ residual)."""
        x = self.conv1_pad(inputs)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act(x)

        # Residual: only possible when the output has the same shape as the
        # input, i.e. no striding and no change in the number of filters.
        if (
            self.strides == 1
            and self.input_filters == self.output_filters
            and not self.nores
        ):
            if self.dropout:
                x = self.dropout_layer(x)
            # Use an op instead of instantiating a new `Add` layer on every
            # call.
            x = keras.ops.add(x, inputs)
        return x

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer."""
        config = super().get_config()
        config.update(
            {
                "input_filters": self.input_filters,
                "output_filters": self.output_filters,
                "kernel_size": self.kernel_size,
                "strides": self.strides,
                "data_format": self.data_format,
                "batch_norm_momentum": self.batch_norm_momentum,
                "batch_norm_epsilon": self.batch_norm_epsilon,
                "activation": self.activation,
                "dropout": self.dropout,
                "nores": self.nores,
            }
        )

        return config