keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.15.0.dev20240911134614__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. keras_hub/api/__init__.py +1 -0
  2. keras_hub/api/bounding_box/__init__.py +36 -0
  3. keras_hub/api/layers/__init__.py +14 -0
  4. keras_hub/api/models/__init__.py +75 -31
  5. keras_hub/api/tokenizers/__init__.py +30 -0
  6. keras_hub/src/bounding_box/__init__.py +13 -0
  7. keras_hub/src/bounding_box/converters.py +529 -0
  8. keras_hub/src/bounding_box/formats.py +162 -0
  9. keras_hub/src/bounding_box/iou.py +263 -0
  10. keras_hub/src/bounding_box/to_dense.py +95 -0
  11. keras_hub/src/bounding_box/to_ragged.py +99 -0
  12. keras_hub/src/bounding_box/utils.py +194 -0
  13. keras_hub/src/bounding_box/validate_format.py +99 -0
  14. keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
  15. keras_hub/src/layers/preprocessing/image_converter.py +130 -0
  16. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
  17. keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
  18. keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
  19. keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
  20. keras_hub/src/layers/preprocessing/random_swap.py +33 -31
  21. keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
  22. keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
  23. keras_hub/src/models/albert/__init__.py +1 -2
  24. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
  25. keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +29 -10
  26. keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
  27. keras_hub/src/models/albert/albert_tokenizer.py +17 -36
  28. keras_hub/src/models/backbone.py +12 -34
  29. keras_hub/src/models/bart/__init__.py +1 -2
  30. keras_hub/src/models/bart/bart_preprocessor.py +6 -18
  31. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
  32. keras_hub/src/models/bart/bart_tokenizer.py +12 -39
  33. keras_hub/src/models/bert/__init__.py +1 -5
  34. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
  35. keras_hub/src/models/bert/bert_presets.py +1 -4
  36. keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +12 -10
  37. keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
  38. keras_hub/src/models/bert/bert_tokenizer.py +17 -35
  39. keras_hub/src/models/bloom/__init__.py +1 -2
  40. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
  41. keras_hub/src/models/bloom/bloom_preprocessor.py +5 -12
  42. keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
  43. keras_hub/src/models/causal_lm.py +10 -29
  44. keras_hub/src/models/causal_lm_preprocessor.py +195 -0
  45. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
  46. keras_hub/src/models/deberta_v3/__init__.py +1 -4
  47. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
  48. keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +11 -11
  49. keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
  50. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
  51. keras_hub/src/models/densenet/densenet_backbone.py +46 -22
  52. keras_hub/src/models/distil_bert/__init__.py +1 -4
  53. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
  54. keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +12 -12
  55. keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
  56. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
  57. keras_hub/src/models/efficientnet/__init__.py +13 -0
  58. keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
  59. keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
  60. keras_hub/src/models/efficientnet/mbconv.py +238 -0
  61. keras_hub/src/models/electra/__init__.py +1 -2
  62. keras_hub/src/models/electra/electra_preprocessor.py +6 -5
  63. keras_hub/src/models/electra/electra_tokenizer.py +17 -32
  64. keras_hub/src/models/f_net/__init__.py +1 -2
  65. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
  66. keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +10 -8
  67. keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
  68. keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
  69. keras_hub/src/models/falcon/__init__.py +1 -2
  70. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
  71. keras_hub/src/models/falcon/falcon_preprocessor.py +5 -12
  72. keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
  73. keras_hub/src/models/gemma/__init__.py +1 -2
  74. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
  75. keras_hub/src/models/gemma/gemma_preprocessor.py +5 -12
  76. keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
  77. keras_hub/src/models/gpt2/__init__.py +1 -2
  78. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
  79. keras_hub/src/models/gpt2/gpt2_preprocessor.py +5 -12
  80. keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
  81. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
  82. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +5 -12
  83. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
  84. keras_hub/src/models/image_classifier.py +0 -5
  85. keras_hub/src/models/image_classifier_preprocessor.py +83 -0
  86. keras_hub/src/models/llama/__init__.py +1 -2
  87. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
  88. keras_hub/src/models/llama/llama_preprocessor.py +5 -12
  89. keras_hub/src/models/llama/llama_tokenizer.py +12 -25
  90. keras_hub/src/models/llama3/__init__.py +1 -2
  91. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
  92. keras_hub/src/models/llama3/llama3_preprocessor.py +2 -0
  93. keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
  94. keras_hub/src/models/masked_lm.py +0 -2
  95. keras_hub/src/models/masked_lm_preprocessor.py +156 -0
  96. keras_hub/src/models/mistral/__init__.py +1 -2
  97. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
  98. keras_hub/src/models/mistral/mistral_preprocessor.py +5 -12
  99. keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
  100. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
  101. keras_hub/src/models/mobilenet/__init__.py +13 -0
  102. keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
  103. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
  104. keras_hub/src/models/opt/__init__.py +1 -2
  105. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
  106. keras_hub/src/models/opt/opt_preprocessor.py +5 -12
  107. keras_hub/src/models/opt/opt_tokenizer.py +12 -41
  108. keras_hub/src/models/pali_gemma/__init__.py +1 -4
  109. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
  110. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
  111. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
  112. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +10 -2
  113. keras_hub/src/models/phi3/__init__.py +1 -2
  114. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
  115. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
  116. keras_hub/src/models/phi3/phi3_preprocessor.py +5 -12
  117. keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
  118. keras_hub/src/models/preprocessor.py +76 -83
  119. keras_hub/src/models/resnet/__init__.py +6 -0
  120. keras_hub/src/models/resnet/resnet_backbone.py +387 -26
  121. keras_hub/src/models/resnet/resnet_image_classifier.py +7 -3
  122. keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
  123. keras_hub/src/models/resnet/resnet_image_converter.py +23 -0
  124. keras_hub/src/models/resnet/resnet_presets.py +95 -0
  125. keras_hub/src/models/roberta/__init__.py +1 -2
  126. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
  127. keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +11 -11
  128. keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
  129. keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
  130. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
  131. keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
  132. keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
  133. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
  134. keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
  135. keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
  136. keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
  137. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
  138. keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
  139. keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
  140. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
  141. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
  142. keras_hub/src/models/t5/__init__.py +1 -2
  143. keras_hub/src/models/t5/t5_tokenizer.py +13 -23
  144. keras_hub/src/models/task.py +71 -116
  145. keras_hub/src/models/{classifier.py → text_classifier.py} +8 -13
  146. keras_hub/src/models/text_classifier_preprocessor.py +138 -0
  147. keras_hub/src/models/whisper/__init__.py +1 -2
  148. keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
  149. keras_hub/src/models/whisper/whisper_backbone.py +0 -3
  150. keras_hub/src/models/whisper/whisper_presets.py +10 -10
  151. keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
  152. keras_hub/src/models/xlm_roberta/__init__.py +1 -4
  153. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
  154. keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +11 -11
  155. keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
  156. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
  157. keras_hub/src/tests/test_case.py +25 -0
  158. keras_hub/src/tokenizers/byte_pair_tokenizer.py +29 -17
  159. keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
  160. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +19 -7
  161. keras_hub/src/tokenizers/tokenizer.py +67 -32
  162. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
  163. keras_hub/src/tokenizers/word_piece_tokenizer.py +33 -47
  164. keras_hub/src/utils/keras_utils.py +0 -50
  165. keras_hub/src/utils/preset_utils.py +238 -67
  166. keras_hub/src/utils/tensor_utils.py +187 -69
  167. keras_hub/src/utils/timm/convert_resnet.py +20 -16
  168. keras_hub/src/utils/timm/preset_loader.py +67 -0
  169. keras_hub/src/utils/transformers/convert_albert.py +193 -0
  170. keras_hub/src/utils/transformers/convert_bart.py +373 -0
  171. keras_hub/src/utils/transformers/convert_bert.py +7 -17
  172. keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
  173. keras_hub/src/utils/transformers/convert_gemma.py +5 -19
  174. keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
  175. keras_hub/src/utils/transformers/convert_llama3.py +7 -18
  176. keras_hub/src/utils/transformers/convert_mistral.py +129 -0
  177. keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
  178. keras_hub/src/utils/transformers/preset_loader.py +77 -0
  179. keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
  180. keras_hub/src/version_utils.py +1 -1
  181. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.15.0.dev20240911134614.dist-info}/METADATA +1 -2
  182. keras_hub_nightly-0.15.0.dev20240911134614.dist-info/RECORD +338 -0
  183. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.15.0.dev20240911134614.dist-info}/WHEEL +1 -1
  184. keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
  185. keras_hub/src/utils/timm/convert.py +0 -37
  186. keras_hub/src/utils/transformers/convert.py +0 -101
  187. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
  188. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.15.0.dev20240911134614.dist-info}/top_level.txt +0 -0
@@ -27,9 +27,10 @@ class ResNetBackbone(FeaturePyramidBackbone):
27
27
  This class implements a ResNet backbone as described in [Deep Residual
28
28
  Learning for Image Recognition](https://arxiv.org/abs/1512.03385)(
29
29
  CVPR 2016), [Identity Mappings in Deep Residual Networks](
30
- https://arxiv.org/abs/1603.05027)(ECCV 2016) and [ResNet strikes back: An
30
+ https://arxiv.org/abs/1603.05027)(ECCV 2016), [ResNet strikes back: An
31
31
  improved training procedure in timm](https://arxiv.org/abs/2110.00476)(
32
- NeurIPS 2021 Workshop).
32
+ NeurIPS 2021 Workshop) and [Bag of Tricks for Image Classification with
33
+ Convolutional Neural Networks](https://arxiv.org/abs/1812.01187).
33
34
 
34
35
  The difference in ResNet and ResNetV2 rests in the structure of their
35
36
  individual building blocks. In ResNetV2, the batch normalization and
@@ -37,18 +38,31 @@ class ResNetBackbone(FeaturePyramidBackbone):
37
38
  the batch normalization and ReLU activation are applied after the
38
39
  convolution layers.
39
40
 
41
+ ResNetVd introduces two key modifications to the standard ResNet. First,
42
+ the initial convolutional layer is replaced by a series of three
43
+ successive convolutional layers. Second, shortcut connections use an
44
+ additional pooling operation rather than performing downsampling within
45
+ the convolutional layers themselves.
46
+
40
47
  Note that `ResNetBackbone` expects the inputs to be images with a value
41
48
  range of `[0, 255]` when `include_rescaling=True`.
42
49
 
43
50
  Args:
51
+ input_conv_filters: list of ints. The number of filters of the initial
52
+ convolution(s).
53
+ input_conv_kernel_sizes: list of ints. The kernel sizes of the initial
54
+ convolution(s).
44
55
  stackwise_num_filters: list of ints. The number of filters for each
45
56
  stack.
46
57
  stackwise_num_blocks: list of ints. The number of blocks for each stack.
47
58
  stackwise_num_strides: list of ints. The number of strides for each
48
59
  stack.
49
60
  block_type: str. The block type to stack. One of `"basic_block"` or
50
- `"bottleneck_block"`. Use `"basic_block"` for ResNet18 and ResNet34.
51
- Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152.
61
+ `"bottleneck_block"`, `"basic_block_vd"` or
62
+ `"bottleneck_block_vd"`. Use `"basic_block"` for ResNet18 and
63
+ ResNet34. Use `"bottleneck_block"` for ResNet50, ResNet101 and
64
+ ResNet152 and the `"_vd"` suffix for the respective ResNet_vd
65
+ variants.
52
66
  use_pre_activation: boolean. Whether to use pre-activation or not.
53
67
  `True` for ResNetV2, `False` for ResNet.
54
68
  include_rescaling: boolean. If `True`, rescale the input using
@@ -88,6 +102,8 @@ class ResNetBackbone(FeaturePyramidBackbone):
88
102
 
89
103
  # Randomly initialized ResNetV2 backbone with a custom config.
90
104
  model = keras_hub.models.ResNetBackbone(
105
+ input_conv_filters=[64],
106
+ input_conv_kernel_sizes=[7],
91
107
  stackwise_num_filters=[64, 64, 64],
92
108
  stackwise_num_blocks=[2, 2, 2],
93
109
  stackwise_num_strides=[1, 2, 2],
@@ -101,6 +117,8 @@ class ResNetBackbone(FeaturePyramidBackbone):
101
117
 
102
118
  def __init__(
103
119
  self,
120
+ input_conv_filters,
121
+ input_conv_kernel_sizes,
104
122
  stackwise_num_filters,
105
123
  stackwise_num_blocks,
106
124
  stackwise_num_strides,
@@ -113,6 +131,13 @@ class ResNetBackbone(FeaturePyramidBackbone):
113
131
  dtype=None,
114
132
  **kwargs,
115
133
  ):
134
+ if len(input_conv_filters) != len(input_conv_kernel_sizes):
135
+ raise ValueError(
136
+ "The length of `input_conv_filters` and"
137
+ "`input_conv_kernel_sizes` must be the same. "
138
+ f"Received: input_conv_filters={input_conv_filters}, "
139
+ f"input_conv_kernel_sizes={input_conv_kernel_sizes}."
140
+ )
116
141
  if len(stackwise_num_filters) != len(stackwise_num_blocks) or len(
117
142
  stackwise_num_filters
118
143
  ) != len(stackwise_num_strides):
@@ -128,14 +153,20 @@ class ResNetBackbone(FeaturePyramidBackbone):
128
153
  "The first element of `stackwise_num_filters` must be 64. "
129
154
  f"Received: stackwise_num_filters={stackwise_num_filters}"
130
155
  )
131
- if block_type not in ("basic_block", "bottleneck_block"):
156
+ if block_type not in (
157
+ "basic_block",
158
+ "bottleneck_block",
159
+ "basic_block_vd",
160
+ "bottleneck_block_vd",
161
+ ):
132
162
  raise ValueError(
133
- '`block_type` must be either `"basic_block"` or '
134
- f'`"bottleneck_block"`. Received block_type={block_type}.'
163
+ '`block_type` must be either `"basic_block"`, '
164
+ '`"bottleneck_block"`, `"basic_block_vd"` or '
165
+ f'`"bottleneck_block_vd"`. Received block_type={block_type}.'
135
166
  )
136
- version = "v1" if not use_pre_activation else "v2"
137
167
  data_format = standardize_data_format(data_format)
138
168
  bn_axis = -1 if data_format == "channels_last" else 1
169
+ num_input_convs = len(input_conv_filters)
139
170
  num_stacks = len(stackwise_num_filters)
140
171
 
141
172
  # === Functional Model ===
@@ -155,29 +186,56 @@ class ResNetBackbone(FeaturePyramidBackbone):
155
186
  # The padding between torch and tensorflow/jax differs when `strides>1`.
156
187
  # Therefore, we need to manually pad the tensor.
157
188
  x = layers.ZeroPadding2D(
158
- 3,
189
+ (input_conv_kernel_sizes[0] - 1) // 2,
159
190
  data_format=data_format,
160
191
  dtype=dtype,
161
192
  name="conv1_pad",
162
193
  )(x)
163
194
  x = layers.Conv2D(
164
- 64,
165
- 7,
195
+ input_conv_filters[0],
196
+ input_conv_kernel_sizes[0],
166
197
  strides=2,
167
198
  data_format=data_format,
168
199
  use_bias=False,
200
+ padding="valid",
169
201
  dtype=dtype,
170
202
  name="conv1_conv",
171
203
  )(x)
204
+ for conv_index in range(1, num_input_convs):
205
+ x = layers.BatchNormalization(
206
+ axis=bn_axis,
207
+ epsilon=1e-5,
208
+ momentum=0.9,
209
+ dtype=dtype,
210
+ name=f"conv{conv_index}_bn",
211
+ )(x)
212
+ x = layers.Activation(
213
+ "relu", dtype=dtype, name=f"conv{conv_index}_relu"
214
+ )(x)
215
+ x = layers.Conv2D(
216
+ input_conv_filters[conv_index],
217
+ input_conv_kernel_sizes[conv_index],
218
+ strides=1,
219
+ data_format=data_format,
220
+ use_bias=False,
221
+ padding="same",
222
+ dtype=dtype,
223
+ name=f"conv{conv_index+1}_conv",
224
+ )(x)
225
+
172
226
  if not use_pre_activation:
173
227
  x = layers.BatchNormalization(
174
228
  axis=bn_axis,
175
229
  epsilon=1e-5,
176
230
  momentum=0.9,
177
231
  dtype=dtype,
178
- name="conv1_bn",
232
+ name=f"conv{num_input_convs}_bn",
233
+ )(x)
234
+ x = layers.Activation(
235
+ "relu",
236
+ dtype=dtype,
237
+ name=f"conv{num_input_convs}_relu",
179
238
  )(x)
180
- x = layers.Activation("relu", dtype=dtype, name="conv1_relu")(x)
181
239
 
182
240
  if use_pre_activation:
183
241
  # A workaround for ResNetV2: we need -inf padding to prevent zeros
@@ -210,12 +268,10 @@ class ResNetBackbone(FeaturePyramidBackbone):
210
268
  stride=stackwise_num_strides[stack_index],
211
269
  block_type=block_type,
212
270
  use_pre_activation=use_pre_activation,
213
- first_shortcut=(
214
- block_type == "bottleneck_block" or stack_index > 0
215
- ),
271
+ first_shortcut=(block_type != "basic_block" or stack_index > 0),
216
272
  data_format=data_format,
217
273
  dtype=dtype,
218
- name=f"{version}_stack{stack_index}",
274
+ name=f"stack{stack_index}",
219
275
  )
220
276
  pyramid_outputs[f"P{stack_index + 2}"] = x
221
277
 
@@ -248,6 +304,8 @@ class ResNetBackbone(FeaturePyramidBackbone):
248
304
  )
249
305
 
250
306
  # === Config ===
307
+ self.input_conv_filters = input_conv_filters
308
+ self.input_conv_kernel_sizes = input_conv_kernel_sizes
251
309
  self.stackwise_num_filters = stackwise_num_filters
252
310
  self.stackwise_num_blocks = stackwise_num_blocks
253
311
  self.stackwise_num_strides = stackwise_num_strides
@@ -262,6 +320,8 @@ class ResNetBackbone(FeaturePyramidBackbone):
262
320
  config = super().get_config()
263
321
  config.update(
264
322
  {
323
+ "input_conv_filters": self.input_conv_filters,
324
+ "input_conv_kernel_sizes": self.input_conv_kernel_sizes,
265
325
  "stackwise_num_filters": self.stackwise_num_filters,
266
326
  "stackwise_num_blocks": self.stackwise_num_blocks,
267
327
  "stackwise_num_strides": self.stackwise_num_strides,
@@ -327,7 +387,10 @@ def apply_basic_block(
327
387
  )(x_preact)
328
388
 
329
389
  if conv_shortcut:
330
- x = x_preact if x_preact is not None else x
390
+ if x_preact is not None:
391
+ shortcut = x_preact
392
+ else:
393
+ shortcut = x
331
394
  shortcut = layers.Conv2D(
332
395
  filters,
333
396
  1,
@@ -336,7 +399,7 @@ def apply_basic_block(
336
399
  use_bias=False,
337
400
  dtype=dtype,
338
401
  name=f"{name}_0_conv",
339
- )(x)
402
+ )(shortcut)
340
403
  if not use_pre_activation:
341
404
  shortcut = layers.BatchNormalization(
342
405
  axis=bn_axis,
@@ -452,7 +515,10 @@ def apply_bottleneck_block(
452
515
  )(x_preact)
453
516
 
454
517
  if conv_shortcut:
455
- x = x_preact if x_preact is not None else x
518
+ if x_preact is not None:
519
+ shortcut = x_preact
520
+ else:
521
+ shortcut = x
456
522
  shortcut = layers.Conv2D(
457
523
  4 * filters,
458
524
  1,
@@ -461,7 +527,295 @@ def apply_bottleneck_block(
461
527
  use_bias=False,
462
528
  dtype=dtype,
463
529
  name=f"{name}_0_conv",
530
+ )(shortcut)
531
+ if not use_pre_activation:
532
+ shortcut = layers.BatchNormalization(
533
+ axis=bn_axis,
534
+ epsilon=1e-5,
535
+ momentum=0.9,
536
+ dtype=dtype,
537
+ name=f"{name}_0_bn",
538
+ )(shortcut)
539
+ else:
540
+ shortcut = x
541
+
542
+ x = x_preact if x_preact is not None else x
543
+ x = layers.Conv2D(
544
+ filters,
545
+ 1,
546
+ strides=1,
547
+ data_format=data_format,
548
+ use_bias=False,
549
+ dtype=dtype,
550
+ name=f"{name}_1_conv",
551
+ )(x)
552
+ x = layers.BatchNormalization(
553
+ axis=bn_axis,
554
+ epsilon=1e-5,
555
+ momentum=0.9,
556
+ dtype=dtype,
557
+ name=f"{name}_1_bn",
558
+ )(x)
559
+ x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x)
560
+
561
+ if stride > 1:
562
+ x = layers.ZeroPadding2D(
563
+ (kernel_size - 1) // 2,
564
+ data_format=data_format,
565
+ dtype=dtype,
566
+ name=f"{name}_2_pad",
464
567
  )(x)
568
+ x = layers.Conv2D(
569
+ filters,
570
+ kernel_size,
571
+ strides=stride,
572
+ padding="valid" if stride > 1 else "same",
573
+ data_format=data_format,
574
+ use_bias=False,
575
+ dtype=dtype,
576
+ name=f"{name}_2_conv",
577
+ )(x)
578
+ x = layers.BatchNormalization(
579
+ axis=bn_axis,
580
+ epsilon=1e-5,
581
+ momentum=0.9,
582
+ dtype=dtype,
583
+ name=f"{name}_2_bn",
584
+ )(x)
585
+ x = layers.Activation("relu", dtype=dtype, name=f"{name}_2_relu")(x)
586
+
587
+ x = layers.Conv2D(
588
+ 4 * filters,
589
+ 1,
590
+ data_format=data_format,
591
+ use_bias=False,
592
+ dtype=dtype,
593
+ name=f"{name}_3_conv",
594
+ )(x)
595
+ if not use_pre_activation:
596
+ x = layers.BatchNormalization(
597
+ axis=bn_axis,
598
+ epsilon=1e-5,
599
+ momentum=0.9,
600
+ dtype=dtype,
601
+ name=f"{name}_3_bn",
602
+ )(x)
603
+ x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x])
604
+ x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x)
605
+ else:
606
+ x = layers.Add(dtype=dtype, name=f"{name}_out")([shortcut, x])
607
+ return x
608
+
609
+
610
def apply_basic_block_vd(
    x,
    filters,
    kernel_size=3,
    stride=1,
    conv_shortcut=False,
    use_pre_activation=False,
    data_format=None,
    dtype=None,
    name=None,
):
    """Applies a basic residual block of the ResNet-vd flavour.

    The main path is two `kernel_size` convolutions (the first one carries
    the stride). When a projection shortcut is requested, downsampling on
    the shortcut is done by average pooling followed by a stride-1 1x1
    convolution instead of a strided 1x1 convolution.

    Args:
        x: Tensor. The input tensor to pass through the block.
        filters: int. The number of filters in the block.
        kernel_size: int. The kernel size of the two convolutions on the
            main path. Defaults to `3`.
        stride: int. The stride length of the first layer. Defaults to `1`.
        conv_shortcut: bool. If `True`, use a projection shortcut: average
            pooling (only when `stride > 1`) followed by a 1x1 convolution.
            If `False`, the block input is used unchanged as the shortcut;
            nothing resizes it, so the input must already match the main
            path's output shape. Defaults to `False`.
        use_pre_activation: boolean. Whether to use pre-activation or not.
            `True` for ResNetV2, `False` for ResNet. Defaults to `False`.
        data_format: `None` or str. The ordering of the dimensions in the
            inputs. Can be `"channels_last"`
            (`(batch_size, height, width, channels)`) or `"channels_first"`
            (`(batch_size, channels, height, width)`).
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
            to use for the model's computations and weights.
        name: str. A prefix for the layer names used in the block.

    Returns:
        The output tensor for the basic residual block.
    """
    data_format = data_format or keras.config.image_data_format()
    bn_axis = -1 if data_format == "channels_last" else 1

    x_preact = None
    if use_pre_activation:
        x_preact = layers.BatchNormalization(
            axis=bn_axis,
            epsilon=1e-5,
            momentum=0.9,
            dtype=dtype,
            name=f"{name}_pre_activation_bn",
        )(x)
        x_preact = layers.Activation(
            "relu", dtype=dtype, name=f"{name}_pre_activation_relu"
        )(x_preact)

    if conv_shortcut:
        shortcut = x_preact if x_preact is not None else x
        if stride > 1:
            # The projection below is stride-1, so pooling is the only place
            # the shortcut gets downsampled. It has to run for the
            # pre-activation input as well, otherwise the final `Add` sees
            # mismatched spatial shapes.
            shortcut = layers.AveragePooling2D(
                2,
                strides=stride,
                data_format=data_format,
                dtype=dtype,
                padding="same",
            )(shortcut)
        shortcut = layers.Conv2D(
            filters,
            1,
            strides=1,
            data_format=data_format,
            use_bias=False,
            dtype=dtype,
            name=f"{name}_0_conv",
        )(shortcut)
        if not use_pre_activation:
            shortcut = layers.BatchNormalization(
                axis=bn_axis,
                epsilon=1e-5,
                momentum=0.9,
                dtype=dtype,
                name=f"{name}_0_bn",
            )(shortcut)
    else:
        shortcut = x

    x = x_preact if x_preact is not None else x
    if stride > 1:
        # Explicit symmetric padding + "valid" conv instead of "same" for
        # the strided convolution.
        x = layers.ZeroPadding2D(
            (kernel_size - 1) // 2,
            data_format=data_format,
            dtype=dtype,
            name=f"{name}_1_pad",
        )(x)
    x = layers.Conv2D(
        filters,
        kernel_size,
        strides=stride,
        padding="valid" if stride > 1 else "same",
        data_format=data_format,
        use_bias=False,
        dtype=dtype,
        name=f"{name}_1_conv",
    )(x)
    x = layers.BatchNormalization(
        axis=bn_axis,
        epsilon=1e-5,
        momentum=0.9,
        dtype=dtype,
        name=f"{name}_1_bn",
    )(x)
    x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x)

    x = layers.Conv2D(
        filters,
        kernel_size,
        strides=1,
        padding="same",
        data_format=data_format,
        use_bias=False,
        dtype=dtype,
        name=f"{name}_2_conv",
    )(x)
    if not use_pre_activation:
        # Post-activation: normalize, merge with the shortcut, then apply
        # the ReLU.
        x = layers.BatchNormalization(
            axis=bn_axis,
            epsilon=1e-5,
            momentum=0.9,
            dtype=dtype,
            name=f"{name}_2_bn",
        )(x)
        x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x])
        x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x)
    else:
        # Pre-activation: the block output is the raw sum.
        x = layers.Add(dtype=dtype, name=f"{name}_out")([shortcut, x])
    return x
744
+
745
+
746
+ def apply_bottleneck_block_vd(
747
+ x,
748
+ filters,
749
+ kernel_size=3,
750
+ stride=1,
751
+ conv_shortcut=False,
752
+ use_pre_activation=False,
753
+ data_format=None,
754
+ dtype=None,
755
+ name=None,
756
+ ):
757
+ """Applies a bottleneck residual block.
758
+
759
+ Args:
760
+ x: Tensor. The input tensor to pass through the block.
761
+ filters: int. The number of filters in the block.
762
+ kernel_size: int. The kernel size of the bottleneck layer. Defaults to
763
+ `3`.
764
+ stride: int. The stride length of the first layer. Defaults to `1`.
765
+ conv_shortcut: bool. If `True`, use a convolution shortcut. If `False`,
766
+ use an identity or pooling shortcut based on the stride. Defaults to
767
+ `False`.
768
+ use_pre_activation: boolean. Whether to use pre-activation or not.
769
+ `True` for ResNetV2, `False` for ResNet. Defaults to `False`.
770
+ data_format: `None` or str. the ordering of the dimensions in the
771
+ inputs. Can be `"channels_last"`
772
+ (`(batch_size, height, width, channels)`) or `"channels_first"`
773
+ (`(batch_size, channels, height, width)`).
774
+ dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
775
+ to use for the model's computations and weights.
776
+ name: str. A prefix for the layer names used in the block.
777
+
778
+ Returns:
779
+ The output tensor for the residual block.
780
+ """
781
+ data_format = data_format or keras.config.image_data_format()
782
+ bn_axis = -1 if data_format == "channels_last" else 1
783
+
784
+ x_preact = None
785
+ if use_pre_activation:
786
+ x_preact = layers.BatchNormalization(
787
+ axis=bn_axis,
788
+ epsilon=1e-5,
789
+ momentum=0.9,
790
+ dtype=dtype,
791
+ name=f"{name}_pre_activation_bn",
792
+ )(x)
793
+ x_preact = layers.Activation(
794
+ "relu", dtype=dtype, name=f"{name}_pre_activation_relu"
795
+ )(x_preact)
796
+
797
+ if conv_shortcut:
798
+ if x_preact is not None:
799
+ shortcut = x_preact
800
+ elif stride > 1:
801
+ shortcut = layers.AveragePooling2D(
802
+ 2,
803
+ strides=stride,
804
+ data_format=data_format,
805
+ dtype=dtype,
806
+ padding="same",
807
+ )(x)
808
+ else:
809
+ shortcut = x
810
+ shortcut = layers.Conv2D(
811
+ 4 * filters,
812
+ 1,
813
+ strides=1,
814
+ data_format=data_format,
815
+ use_bias=False,
816
+ dtype=dtype,
817
+ name=f"{name}_0_conv",
818
+ )(shortcut)
465
819
  if not use_pre_activation:
466
820
  shortcut = layers.BatchNormalization(
467
821
  axis=bn_axis,
@@ -561,8 +915,11 @@ def apply_stack(
561
915
  blocks: int. The number of blocks in the stack.
562
916
  stride: int. The stride length of the first layer in the first block.
563
917
  block_type: str. The block type to stack. One of `"basic_block"` or
564
- `"bottleneck_block"`. Use `"basic_block"` for ResNet18 and ResNet34.
565
- Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152.
918
+ `"bottleneck_block"`, `"basic_block_vd"` or
919
+ `"bottleneck_block_vd"`. Use `"basic_block"` for ResNet18 and
920
+ ResNet34. Use `"bottleneck_block"` for ResNet50, ResNet101 and
921
+ ResNet152 and the `"_vd"` suffix for the respective ResNet_vd
922
+ variants.
566
923
  use_pre_activation: boolean. Whether to use pre-activation or not.
567
924
  `True` for ResNetV2, `False` for ResNet and ResNeXt.
568
925
  first_shortcut: bool. If `True`, use a convolution shortcut. If `False`,
@@ -580,17 +937,21 @@ def apply_stack(
580
937
  Output tensor for the stacked blocks.
581
938
  """
582
939
  if name is None:
583
- version = "v1" if not use_pre_activation else "v2"
584
- name = f"{version}_stack"
940
+ name = "stack"
585
941
 
586
942
  if block_type == "basic_block":
587
943
  block_fn = apply_basic_block
588
944
  elif block_type == "bottleneck_block":
589
945
  block_fn = apply_bottleneck_block
946
+ elif block_type == "basic_block_vd":
947
+ block_fn = apply_basic_block_vd
948
+ elif block_type == "bottleneck_block_vd":
949
+ block_fn = apply_bottleneck_block_vd
590
950
  else:
591
951
  raise ValueError(
592
- '`block_type` must be either `"basic_block"` or '
593
- f'`"bottleneck_block"`. Received block_type={block_type}.'
952
+ '`block_type` must be either `"basic_block"`, '
953
+ '`"bottleneck_block"`, `"basic_block_vd"` or '
954
+ f'`"bottleneck_block_vd"`. Received block_type={block_type}.'
594
955
  )
595
956
  for i in range(blocks):
596
957
  if i == 0:
@@ -16,6 +16,9 @@ import keras
16
16
  from keras_hub.src.api_export import keras_hub_export
17
17
  from keras_hub.src.models.image_classifier import ImageClassifier
18
18
  from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
19
+ from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import (
20
+ ResNetImageClassifierPreprocessor,
21
+ )
19
22
 
20
23
 
21
24
  @keras_hub_export("keras_hub.models.ResNetImageClassifier")
@@ -88,21 +91,22 @@ class ResNetImageClassifier(ImageClassifier):
88
91
  """
89
92
 
90
93
  backbone_cls = ResNetBackbone
94
+ preprocessor_cls = ResNetImageClassifierPreprocessor
91
95
 
92
96
  def __init__(
93
97
  self,
94
98
  backbone,
95
99
  num_classes,
96
- activation="softmax",
100
+ preprocessor=None,
101
+ activation=None,
97
102
  head_dtype=None,
98
- preprocessor=None, # adding this dummy arg for saved model test
99
- # TODO: once preprocessor flow is figured out, this needs to be updated
100
103
  **kwargs,
101
104
  ):
102
105
  head_dtype = head_dtype or backbone.dtype_policy
103
106
 
104
107
  # === Layers ===
105
108
  self.backbone = backbone
109
+ self.preprocessor = preprocessor
106
110
  self.output_dense = keras.layers.Dense(
107
111
  num_classes,
108
112
  activation=activation,
@@ -0,0 +1,28 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from keras_hub.src.api_export import keras_hub_export
16
+ from keras_hub.src.models.image_classifier_preprocessor import (
17
+ ImageClassifierPreprocessor,
18
+ )
19
+ from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
20
+ from keras_hub.src.models.resnet.resnet_image_converter import (
21
+ ResNetImageConverter,
22
+ )
23
+
24
+
25
+ @keras_hub_export("keras_hub.models.ResNetImageClassifierPreprocessor")
26
+ class ResNetImageClassifierPreprocessor(ImageClassifierPreprocessor):
27
+ backbone_cls = ResNetBackbone
28
+ image_converter_cls = ResNetImageConverter
@@ -0,0 +1,23 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from keras_hub.src.api_export import keras_hub_export
15
+ from keras_hub.src.layers.preprocessing.resizing_image_converter import (
16
+ ResizingImageConverter,
17
+ )
18
+ from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
19
+
20
+
21
+ @keras_hub_export("keras_hub.layers.ResNetImageConverter")
22
+ class ResNetImageConverter(ResizingImageConverter):
23
+ backbone_cls = ResNetBackbone