keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,612 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import keras
|
15
|
+
from keras import layers
|
16
|
+
from keras import ops
|
17
|
+
|
18
|
+
from keras_hub.src.api_export import keras_hub_export
|
19
|
+
from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
|
20
|
+
from keras_hub.src.utils.keras_utils import standardize_data_format
|
21
|
+
|
22
|
+
|
23
|
+
@keras_hub_export("keras_hub.models.ResNetBackbone")
class ResNetBackbone(FeaturePyramidBackbone):
    """ResNet and ResNetV2 core network with hyperparameters.

    This class implements a ResNet backbone as described in [Deep Residual
    Learning for Image Recognition](https://arxiv.org/abs/1512.03385)(
    CVPR 2016), [Identity Mappings in Deep Residual Networks](
    https://arxiv.org/abs/1603.05027)(ECCV 2016) and [ResNet strikes back: An
    improved training procedure in timm](https://arxiv.org/abs/2110.00476)(
    NeurIPS 2021 Workshop).

    The difference between ResNet and ResNetV2 lies in the structure of their
    individual building blocks. In ResNetV2, the batch normalization and
    ReLU activation precede the convolution layers, as opposed to ResNet where
    the batch normalization and ReLU activation are applied after the
    convolution layers.

    Note that `ResNetBackbone` expects the inputs to be images with a value
    range of `[0, 255]` when `include_rescaling=True`.

    The output of every stack is also recorded in `self.pyramid_outputs`
    under the keys `"P2"`, `"P3"`, ... (one key per stack, in order).

    Args:
        stackwise_num_filters: list of ints. The number of filters for each
            stack. Must be non-empty and its first element must be `64`.
        stackwise_num_blocks: list of ints. The number of blocks for each
            stack. Must have the same length as `stackwise_num_filters`.
        stackwise_num_strides: list of ints. The stride for each stack. Must
            have the same length as `stackwise_num_filters`.
        block_type: str. The block type to stack. One of `"basic_block"` or
            `"bottleneck_block"`. Use `"basic_block"` for ResNet18 and ResNet34.
            Use `"bottleneck_block"` for ResNet50, ResNet101 and ResNet152.
        use_pre_activation: boolean. Whether to use pre-activation or not.
            `True` for ResNetV2, `False` for ResNet.
        include_rescaling: boolean. If `True`, rescale the input using
            `Rescaling` and `Normalization` layers. If `False`, do nothing.
            Defaults to `True`.
        image_shape: tuple. The input shape without the batch size.
            Defaults to `(None, None, 3)`.
        pooling: `None` or str. Pooling mode for feature extraction. Defaults
            to `"avg"`. Any value other than the ones below raises a
            `ValueError`.
            - `None` means that the output of the model will be the 4D tensor
                from the last convolutional block.
            - `avg` means that global average pooling will be applied to the
                output of the last convolutional block, resulting in a 2D
                tensor.
            - `max` means that global max pooling will be applied to the
                output of the last convolutional block, resulting in a 2D
                tensor.
        data_format: `None` or str. If specified, either `"channels_last"` or
            `"channels_first"`. The ordering of the dimensions in the
            inputs. `"channels_last"` corresponds to inputs with shape
            `(batch_size, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch_size, channels, height, width)`. It defaults to the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`. If you never set it, then it will be
            `"channels_last"`. The resolved value is saved in the config.
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
            to use for the model's computations and weights.

    Raises:
        ValueError: if the stackwise lists differ in length or are empty, if
            `stackwise_num_filters[0] != 64`, if `block_type` is not one of
            the supported values, or if `pooling` is not `None`, `"avg"` or
            `"max"`.

    Examples:
    ```python
    input_data = np.random.uniform(0, 255, size=(2, 224, 224, 3))

    # Pretrained ResNet backbone.
    model = keras_hub.models.ResNetBackbone.from_preset("resnet50")
    model(input_data)

    # Randomly initialized ResNetV2 backbone with a custom config.
    model = keras_hub.models.ResNetBackbone(
        stackwise_num_filters=[64, 64, 64],
        stackwise_num_blocks=[2, 2, 2],
        stackwise_num_strides=[1, 2, 2],
        block_type="basic_block",
        use_pre_activation=True,
        pooling="avg",
    )
    model(input_data)
    ```
    """

    def __init__(
        self,
        stackwise_num_filters,
        stackwise_num_blocks,
        stackwise_num_strides,
        block_type,
        use_pre_activation=False,
        include_rescaling=True,
        image_shape=(None, None, 3),
        pooling="avg",
        data_format=None,
        dtype=None,
        **kwargs,
    ):
        # === Argument validation ===
        if len(stackwise_num_filters) != len(stackwise_num_blocks) or len(
            stackwise_num_filters
        ) != len(stackwise_num_strides):
            raise ValueError(
                "The length of `stackwise_num_filters`, `stackwise_num_blocks` "
                "and `stackwise_num_strides` must be the same. Received: "
                f"stackwise_num_filters={stackwise_num_filters}, "
                f"stackwise_num_blocks={stackwise_num_blocks}, "
                f"stackwise_num_strides={stackwise_num_strides}"
            )
        # Guard the `[0]` lookup below: empty lists pass the length check but
        # would otherwise fail with an opaque IndexError.
        if not stackwise_num_filters:
            raise ValueError(
                "`stackwise_num_filters`, `stackwise_num_blocks` and "
                "`stackwise_num_strides` must not be empty. Received: "
                f"stackwise_num_filters={stackwise_num_filters}"
            )
        if stackwise_num_filters[0] != 64:
            raise ValueError(
                "The first element of `stackwise_num_filters` must be 64. "
                f"Received: stackwise_num_filters={stackwise_num_filters}"
            )
        if block_type not in ("basic_block", "bottleneck_block"):
            raise ValueError(
                '`block_type` must be either `"basic_block"` or '
                f'`"bottleneck_block"`. Received block_type={block_type}.'
            )
        # Reject unknown pooling modes instead of silently returning the
        # unpooled feature map (e.g. for a typo such as "average").
        if pooling not in (None, "avg", "max"):
            raise ValueError(
                '`pooling` must be one of `None`, `"avg"` or `"max"`. '
                f"Received: pooling={pooling}"
            )
        version = "v1" if not use_pre_activation else "v2"
        data_format = standardize_data_format(data_format)
        bn_axis = -1 if data_format == "channels_last" else 1
        num_stacks = len(stackwise_num_filters)

        # === Functional Model ===
        image_input = layers.Input(shape=image_shape)
        if include_rescaling:
            x = layers.Rescaling(scale=1 / 255.0, dtype=dtype)(image_input)
            x = layers.Normalization(
                axis=bn_axis,
                mean=(0.485, 0.456, 0.406),
                variance=(0.229**2, 0.224**2, 0.225**2),
                dtype=dtype,
                name="normalization",
            )(x)
        else:
            x = image_input

        # The padding between torch and tensorflow/jax differs when `strides>1`.
        # Therefore, we need to manually pad the tensor.
        x = layers.ZeroPadding2D(
            3,
            data_format=data_format,
            dtype=dtype,
            name="conv1_pad",
        )(x)
        x = layers.Conv2D(
            64,
            7,
            strides=2,
            data_format=data_format,
            use_bias=False,
            dtype=dtype,
            name="conv1_conv",
        )(x)
        if not use_pre_activation:
            x = layers.BatchNormalization(
                axis=bn_axis,
                epsilon=1e-5,
                momentum=0.9,
                dtype=dtype,
                name="conv1_bn",
            )(x)
            x = layers.Activation("relu", dtype=dtype, name="conv1_relu")(x)

        if use_pre_activation:
            # A workaround for ResNetV2: we need -inf padding to prevent zeros
            # from being the max values in the following `MaxPooling2D`.
            pad_width = [[1, 1], [1, 1]]
            if data_format == "channels_last":
                pad_width += [[0, 0]]
            else:
                pad_width = [[0, 0]] + pad_width
            # Leading entry leaves the batch axis unpadded.
            pad_width = [[0, 0]] + pad_width
            x = ops.pad(x, pad_width=pad_width, constant_values=float("-inf"))
        else:
            x = layers.ZeroPadding2D(
                1, data_format=data_format, dtype=dtype, name="pool1_pad"
            )(x)
        x = layers.MaxPooling2D(
            3,
            strides=2,
            data_format=data_format,
            dtype=dtype,
            name="pool1_pool",
        )(x)

        pyramid_outputs = {}
        for stack_index in range(num_stacks):
            x = apply_stack(
                x,
                filters=stackwise_num_filters[stack_index],
                blocks=stackwise_num_blocks[stack_index],
                stride=stackwise_num_strides[stack_index],
                block_type=block_type,
                use_pre_activation=use_pre_activation,
                # The first basic-block stack takes the 64-channel stem output
                # with 64 filters, so it has no projection shortcut.
                first_shortcut=(
                    block_type == "bottleneck_block" or stack_index > 0
                ),
                data_format=data_format,
                dtype=dtype,
                name=f"{version}_stack{stack_index}",
            )
            # Stack outputs are exposed as pyramid levels P2, P3, ...
            pyramid_outputs[f"P{stack_index + 2}"] = x

        if use_pre_activation:
            x = layers.BatchNormalization(
                axis=bn_axis,
                epsilon=1e-5,
                momentum=0.9,
                dtype=dtype,
                name="post_bn",
            )(x)
            x = layers.Activation("relu", dtype=dtype, name="post_relu")(x)

        if pooling == "avg":
            feature_map_output = layers.GlobalAveragePooling2D(
                data_format=data_format, dtype=dtype
            )(x)
        elif pooling == "max":
            feature_map_output = layers.GlobalMaxPooling2D(
                data_format=data_format, dtype=dtype
            )(x)
        else:
            feature_map_output = x

        super().__init__(
            inputs=image_input,
            outputs=feature_map_output,
            dtype=dtype,
            **kwargs,
        )

        # === Config ===
        self.stackwise_num_filters = stackwise_num_filters
        self.stackwise_num_blocks = stackwise_num_blocks
        self.stackwise_num_strides = stackwise_num_strides
        self.block_type = block_type
        self.use_pre_activation = use_pre_activation
        self.include_rescaling = include_rescaling
        self.image_shape = image_shape
        self.pooling = pooling
        # Store the resolved (standardized) value so a reloaded model keeps
        # the same channel ordering it was built with.
        self.data_format = data_format
        self.pyramid_outputs = pyramid_outputs

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "stackwise_num_filters": self.stackwise_num_filters,
                "stackwise_num_blocks": self.stackwise_num_blocks,
                "stackwise_num_strides": self.stackwise_num_strides,
                "block_type": self.block_type,
                "use_pre_activation": self.use_pre_activation,
                "include_rescaling": self.include_rescaling,
                "image_shape": self.image_shape,
                "pooling": self.pooling,
                # Without this, a channels_first model is rebuilt as
                # channels_last by `from_config`.
                "data_format": self.data_format,
            }
        )
        return config
|
276
|
+
|
277
|
+
|
278
|
+
def apply_basic_block(
    x,
    filters,
    kernel_size=3,
    stride=1,
    conv_shortcut=False,
    use_pre_activation=False,
    data_format=None,
    dtype=None,
    name=None,
):
    """Applies a basic residual block.

    Args:
        x: Tensor. The input tensor to pass through the block.
        filters: int. The number of filters in both convolution layers of the
            block (and in the convolution shortcut, if used).
        kernel_size: int. The kernel size of both convolution layers in the
            block. Defaults to `3`.
        stride: int. The stride length of the first convolution layer.
            Defaults to `1`.
        conv_shortcut: bool. If `True`, use a 1x1 convolution shortcut. If
            `False`, use an identity shortcut when `stride == 1` and a
            strided 1x1 max-pooling shortcut when `stride > 1`, so that the
            shortcut's spatial size matches the main path. Defaults to
            `False`.
        use_pre_activation: boolean. Whether to use pre-activation or not.
            `True` for ResNetV2, `False` for ResNet. Defaults to `False`.
        data_format: `None` or str. The ordering of the dimensions in the
            inputs. Can be `"channels_last"`
            (`(batch_size, height, width, channels)`) or `"channels_first"`
            (`(batch_size, channels, height, width)`). Falls back to
            `keras.config.image_data_format()` when `None`.
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
            to use for the model's computations and weights.
        name: str. A prefix for the layer names used in the block.

    Returns:
        The output tensor for the basic residual block.
    """
    data_format = data_format or keras.config.image_data_format()
    bn_axis = -1 if data_format == "channels_last" else 1

    x_preact = None
    if use_pre_activation:
        x_preact = layers.BatchNormalization(
            axis=bn_axis,
            epsilon=1e-5,
            momentum=0.9,
            dtype=dtype,
            name=f"{name}_pre_activation_bn",
        )(x)
        x_preact = layers.Activation(
            "relu", dtype=dtype, name=f"{name}_pre_activation_relu"
        )(x_preact)

    if conv_shortcut:
        # Pre-activation blocks project the activated input.
        x = x_preact if x_preact is not None else x
        shortcut = layers.Conv2D(
            filters,
            1,
            strides=stride,
            data_format=data_format,
            use_bias=False,
            dtype=dtype,
            name=f"{name}_0_conv",
        )(x)
        if not use_pre_activation:
            shortcut = layers.BatchNormalization(
                axis=bn_axis,
                epsilon=1e-5,
                momentum=0.9,
                dtype=dtype,
                name=f"{name}_0_bn",
            )(shortcut)
    elif stride > 1:
        # The main path below is downsampled by `stride`; subsample the
        # identity path the same way, otherwise the final `Add` receives
        # tensors with mismatched spatial shapes.
        shortcut = layers.MaxPooling2D(
            1,
            strides=stride,
            data_format=data_format,
            dtype=dtype,
            name=f"{name}_0_max_pooling",
        )(x)
    else:
        shortcut = x

    x = x_preact if x_preact is not None else x
    if stride > 1:
        # Explicit symmetric padding + "valid" conv mirrors torch's padding
        # behavior for strided convolutions.
        x = layers.ZeroPadding2D(
            (kernel_size - 1) // 2,
            data_format=data_format,
            dtype=dtype,
            name=f"{name}_1_pad",
        )(x)
    x = layers.Conv2D(
        filters,
        kernel_size,
        strides=stride,
        padding="valid" if stride > 1 else "same",
        data_format=data_format,
        use_bias=False,
        dtype=dtype,
        name=f"{name}_1_conv",
    )(x)
    x = layers.BatchNormalization(
        axis=bn_axis,
        epsilon=1e-5,
        momentum=0.9,
        dtype=dtype,
        name=f"{name}_1_bn",
    )(x)
    x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x)

    x = layers.Conv2D(
        filters,
        kernel_size,
        strides=1,
        padding="same",
        data_format=data_format,
        use_bias=False,
        dtype=dtype,
        name=f"{name}_2_conv",
    )(x)
    if not use_pre_activation:
        # Post-activation (v1): BN, residual add, then ReLU.
        x = layers.BatchNormalization(
            axis=bn_axis,
            epsilon=1e-5,
            momentum=0.9,
            dtype=dtype,
            name=f"{name}_2_bn",
        )(x)
        x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x])
        x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x)
    else:
        # Pre-activation (v2): plain residual add; the next block (or the
        # backbone's post BN/ReLU) normalizes the result.
        x = layers.Add(dtype=dtype, name=f"{name}_out")([shortcut, x])
    return x
|
401
|
+
|
402
|
+
|
403
|
+
def apply_bottleneck_block(
    x,
    filters,
    kernel_size=3,
    stride=1,
    conv_shortcut=False,
    use_pre_activation=False,
    data_format=None,
    dtype=None,
    name=None,
):
    """Applies a bottleneck residual block.

    The residual branch is a `1x1` convolution with `filters` channels, a
    `kernel_size x kernel_size` convolution with `filters` channels that
    carries the block's stride, and a `1x1` convolution expanding to
    `4 * filters` channels. The branch is then added to a shortcut path.

    Args:
        x: Tensor. The input tensor to pass through the block.
        filters: int. The number of filters in the block. The block outputs
            `4 * filters` channels.
        kernel_size: int. The kernel size of the bottleneck layer. Defaults to
            `3`.
        stride: int. The stride of the `kernel_size` convolution (the second
            layer of the residual branch) and of the shortcut path. Defaults
            to `1`.
        conv_shortcut: bool. If `True`, use a `1x1` convolution shortcut. If
            `False`, use an identity shortcut when `stride == 1`, or a strided
            `1x1` max-pooling shortcut when `stride > 1`. A non-convolution
            shortcut requires the input to already have `4 * filters`
            channels. Defaults to `False`.
        use_pre_activation: boolean. Whether to use pre-activation or not.
            `True` for ResNetV2, `False` for ResNet. Defaults to `False`.
        data_format: `None` or str. The ordering of the dimensions in the
            inputs. Can be `"channels_last"`
            (`(batch_size, height, width, channels)`) or `"channels_first"`
            (`(batch_size, channels, height, width)`).
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
            to use for the model's computations and weights.
        name: str. A prefix for the layer names used in the block.

    Returns:
        The output tensor for the residual block, with `4 * filters`
        channels.
    """
    data_format = data_format or keras.config.image_data_format()
    bn_axis = -1 if data_format == "channels_last" else 1

    x_preact = None
    if use_pre_activation:
        # ResNetV2: normalize and activate the input once. The residual
        # branch and a convolution shortcut consume this pre-activated
        # tensor, while an identity/pooling shortcut uses the raw input.
        x_preact = layers.BatchNormalization(
            axis=bn_axis,
            epsilon=1e-5,
            momentum=0.9,
            dtype=dtype,
            name=f"{name}_pre_activation_bn",
        )(x)
        x_preact = layers.Activation(
            "relu", dtype=dtype, name=f"{name}_pre_activation_relu"
        )(x_preact)

    # Tensor fed into the residual branch (and into a convolution shortcut).
    branch_input = x_preact if x_preact is not None else x

    if conv_shortcut:
        shortcut = layers.Conv2D(
            4 * filters,
            1,
            strides=stride,
            data_format=data_format,
            use_bias=False,
            dtype=dtype,
            name=f"{name}_0_conv",
        )(branch_input)
        if not use_pre_activation:
            shortcut = layers.BatchNormalization(
                axis=bn_axis,
                epsilon=1e-5,
                momentum=0.9,
                dtype=dtype,
                name=f"{name}_0_bn",
            )(shortcut)
    elif stride > 1:
        # Downsample the identity path so that it can be added to the
        # strided residual branch. A `1x1` max pool with stride `s` yields
        # the same spatial size as the padded, strided convolution below
        # (for odd `kernel_size`) and adds no weights.
        shortcut = layers.MaxPooling2D(
            1,
            strides=stride,
            data_format=data_format,
            dtype=dtype,
            name=f"{name}_0_pool",
        )(x)
    else:
        shortcut = x

    x = layers.Conv2D(
        filters,
        1,
        strides=1,
        data_format=data_format,
        use_bias=False,
        dtype=dtype,
        name=f"{name}_1_conv",
    )(branch_input)
    x = layers.BatchNormalization(
        axis=bn_axis,
        epsilon=1e-5,
        momentum=0.9,
        dtype=dtype,
        name=f"{name}_1_bn",
    )(x)
    x = layers.Activation("relu", dtype=dtype, name=f"{name}_1_relu")(x)

    if stride > 1:
        # Explicit symmetric padding followed by a "valid" strided conv
        # instead of "same" padding, so the padding does not depend on the
        # parity of the input size.
        x = layers.ZeroPadding2D(
            (kernel_size - 1) // 2,
            data_format=data_format,
            dtype=dtype,
            name=f"{name}_2_pad",
        )(x)
    x = layers.Conv2D(
        filters,
        kernel_size,
        strides=stride,
        padding="valid" if stride > 1 else "same",
        data_format=data_format,
        use_bias=False,
        dtype=dtype,
        name=f"{name}_2_conv",
    )(x)
    x = layers.BatchNormalization(
        axis=bn_axis,
        epsilon=1e-5,
        momentum=0.9,
        dtype=dtype,
        name=f"{name}_2_bn",
    )(x)
    x = layers.Activation("relu", dtype=dtype, name=f"{name}_2_relu")(x)

    x = layers.Conv2D(
        4 * filters,
        1,
        data_format=data_format,
        use_bias=False,
        dtype=dtype,
        name=f"{name}_3_conv",
    )(x)
    if not use_pre_activation:
        # ResNet (v1): BN on the branch, add, then a final ReLU.
        x = layers.BatchNormalization(
            axis=bn_axis,
            epsilon=1e-5,
            momentum=0.9,
            dtype=dtype,
            name=f"{name}_3_bn",
        )(x)
        x = layers.Add(dtype=dtype, name=f"{name}_add")([shortcut, x])
        x = layers.Activation("relu", dtype=dtype, name=f"{name}_out")(x)
    else:
        # ResNetV2: plain addition; normalization/activation happen at the
        # start of the next block instead.
        x = layers.Add(dtype=dtype, name=f"{name}_out")([shortcut, x])
    return x
|
542
|
+
|
543
|
+
|
544
|
+
def apply_stack(
    x,
    filters,
    blocks,
    stride,
    block_type,
    use_pre_activation,
    first_shortcut=True,
    data_format=None,
    dtype=None,
    name=None,
):
    """Applies a set of stacked residual blocks.

    Only the first block of the stack uses `stride` and `first_shortcut`.
    Every following block uses a stride of `1` and `conv_shortcut=False`.

    Args:
        x: Tensor. The input tensor to pass through the stack.
        filters: int. The number of filters in a block.
        blocks: int. The number of blocks in the stack.
        stride: int. The stride length of the first block in the stack.
        block_type: str. The block type to stack. One of `"basic_block"` or
            `"bottleneck_block"`. Use `"basic_block"` for ResNet18 and
            ResNet34. Use `"bottleneck_block"` for ResNet50, ResNet101 and
            ResNet152.
        use_pre_activation: boolean. Whether to use pre-activation or not.
            `True` for ResNetV2, `False` for ResNet and ResNeXt.
        first_shortcut: bool. Passed as `conv_shortcut` to the first block.
            If `True`, that block uses a convolution shortcut. If `False`, it
            uses an identity or pooling shortcut based on the stride.
            Defaults to `True`.
        data_format: `None` or str. The ordering of the dimensions in the
            inputs. Can be `"channels_last"`
            (`(batch_size, height, width, channels)`) or `"channels_first"`
            (`(batch_size, channels, height, width)`).
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
            to use for the model's computations and weights.
        name: str. A prefix for the layer names used in the stack. Defaults
            to `"v1_stack"`, or `"v2_stack"` when `use_pre_activation` is
            truthy. Block `i` is named `f"{name}_block{i}"`.

    Returns:
        Output tensor for the stacked blocks.

    Raises:
        ValueError: If `block_type` is not `"basic_block"` or
            `"bottleneck_block"`.
    """
    if name is None:
        version = "v2" if use_pre_activation else "v1"
        name = f"{version}_stack"

    if block_type == "basic_block":
        block_fn = apply_basic_block
    elif block_type == "bottleneck_block":
        block_fn = apply_bottleneck_block
    else:
        raise ValueError(
            '`block_type` must be either `"basic_block"` or '
            f'`"bottleneck_block"`. Received block_type={block_type}.'
        )

    for i in range(blocks):
        is_first_block = i == 0
        x = block_fn(
            x,
            filters,
            # Only the first block downsamples and may project the shortcut.
            stride=stride if is_first_block else 1,
            conv_shortcut=first_shortcut if is_first_block else False,
            use_pre_activation=use_pre_activation,
            data_format=data_format,
            dtype=dtype,
            name=f"{name}_block{i}",
        )
    return x
|