keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.15.0.dev20240911134614__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/__init__.py +1 -0
- keras_hub/api/bounding_box/__init__.py +36 -0
- keras_hub/api/layers/__init__.py +14 -0
- keras_hub/api/models/__init__.py +75 -31
- keras_hub/api/tokenizers/__init__.py +30 -0
- keras_hub/src/bounding_box/__init__.py +13 -0
- keras_hub/src/bounding_box/converters.py +529 -0
- keras_hub/src/bounding_box/formats.py +162 -0
- keras_hub/src/bounding_box/iou.py +263 -0
- keras_hub/src/bounding_box/to_dense.py +95 -0
- keras_hub/src/bounding_box/to_ragged.py +99 -0
- keras_hub/src/bounding_box/utils.py +194 -0
- keras_hub/src/bounding_box/validate_format.py +99 -0
- keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
- keras_hub/src/layers/preprocessing/image_converter.py +130 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
- keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
- keras_hub/src/layers/preprocessing/random_swap.py +33 -31
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
- keras_hub/src/models/albert/__init__.py +1 -2
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
- keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +29 -10
- keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
- keras_hub/src/models/albert/albert_tokenizer.py +17 -36
- keras_hub/src/models/backbone.py +12 -34
- keras_hub/src/models/bart/__init__.py +1 -2
- keras_hub/src/models/bart/bart_preprocessor.py +6 -18
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
- keras_hub/src/models/bart/bart_tokenizer.py +12 -39
- keras_hub/src/models/bert/__init__.py +1 -5
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
- keras_hub/src/models/bert/bert_presets.py +1 -4
- keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +12 -10
- keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
- keras_hub/src/models/bert/bert_tokenizer.py +17 -35
- keras_hub/src/models/bloom/__init__.py +1 -2
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/bloom/bloom_preprocessor.py +5 -12
- keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
- keras_hub/src/models/causal_lm.py +10 -29
- keras_hub/src/models/causal_lm_preprocessor.py +195 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
- keras_hub/src/models/deberta_v3/__init__.py +1 -4
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
- keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +11 -11
- keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
- keras_hub/src/models/densenet/densenet_backbone.py +46 -22
- keras_hub/src/models/distil_bert/__init__.py +1 -4
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
- keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +12 -12
- keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
- keras_hub/src/models/efficientnet/__init__.py +13 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
- keras_hub/src/models/efficientnet/mbconv.py +238 -0
- keras_hub/src/models/electra/__init__.py +1 -2
- keras_hub/src/models/electra/electra_preprocessor.py +6 -5
- keras_hub/src/models/electra/electra_tokenizer.py +17 -32
- keras_hub/src/models/f_net/__init__.py +1 -2
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
- keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +10 -8
- keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
- keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
- keras_hub/src/models/falcon/__init__.py +1 -2
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/falcon/falcon_preprocessor.py +5 -12
- keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
- keras_hub/src/models/gemma/__init__.py +1 -2
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
- keras_hub/src/models/gemma/gemma_preprocessor.py +5 -12
- keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
- keras_hub/src/models/gpt2/__init__.py +1 -2
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +5 -12
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +5 -12
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
- keras_hub/src/models/image_classifier.py +0 -5
- keras_hub/src/models/image_classifier_preprocessor.py +83 -0
- keras_hub/src/models/llama/__init__.py +1 -2
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
- keras_hub/src/models/llama/llama_preprocessor.py +5 -12
- keras_hub/src/models/llama/llama_tokenizer.py +12 -25
- keras_hub/src/models/llama3/__init__.py +1 -2
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/llama3/llama3_preprocessor.py +2 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
- keras_hub/src/models/masked_lm.py +0 -2
- keras_hub/src/models/masked_lm_preprocessor.py +156 -0
- keras_hub/src/models/mistral/__init__.py +1 -2
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/mistral/mistral_preprocessor.py +5 -12
- keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
- keras_hub/src/models/mobilenet/__init__.py +13 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
- keras_hub/src/models/opt/__init__.py +1 -2
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
- keras_hub/src/models/opt/opt_preprocessor.py +5 -12
- keras_hub/src/models/opt/opt_tokenizer.py +12 -41
- keras_hub/src/models/pali_gemma/__init__.py +1 -4
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +10 -2
- keras_hub/src/models/phi3/__init__.py +1 -2
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/phi3/phi3_preprocessor.py +5 -12
- keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
- keras_hub/src/models/preprocessor.py +76 -83
- keras_hub/src/models/resnet/__init__.py +6 -0
- keras_hub/src/models/resnet/resnet_backbone.py +387 -26
- keras_hub/src/models/resnet/resnet_image_classifier.py +7 -3
- keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
- keras_hub/src/models/resnet/resnet_image_converter.py +23 -0
- keras_hub/src/models/resnet/resnet_presets.py +95 -0
- keras_hub/src/models/roberta/__init__.py +1 -2
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
- keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +11 -11
- keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
- keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
- keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
- keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
- keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
- keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
- keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
- keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
- keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
- keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
- keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
- keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
- keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
- keras_hub/src/models/t5/__init__.py +1 -2
- keras_hub/src/models/t5/t5_tokenizer.py +13 -23
- keras_hub/src/models/task.py +71 -116
- keras_hub/src/models/{classifier.py → text_classifier.py} +8 -13
- keras_hub/src/models/text_classifier_preprocessor.py +138 -0
- keras_hub/src/models/whisper/__init__.py +1 -2
- keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
- keras_hub/src/models/whisper/whisper_backbone.py +0 -3
- keras_hub/src/models/whisper/whisper_presets.py +10 -10
- keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
- keras_hub/src/models/xlm_roberta/__init__.py +1 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
- keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +11 -11
- keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
- keras_hub/src/tests/test_case.py +25 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +29 -17
- keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +19 -7
- keras_hub/src/tokenizers/tokenizer.py +67 -32
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
- keras_hub/src/tokenizers/word_piece_tokenizer.py +33 -47
- keras_hub/src/utils/keras_utils.py +0 -50
- keras_hub/src/utils/preset_utils.py +238 -67
- keras_hub/src/utils/tensor_utils.py +187 -69
- keras_hub/src/utils/timm/convert_resnet.py +20 -16
- keras_hub/src/utils/timm/preset_loader.py +67 -0
- keras_hub/src/utils/transformers/convert_albert.py +193 -0
- keras_hub/src/utils/transformers/convert_bart.py +373 -0
- keras_hub/src/utils/transformers/convert_bert.py +7 -17
- keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
- keras_hub/src/utils/transformers/convert_gemma.py +5 -19
- keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
- keras_hub/src/utils/transformers/convert_llama3.py +7 -18
- keras_hub/src/utils/transformers/convert_mistral.py +129 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
- keras_hub/src/utils/transformers/preset_loader.py +77 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.15.0.dev20240911134614.dist-info}/METADATA +1 -2
- keras_hub_nightly-0.15.0.dev20240911134614.dist-info/RECORD +338 -0
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.15.0.dev20240911134614.dist-info}/WHEEL +1 -1
- keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
- keras_hub/src/utils/timm/convert.py +0 -37
- keras_hub/src/utils/transformers/convert.py +0 -101
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.15.0.dev20240911134614.dist-info}/top_level.txt +0 -0
keras_hub/src/models/resnet/resnet_backbone.py
@@ -27,9 +27,10 @@ class ResNetBackbone(FeaturePyramidBackbone):
     This class implements a ResNet backbone as described in [Deep Residual
     Learning for Image Recognition](https://arxiv.org/abs/1512.03385)(
     CVPR 2016), [Identity Mappings in Deep Residual Networks](
-    https://arxiv.org/abs/1603.05027)(ECCV 2016)
+    https://arxiv.org/abs/1603.05027)(ECCV 2016), [ResNet strikes back: An
     improved training procedure in timm](https://arxiv.org/abs/2110.00476)(
-    NeurIPS 2021 Workshop)
+    NeurIPS 2021 Workshop) and [Bag of Tricks for Image Classification with
+    Convolutional Neural Networks](https://arxiv.org/abs/1812.01187).

     The difference in ResNet and ResNetV2 rests in the structure of their
     individual building blocks. In ResNetV2, the batch normalization and
@@ -37,18 +38,31 @@ class ResNetBackbone(FeaturePyramidBackbone):
     the batch normalization and ReLU activation are applied after the
     convolution layers.

+    ResNetVd introduces two key modifications to the standard ResNet. First,
+    the initial convolutional layer is replaced by a series of three
+    successive convolutional layers. Second, shortcut connections use an
+    additional pooling operation rather than performing downsampling within
+    the convolutional layers themselves.
+
     Note that `ResNetBackbone` expects the inputs to be images with a value
     range of `[0, 255]` when `include_rescaling=True`.

     Args:
+        input_conv_filters: list of ints. The number of filters of the initial
+            convolution(s).
+        input_conv_kernel_sizes: list of ints. The kernel sizes of the initial
+            convolution(s).
         stackwise_num_filters: list of ints. The number of filters for each
             stack.
         stackwise_num_blocks: list of ints. The number of blocks for each stack.
         stackwise_num_strides: list of ints. The number of strides for each
             stack.
         block_type: str. The block type to stack. One of `"basic_block"` or
-            `"bottleneck_block"`. Use `"basic_block"` for ResNet18 and ResNet34.
-            Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152.
+            `"bottleneck_block"`, `"basic_block_vd"` or
+            `"bottleneck_block_vd"`. Use `"basic_block"` for ResNet18 and
+            ResNet34. Use `"bottleneck_block"` for ResNet50, ResNet101 and
+            ResNet152 and the `"_vd"` prefix for the respective ResNet_vd
+            variants.
         use_pre_activation: boolean. Whether to use pre-activation or not.
             `True` for ResNetV2, `False` for ResNet.
         include_rescaling: boolean. If `True`, rescale the input using
@@ -88,6 +102,8 @@ class ResNetBackbone(FeaturePyramidBackbone):

     # Randomly initialized ResNetV2 backbone with a custom config.
     model = keras_hub.models.ResNetBackbone(
+        input_conv_filters=[64],
+        input_conv_kernel_sizes=[7],
         stackwise_num_filters=[64, 64, 64],
         stackwise_num_blocks=[2, 2, 2],
         stackwise_num_strides=[1, 2, 2],
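For orientation, here is how the new stem arguments compose into a ResNet_vd-style backbone. This is a minimal sketch: the `[32, 32, 64]` / `[3, 3, 3]` deep-stem values and the four-stack ResNet18 layout follow the ResNet_vd convention from the Bag of Tricks paper cited above; they are illustrative assumptions, not values taken from this diff.

```python
import keras_hub

# Hypothetical ResNet18-vd style configuration (values assumed, not from
# this diff): three stacked convolutions replace the single 7x7 stem conv,
# and the "_vd" block type enables the pooled shortcut.
vd_backbone = keras_hub.models.ResNetBackbone(
    input_conv_filters=[32, 32, 64],
    input_conv_kernel_sizes=[3, 3, 3],
    stackwise_num_filters=[64, 128, 256, 512],
    stackwise_num_blocks=[2, 2, 2, 2],
    stackwise_num_strides=[1, 2, 2, 2],
    block_type="basic_block_vd",
    use_pre_activation=False,
)
```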
@@ -101,6 +117,8 @@ class ResNetBackbone(FeaturePyramidBackbone):

     def __init__(
         self,
+        input_conv_filters,
+        input_conv_kernel_sizes,
         stackwise_num_filters,
         stackwise_num_blocks,
         stackwise_num_strides,
@@ -113,6 +131,13 @@ class ResNetBackbone(FeaturePyramidBackbone):
         dtype=None,
         **kwargs,
     ):
+        if len(input_conv_filters) != len(input_conv_kernel_sizes):
+            raise ValueError(
+                "The length of `input_conv_filters` and"
+                "`input_conv_kernel_sizes` must be the same. "
+                f"Received: input_conv_filters={input_conv_filters}, "
+                f"input_conv_kernel_sizes={input_conv_kernel_sizes}."
+            )
         if len(stackwise_num_filters) != len(stackwise_num_blocks) or len(
             stackwise_num_filters
         ) != len(stackwise_num_strides):
@@ -128,14 +153,20 @@ class ResNetBackbone(FeaturePyramidBackbone):
                 "The first element of `stackwise_num_filters` must be 64. "
                 f"Received: stackwise_num_filters={stackwise_num_filters}"
             )
-        if block_type not in ("basic_block", "bottleneck_block"):
+        if block_type not in (
+            "basic_block",
+            "bottleneck_block",
+            "basic_block_vd",
+            "bottleneck_block_vd",
+        ):
             raise ValueError(
-                '`block_type` must be either `"basic_block"` or '
-                f'`"bottleneck_block"`. Received block_type={block_type}.'
+                '`block_type` must be either `"basic_block"`, '
+                '`"bottleneck_block"`, `"basic_block_vd"` or '
+                f'`"bottleneck_block_vd"`. Received block_type={block_type}.'
             )
-        version = "v1" if not use_pre_activation else "v2"
         data_format = standardize_data_format(data_format)
         bn_axis = -1 if data_format == "channels_last" else 1
+        num_input_convs = len(input_conv_filters)
         num_stacks = len(stackwise_num_filters)

         # === Functional Model ===
@@ -155,29 +186,56 @@ class ResNetBackbone(FeaturePyramidBackbone):
         # The padding between torch and tensorflow/jax differs when `strides>1`.
         # Therefore, we need to manually pad the tensor.
         x = layers.ZeroPadding2D(
-            3,
+            (input_conv_kernel_sizes[0] - 1) // 2,
             data_format=data_format,
             dtype=dtype,
             name="conv1_pad",
         )(x)
         x = layers.Conv2D(
-            64,
-            7,
+            input_conv_filters[0],
+            input_conv_kernel_sizes[0],
             strides=2,
             data_format=data_format,
             use_bias=False,
+            padding="valid",
             dtype=dtype,
             name="conv1_conv",
         )(x)
+        for conv_index in range(1, num_input_convs):
+            x = layers.BatchNormalization(
+                axis=bn_axis,
+                epsilon=1e-5,
+                momentum=0.9,
+                dtype=dtype,
+                name=f"conv{conv_index}_bn",
+            )(x)
+            x = layers.Activation(
+                "relu", dtype=dtype, name=f"conv{conv_index}_relu"
+            )(x)
+            x = layers.Conv2D(
+                input_conv_filters[conv_index],
+                input_conv_kernel_sizes[conv_index],
+                strides=1,
+                data_format=data_format,
+                use_bias=False,
+                padding="same",
+                dtype=dtype,
+                name=f"conv{conv_index+1}_conv",
+            )(x)
+
         if not use_pre_activation:
             x = layers.BatchNormalization(
                 axis=bn_axis,
                 epsilon=1e-5,
                 momentum=0.9,
                 dtype=dtype,
-                name="conv1_bn",
+                name=f"conv{num_input_convs}_bn",
+            )(x)
+            x = layers.Activation(
+                "relu",
+                dtype=dtype,
+                name=f"conv{num_input_convs}_relu",
             )(x)
-            x = layers.Activation("relu", dtype=dtype, name="conv1_relu")(x)

         if use_pre_activation:
             # A workaround for ResNetV2: we need -inf padding to prevent zeros
@@ -210,12 +268,10 @@ class ResNetBackbone(FeaturePyramidBackbone):
                 stride=stackwise_num_strides[stack_index],
                 block_type=block_type,
                 use_pre_activation=use_pre_activation,
-                first_shortcut=(
-                    block_type == "bottleneck_block" or stack_index > 0
-                ),
+                first_shortcut=(block_type != "basic_block" or stack_index > 0),
                 data_format=data_format,
                 dtype=dtype,
-                name=f"{version}_stack{stack_index}",
+                name=f"stack{stack_index}",
             )
             pyramid_outputs[f"P{stack_index + 2}"] = x

@@ -248,6 +304,8 @@ class ResNetBackbone(FeaturePyramidBackbone):
         )

         # === Config ===
+        self.input_conv_filters = input_conv_filters
+        self.input_conv_kernel_sizes = input_conv_kernel_sizes
         self.stackwise_num_filters = stackwise_num_filters
         self.stackwise_num_blocks = stackwise_num_blocks
         self.stackwise_num_strides = stackwise_num_strides
@@ -262,6 +320,8 @@ class ResNetBackbone(FeaturePyramidBackbone):
         config = super().get_config()
         config.update(
             {
+                "input_conv_filters": self.input_conv_filters,
+                "input_conv_kernel_sizes": self.input_conv_kernel_sizes,
                 "stackwise_num_filters": self.stackwise_num_filters,
                 "stackwise_num_blocks": self.stackwise_num_blocks,
                 "stackwise_num_strides": self.stackwise_num_strides,
@@ -327,7 +387,10 @@ def apply_basic_block(
         )(x_preact)

     if conv_shortcut:
-        x = x_preact if x_preact is not None else x
+        if x_preact is not None:
+            shortcut = x_preact
+        else:
+            shortcut = x
         shortcut = layers.Conv2D(
             filters,
             1,
@@ -336,7 +399,7 @@ def apply_basic_block(
             use_bias=False,
             dtype=dtype,
             name=f"{name}_0_conv",
-        )(x)
+        )(shortcut)
         if not use_pre_activation:
             shortcut = layers.BatchNormalization(
                 axis=bn_axis,
@@ -452,7 +515,10 @@ def apply_bottleneck_block(
         )(x_preact)

     if conv_shortcut:
-        x = x_preact if x_preact is not None else x
+        if x_preact is not None:
+            shortcut = x_preact
+        else:
+            shortcut = x
         shortcut = layers.Conv2D(
             4 * filters,
             1,
@@ -461,7 +527,295 @@ def apply_bottleneck_block(
             use_bias=False,
             dtype=dtype,
             name=f"{name}_0_conv",
+        )(shortcut)
+        if not use_pre_activation:
+            shortcut = layers.BatchNormalization(
+                axis=bn_axis,
+                epsilon=1e-5,
+                momentum=0.9,
+                dtype=dtype,
+                name=f"{name}_0_bn",
+            )(shortcut)
+    else:
+        shortcut = x
+
+    x = x_preact if x_preact is not None else x
+    x = layers.Conv2D(
+        filters,
+        1,
+        strides=1,
+        data_format=data_format,
+        use_bias=False,
+        dtype=dtype,
+        name=f"{name}_1_conv",
+    )(x)
+    x = layers.BatchNormalization(
+        axis=bn_axis,
+        epsilon=1e-5,
+        momentum=0.9,
+        dtype=dtype,
+        name=f"{name}_1_bn",
+    )(x)
+    x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x)
+
+    if stride > 1:
+        x = layers.ZeroPadding2D(
+            (kernel_size - 1) // 2,
+            data_format=data_format,
+            dtype=dtype,
+            name=f"{name}_2_pad",
         )(x)
+    x = layers.Conv2D(
+        filters,
+        kernel_size,
+        strides=stride,
+        padding="valid" if stride > 1 else "same",
+        data_format=data_format,
+        use_bias=False,
+        dtype=dtype,
+        name=f"{name}_2_conv",
+    )(x)
+    x = layers.BatchNormalization(
+        axis=bn_axis,
+        epsilon=1e-5,
+        momentum=0.9,
+        dtype=dtype,
+        name=f"{name}_2_bn",
+    )(x)
+    x = layers.Activation("relu", dtype=dtype, name=f"{name}_2_relu")(x)
+
+    x = layers.Conv2D(
+        4 * filters,
+        1,
+        data_format=data_format,
+        use_bias=False,
+        dtype=dtype,
+        name=f"{name}_3_conv",
+    )(x)
+    if not use_pre_activation:
+        x = layers.BatchNormalization(
+            axis=bn_axis,
+            epsilon=1e-5,
+            momentum=0.9,
+            dtype=dtype,
+            name=f"{name}_3_bn",
+        )(x)
+        x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x])
+        x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x)
+    else:
+        x = layers.Add(dtype=dtype, name=f"{name}_out")([shortcut, x])
+    return x
+
+
+def apply_basic_block_vd(
+    x,
+    filters,
+    kernel_size=3,
+    stride=1,
+    conv_shortcut=False,
+    use_pre_activation=False,
+    data_format=None,
+    dtype=None,
+    name=None,
+):
+    """Applies a basic residual block.
+
+    Args:
+        x: Tensor. The input tensor to pass through the block.
+        filters: int. The number of filters in the block.
+        kernel_size: int. The kernel size of the bottleneck layer. Defaults to
+            `3`.
+        stride: int. The stride length of the first layer. Defaults to `1`.
+        conv_shortcut: bool. If `True`, use a convolution shortcut. If `False`,
+            use an identity or pooling shortcut based on the stride. Defaults to
+            `False`.
+        use_pre_activation: boolean. Whether to use pre-activation or not.
+            `True` for ResNetV2, `False` for ResNet. Defaults to `False`.
+        data_format: `None` or str. the ordering of the dimensions in the
+            inputs. Can be `"channels_last"`
+            (`(batch_size, height, width, channels)`) or`"channels_first"`
+            (`(batch_size, channels, height, width)`).
+        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
+            to use for the models computations and weights.
+        name: str. A prefix for the layer names used in the block.
+
+    Returns:
+        The output tensor for the basic residual block.
+    """
+    data_format = data_format or keras.config.image_data_format()
+    bn_axis = -1 if data_format == "channels_last" else 1
+
+    x_preact = None
+    if use_pre_activation:
+        x_preact = layers.BatchNormalization(
+            axis=bn_axis,
+            epsilon=1e-5,
+            momentum=0.9,
+            dtype=dtype,
+            name=f"{name}_pre_activation_bn",
+        )(x)
+        x_preact = layers.Activation(
+            "relu", dtype=dtype, name=f"{name}_pre_activation_relu"
+        )(x_preact)
+
+    if conv_shortcut:
+        if x_preact is not None:
+            shortcut = x_preact
+        elif stride > 1:
+            shortcut = layers.AveragePooling2D(
+                2,
+                strides=stride,
+                data_format=data_format,
+                dtype=dtype,
+                padding="same",
+            )(x)
+        else:
+            shortcut = x
+        shortcut = layers.Conv2D(
+            filters,
+            1,
+            strides=1,
+            data_format=data_format,
+            use_bias=False,
+            dtype=dtype,
+            name=f"{name}_0_conv",
+        )(shortcut)
+        if not use_pre_activation:
+            shortcut = layers.BatchNormalization(
+                axis=bn_axis,
+                epsilon=1e-5,
+                momentum=0.9,
+                dtype=dtype,
+                name=f"{name}_0_bn",
+            )(shortcut)
+    else:
+        shortcut = x
+
+    x = x_preact if x_preact is not None else x
+    if stride > 1:
+        x = layers.ZeroPadding2D(
+            (kernel_size - 1) // 2,
+            data_format=data_format,
+            dtype=dtype,
+            name=f"{name}_1_pad",
+        )(x)
+    x = layers.Conv2D(
+        filters,
+        kernel_size,
+        strides=stride,
+        padding="valid" if stride > 1 else "same",
+        data_format=data_format,
+        use_bias=False,
+        dtype=dtype,
+        name=f"{name}_1_conv",
+    )(x)
+    x = layers.BatchNormalization(
+        axis=bn_axis,
+        epsilon=1e-5,
+        momentum=0.9,
+        dtype=dtype,
+        name=f"{name}_1_bn",
+    )(x)
+    x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x)
+
+    x = layers.Conv2D(
+        filters,
+        kernel_size,
+        strides=1,
+        padding="same",
+        data_format=data_format,
+        use_bias=False,
+        dtype=dtype,
+        name=f"{name}_2_conv",
+    )(x)
+    if not use_pre_activation:
+        x = layers.BatchNormalization(
+            axis=bn_axis,
+            epsilon=1e-5,
+            momentum=0.9,
+            dtype=dtype,
+            name=f"{name}_2_bn",
+        )(x)
+        x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x])
+        x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x)
+    else:
+        x = layers.Add(dtype=dtype, name=f"{name}_out")([shortcut, x])
+    return x
+
+
+def apply_bottleneck_block_vd(
+    x,
+    filters,
+    kernel_size=3,
+    stride=1,
+    conv_shortcut=False,
+    use_pre_activation=False,
+    data_format=None,
+    dtype=None,
+    name=None,
+):
+    """Applies a bottleneck residual block.
+
+    Args:
+        x: Tensor. The input tensor to pass through the block.
+        filters: int. The number of filters in the block.
+        kernel_size: int. The kernel size of the bottleneck layer. Defaults to
+            `3`.
+        stride: int. The stride length of the first layer. Defaults to `1`.
+        conv_shortcut: bool. If `True`, use a convolution shortcut. If `False`,
+            use an identity or pooling shortcut based on the stride. Defaults to
+            `False`.
+        use_pre_activation: boolean. Whether to use pre-activation or not.
+            `True` for ResNetV2, `False` for ResNet. Defaults to `False`.
+        data_format: `None` or str. the ordering of the dimensions in the
+            inputs. Can be `"channels_last"`
+            (`(batch_size, height, width, channels)`) or`"channels_first"`
+            (`(batch_size, channels, height, width)`).
+        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
+            to use for the models computations and weights.
+        name: str. A prefix for the layer names used in the block.
+
+    Returns:
+        The output tensor for the residual block.
+    """
+    data_format = data_format or keras.config.image_data_format()
+    bn_axis = -1 if data_format == "channels_last" else 1
+
+    x_preact = None
+    if use_pre_activation:
+        x_preact = layers.BatchNormalization(
+            axis=bn_axis,
+            epsilon=1e-5,
+            momentum=0.9,
+            dtype=dtype,
+            name=f"{name}_pre_activation_bn",
+        )(x)
+        x_preact = layers.Activation(
+            "relu", dtype=dtype, name=f"{name}_pre_activation_relu"
+        )(x_preact)
+
+    if conv_shortcut:
+        if x_preact is not None:
+            shortcut = x_preact
+        elif stride > 1:
+            shortcut = layers.AveragePooling2D(
+                2,
+                strides=stride,
+                data_format=data_format,
+                dtype=dtype,
+                padding="same",
+            )(x)
+        else:
+            shortcut = x
+        shortcut = layers.Conv2D(
+            4 * filters,
+            1,
+            strides=1,
+            data_format=data_format,
+            use_bias=False,
+            dtype=dtype,
+            name=f"{name}_0_conv",
+        )(shortcut)
         if not use_pre_activation:
             shortcut = layers.BatchNormalization(
                 axis=bn_axis,
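The distinguishing move in the new `_vd` blocks is the shortcut path: when downsampling, they average-pool first and then project with a stride-1 1x1 convolution, instead of a single strided 1x1 convolution that reads only one of every four input positions. A standalone sketch of just that shortcut, in plain Keras (not the library code above):

```python
import keras
from keras import layers

# Sketch of the "_vd" pooled shortcut in isolation: pool with stride 2,
# then project channels with a stride-1 1x1 conv. A plain strided 1x1
# conv would reach the same output shape but would discard three of
# every four input positions instead of averaging over them.
inputs = keras.Input(shape=(56, 56, 64))
shortcut = layers.AveragePooling2D(2, strides=2, padding="same")(inputs)
shortcut = layers.Conv2D(128, 1, strides=1, use_bias=False)(shortcut)
model = keras.Model(inputs, shortcut)
print(model.output_shape)  # (None, 28, 28, 128)
```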
@@ -561,8 +915,11 @@ def apply_stack(
         blocks: int. The number of blocks in the stack.
         stride: int. The stride length of the first layer in the first block.
         block_type: str. The block type to stack. One of `"basic_block"` or
-            `"bottleneck_block"`. Use `"basic_block"` for ResNet18 and ResNet34.
-            Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152.
+            `"bottleneck_block"`, `"basic_block_vd"` or
+            `"bottleneck_block_vd"`. Use `"basic_block"` for ResNet18 and
+            ResNet34. Use `"bottleneck_block"` for ResNet50, ResNet101 and
+            ResNet152 and the `"_vd"` prefix for the respective ResNet_vd
+            variants.
         use_pre_activation: boolean. Whether to use pre-activation or not.
             `True` for ResNetV2, `False` for ResNet and ResNeXt.
         first_shortcut: bool. If `True`, use a convolution shortcut. If `False`,
@@ -580,17 +937,21 @@ def apply_stack(
         Output tensor for the stacked blocks.
     """
     if name is None:
-        version = "v1" if not use_pre_activation else "v2"
-        name = f"{version}_stack"
+        name = "stack"

     if block_type == "basic_block":
         block_fn = apply_basic_block
     elif block_type == "bottleneck_block":
         block_fn = apply_bottleneck_block
+    elif block_type == "basic_block_vd":
+        block_fn = apply_basic_block_vd
+    elif block_type == "bottleneck_block_vd":
+        block_fn = apply_bottleneck_block_vd
     else:
         raise ValueError(
-            '`block_type` must be either `"basic_block"` or '
-            f'`"bottleneck_block"`. Received block_type={block_type}.'
+            '`block_type` must be either `"basic_block"`, '
+            '`"bottleneck_block"`, `"basic_block_vd"` or '
+            f'`"bottleneck_block_vd"`. Received block_type={block_type}.'
         )
     for i in range(blocks):
         if i == 0:
keras_hub/src/models/resnet/resnet_image_classifier.py
@@ -16,6 +16,9 @@ import keras
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.image_classifier import ImageClassifier
 from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
+from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import (
+    ResNetImageClassifierPreprocessor,
+)


 @keras_hub_export("keras_hub.models.ResNetImageClassifier")
@@ -88,21 +91,22 @@ class ResNetImageClassifier(ImageClassifier):
     """

     backbone_cls = ResNetBackbone
+    preprocessor_cls = ResNetImageClassifierPreprocessor

     def __init__(
         self,
         backbone,
         num_classes,
-        activation=None,
+        preprocessor=None,
+        activation=None,
         head_dtype=None,
-        preprocessor=None,  # adding this dummy arg for saved model test
-        # TODO: once preprocessor flow is figured out, this needs to be updated
         **kwargs,
     ):
         head_dtype = head_dtype or backbone.dtype_policy

         # === Layers ===
         self.backbone = backbone
+        self.preprocessor = preprocessor
         self.output_dense = keras.layers.Dense(
             num_classes,
             activation=activation,
keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py (new file)
@@ -0,0 +1,28 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.image_classifier_preprocessor import (
+    ImageClassifierPreprocessor,
+)
+from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
+from keras_hub.src.models.resnet.resnet_image_converter import (
+    ResNetImageConverter,
+)
+
+
+@keras_hub_export("keras_hub.models.ResNetImageClassifierPreprocessor")
+class ResNetImageClassifierPreprocessor(ImageClassifierPreprocessor):
+    backbone_cls = ResNetBackbone
+    image_converter_cls = ResNetImageConverter
keras_hub/src/models/resnet/resnet_image_converter.py (new file)
@@ -0,0 +1,23 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.resizing_image_converter import (
+    ResizingImageConverter,
+)
+from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
+
+
+@keras_hub_export("keras_hub.layers.ResNetImageConverter")
+class ResNetImageConverter(ResizingImageConverter):
+    backbone_cls = ResNetBackbone
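Taken together, the two new files plus the classifier changes give ResNet the same converter → preprocessor → task wiring used by the text models. A minimal sketch of how the pieces might be assembled by hand; the constructors of `ResizingImageConverter` and `ImageClassifierPreprocessor` do not appear in this diff, so the `height`/`width` arguments and the `image_converter` parameter name are assumptions:

```python
import keras_hub

# Assumed constructor arguments; neither ResizingImageConverter nor
# ImageClassifierPreprocessor signatures are shown in this diff.
converter = keras_hub.layers.ResNetImageConverter(height=224, width=224)
preprocessor = keras_hub.models.ResNetImageClassifierPreprocessor(
    image_converter=converter,
)
classifier = keras_hub.models.ResNetImageClassifier(
    backbone=vd_backbone,  # e.g. the backbone from the earlier sketch
    num_classes=1000,
    preprocessor=preprocessor,  # now a first-class argument stored on the task
)
```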