keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198)
  1. keras_hub/__init__.py +0 -6
  2. keras_hub/api/__init__.py +2 -0
  3. keras_hub/api/bounding_box/__init__.py +36 -0
  4. keras_hub/api/layers/__init__.py +14 -0
  5. keras_hub/api/models/__init__.py +97 -48
  6. keras_hub/api/tokenizers/__init__.py +30 -0
  7. keras_hub/api/utils/__init__.py +22 -0
  8. keras_hub/src/api_export.py +15 -9
  9. keras_hub/src/bounding_box/__init__.py +13 -0
  10. keras_hub/src/bounding_box/converters.py +529 -0
  11. keras_hub/src/bounding_box/formats.py +162 -0
  12. keras_hub/src/bounding_box/iou.py +263 -0
  13. keras_hub/src/bounding_box/to_dense.py +95 -0
  14. keras_hub/src/bounding_box/to_ragged.py +99 -0
  15. keras_hub/src/bounding_box/utils.py +194 -0
  16. keras_hub/src/bounding_box/validate_format.py +99 -0
  17. keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
  18. keras_hub/src/layers/preprocessing/image_converter.py +130 -0
  19. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
  20. keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
  21. keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
  22. keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
  23. keras_hub/src/layers/preprocessing/random_swap.py +33 -31
  24. keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
  25. keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
  26. keras_hub/src/models/albert/__init__.py +1 -2
  27. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
  28. keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
  29. keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
  30. keras_hub/src/models/albert/albert_tokenizer.py +17 -36
  31. keras_hub/src/models/backbone.py +12 -34
  32. keras_hub/src/models/bart/__init__.py +1 -2
  33. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
  34. keras_hub/src/models/bart/bart_tokenizer.py +12 -39
  35. keras_hub/src/models/bert/__init__.py +1 -5
  36. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
  37. keras_hub/src/models/bert/bert_presets.py +1 -4
  38. keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
  39. keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
  40. keras_hub/src/models/bert/bert_tokenizer.py +17 -35
  41. keras_hub/src/models/bloom/__init__.py +1 -2
  42. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
  43. keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
  44. keras_hub/src/models/causal_lm.py +10 -29
  45. keras_hub/src/models/causal_lm_preprocessor.py +195 -0
  46. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
  47. keras_hub/src/models/deberta_v3/__init__.py +1 -4
  48. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
  49. keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
  50. keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
  51. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
  52. keras_hub/src/models/densenet/densenet_backbone.py +46 -22
  53. keras_hub/src/models/distil_bert/__init__.py +1 -4
  54. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
  55. keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
  56. keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
  57. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
  58. keras_hub/src/models/efficientnet/__init__.py +13 -0
  59. keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
  60. keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
  61. keras_hub/src/models/efficientnet/mbconv.py +238 -0
  62. keras_hub/src/models/electra/__init__.py +1 -2
  63. keras_hub/src/models/electra/electra_tokenizer.py +17 -32
  64. keras_hub/src/models/f_net/__init__.py +1 -2
  65. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
  66. keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
  67. keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
  68. keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
  69. keras_hub/src/models/falcon/__init__.py +1 -2
  70. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
  71. keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
  72. keras_hub/src/models/gemma/__init__.py +1 -2
  73. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
  74. keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
  75. keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
  76. keras_hub/src/models/gpt2/__init__.py +1 -2
  77. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
  78. keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
  79. keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
  80. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
  81. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
  82. keras_hub/src/models/image_classifier.py +0 -5
  83. keras_hub/src/models/image_classifier_preprocessor.py +83 -0
  84. keras_hub/src/models/llama/__init__.py +1 -2
  85. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
  86. keras_hub/src/models/llama/llama_tokenizer.py +12 -25
  87. keras_hub/src/models/llama3/__init__.py +1 -2
  88. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
  89. keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
  90. keras_hub/src/models/masked_lm.py +0 -2
  91. keras_hub/src/models/masked_lm_preprocessor.py +156 -0
  92. keras_hub/src/models/mistral/__init__.py +1 -2
  93. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
  94. keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
  95. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
  96. keras_hub/src/models/mobilenet/__init__.py +13 -0
  97. keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
  98. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
  99. keras_hub/src/models/opt/__init__.py +1 -2
  100. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
  101. keras_hub/src/models/opt/opt_tokenizer.py +12 -41
  102. keras_hub/src/models/pali_gemma/__init__.py +1 -4
  103. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
  104. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
  105. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
  106. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
  107. keras_hub/src/models/phi3/__init__.py +1 -2
  108. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
  109. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
  110. keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
  111. keras_hub/src/models/preprocessor.py +72 -83
  112. keras_hub/src/models/resnet/__init__.py +6 -0
  113. keras_hub/src/models/resnet/resnet_backbone.py +390 -42
  114. keras_hub/src/models/resnet/resnet_image_classifier.py +33 -6
  115. keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
  116. keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
  117. keras_hub/src/models/resnet/resnet_presets.py +95 -0
  118. keras_hub/src/models/retinanet/__init__.py +13 -0
  119. keras_hub/src/models/retinanet/anchor_generator.py +175 -0
  120. keras_hub/src/models/retinanet/box_matcher.py +259 -0
  121. keras_hub/src/models/retinanet/non_max_supression.py +578 -0
  122. keras_hub/src/models/roberta/__init__.py +1 -2
  123. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
  124. keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
  125. keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
  126. keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
  127. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
  128. keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
  129. keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
  130. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
  131. keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
  132. keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
  133. keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
  134. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
  135. keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
  136. keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
  137. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
  138. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
  139. keras_hub/src/models/t5/__init__.py +1 -2
  140. keras_hub/src/models/t5/t5_tokenizer.py +13 -23
  141. keras_hub/src/models/task.py +71 -116
  142. keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
  143. keras_hub/src/models/text_classifier_preprocessor.py +138 -0
  144. keras_hub/src/models/whisper/__init__.py +1 -2
  145. keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
  146. keras_hub/src/models/whisper/whisper_backbone.py +0 -3
  147. keras_hub/src/models/whisper/whisper_presets.py +10 -10
  148. keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
  149. keras_hub/src/models/xlm_roberta/__init__.py +1 -4
  150. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
  151. keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
  152. keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
  153. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
  154. keras_hub/src/tests/test_case.py +46 -0
  155. keras_hub/src/tokenizers/byte_pair_tokenizer.py +30 -17
  156. keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
  157. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +20 -7
  158. keras_hub/src/tokenizers/tokenizer.py +67 -32
  159. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
  160. keras_hub/src/tokenizers/word_piece_tokenizer.py +34 -47
  161. keras_hub/src/utils/imagenet/__init__.py +13 -0
  162. keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
  163. keras_hub/src/utils/keras_utils.py +0 -50
  164. keras_hub/src/utils/preset_utils.py +230 -68
  165. keras_hub/src/utils/tensor_utils.py +187 -69
  166. keras_hub/src/utils/timm/convert_resnet.py +19 -16
  167. keras_hub/src/utils/timm/preset_loader.py +66 -0
  168. keras_hub/src/utils/transformers/convert_albert.py +193 -0
  169. keras_hub/src/utils/transformers/convert_bart.py +373 -0
  170. keras_hub/src/utils/transformers/convert_bert.py +7 -17
  171. keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
  172. keras_hub/src/utils/transformers/convert_gemma.py +5 -19
  173. keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
  174. keras_hub/src/utils/transformers/convert_llama3.py +7 -18
  175. keras_hub/src/utils/transformers/convert_mistral.py +129 -0
  176. keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
  177. keras_hub/src/utils/transformers/preset_loader.py +77 -0
  178. keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
  179. keras_hub/src/version_utils.py +1 -1
  180. keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
  181. keras_hub_nightly-0.16.0.dev2024092017.dist-info/RECORD +334 -0
  182. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
  183. keras_hub/src/models/bart/bart_preprocessor.py +0 -276
  184. keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
  185. keras_hub/src/models/electra/electra_preprocessor.py +0 -154
  186. keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
  187. keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
  188. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
  189. keras_hub/src/models/llama/llama_preprocessor.py +0 -189
  190. keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
  191. keras_hub/src/models/opt/opt_preprocessor.py +0 -188
  192. keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
  193. keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
  194. keras_hub/src/utils/timm/convert.py +0 -37
  195. keras_hub/src/utils/transformers/convert.py +0 -101
  196. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +0 -34
  197. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
  198. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
keras_hub/src/models/resnet/resnet_backbone.py
@@ -27,9 +27,10 @@ class ResNetBackbone(FeaturePyramidBackbone):
  This class implements a ResNet backbone as described in [Deep Residual
  Learning for Image Recognition](https://arxiv.org/abs/1512.03385)(
  CVPR 2016), [Identity Mappings in Deep Residual Networks](
- https://arxiv.org/abs/1603.05027)(ECCV 2016) and [ResNet strikes back: An
+ https://arxiv.org/abs/1603.05027)(ECCV 2016), [ResNet strikes back: An
  improved training procedure in timm](https://arxiv.org/abs/2110.00476)(
- NeurIPS 2021 Workshop).
+ NeurIPS 2021 Workshop) and [Bag of Tricks for Image Classification with
+ Convolutional Neural Networks](https://arxiv.org/abs/1812.01187).

  The difference in ResNet and ResNetV2 rests in the structure of their
  individual building blocks. In ResNetV2, the batch normalization and
@@ -37,18 +38,31 @@ class ResNetBackbone(FeaturePyramidBackbone):
  the batch normalization and ReLU activation are applied after the
  convolution layers.

+ ResNetVd introduces two key modifications to the standard ResNet. First,
+ the initial convolutional layer is replaced by a series of three
+ successive convolutional layers. Second, shortcut connections use an
+ additional pooling operation rather than performing downsampling within
+ the convolutional layers themselves.
+
  Note that `ResNetBackbone` expects the inputs to be images with a value
  range of `[0, 255]` when `include_rescaling=True`.

  Args:
+ input_conv_filters: list of ints. The number of filters of the initial
+ convolution(s).
+ input_conv_kernel_sizes: list of ints. The kernel sizes of the initial
+ convolution(s).
  stackwise_num_filters: list of ints. The number of filters for each
  stack.
  stackwise_num_blocks: list of ints. The number of blocks for each stack.
  stackwise_num_strides: list of ints. The number of strides for each
  stack.
- block_type: str. The block type to stack. One of `"basic_block"` or
- `"bottleneck_block"`. Use `"basic_block"` for ResNet18 and ResNet34.
- Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152.
+ block_type: str. The block type to stack. One of `"basic_block"`,
+ `"bottleneck_block"`, `"basic_block_vd"` or
+ `"bottleneck_block_vd"`. Use `"basic_block"` for ResNet18 and
+ ResNet34. Use `"bottleneck_block"` for ResNet50, ResNet101 and
+ ResNet152 and the `"_vd"` prefix for the respective ResNet_vd
+ variants.
  use_pre_activation: boolean. Whether to use pre-activation or not.
  `True` for ResNetV2, `False` for ResNet.
  include_rescaling: boolean. If `True`, rescale the input using
@@ -88,6 +102,8 @@ class ResNetBackbone(FeaturePyramidBackbone):

  # Randomly initialized ResNetV2 backbone with a custom config.
  model = keras_hub.models.ResNetBackbone(
+ input_conv_filters=[64],
+ input_conv_kernel_sizes=[7],
  stackwise_num_filters=[64, 64, 64],
  stackwise_num_blocks=[2, 2, 2],
  stackwise_num_strides=[1, 2, 2],
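A minimal illustrative sketch of how the new `input_conv_filters`, `input_conv_kernel_sizes` and `"bottleneck_block_vd"` options fit together for a ResNet-vd style backbone. The stem and stack values below are assumptions chosen to satisfy the constructor checks in this diff; they are not presets shipped in the package:

    import numpy as np
    import keras_hub

    # Assumed ResNet50-vd style layout: a deep stem of three convolutions
    # plus "_vd" bottleneck blocks with the average-pooling shortcut.
    backbone = keras_hub.models.ResNetBackbone(
        input_conv_filters=[32, 32, 64],
        input_conv_kernel_sizes=[3, 3, 3],
        stackwise_num_filters=[64, 128, 256, 512],
        stackwise_num_blocks=[3, 4, 6, 3],
        stackwise_num_strides=[1, 2, 2, 2],
        block_type="bottleneck_block_vd",
        use_pre_activation=False,
    )
    # The backbone now returns the final feature map (global pooling was
    # removed from the backbone in this release), so the output is 4D.
    features = backbone(np.ones((8, 224, 224, 3)))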
@@ -101,6 +117,8 @@ class ResNetBackbone(FeaturePyramidBackbone):

  def __init__(
  self,
+ input_conv_filters,
+ input_conv_kernel_sizes,
  stackwise_num_filters,
  stackwise_num_blocks,
  stackwise_num_strides,
@@ -108,11 +126,17 @@ class ResNetBackbone(FeaturePyramidBackbone):
  use_pre_activation=False,
  include_rescaling=True,
  image_shape=(None, None, 3),
- pooling="avg",
  data_format=None,
  dtype=None,
  **kwargs,
  ):
+ if len(input_conv_filters) != len(input_conv_kernel_sizes):
+ raise ValueError(
+ "The length of `input_conv_filters` and"
+ "`input_conv_kernel_sizes` must be the same. "
+ f"Received: input_conv_filters={input_conv_filters}, "
+ f"input_conv_kernel_sizes={input_conv_kernel_sizes}."
+ )
  if len(stackwise_num_filters) != len(stackwise_num_blocks) or len(
  stackwise_num_filters
  ) != len(stackwise_num_strides):
@@ -128,14 +152,20 @@ class ResNetBackbone(FeaturePyramidBackbone):
  "The first element of `stackwise_num_filters` must be 64. "
  f"Received: stackwise_num_filters={stackwise_num_filters}"
  )
- if block_type not in ("basic_block", "bottleneck_block"):
+ if block_type not in (
+ "basic_block",
+ "bottleneck_block",
+ "basic_block_vd",
+ "bottleneck_block_vd",
+ ):
  raise ValueError(
- '`block_type` must be either `"basic_block"` or '
- f'`"bottleneck_block"`. Received block_type={block_type}.'
+ '`block_type` must be either `"basic_block"`, '
+ '`"bottleneck_block"`, `"basic_block_vd"` or '
+ f'`"bottleneck_block_vd"`. Received block_type={block_type}.'
  )
- version = "v1" if not use_pre_activation else "v2"
  data_format = standardize_data_format(data_format)
  bn_axis = -1 if data_format == "channels_last" else 1
+ num_input_convs = len(input_conv_filters)
  num_stacks = len(stackwise_num_filters)

  # === Functional Model ===
@@ -155,29 +185,56 @@ class ResNetBackbone(FeaturePyramidBackbone):
  # The padding between torch and tensorflow/jax differs when `strides>1`.
  # Therefore, we need to manually pad the tensor.
  x = layers.ZeroPadding2D(
- 3,
+ (input_conv_kernel_sizes[0] - 1) // 2,
  data_format=data_format,
  dtype=dtype,
  name="conv1_pad",
  )(x)
  x = layers.Conv2D(
- 64,
- 7,
+ input_conv_filters[0],
+ input_conv_kernel_sizes[0],
  strides=2,
  data_format=data_format,
  use_bias=False,
+ padding="valid",
  dtype=dtype,
  name="conv1_conv",
  )(x)
+ for conv_index in range(1, num_input_convs):
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ epsilon=1e-5,
+ momentum=0.9,
+ dtype=dtype,
+ name=f"conv{conv_index}_bn",
+ )(x)
+ x = layers.Activation(
+ "relu", dtype=dtype, name=f"conv{conv_index}_relu"
+ )(x)
+ x = layers.Conv2D(
+ input_conv_filters[conv_index],
+ input_conv_kernel_sizes[conv_index],
+ strides=1,
+ data_format=data_format,
+ use_bias=False,
+ padding="same",
+ dtype=dtype,
+ name=f"conv{conv_index+1}_conv",
+ )(x)
+
  if not use_pre_activation:
  x = layers.BatchNormalization(
  axis=bn_axis,
  epsilon=1e-5,
  momentum=0.9,
  dtype=dtype,
- name="conv1_bn",
+ name=f"conv{num_input_convs}_bn",
+ )(x)
+ x = layers.Activation(
+ "relu",
+ dtype=dtype,
+ name=f"conv{num_input_convs}_relu",
  )(x)
- x = layers.Activation("relu", dtype=dtype, name="conv1_relu")(x)

  if use_pre_activation:
  # A workaround for ResNetV2: we need -inf padding to prevent zeros
@@ -210,12 +267,10 @@ class ResNetBackbone(FeaturePyramidBackbone):
  stride=stackwise_num_strides[stack_index],
  block_type=block_type,
  use_pre_activation=use_pre_activation,
- first_shortcut=(
- block_type == "bottleneck_block" or stack_index > 0
- ),
+ first_shortcut=(block_type != "basic_block" or stack_index > 0),
  data_format=data_format,
  dtype=dtype,
- name=f"{version}_stack{stack_index}",
+ name=f"stack{stack_index}",
  )
  pyramid_outputs[f"P{stack_index + 2}"] = x

@@ -229,25 +284,16 @@ class ResNetBackbone(FeaturePyramidBackbone):
  )(x)
  x = layers.Activation("relu", dtype=dtype, name="post_relu")(x)

- if pooling == "avg":
- feature_map_output = layers.GlobalAveragePooling2D(
- data_format=data_format, dtype=dtype
- )(x)
- elif pooling == "max":
- feature_map_output = layers.GlobalMaxPooling2D(
- data_format=data_format, dtype=dtype
- )(x)
- else:
- feature_map_output = x
-
  super().__init__(
  inputs=image_input,
- outputs=feature_map_output,
+ outputs=x,
  dtype=dtype,
  **kwargs,
  )

  # === Config ===
+ self.input_conv_filters = input_conv_filters
+ self.input_conv_kernel_sizes = input_conv_kernel_sizes
  self.stackwise_num_filters = stackwise_num_filters
  self.stackwise_num_blocks = stackwise_num_blocks
  self.stackwise_num_strides = stackwise_num_strides
@@ -255,13 +301,15 @@ class ResNetBackbone(FeaturePyramidBackbone):
  self.use_pre_activation = use_pre_activation
  self.include_rescaling = include_rescaling
  self.image_shape = image_shape
- self.pooling = pooling
  self.pyramid_outputs = pyramid_outputs
+ self.data_format = data_format

  def get_config(self):
  config = super().get_config()
  config.update(
  {
+ "input_conv_filters": self.input_conv_filters,
+ "input_conv_kernel_sizes": self.input_conv_kernel_sizes,
  "stackwise_num_filters": self.stackwise_num_filters,
  "stackwise_num_blocks": self.stackwise_num_blocks,
  "stackwise_num_strides": self.stackwise_num_strides,
@@ -269,7 +317,6 @@ class ResNetBackbone(FeaturePyramidBackbone):
  "use_pre_activation": self.use_pre_activation,
  "include_rescaling": self.include_rescaling,
  "image_shape": self.image_shape,
- "pooling": self.pooling,
  }
  )
  return config
@@ -327,7 +374,10 @@ def apply_basic_block(
  )(x_preact)

  if conv_shortcut:
- x = x_preact if x_preact is not None else x
+ if x_preact is not None:
+ shortcut = x_preact
+ else:
+ shortcut = x
  shortcut = layers.Conv2D(
  filters,
  1,
@@ -336,7 +386,7 @@ def apply_basic_block(
  use_bias=False,
  dtype=dtype,
  name=f"{name}_0_conv",
- )(x)
+ )(shortcut)
  if not use_pre_activation:
  shortcut = layers.BatchNormalization(
  axis=bn_axis,
@@ -452,7 +502,10 @@ def apply_bottleneck_block(
  )(x_preact)

  if conv_shortcut:
- x = x_preact if x_preact is not None else x
+ if x_preact is not None:
+ shortcut = x_preact
+ else:
+ shortcut = x
  shortcut = layers.Conv2D(
  4 * filters,
  1,
@@ -461,7 +514,295 @@ def apply_bottleneck_block(
  use_bias=False,
  dtype=dtype,
  name=f"{name}_0_conv",
+ )(shortcut)
+ if not use_pre_activation:
+ shortcut = layers.BatchNormalization(
+ axis=bn_axis,
+ epsilon=1e-5,
+ momentum=0.9,
+ dtype=dtype,
+ name=f"{name}_0_bn",
+ )(shortcut)
+ else:
+ shortcut = x
+
+ x = x_preact if x_preact is not None else x
+ x = layers.Conv2D(
+ filters,
+ 1,
+ strides=1,
+ data_format=data_format,
+ use_bias=False,
+ dtype=dtype,
+ name=f"{name}_1_conv",
+ )(x)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ epsilon=1e-5,
+ momentum=0.9,
+ dtype=dtype,
+ name=f"{name}_1_bn",
+ )(x)
+ x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x)
+
+ if stride > 1:
+ x = layers.ZeroPadding2D(
+ (kernel_size - 1) // 2,
+ data_format=data_format,
+ dtype=dtype,
+ name=f"{name}_2_pad",
+ )(x)
+ x = layers.Conv2D(
+ filters,
+ kernel_size,
+ strides=stride,
+ padding="valid" if stride > 1 else "same",
+ data_format=data_format,
+ use_bias=False,
+ dtype=dtype,
+ name=f"{name}_2_conv",
+ )(x)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ epsilon=1e-5,
+ momentum=0.9,
+ dtype=dtype,
+ name=f"{name}_2_bn",
+ )(x)
+ x = layers.Activation("relu", dtype=dtype, name=f"{name}_2_relu")(x)
+
+ x = layers.Conv2D(
+ 4 * filters,
+ 1,
+ data_format=data_format,
+ use_bias=False,
+ dtype=dtype,
+ name=f"{name}_3_conv",
+ )(x)
+ if not use_pre_activation:
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ epsilon=1e-5,
+ momentum=0.9,
+ dtype=dtype,
+ name=f"{name}_3_bn",
  )(x)
+ x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x])
+ x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x)
+ else:
+ x = layers.Add(dtype=dtype, name=f"{name}_out")([shortcut, x])
+ return x
+
+
+ def apply_basic_block_vd(
+ x,
+ filters,
+ kernel_size=3,
+ stride=1,
+ conv_shortcut=False,
+ use_pre_activation=False,
+ data_format=None,
+ dtype=None,
+ name=None,
+ ):
+ """Applies a basic residual block.
+
+ Args:
+ x: Tensor. The input tensor to pass through the block.
+ filters: int. The number of filters in the block.
+ kernel_size: int. The kernel size of the bottleneck layer. Defaults to
+ `3`.
+ stride: int. The stride length of the first layer. Defaults to `1`.
+ conv_shortcut: bool. If `True`, use a convolution shortcut. If `False`,
+ use an identity or pooling shortcut based on the stride. Defaults to
+ `False`.
+ use_pre_activation: boolean. Whether to use pre-activation or not.
+ `True` for ResNetV2, `False` for ResNet. Defaults to `False`.
+ data_format: `None` or str. the ordering of the dimensions in the
+ inputs. Can be `"channels_last"`
+ (`(batch_size, height, width, channels)`) or`"channels_first"`
+ (`(batch_size, channels, height, width)`).
+ dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
+ to use for the models computations and weights.
+ name: str. A prefix for the layer names used in the block.
+
+ Returns:
+ The output tensor for the basic residual block.
+ """
+ data_format = data_format or keras.config.image_data_format()
+ bn_axis = -1 if data_format == "channels_last" else 1
+
+ x_preact = None
+ if use_pre_activation:
+ x_preact = layers.BatchNormalization(
+ axis=bn_axis,
+ epsilon=1e-5,
+ momentum=0.9,
+ dtype=dtype,
+ name=f"{name}_pre_activation_bn",
+ )(x)
+ x_preact = layers.Activation(
+ "relu", dtype=dtype, name=f"{name}_pre_activation_relu"
+ )(x_preact)
+
+ if conv_shortcut:
+ if x_preact is not None:
+ shortcut = x_preact
+ elif stride > 1:
+ shortcut = layers.AveragePooling2D(
+ 2,
+ strides=stride,
+ data_format=data_format,
+ dtype=dtype,
+ padding="same",
+ )(x)
+ else:
+ shortcut = x
+ shortcut = layers.Conv2D(
+ filters,
+ 1,
+ strides=1,
+ data_format=data_format,
+ use_bias=False,
+ dtype=dtype,
+ name=f"{name}_0_conv",
+ )(shortcut)
+ if not use_pre_activation:
+ shortcut = layers.BatchNormalization(
+ axis=bn_axis,
+ epsilon=1e-5,
+ momentum=0.9,
+ dtype=dtype,
+ name=f"{name}_0_bn",
+ )(shortcut)
+ else:
+ shortcut = x
+
+ x = x_preact if x_preact is not None else x
+ if stride > 1:
+ x = layers.ZeroPadding2D(
+ (kernel_size - 1) // 2,
+ data_format=data_format,
+ dtype=dtype,
+ name=f"{name}_1_pad",
+ )(x)
+ x = layers.Conv2D(
+ filters,
+ kernel_size,
+ strides=stride,
+ padding="valid" if stride > 1 else "same",
+ data_format=data_format,
+ use_bias=False,
+ dtype=dtype,
+ name=f"{name}_1_conv",
+ )(x)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ epsilon=1e-5,
+ momentum=0.9,
+ dtype=dtype,
+ name=f"{name}_1_bn",
+ )(x)
+ x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x)
+
+ x = layers.Conv2D(
+ filters,
+ kernel_size,
+ strides=1,
+ padding="same",
+ data_format=data_format,
+ use_bias=False,
+ dtype=dtype,
+ name=f"{name}_2_conv",
+ )(x)
+ if not use_pre_activation:
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ epsilon=1e-5,
+ momentum=0.9,
+ dtype=dtype,
+ name=f"{name}_2_bn",
+ )(x)
+ x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x])
+ x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x)
+ else:
+ x = layers.Add(dtype=dtype, name=f"{name}_out")([shortcut, x])
+ return x
+
+
+ def apply_bottleneck_block_vd(
+ x,
+ filters,
+ kernel_size=3,
+ stride=1,
+ conv_shortcut=False,
+ use_pre_activation=False,
+ data_format=None,
+ dtype=None,
+ name=None,
+ ):
+ """Applies a bottleneck residual block.
+
+ Args:
+ x: Tensor. The input tensor to pass through the block.
+ filters: int. The number of filters in the block.
+ kernel_size: int. The kernel size of the bottleneck layer. Defaults to
+ `3`.
+ stride: int. The stride length of the first layer. Defaults to `1`.
+ conv_shortcut: bool. If `True`, use a convolution shortcut. If `False`,
+ use an identity or pooling shortcut based on the stride. Defaults to
+ `False`.
+ use_pre_activation: boolean. Whether to use pre-activation or not.
+ `True` for ResNetV2, `False` for ResNet. Defaults to `False`.
+ data_format: `None` or str. the ordering of the dimensions in the
+ inputs. Can be `"channels_last"`
+ (`(batch_size, height, width, channels)`) or`"channels_first"`
+ (`(batch_size, channels, height, width)`).
+ dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
+ to use for the models computations and weights.
+ name: str. A prefix for the layer names used in the block.
+
+ Returns:
+ The output tensor for the residual block.
+ """
+ data_format = data_format or keras.config.image_data_format()
+ bn_axis = -1 if data_format == "channels_last" else 1
+
+ x_preact = None
+ if use_pre_activation:
+ x_preact = layers.BatchNormalization(
+ axis=bn_axis,
+ epsilon=1e-5,
+ momentum=0.9,
+ dtype=dtype,
+ name=f"{name}_pre_activation_bn",
+ )(x)
+ x_preact = layers.Activation(
+ "relu", dtype=dtype, name=f"{name}_pre_activation_relu"
+ )(x_preact)
+
+ if conv_shortcut:
+ if x_preact is not None:
+ shortcut = x_preact
+ elif stride > 1:
+ shortcut = layers.AveragePooling2D(
+ 2,
+ strides=stride,
+ data_format=data_format,
+ dtype=dtype,
+ padding="same",
+ )(x)
+ else:
+ shortcut = x
+ shortcut = layers.Conv2D(
+ 4 * filters,
+ 1,
+ strides=1,
+ data_format=data_format,
+ use_bias=False,
+ dtype=dtype,
+ name=f"{name}_0_conv",
+ )(shortcut)
  if not use_pre_activation:
  shortcut = layers.BatchNormalization(
  axis=bn_axis,
@@ -561,8 +902,11 @@ def apply_stack(
  blocks: int. The number of blocks in the stack.
  stride: int. The stride length of the first layer in the first block.
  block_type: str. The block type to stack. One of `"basic_block"` or
- `"bottleneck_block"`. Use `"basic_block"` for ResNet18 and ResNet34.
- Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152.
+ `"bottleneck_block"`, `"basic_block_vd"` or
+ `"bottleneck_block_vd"`. Use `"basic_block"` for ResNet18 and
+ ResNet34. Use `"bottleneck_block"` for ResNet50, ResNet101 and
+ ResNet152 and the `"_vd"` prefix for the respective ResNet_vd
+ variants.
  use_pre_activation: boolean. Whether to use pre-activation or not.
  `True` for ResNetV2, `False` for ResNet and ResNeXt.
  first_shortcut: bool. If `True`, use a convolution shortcut. If `False`,
@@ -580,17 +924,21 @@ def apply_stack(
  Output tensor for the stacked blocks.
  """
  if name is None:
- version = "v1" if not use_pre_activation else "v2"
- name = f"{version}_stack"
+ name = "stack"

  if block_type == "basic_block":
  block_fn = apply_basic_block
  elif block_type == "bottleneck_block":
  block_fn = apply_bottleneck_block
+ elif block_type == "basic_block_vd":
+ block_fn = apply_basic_block_vd
+ elif block_type == "bottleneck_block_vd":
+ block_fn = apply_bottleneck_block_vd
  else:
  raise ValueError(
- '`block_type` must be either `"basic_block"` or '
- f'`"bottleneck_block"`. Received block_type={block_type}.'
+ '`block_type` must be either `"basic_block"`, '
+ '`"bottleneck_block"`, `"basic_block_vd"` or '
+ f'`"bottleneck_block_vd"`. Received block_type={block_type}.'
  )
  for i in range(blocks):
  if i == 0: