keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +0 -6
- keras_hub/api/__init__.py +2 -0
- keras_hub/api/bounding_box/__init__.py +36 -0
- keras_hub/api/layers/__init__.py +14 -0
- keras_hub/api/models/__init__.py +97 -48
- keras_hub/api/tokenizers/__init__.py +30 -0
- keras_hub/api/utils/__init__.py +22 -0
- keras_hub/src/api_export.py +15 -9
- keras_hub/src/bounding_box/__init__.py +13 -0
- keras_hub/src/bounding_box/converters.py +529 -0
- keras_hub/src/bounding_box/formats.py +162 -0
- keras_hub/src/bounding_box/iou.py +263 -0
- keras_hub/src/bounding_box/to_dense.py +95 -0
- keras_hub/src/bounding_box/to_ragged.py +99 -0
- keras_hub/src/bounding_box/utils.py +194 -0
- keras_hub/src/bounding_box/validate_format.py +99 -0
- keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
- keras_hub/src/layers/preprocessing/image_converter.py +130 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
- keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
- keras_hub/src/layers/preprocessing/random_swap.py +33 -31
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
- keras_hub/src/models/albert/__init__.py +1 -2
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
- keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
- keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
- keras_hub/src/models/albert/albert_tokenizer.py +17 -36
- keras_hub/src/models/backbone.py +12 -34
- keras_hub/src/models/bart/__init__.py +1 -2
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
- keras_hub/src/models/bart/bart_tokenizer.py +12 -39
- keras_hub/src/models/bert/__init__.py +1 -5
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
- keras_hub/src/models/bert/bert_presets.py +1 -4
- keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
- keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
- keras_hub/src/models/bert/bert_tokenizer.py +17 -35
- keras_hub/src/models/bloom/__init__.py +1 -2
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
- keras_hub/src/models/causal_lm.py +10 -29
- keras_hub/src/models/causal_lm_preprocessor.py +195 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
- keras_hub/src/models/deberta_v3/__init__.py +1 -4
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
- keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
- keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
- keras_hub/src/models/densenet/densenet_backbone.py +46 -22
- keras_hub/src/models/distil_bert/__init__.py +1 -4
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
- keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
- keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
- keras_hub/src/models/efficientnet/__init__.py +13 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
- keras_hub/src/models/efficientnet/mbconv.py +238 -0
- keras_hub/src/models/electra/__init__.py +1 -2
- keras_hub/src/models/electra/electra_tokenizer.py +17 -32
- keras_hub/src/models/f_net/__init__.py +1 -2
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
- keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
- keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
- keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
- keras_hub/src/models/falcon/__init__.py +1 -2
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
- keras_hub/src/models/gemma/__init__.py +1 -2
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
- keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
- keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
- keras_hub/src/models/gpt2/__init__.py +1 -2
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
- keras_hub/src/models/image_classifier.py +0 -5
- keras_hub/src/models/image_classifier_preprocessor.py +83 -0
- keras_hub/src/models/llama/__init__.py +1 -2
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
- keras_hub/src/models/llama/llama_tokenizer.py +12 -25
- keras_hub/src/models/llama3/__init__.py +1 -2
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
- keras_hub/src/models/masked_lm.py +0 -2
- keras_hub/src/models/masked_lm_preprocessor.py +156 -0
- keras_hub/src/models/mistral/__init__.py +1 -2
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
- keras_hub/src/models/mobilenet/__init__.py +13 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
- keras_hub/src/models/opt/__init__.py +1 -2
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
- keras_hub/src/models/opt/opt_tokenizer.py +12 -41
- keras_hub/src/models/pali_gemma/__init__.py +1 -4
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
- keras_hub/src/models/phi3/__init__.py +1 -2
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
- keras_hub/src/models/preprocessor.py +72 -83
- keras_hub/src/models/resnet/__init__.py +6 -0
- keras_hub/src/models/resnet/resnet_backbone.py +390 -42
- keras_hub/src/models/resnet/resnet_image_classifier.py +33 -6
- keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
- keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
- keras_hub/src/models/resnet/resnet_presets.py +95 -0
- keras_hub/src/models/retinanet/__init__.py +13 -0
- keras_hub/src/models/retinanet/anchor_generator.py +175 -0
- keras_hub/src/models/retinanet/box_matcher.py +259 -0
- keras_hub/src/models/retinanet/non_max_supression.py +578 -0
- keras_hub/src/models/roberta/__init__.py +1 -2
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
- keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
- keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
- keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
- keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
- keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
- keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
- keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
- keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
- keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
- keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
- keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
- keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
- keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
- keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
- keras_hub/src/models/t5/__init__.py +1 -2
- keras_hub/src/models/t5/t5_tokenizer.py +13 -23
- keras_hub/src/models/task.py +71 -116
- keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
- keras_hub/src/models/text_classifier_preprocessor.py +138 -0
- keras_hub/src/models/whisper/__init__.py +1 -2
- keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
- keras_hub/src/models/whisper/whisper_backbone.py +0 -3
- keras_hub/src/models/whisper/whisper_presets.py +10 -10
- keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
- keras_hub/src/models/xlm_roberta/__init__.py +1 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
- keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
- keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
- keras_hub/src/tests/test_case.py +46 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +30 -17
- keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +20 -7
- keras_hub/src/tokenizers/tokenizer.py +67 -32
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
- keras_hub/src/tokenizers/word_piece_tokenizer.py +34 -47
- keras_hub/src/utils/imagenet/__init__.py +13 -0
- keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
- keras_hub/src/utils/keras_utils.py +0 -50
- keras_hub/src/utils/preset_utils.py +230 -68
- keras_hub/src/utils/tensor_utils.py +187 -69
- keras_hub/src/utils/timm/convert_resnet.py +19 -16
- keras_hub/src/utils/timm/preset_loader.py +66 -0
- keras_hub/src/utils/transformers/convert_albert.py +193 -0
- keras_hub/src/utils/transformers/convert_bart.py +373 -0
- keras_hub/src/utils/transformers/convert_bert.py +7 -17
- keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
- keras_hub/src/utils/transformers/convert_gemma.py +5 -19
- keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
- keras_hub/src/utils/transformers/convert_llama3.py +7 -18
- keras_hub/src/utils/transformers/convert_mistral.py +129 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
- keras_hub/src/utils/transformers/preset_loader.py +77 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
- keras_hub/src/version_utils.py +1 -1
- keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
- keras_hub_nightly-0.16.0.dev2024092017.dist-info/RECORD +334 -0
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
- keras_hub/src/models/bart/bart_preprocessor.py +0 -276
- keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
- keras_hub/src/models/electra/electra_preprocessor.py +0 -154
- keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
- keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
- keras_hub/src/models/llama/llama_preprocessor.py +0 -189
- keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
- keras_hub/src/models/opt/opt_preprocessor.py +0 -188
- keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
- keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
- keras_hub/src/utils/timm/convert.py +0 -37
- keras_hub/src/utils/transformers/convert.py +0 -101
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +0 -34
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
@@ -27,9 +27,10 @@ class ResNetBackbone(FeaturePyramidBackbone):
|
|
27
27
|
This class implements a ResNet backbone as described in [Deep Residual
|
28
28
|
Learning for Image Recognition](https://arxiv.org/abs/1512.03385)(
|
29
29
|
CVPR 2016), [Identity Mappings in Deep Residual Networks](
|
30
|
-
https://arxiv.org/abs/1603.05027)(ECCV 2016)
|
30
|
+
https://arxiv.org/abs/1603.05027)(ECCV 2016), [ResNet strikes back: An
|
31
31
|
improved training procedure in timm](https://arxiv.org/abs/2110.00476)(
|
32
|
-
NeurIPS 2021 Workshop)
|
32
|
+
NeurIPS 2021 Workshop) and [Bag of Tricks for Image Classification with
|
33
|
+
Convolutional Neural Networks](https://arxiv.org/abs/1812.01187).
|
33
34
|
|
34
35
|
The difference in ResNet and ResNetV2 rests in the structure of their
|
35
36
|
individual building blocks. In ResNetV2, the batch normalization and
|
@@ -37,18 +38,31 @@ class ResNetBackbone(FeaturePyramidBackbone):
|
|
37
38
|
the batch normalization and ReLU activation are applied after the
|
38
39
|
convolution layers.
|
39
40
|
|
41
|
+
ResNetVd introduces two key modifications to the standard ResNet. First,
|
42
|
+
the initial convolutional layer is replaced by a series of three
|
43
|
+
successive convolutional layers. Second, shortcut connections use an
|
44
|
+
additional pooling operation rather than performing downsampling within
|
45
|
+
the convolutional layers themselves.
|
46
|
+
|
40
47
|
Note that `ResNetBackbone` expects the inputs to be images with a value
|
41
48
|
range of `[0, 255]` when `include_rescaling=True`.
|
42
49
|
|
43
50
|
Args:
|
51
|
+
input_conv_filters: list of ints. The number of filters of the initial
|
52
|
+
convolution(s).
|
53
|
+
input_conv_kernel_sizes: list of ints. The kernel sizes of the initial
|
54
|
+
convolution(s).
|
44
55
|
stackwise_num_filters: list of ints. The number of filters for each
|
45
56
|
stack.
|
46
57
|
stackwise_num_blocks: list of ints. The number of blocks for each stack.
|
47
58
|
stackwise_num_strides: list of ints. The number of strides for each
|
48
59
|
stack.
|
49
|
-
block_type: str. The block type to stack. One of `"basic_block"
|
50
|
-
`"bottleneck_block"
|
51
|
-
Use `"
|
60
|
+
block_type: str. The block type to stack. One of `"basic_block"`,
|
61
|
+
`"bottleneck_block"`, `"basic_block_vd"` or
|
62
|
+
`"bottleneck_block_vd"`. Use `"basic_block"` for ResNet18 and
|
63
|
+
ResNet34. Use `"bottleneck_block"` for ResNet50, ResNet101 and
|
64
|
+
ResNet152 and the `"_vd"` prefix for the respective ResNet_vd
|
65
|
+
variants.
|
52
66
|
use_pre_activation: boolean. Whether to use pre-activation or not.
|
53
67
|
`True` for ResNetV2, `False` for ResNet.
|
54
68
|
include_rescaling: boolean. If `True`, rescale the input using
|
@@ -88,6 +102,8 @@ class ResNetBackbone(FeaturePyramidBackbone):
|
|
88
102
|
|
89
103
|
# Randomly initialized ResNetV2 backbone with a custom config.
|
90
104
|
model = keras_hub.models.ResNetBackbone(
|
105
|
+
input_conv_filters=[64],
|
106
|
+
input_conv_kernel_sizes=[7],
|
91
107
|
stackwise_num_filters=[64, 64, 64],
|
92
108
|
stackwise_num_blocks=[2, 2, 2],
|
93
109
|
stackwise_num_strides=[1, 2, 2],
|
@@ -101,6 +117,8 @@ class ResNetBackbone(FeaturePyramidBackbone):
|
|
101
117
|
|
102
118
|
def __init__(
|
103
119
|
self,
|
120
|
+
input_conv_filters,
|
121
|
+
input_conv_kernel_sizes,
|
104
122
|
stackwise_num_filters,
|
105
123
|
stackwise_num_blocks,
|
106
124
|
stackwise_num_strides,
|
@@ -108,11 +126,17 @@ class ResNetBackbone(FeaturePyramidBackbone):
|
|
108
126
|
use_pre_activation=False,
|
109
127
|
include_rescaling=True,
|
110
128
|
image_shape=(None, None, 3),
|
111
|
-
pooling="avg",
|
112
129
|
data_format=None,
|
113
130
|
dtype=None,
|
114
131
|
**kwargs,
|
115
132
|
):
|
133
|
+
if len(input_conv_filters) != len(input_conv_kernel_sizes):
|
134
|
+
raise ValueError(
|
135
|
+
"The length of `input_conv_filters` and"
|
136
|
+
"`input_conv_kernel_sizes` must be the same. "
|
137
|
+
f"Received: input_conv_filters={input_conv_filters}, "
|
138
|
+
f"input_conv_kernel_sizes={input_conv_kernel_sizes}."
|
139
|
+
)
|
116
140
|
if len(stackwise_num_filters) != len(stackwise_num_blocks) or len(
|
117
141
|
stackwise_num_filters
|
118
142
|
) != len(stackwise_num_strides):
|
@@ -128,14 +152,20 @@ class ResNetBackbone(FeaturePyramidBackbone):
|
|
128
152
|
"The first element of `stackwise_num_filters` must be 64. "
|
129
153
|
f"Received: stackwise_num_filters={stackwise_num_filters}"
|
130
154
|
)
|
131
|
-
if block_type not in (
|
155
|
+
if block_type not in (
|
156
|
+
"basic_block",
|
157
|
+
"bottleneck_block",
|
158
|
+
"basic_block_vd",
|
159
|
+
"bottleneck_block_vd",
|
160
|
+
):
|
132
161
|
raise ValueError(
|
133
|
-
'`block_type` must be either `"basic_block"
|
134
|
-
|
162
|
+
'`block_type` must be either `"basic_block"`, '
|
163
|
+
'`"bottleneck_block"`, `"basic_block_vd"` or '
|
164
|
+
f'`"bottleneck_block_vd"`. Received block_type={block_type}.'
|
135
165
|
)
|
136
|
-
version = "v1" if not use_pre_activation else "v2"
|
137
166
|
data_format = standardize_data_format(data_format)
|
138
167
|
bn_axis = -1 if data_format == "channels_last" else 1
|
168
|
+
num_input_convs = len(input_conv_filters)
|
139
169
|
num_stacks = len(stackwise_num_filters)
|
140
170
|
|
141
171
|
# === Functional Model ===
|
@@ -155,29 +185,56 @@ class ResNetBackbone(FeaturePyramidBackbone):
|
|
155
185
|
# The padding between torch and tensorflow/jax differs when `strides>1`.
|
156
186
|
# Therefore, we need to manually pad the tensor.
|
157
187
|
x = layers.ZeroPadding2D(
|
158
|
-
|
188
|
+
(input_conv_kernel_sizes[0] - 1) // 2,
|
159
189
|
data_format=data_format,
|
160
190
|
dtype=dtype,
|
161
191
|
name="conv1_pad",
|
162
192
|
)(x)
|
163
193
|
x = layers.Conv2D(
|
164
|
-
|
165
|
-
|
194
|
+
input_conv_filters[0],
|
195
|
+
input_conv_kernel_sizes[0],
|
166
196
|
strides=2,
|
167
197
|
data_format=data_format,
|
168
198
|
use_bias=False,
|
199
|
+
padding="valid",
|
169
200
|
dtype=dtype,
|
170
201
|
name="conv1_conv",
|
171
202
|
)(x)
|
203
|
+
for conv_index in range(1, num_input_convs):
|
204
|
+
x = layers.BatchNormalization(
|
205
|
+
axis=bn_axis,
|
206
|
+
epsilon=1e-5,
|
207
|
+
momentum=0.9,
|
208
|
+
dtype=dtype,
|
209
|
+
name=f"conv{conv_index}_bn",
|
210
|
+
)(x)
|
211
|
+
x = layers.Activation(
|
212
|
+
"relu", dtype=dtype, name=f"conv{conv_index}_relu"
|
213
|
+
)(x)
|
214
|
+
x = layers.Conv2D(
|
215
|
+
input_conv_filters[conv_index],
|
216
|
+
input_conv_kernel_sizes[conv_index],
|
217
|
+
strides=1,
|
218
|
+
data_format=data_format,
|
219
|
+
use_bias=False,
|
220
|
+
padding="same",
|
221
|
+
dtype=dtype,
|
222
|
+
name=f"conv{conv_index+1}_conv",
|
223
|
+
)(x)
|
224
|
+
|
172
225
|
if not use_pre_activation:
|
173
226
|
x = layers.BatchNormalization(
|
174
227
|
axis=bn_axis,
|
175
228
|
epsilon=1e-5,
|
176
229
|
momentum=0.9,
|
177
230
|
dtype=dtype,
|
178
|
-
name="
|
231
|
+
name=f"conv{num_input_convs}_bn",
|
232
|
+
)(x)
|
233
|
+
x = layers.Activation(
|
234
|
+
"relu",
|
235
|
+
dtype=dtype,
|
236
|
+
name=f"conv{num_input_convs}_relu",
|
179
237
|
)(x)
|
180
|
-
x = layers.Activation("relu", dtype=dtype, name="conv1_relu")(x)
|
181
238
|
|
182
239
|
if use_pre_activation:
|
183
240
|
# A workaround for ResNetV2: we need -inf padding to prevent zeros
|
@@ -210,12 +267,10 @@ class ResNetBackbone(FeaturePyramidBackbone):
|
|
210
267
|
stride=stackwise_num_strides[stack_index],
|
211
268
|
block_type=block_type,
|
212
269
|
use_pre_activation=use_pre_activation,
|
213
|
-
first_shortcut=(
|
214
|
-
block_type == "bottleneck_block" or stack_index > 0
|
215
|
-
),
|
270
|
+
first_shortcut=(block_type != "basic_block" or stack_index > 0),
|
216
271
|
data_format=data_format,
|
217
272
|
dtype=dtype,
|
218
|
-
name=f"{
|
273
|
+
name=f"stack{stack_index}",
|
219
274
|
)
|
220
275
|
pyramid_outputs[f"P{stack_index + 2}"] = x
|
221
276
|
|
@@ -229,25 +284,16 @@ class ResNetBackbone(FeaturePyramidBackbone):
|
|
229
284
|
)(x)
|
230
285
|
x = layers.Activation("relu", dtype=dtype, name="post_relu")(x)
|
231
286
|
|
232
|
-
if pooling == "avg":
|
233
|
-
feature_map_output = layers.GlobalAveragePooling2D(
|
234
|
-
data_format=data_format, dtype=dtype
|
235
|
-
)(x)
|
236
|
-
elif pooling == "max":
|
237
|
-
feature_map_output = layers.GlobalMaxPooling2D(
|
238
|
-
data_format=data_format, dtype=dtype
|
239
|
-
)(x)
|
240
|
-
else:
|
241
|
-
feature_map_output = x
|
242
|
-
|
243
287
|
super().__init__(
|
244
288
|
inputs=image_input,
|
245
|
-
outputs=
|
289
|
+
outputs=x,
|
246
290
|
dtype=dtype,
|
247
291
|
**kwargs,
|
248
292
|
)
|
249
293
|
|
250
294
|
# === Config ===
|
295
|
+
self.input_conv_filters = input_conv_filters
|
296
|
+
self.input_conv_kernel_sizes = input_conv_kernel_sizes
|
251
297
|
self.stackwise_num_filters = stackwise_num_filters
|
252
298
|
self.stackwise_num_blocks = stackwise_num_blocks
|
253
299
|
self.stackwise_num_strides = stackwise_num_strides
|
@@ -255,13 +301,15 @@ class ResNetBackbone(FeaturePyramidBackbone):
|
|
255
301
|
self.use_pre_activation = use_pre_activation
|
256
302
|
self.include_rescaling = include_rescaling
|
257
303
|
self.image_shape = image_shape
|
258
|
-
self.pooling = pooling
|
259
304
|
self.pyramid_outputs = pyramid_outputs
|
305
|
+
self.data_format = data_format
|
260
306
|
|
261
307
|
def get_config(self):
|
262
308
|
config = super().get_config()
|
263
309
|
config.update(
|
264
310
|
{
|
311
|
+
"input_conv_filters": self.input_conv_filters,
|
312
|
+
"input_conv_kernel_sizes": self.input_conv_kernel_sizes,
|
265
313
|
"stackwise_num_filters": self.stackwise_num_filters,
|
266
314
|
"stackwise_num_blocks": self.stackwise_num_blocks,
|
267
315
|
"stackwise_num_strides": self.stackwise_num_strides,
|
@@ -269,7 +317,6 @@ class ResNetBackbone(FeaturePyramidBackbone):
|
|
269
317
|
"use_pre_activation": self.use_pre_activation,
|
270
318
|
"include_rescaling": self.include_rescaling,
|
271
319
|
"image_shape": self.image_shape,
|
272
|
-
"pooling": self.pooling,
|
273
320
|
}
|
274
321
|
)
|
275
322
|
return config
|
@@ -327,7 +374,10 @@ def apply_basic_block(
|
|
327
374
|
)(x_preact)
|
328
375
|
|
329
376
|
if conv_shortcut:
|
330
|
-
|
377
|
+
if x_preact is not None:
|
378
|
+
shortcut = x_preact
|
379
|
+
else:
|
380
|
+
shortcut = x
|
331
381
|
shortcut = layers.Conv2D(
|
332
382
|
filters,
|
333
383
|
1,
|
@@ -336,7 +386,7 @@ def apply_basic_block(
|
|
336
386
|
use_bias=False,
|
337
387
|
dtype=dtype,
|
338
388
|
name=f"{name}_0_conv",
|
339
|
-
)(
|
389
|
+
)(shortcut)
|
340
390
|
if not use_pre_activation:
|
341
391
|
shortcut = layers.BatchNormalization(
|
342
392
|
axis=bn_axis,
|
@@ -452,7 +502,10 @@ def apply_bottleneck_block(
|
|
452
502
|
)(x_preact)
|
453
503
|
|
454
504
|
if conv_shortcut:
|
455
|
-
|
505
|
+
if x_preact is not None:
|
506
|
+
shortcut = x_preact
|
507
|
+
else:
|
508
|
+
shortcut = x
|
456
509
|
shortcut = layers.Conv2D(
|
457
510
|
4 * filters,
|
458
511
|
1,
|
@@ -461,7 +514,295 @@ def apply_bottleneck_block(
|
|
461
514
|
use_bias=False,
|
462
515
|
dtype=dtype,
|
463
516
|
name=f"{name}_0_conv",
|
517
|
+
)(shortcut)
|
518
|
+
if not use_pre_activation:
|
519
|
+
shortcut = layers.BatchNormalization(
|
520
|
+
axis=bn_axis,
|
521
|
+
epsilon=1e-5,
|
522
|
+
momentum=0.9,
|
523
|
+
dtype=dtype,
|
524
|
+
name=f"{name}_0_bn",
|
525
|
+
)(shortcut)
|
526
|
+
else:
|
527
|
+
shortcut = x
|
528
|
+
|
529
|
+
x = x_preact if x_preact is not None else x
|
530
|
+
x = layers.Conv2D(
|
531
|
+
filters,
|
532
|
+
1,
|
533
|
+
strides=1,
|
534
|
+
data_format=data_format,
|
535
|
+
use_bias=False,
|
536
|
+
dtype=dtype,
|
537
|
+
name=f"{name}_1_conv",
|
538
|
+
)(x)
|
539
|
+
x = layers.BatchNormalization(
|
540
|
+
axis=bn_axis,
|
541
|
+
epsilon=1e-5,
|
542
|
+
momentum=0.9,
|
543
|
+
dtype=dtype,
|
544
|
+
name=f"{name}_1_bn",
|
545
|
+
)(x)
|
546
|
+
x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x)
|
547
|
+
|
548
|
+
if stride > 1:
|
549
|
+
x = layers.ZeroPadding2D(
|
550
|
+
(kernel_size - 1) // 2,
|
551
|
+
data_format=data_format,
|
552
|
+
dtype=dtype,
|
553
|
+
name=f"{name}_2_pad",
|
554
|
+
)(x)
|
555
|
+
x = layers.Conv2D(
|
556
|
+
filters,
|
557
|
+
kernel_size,
|
558
|
+
strides=stride,
|
559
|
+
padding="valid" if stride > 1 else "same",
|
560
|
+
data_format=data_format,
|
561
|
+
use_bias=False,
|
562
|
+
dtype=dtype,
|
563
|
+
name=f"{name}_2_conv",
|
564
|
+
)(x)
|
565
|
+
x = layers.BatchNormalization(
|
566
|
+
axis=bn_axis,
|
567
|
+
epsilon=1e-5,
|
568
|
+
momentum=0.9,
|
569
|
+
dtype=dtype,
|
570
|
+
name=f"{name}_2_bn",
|
571
|
+
)(x)
|
572
|
+
x = layers.Activation("relu", dtype=dtype, name=f"{name}_2_relu")(x)
|
573
|
+
|
574
|
+
x = layers.Conv2D(
|
575
|
+
4 * filters,
|
576
|
+
1,
|
577
|
+
data_format=data_format,
|
578
|
+
use_bias=False,
|
579
|
+
dtype=dtype,
|
580
|
+
name=f"{name}_3_conv",
|
581
|
+
)(x)
|
582
|
+
if not use_pre_activation:
|
583
|
+
x = layers.BatchNormalization(
|
584
|
+
axis=bn_axis,
|
585
|
+
epsilon=1e-5,
|
586
|
+
momentum=0.9,
|
587
|
+
dtype=dtype,
|
588
|
+
name=f"{name}_3_bn",
|
464
589
|
)(x)
|
590
|
+
x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x])
|
591
|
+
x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x)
|
592
|
+
else:
|
593
|
+
x = layers.Add(dtype=dtype, name=f"{name}_out")([shortcut, x])
|
594
|
+
return x
|
595
|
+
|
596
|
+
|
597
|
+
def apply_basic_block_vd(
|
598
|
+
x,
|
599
|
+
filters,
|
600
|
+
kernel_size=3,
|
601
|
+
stride=1,
|
602
|
+
conv_shortcut=False,
|
603
|
+
use_pre_activation=False,
|
604
|
+
data_format=None,
|
605
|
+
dtype=None,
|
606
|
+
name=None,
|
607
|
+
):
|
608
|
+
"""Applies a basic residual block.
|
609
|
+
|
610
|
+
Args:
|
611
|
+
x: Tensor. The input tensor to pass through the block.
|
612
|
+
filters: int. The number of filters in the block.
|
613
|
+
kernel_size: int. The kernel size of the bottleneck layer. Defaults to
|
614
|
+
`3`.
|
615
|
+
stride: int. The stride length of the first layer. Defaults to `1`.
|
616
|
+
conv_shortcut: bool. If `True`, use a convolution shortcut. If `False`,
|
617
|
+
use an identity or pooling shortcut based on the stride. Defaults to
|
618
|
+
`False`.
|
619
|
+
use_pre_activation: boolean. Whether to use pre-activation or not.
|
620
|
+
`True` for ResNetV2, `False` for ResNet. Defaults to `False`.
|
621
|
+
data_format: `None` or str. the ordering of the dimensions in the
|
622
|
+
inputs. Can be `"channels_last"`
|
623
|
+
(`(batch_size, height, width, channels)`) or`"channels_first"`
|
624
|
+
(`(batch_size, channels, height, width)`).
|
625
|
+
dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
|
626
|
+
to use for the models computations and weights.
|
627
|
+
name: str. A prefix for the layer names used in the block.
|
628
|
+
|
629
|
+
Returns:
|
630
|
+
The output tensor for the basic residual block.
|
631
|
+
"""
|
632
|
+
data_format = data_format or keras.config.image_data_format()
|
633
|
+
bn_axis = -1 if data_format == "channels_last" else 1
|
634
|
+
|
635
|
+
x_preact = None
|
636
|
+
if use_pre_activation:
|
637
|
+
x_preact = layers.BatchNormalization(
|
638
|
+
axis=bn_axis,
|
639
|
+
epsilon=1e-5,
|
640
|
+
momentum=0.9,
|
641
|
+
dtype=dtype,
|
642
|
+
name=f"{name}_pre_activation_bn",
|
643
|
+
)(x)
|
644
|
+
x_preact = layers.Activation(
|
645
|
+
"relu", dtype=dtype, name=f"{name}_pre_activation_relu"
|
646
|
+
)(x_preact)
|
647
|
+
|
648
|
+
if conv_shortcut:
|
649
|
+
if x_preact is not None:
|
650
|
+
shortcut = x_preact
|
651
|
+
elif stride > 1:
|
652
|
+
shortcut = layers.AveragePooling2D(
|
653
|
+
2,
|
654
|
+
strides=stride,
|
655
|
+
data_format=data_format,
|
656
|
+
dtype=dtype,
|
657
|
+
padding="same",
|
658
|
+
)(x)
|
659
|
+
else:
|
660
|
+
shortcut = x
|
661
|
+
shortcut = layers.Conv2D(
|
662
|
+
filters,
|
663
|
+
1,
|
664
|
+
strides=1,
|
665
|
+
data_format=data_format,
|
666
|
+
use_bias=False,
|
667
|
+
dtype=dtype,
|
668
|
+
name=f"{name}_0_conv",
|
669
|
+
)(shortcut)
|
670
|
+
if not use_pre_activation:
|
671
|
+
shortcut = layers.BatchNormalization(
|
672
|
+
axis=bn_axis,
|
673
|
+
epsilon=1e-5,
|
674
|
+
momentum=0.9,
|
675
|
+
dtype=dtype,
|
676
|
+
name=f"{name}_0_bn",
|
677
|
+
)(shortcut)
|
678
|
+
else:
|
679
|
+
shortcut = x
|
680
|
+
|
681
|
+
x = x_preact if x_preact is not None else x
|
682
|
+
if stride > 1:
|
683
|
+
x = layers.ZeroPadding2D(
|
684
|
+
(kernel_size - 1) // 2,
|
685
|
+
data_format=data_format,
|
686
|
+
dtype=dtype,
|
687
|
+
name=f"{name}_1_pad",
|
688
|
+
)(x)
|
689
|
+
x = layers.Conv2D(
|
690
|
+
filters,
|
691
|
+
kernel_size,
|
692
|
+
strides=stride,
|
693
|
+
padding="valid" if stride > 1 else "same",
|
694
|
+
data_format=data_format,
|
695
|
+
use_bias=False,
|
696
|
+
dtype=dtype,
|
697
|
+
name=f"{name}_1_conv",
|
698
|
+
)(x)
|
699
|
+
x = layers.BatchNormalization(
|
700
|
+
axis=bn_axis,
|
701
|
+
epsilon=1e-5,
|
702
|
+
momentum=0.9,
|
703
|
+
dtype=dtype,
|
704
|
+
name=f"{name}_1_bn",
|
705
|
+
)(x)
|
706
|
+
x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x)
|
707
|
+
|
708
|
+
x = layers.Conv2D(
|
709
|
+
filters,
|
710
|
+
kernel_size,
|
711
|
+
strides=1,
|
712
|
+
padding="same",
|
713
|
+
data_format=data_format,
|
714
|
+
use_bias=False,
|
715
|
+
dtype=dtype,
|
716
|
+
name=f"{name}_2_conv",
|
717
|
+
)(x)
|
718
|
+
if not use_pre_activation:
|
719
|
+
x = layers.BatchNormalization(
|
720
|
+
axis=bn_axis,
|
721
|
+
epsilon=1e-5,
|
722
|
+
momentum=0.9,
|
723
|
+
dtype=dtype,
|
724
|
+
name=f"{name}_2_bn",
|
725
|
+
)(x)
|
726
|
+
x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x])
|
727
|
+
x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x)
|
728
|
+
else:
|
729
|
+
x = layers.Add(dtype=dtype, name=f"{name}_out")([shortcut, x])
|
730
|
+
return x
|
731
|
+
|
732
|
+
|
733
|
+
def apply_bottleneck_block_vd(
|
734
|
+
x,
|
735
|
+
filters,
|
736
|
+
kernel_size=3,
|
737
|
+
stride=1,
|
738
|
+
conv_shortcut=False,
|
739
|
+
use_pre_activation=False,
|
740
|
+
data_format=None,
|
741
|
+
dtype=None,
|
742
|
+
name=None,
|
743
|
+
):
|
744
|
+
"""Applies a bottleneck residual block.
|
745
|
+
|
746
|
+
Args:
|
747
|
+
x: Tensor. The input tensor to pass through the block.
|
748
|
+
filters: int. The number of filters in the block.
|
749
|
+
kernel_size: int. The kernel size of the bottleneck layer. Defaults to
|
750
|
+
`3`.
|
751
|
+
stride: int. The stride length of the first layer. Defaults to `1`.
|
752
|
+
conv_shortcut: bool. If `True`, use a convolution shortcut. If `False`,
|
753
|
+
use an identity or pooling shortcut based on the stride. Defaults to
|
754
|
+
`False`.
|
755
|
+
use_pre_activation: boolean. Whether to use pre-activation or not.
|
756
|
+
`True` for ResNetV2, `False` for ResNet. Defaults to `False`.
|
757
|
+
data_format: `None` or str. the ordering of the dimensions in the
|
758
|
+
inputs. Can be `"channels_last"`
|
759
|
+
(`(batch_size, height, width, channels)`) or`"channels_first"`
|
760
|
+
(`(batch_size, channels, height, width)`).
|
761
|
+
dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
|
762
|
+
to use for the models computations and weights.
|
763
|
+
name: str. A prefix for the layer names used in the block.
|
764
|
+
|
765
|
+
Returns:
|
766
|
+
The output tensor for the residual block.
|
767
|
+
"""
|
768
|
+
data_format = data_format or keras.config.image_data_format()
|
769
|
+
bn_axis = -1 if data_format == "channels_last" else 1
|
770
|
+
|
771
|
+
x_preact = None
|
772
|
+
if use_pre_activation:
|
773
|
+
x_preact = layers.BatchNormalization(
|
774
|
+
axis=bn_axis,
|
775
|
+
epsilon=1e-5,
|
776
|
+
momentum=0.9,
|
777
|
+
dtype=dtype,
|
778
|
+
name=f"{name}_pre_activation_bn",
|
779
|
+
)(x)
|
780
|
+
x_preact = layers.Activation(
|
781
|
+
"relu", dtype=dtype, name=f"{name}_pre_activation_relu"
|
782
|
+
)(x_preact)
|
783
|
+
|
784
|
+
if conv_shortcut:
|
785
|
+
if x_preact is not None:
|
786
|
+
shortcut = x_preact
|
787
|
+
elif stride > 1:
|
788
|
+
shortcut = layers.AveragePooling2D(
|
789
|
+
2,
|
790
|
+
strides=stride,
|
791
|
+
data_format=data_format,
|
792
|
+
dtype=dtype,
|
793
|
+
padding="same",
|
794
|
+
)(x)
|
795
|
+
else:
|
796
|
+
shortcut = x
|
797
|
+
shortcut = layers.Conv2D(
|
798
|
+
4 * filters,
|
799
|
+
1,
|
800
|
+
strides=1,
|
801
|
+
data_format=data_format,
|
802
|
+
use_bias=False,
|
803
|
+
dtype=dtype,
|
804
|
+
name=f"{name}_0_conv",
|
805
|
+
)(shortcut)
|
465
806
|
if not use_pre_activation:
|
466
807
|
shortcut = layers.BatchNormalization(
|
467
808
|
axis=bn_axis,
|
@@ -561,8 +902,11 @@ def apply_stack(
|
|
561
902
|
blocks: int. The number of blocks in the stack.
|
562
903
|
stride: int. The stride length of the first layer in the first block.
|
563
904
|
block_type: str. The block type to stack. One of `"basic_block"` or
|
564
|
-
`"bottleneck_block"
|
565
|
-
Use `"
|
905
|
+
`"bottleneck_block"`, `"basic_block_vd"` or
|
906
|
+
`"bottleneck_block_vd"`. Use `"basic_block"` for ResNet18 and
|
907
|
+
ResNet34. Use `"bottleneck_block"` for ResNet50, ResNet101 and
|
908
|
+
ResNet152 and the `"_vd"` prefix for the respective ResNet_vd
|
909
|
+
variants.
|
566
910
|
use_pre_activation: boolean. Whether to use pre-activation or not.
|
567
911
|
`True` for ResNetV2, `False` for ResNet and ResNeXt.
|
568
912
|
first_shortcut: bool. If `True`, use a convolution shortcut. If `False`,
|
@@ -580,17 +924,21 @@ def apply_stack(
|
|
580
924
|
Output tensor for the stacked blocks.
|
581
925
|
"""
|
582
926
|
if name is None:
|
583
|
-
|
584
|
-
name = f"{version}_stack"
|
927
|
+
name = "stack"
|
585
928
|
|
586
929
|
if block_type == "basic_block":
|
587
930
|
block_fn = apply_basic_block
|
588
931
|
elif block_type == "bottleneck_block":
|
589
932
|
block_fn = apply_bottleneck_block
|
933
|
+
elif block_type == "basic_block_vd":
|
934
|
+
block_fn = apply_basic_block_vd
|
935
|
+
elif block_type == "bottleneck_block_vd":
|
936
|
+
block_fn = apply_bottleneck_block_vd
|
590
937
|
else:
|
591
938
|
raise ValueError(
|
592
|
-
'`block_type` must be either `"basic_block"
|
593
|
-
|
939
|
+
'`block_type` must be either `"basic_block"`, '
|
940
|
+
'`"bottleneck_block"`, `"basic_block_vd"` or '
|
941
|
+
f'`"bottleneck_block_vd"`. Received block_type={block_type}.'
|
594
942
|
)
|
595
943
|
for i in range(blocks):
|
596
944
|
if i == 0:
|