keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +21 -3
- keras_hub/api/models/__init__.py +71 -12
- keras_hub/api/tokenizers/__init__.py +1 -1
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
- keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
- keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
- keras_hub/src/layers/modeling/rms_normalization.py +36 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
- keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
- keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
- keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
- keras_hub/src/layers/preprocessing/image_converter.py +170 -34
- keras_hub/src/metrics/bleu.py +4 -3
- keras_hub/src/models/albert/albert_presets.py +4 -12
- keras_hub/src/models/albert/albert_text_classifier.py +7 -7
- keras_hub/src/models/backbone.py +3 -14
- keras_hub/src/models/bart/bart_backbone.py +4 -4
- keras_hub/src/models/bart/bart_presets.py +3 -9
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
- keras_hub/src/models/basnet/__init__.py +5 -0
- keras_hub/src/models/basnet/basnet.py +122 -0
- keras_hub/src/models/basnet/basnet_backbone.py +366 -0
- keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
- keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
- keras_hub/src/models/basnet/basnet_presets.py +17 -0
- keras_hub/src/models/bert/bert_presets.py +14 -32
- keras_hub/src/models/bert/bert_text_classifier.py +3 -3
- keras_hub/src/models/bloom/bloom_presets.py +8 -24
- keras_hub/src/models/causal_lm.py +56 -12
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
- keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
- keras_hub/src/models/densenet/densenet_backbone.py +6 -4
- keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
- keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
- keras_hub/src/models/densenet/densenet_presets.py +9 -15
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
- keras_hub/src/models/efficientnet/mbconv.py +53 -22
- keras_hub/src/models/electra/electra_backbone.py +2 -2
- keras_hub/src/models/electra/electra_presets.py +6 -18
- keras_hub/src/models/f_net/f_net_presets.py +2 -6
- keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
- keras_hub/src/models/falcon/falcon_backbone.py +5 -3
- keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
- keras_hub/src/models/falcon/falcon_presets.py +1 -3
- keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
- keras_hub/src/models/feature_pyramid_backbone.py +1 -1
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +496 -0
- keras_hub/src/models/flux/flux_maths.py +225 -0
- keras_hub/src/models/flux/flux_model.py +236 -0
- keras_hub/src/models/flux/flux_presets.py +3 -0
- keras_hub/src/models/flux/flux_text_to_image.py +146 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_backbone.py +35 -20
- keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
- keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
- keras_hub/src/models/gemma/gemma_presets.py +29 -63
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
- keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
- keras_hub/src/models/image_classifier.py +147 -2
- keras_hub/src/models/image_classifier_preprocessor.py +6 -3
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_segmenter.py +0 -5
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/image_to_image.py +417 -0
- keras_hub/src/models/inpaint.py +520 -0
- keras_hub/src/models/llama/llama_backbone.py +138 -12
- keras_hub/src/models/llama/llama_causal_lm.py +3 -1
- keras_hub/src/models/llama/llama_presets.py +10 -20
- keras_hub/src/models/llama3/llama3_backbone.py +12 -11
- keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
- keras_hub/src/models/llama3/llama3_presets.py +4 -12
- keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
- keras_hub/src/models/mistral/mistral_backbone.py +16 -15
- keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
- keras_hub/src/models/mistral/mistral_presets.py +3 -9
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
- keras_hub/src/models/mit/__init__.py +6 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
- keras_hub/src/models/mit/mit_image_classifier.py +12 -0
- keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/mit/mit_image_converter.py +8 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
- keras_hub/src/models/mit/mit_presets.py +139 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
- keras_hub/src/models/opt/opt_causal_lm.py +2 -2
- keras_hub/src/models/opt/opt_presets.py +4 -12
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
- keras_hub/src/models/phi3/phi3_decoder.py +0 -1
- keras_hub/src/models/phi3/phi3_presets.py +2 -6
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
- keras_hub/src/models/preprocessor.py +25 -11
- keras_hub/src/models/resnet/resnet_backbone.py +3 -14
- keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
- keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
- keras_hub/src/models/resnet/resnet_presets.py +127 -18
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
- keras_hub/src/models/roberta/roberta_backbone.py +2 -2
- keras_hub/src/models/roberta/roberta_presets.py +6 -8
- keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
- keras_hub/src/models/sam/__init__.py +5 -0
- keras_hub/src/models/sam/sam_backbone.py +2 -3
- keras_hub/src/models/sam/sam_image_converter.py +2 -4
- keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
- keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
- keras_hub/src/models/sam/sam_layers.py +5 -3
- keras_hub/src/models/sam/sam_presets.py +3 -9
- keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
- keras_hub/src/models/sam/sam_transformer.py +5 -4
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +167 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +136 -0
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +47 -19
- keras_hub/src/models/task.py +47 -39
- keras_hub/src/models/text_classifier.py +2 -2
- keras_hub/src/models/text_to_image.py +106 -41
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +184 -0
- keras_hub/src/models/vae/vae_layers.py +739 -0
- keras_hub/src/models/vgg/__init__.py +5 -0
- keras_hub/src/models/vgg/vgg_backbone.py +4 -24
- keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
- keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
- keras_hub/src/models/vgg/vgg_presets.py +48 -0
- keras_hub/src/models/vit/__init__.py +5 -0
- keras_hub/src/models/vit/vit_backbone.py +152 -0
- keras_hub/src/models/vit/vit_image_classifier.py +187 -0
- keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vit/vit_image_converter.py +73 -0
- keras_hub/src/models/vit/vit_layers.py +391 -0
- keras_hub/src/models/vit/vit_presets.py +126 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
- keras_hub/src/models/vit_det/vit_layers.py +3 -3
- keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
- keras_hub/src/models/whisper/whisper_backbone.py +6 -5
- keras_hub/src/models/whisper/whisper_decoder.py +3 -5
- keras_hub/src/models/whisper/whisper_presets.py +10 -30
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
- keras_hub/src/models/xlnet/relative_attention.py +20 -19
- keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
- keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
- keras_hub/src/samplers/contrastive_sampler.py +2 -3
- keras_hub/src/samplers/sampler.py +2 -1
- keras_hub/src/tests/test_case.py +41 -6
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
- keras_hub/src/tokenizers/tokenizer.py +10 -13
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
- keras_hub/src/utils/keras_utils.py +2 -13
- keras_hub/src/utils/pipeline_model.py +3 -3
- keras_hub/src/utils/preset_utils.py +196 -144
- keras_hub/src/utils/tensor_utils.py +4 -4
- keras_hub/src/utils/timm/convert_densenet.py +6 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
- keras_hub/src/utils/timm/convert_resnet.py +1 -1
- keras_hub/src/utils/timm/convert_vgg.py +85 -0
- keras_hub/src/utils/timm/preset_loader.py +14 -9
- keras_hub/src/utils/transformers/convert_llama3.py +21 -5
- keras_hub/src/utils/transformers/convert_vit.py +150 -0
- keras_hub/src/utils/transformers/preset_loader.py +23 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
- keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
- keras_hub/src/models/mix_transformer/__init__.py +0 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,17 @@
|
|
1
|
+
import math
|
2
|
+
|
1
3
|
import keras
|
2
4
|
|
5
|
+
from keras_hub.src.utils.keras_utils import standardize_data_format
|
6
|
+
|
3
7
|
|
4
8
|
class FeaturePyramid(keras.layers.Layer):
|
5
9
|
"""A Feature Pyramid Network (FPN) layer.
|
6
10
|
|
7
11
|
This implements the paper:
|
8
|
-
Tsung-Yi Lin, Piotr Dollar, Ross Girshick, Kaiming He,
|
9
|
-
and Serge Belongie.
|
12
|
+
Tsung-Yi Lin, Piotr Dollar, Ross Girshick, Kaiming He,
|
13
|
+
Bharath Hariharan, and Serge Belongie.
|
14
|
+
Feature Pyramid Networks for Object Detection.
|
10
15
|
(https://arxiv.org/pdf/1612.03144)
|
11
16
|
|
12
17
|
Feature Pyramid Networks (FPNs) are basic components that are added to an
|
@@ -37,14 +42,18 @@ class FeaturePyramid(keras.layers.Layer):
|
|
37
42
|
Args:
|
38
43
|
min_level: int. The minimum level of the feature pyramid.
|
39
44
|
max_level: int. The maximum level of the feature pyramid.
|
45
|
+
use_p5: bool. If True, uses the output of the last layer (`P5` from
|
46
|
+
Feature Pyramid Network) as input for creating coarser convolution
|
47
|
+
layers (`P6`, `P7`). If False, uses the direct input `P5`
|
48
|
+
for creating coarser convolution layers.
|
40
49
|
num_filters: int. The number of filters in each feature map.
|
41
50
|
activation: string or `keras.activations`. The activation function
|
42
51
|
to be used in network.
|
43
52
|
Defaults to `"relu"`.
|
44
|
-
kernel_initializer: `str` or `keras.initializers
|
53
|
+
kernel_initializer: `str` or `keras.initializers`.
|
45
54
|
The kernel initializer for the convolution layers.
|
46
55
|
Defaults to `"VarianceScaling"`.
|
47
|
-
bias_initializer: `str` or `keras.initializers
|
56
|
+
bias_initializer: `str` or `keras.initializers`.
|
48
57
|
The bias initializer for the convolution layers.
|
49
58
|
Defaults to `"zeros"`.
|
50
59
|
batch_norm_momentum: float.
|
@@ -53,10 +62,10 @@ class FeaturePyramid(keras.layers.Layer):
|
|
53
62
|
batch_norm_epsilon: float.
|
54
63
|
The epsilon for the batch normalization layers.
|
55
64
|
Defaults to `0.001`.
|
56
|
-
kernel_regularizer: `str` or `keras.regularizers
|
65
|
+
kernel_regularizer: `str` or `keras.regularizers`.
|
57
66
|
The kernel regularizer for the convolution layers.
|
58
67
|
Defaults to `None`.
|
59
|
-
bias_regularizer: `str` or `keras.regularizers
|
68
|
+
bias_regularizer: `str` or `keras.regularizers`.
|
60
69
|
The bias regularizer for the convolution layers.
|
61
70
|
Defaults to `None`.
|
62
71
|
use_batch_norm: bool. Whether to use batch normalization.
|
@@ -69,6 +78,7 @@ class FeaturePyramid(keras.layers.Layer):
|
|
69
78
|
self,
|
70
79
|
min_level,
|
71
80
|
max_level,
|
81
|
+
use_p5,
|
72
82
|
num_filters=256,
|
73
83
|
activation="relu",
|
74
84
|
kernel_initializer="VarianceScaling",
|
@@ -78,6 +88,7 @@ class FeaturePyramid(keras.layers.Layer):
|
|
78
88
|
kernel_regularizer=None,
|
79
89
|
bias_regularizer=None,
|
80
90
|
use_batch_norm=False,
|
91
|
+
data_format=None,
|
81
92
|
**kwargs,
|
82
93
|
):
|
83
94
|
super().__init__(**kwargs)
|
@@ -89,6 +100,7 @@ class FeaturePyramid(keras.layers.Layer):
|
|
89
100
|
self.min_level = min_level
|
90
101
|
self.max_level = max_level
|
91
102
|
self.num_filters = num_filters
|
103
|
+
self.use_p5 = use_p5
|
92
104
|
self.activation = keras.activations.get(activation)
|
93
105
|
self.kernel_initializer = keras.initializers.get(kernel_initializer)
|
94
106
|
self.bias_initializer = keras.initializers.get(bias_initializer)
|
@@ -103,8 +115,8 @@ class FeaturePyramid(keras.layers.Layer):
|
|
103
115
|
self.bias_regularizer = keras.regularizers.get(bias_regularizer)
|
104
116
|
else:
|
105
117
|
self.bias_regularizer = None
|
106
|
-
self.data_format =
|
107
|
-
self.batch_norm_axis = -1 if
|
118
|
+
self.data_format = standardize_data_format(data_format)
|
119
|
+
self.batch_norm_axis = -1 if data_format == "channels_last" else 1
|
108
120
|
|
109
121
|
def build(self, input_shapes):
|
110
122
|
input_shapes = {
|
@@ -117,7 +129,6 @@ class FeaturePyramid(keras.layers.Layer):
|
|
117
129
|
}
|
118
130
|
input_levels = [int(level[1]) for level in input_shapes]
|
119
131
|
backbone_max_level = min(max(input_levels), self.max_level)
|
120
|
-
|
121
132
|
# Build lateral layers
|
122
133
|
self.lateral_conv_layers = {}
|
123
134
|
for i in range(self.min_level, backbone_max_level + 1):
|
@@ -134,7 +145,11 @@ class FeaturePyramid(keras.layers.Layer):
|
|
134
145
|
dtype=self.dtype_policy,
|
135
146
|
name=f"lateral_conv_{level}",
|
136
147
|
)
|
137
|
-
self.lateral_conv_layers[level].build(
|
148
|
+
self.lateral_conv_layers[level].build(
|
149
|
+
(None, None, None, input_shapes[level][-1])
|
150
|
+
if self.data_format == "channels_last"
|
151
|
+
else (None, input_shapes[level][1], None, None)
|
152
|
+
)
|
138
153
|
|
139
154
|
self.lateral_batch_norm_layers = {}
|
140
155
|
if self.use_batch_norm:
|
@@ -149,9 +164,9 @@ class FeaturePyramid(keras.layers.Layer):
|
|
149
164
|
)
|
150
165
|
)
|
151
166
|
self.lateral_batch_norm_layers[level].build(
|
152
|
-
(None, None, None,
|
167
|
+
(None, None, None, self.num_filters)
|
153
168
|
if self.data_format == "channels_last"
|
154
|
-
else (None,
|
169
|
+
else (None, self.num_filters, None, None)
|
155
170
|
)
|
156
171
|
|
157
172
|
# Build output layers
|
@@ -171,9 +186,9 @@ class FeaturePyramid(keras.layers.Layer):
|
|
171
186
|
name=f"output_conv_{level}",
|
172
187
|
)
|
173
188
|
self.output_conv_layers[level].build(
|
174
|
-
(None, None, None,
|
189
|
+
(None, None, None, self.num_filters)
|
175
190
|
if self.data_format == "channels_last"
|
176
|
-
else (None,
|
191
|
+
else (None, self.num_filters, None, None)
|
177
192
|
)
|
178
193
|
|
179
194
|
# Build coarser layers
|
@@ -192,11 +207,18 @@ class FeaturePyramid(keras.layers.Layer):
|
|
192
207
|
dtype=self.dtype_policy,
|
193
208
|
name=f"coarser_{level}",
|
194
209
|
)
|
195
|
-
self.
|
196
|
-
(
|
197
|
-
|
198
|
-
|
199
|
-
|
210
|
+
if i == backbone_max_level + 1 and self.use_p5:
|
211
|
+
self.output_conv_layers[level].build(
|
212
|
+
(None, None, None, input_shapes[f"P{i - 1}"][-1])
|
213
|
+
if self.data_format == "channels_last"
|
214
|
+
else (None, input_shapes[f"P{i - 1}"][1], None, None)
|
215
|
+
)
|
216
|
+
else:
|
217
|
+
self.output_conv_layers[level].build(
|
218
|
+
(None, None, None, self.num_filters)
|
219
|
+
if self.data_format == "channels_last"
|
220
|
+
else (None, self.num_filters, None, None)
|
221
|
+
)
|
200
222
|
|
201
223
|
# Build batch norm layers
|
202
224
|
self.output_batch_norms = {}
|
@@ -212,9 +234,9 @@ class FeaturePyramid(keras.layers.Layer):
|
|
212
234
|
)
|
213
235
|
)
|
214
236
|
self.output_batch_norms[level].build(
|
215
|
-
(None, None, None,
|
237
|
+
(None, None, None, self.num_filters)
|
216
238
|
if self.data_format == "channels_last"
|
217
|
-
else (None,
|
239
|
+
else (None, self.num_filters, None, None)
|
218
240
|
)
|
219
241
|
|
220
242
|
# The same upsampling layer is used for all levels
|
@@ -255,7 +277,7 @@ class FeaturePyramid(keras.layers.Layer):
|
|
255
277
|
if i < backbone_max_level:
|
256
278
|
# for the top most output, it doesn't need to merge with any
|
257
279
|
# upper stream outputs
|
258
|
-
upstream_output = self.top_down_op(output_features[f"P{i+1}"])
|
280
|
+
upstream_output = self.top_down_op(output_features[f"P{i + 1}"])
|
259
281
|
output = self.merge_op([output, upstream_output])
|
260
282
|
output_features[level] = (
|
261
283
|
self.lateral_batch_norm_layers[level](output)
|
@@ -273,7 +295,11 @@ class FeaturePyramid(keras.layers.Layer):
|
|
273
295
|
|
274
296
|
for i in range(backbone_max_level + 1, self.max_level + 1):
|
275
297
|
level = f"P{i}"
|
276
|
-
feats_in =
|
298
|
+
feats_in = (
|
299
|
+
inputs[f"P{i - 1}"]
|
300
|
+
if i == backbone_max_level + 1 and self.use_p5
|
301
|
+
else output_features[f"P{i - 1}"]
|
302
|
+
)
|
277
303
|
if i > backbone_max_level + 1:
|
278
304
|
feats_in = self.activation(feats_in)
|
279
305
|
output_features[level] = (
|
@@ -283,7 +309,10 @@ class FeaturePyramid(keras.layers.Layer):
|
|
283
309
|
if self.use_batch_norm
|
284
310
|
else self.output_conv_layers[level](feats_in)
|
285
311
|
)
|
286
|
-
|
312
|
+
output_features = {
|
313
|
+
f"P{i}": output_features[f"P{i}"]
|
314
|
+
for i in range(self.min_level, self.max_level + 1)
|
315
|
+
}
|
287
316
|
return output_features
|
288
317
|
|
289
318
|
def get_config(self):
|
@@ -293,7 +322,9 @@ class FeaturePyramid(keras.layers.Layer):
|
|
293
322
|
"min_level": self.min_level,
|
294
323
|
"max_level": self.max_level,
|
295
324
|
"num_filters": self.num_filters,
|
325
|
+
"use_p5": self.use_p5,
|
296
326
|
"use_batch_norm": self.use_batch_norm,
|
327
|
+
"data_format": self.data_format,
|
297
328
|
"activation": keras.activations.serialize(self.activation),
|
298
329
|
"kernel_initializer": keras.initializers.serialize(
|
299
330
|
self.kernel_initializer
|
@@ -320,34 +351,51 @@ class FeaturePyramid(keras.layers.Layer):
|
|
320
351
|
|
321
352
|
def compute_output_shape(self, input_shapes):
|
322
353
|
output_shape = {}
|
323
|
-
print(input_shapes)
|
324
354
|
input_levels = [int(level[1]) for level in input_shapes]
|
325
355
|
backbone_max_level = min(max(input_levels), self.max_level)
|
326
356
|
|
327
357
|
for i in range(self.min_level, backbone_max_level + 1):
|
328
358
|
level = f"P{i}"
|
329
359
|
if self.data_format == "channels_last":
|
330
|
-
output_shape[level] = input_shapes[level][:-1] + (
|
360
|
+
output_shape[level] = input_shapes[level][:-1] + (
|
361
|
+
self.num_filters,
|
362
|
+
)
|
331
363
|
else:
|
332
364
|
output_shape[level] = (
|
333
365
|
input_shapes[level][0],
|
334
|
-
|
366
|
+
self.num_filters,
|
335
367
|
) + input_shapes[level][1:3]
|
336
368
|
|
337
369
|
intermediate_shape = input_shapes[f"P{backbone_max_level}"]
|
338
370
|
intermediate_shape = (
|
339
371
|
(
|
340
372
|
intermediate_shape[0],
|
341
|
-
|
342
|
-
|
343
|
-
|
373
|
+
(
|
374
|
+
int(math.ceil(intermediate_shape[1] / 2))
|
375
|
+
if intermediate_shape[1] is not None
|
376
|
+
else None
|
377
|
+
),
|
378
|
+
(
|
379
|
+
int(math.ceil(intermediate_shape[1] / 2))
|
380
|
+
if intermediate_shape[1] is not None
|
381
|
+
else None
|
382
|
+
),
|
383
|
+
self.num_filters,
|
344
384
|
)
|
345
385
|
if self.data_format == "channels_last"
|
346
386
|
else (
|
347
387
|
intermediate_shape[0],
|
348
|
-
|
349
|
-
|
350
|
-
|
388
|
+
self.num_filters,
|
389
|
+
(
|
390
|
+
int(math.ceil(intermediate_shape[1] / 2))
|
391
|
+
if intermediate_shape[1] is not None
|
392
|
+
else None
|
393
|
+
),
|
394
|
+
(
|
395
|
+
int(math.ceil(intermediate_shape[1] / 2))
|
396
|
+
if intermediate_shape[1] is not None
|
397
|
+
else None
|
398
|
+
),
|
351
399
|
)
|
352
400
|
)
|
353
401
|
|
@@ -357,16 +405,32 @@ class FeaturePyramid(keras.layers.Layer):
|
|
357
405
|
intermediate_shape = (
|
358
406
|
(
|
359
407
|
intermediate_shape[0],
|
360
|
-
|
361
|
-
|
362
|
-
|
408
|
+
(
|
409
|
+
int(math.ceil(intermediate_shape[1] / 2))
|
410
|
+
if intermediate_shape[1] is not None
|
411
|
+
else None
|
412
|
+
),
|
413
|
+
(
|
414
|
+
int(math.ceil(intermediate_shape[1] / 2))
|
415
|
+
if intermediate_shape[1] is not None
|
416
|
+
else None
|
417
|
+
),
|
418
|
+
self.num_filters,
|
363
419
|
)
|
364
420
|
if self.data_format == "channels_last"
|
365
421
|
else (
|
366
422
|
intermediate_shape[0],
|
367
|
-
|
368
|
-
|
369
|
-
|
423
|
+
self.num_filters,
|
424
|
+
(
|
425
|
+
int(math.ceil(intermediate_shape[1] / 2))
|
426
|
+
if intermediate_shape[1] is not None
|
427
|
+
else None
|
428
|
+
),
|
429
|
+
(
|
430
|
+
int(math.ceil(intermediate_shape[1] / 2))
|
431
|
+
if intermediate_shape[1] is not None
|
432
|
+
else None
|
433
|
+
),
|
370
434
|
)
|
371
435
|
)
|
372
436
|
|
@@ -3,6 +3,7 @@ import math
|
|
3
3
|
import keras
|
4
4
|
from keras import ops
|
5
5
|
|
6
|
+
# TODO: https://github.com/keras-team/keras-hub/issues/1965
|
6
7
|
from keras_hub.src.bounding_box import converters
|
7
8
|
from keras_hub.src.bounding_box import utils
|
8
9
|
from keras_hub.src.bounding_box import validate_format
|
@@ -0,0 +1,192 @@
|
|
1
|
+
import keras
|
2
|
+
|
3
|
+
from keras_hub.src.utils.keras_utils import standardize_data_format
|
4
|
+
|
5
|
+
|
6
|
+
class PredictionHead(keras.layers.Layer):
|
7
|
+
"""A head for classification or bounding box regression predictions.
|
8
|
+
|
9
|
+
Args:
|
10
|
+
output_filters: int. The umber of convolution filters in the final
|
11
|
+
layer. The number of output channels determines the prediction type:
|
12
|
+
- **Classification**:
|
13
|
+
`output_filters = num_anchors * num_classes`
|
14
|
+
Predicts class probabilities for each anchor.
|
15
|
+
- **Bounding Box Regression**:
|
16
|
+
`output_filters = num_anchors * 4` Predicts bounding box
|
17
|
+
offsets (x1, y1, x2, y2) for each anchor.
|
18
|
+
num_filters: int. The number of convolution filters to use in the base
|
19
|
+
layer.
|
20
|
+
num_conv_layers: int. The number of convolution layers before the final
|
21
|
+
layer.
|
22
|
+
use_prior_probability: bool. Set to True to use prior probability in the
|
23
|
+
bias initializer for the final convolution layer.
|
24
|
+
Defaults to `False`.
|
25
|
+
prior_probability: float. The prior probability value to use for
|
26
|
+
initializing the bias. Only used if `use_prior_probability` is
|
27
|
+
`True`. Defaults to `0.01`.
|
28
|
+
kernel_initializer: `str` or `keras.initializers`. The kernel
|
29
|
+
initializer for the convolution layers. Defaults to
|
30
|
+
`"random_normal"`.
|
31
|
+
bias_initializer: `str` or `keras.initializers`. The bias initializer
|
32
|
+
for the convolution layers. Defaults to `"zeros"`.
|
33
|
+
kernel_regularizer: `str` or `keras.regularizers`. The kernel
|
34
|
+
regularizer for the convolution layers. Defaults to `None`.
|
35
|
+
bias_regularizer: `str` or `keras.regularizers`. The bias regularizer
|
36
|
+
for the convolution layers. Defaults to `None`.
|
37
|
+
use_group_norm: bool. Whether to use Group Normalization after
|
38
|
+
the convolution layers. Defaults to `False`.
|
39
|
+
|
40
|
+
Returns:
|
41
|
+
A function representing either the classification
|
42
|
+
or the box regression head depending on `output_filters`.
|
43
|
+
"""
|
44
|
+
|
45
|
+
def __init__(
|
46
|
+
self,
|
47
|
+
output_filters,
|
48
|
+
num_filters,
|
49
|
+
num_conv_layers,
|
50
|
+
use_prior_probability=False,
|
51
|
+
prior_probability=0.01,
|
52
|
+
activation="relu",
|
53
|
+
kernel_initializer="random_normal",
|
54
|
+
bias_initializer="zeros",
|
55
|
+
kernel_regularizer=None,
|
56
|
+
bias_regularizer=None,
|
57
|
+
use_group_norm=False,
|
58
|
+
data_format=None,
|
59
|
+
**kwargs,
|
60
|
+
):
|
61
|
+
super().__init__(**kwargs)
|
62
|
+
|
63
|
+
self.output_filters = output_filters
|
64
|
+
self.num_filters = num_filters
|
65
|
+
self.num_conv_layers = num_conv_layers
|
66
|
+
self.use_prior_probability = use_prior_probability
|
67
|
+
self.prior_probability = prior_probability
|
68
|
+
self.activation = keras.activations.get(activation)
|
69
|
+
self.kernel_initializer = keras.initializers.get(kernel_initializer)
|
70
|
+
self.bias_initializer = keras.initializers.get(bias_initializer)
|
71
|
+
if kernel_regularizer is not None:
|
72
|
+
self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
|
73
|
+
else:
|
74
|
+
self.kernel_regularizer = None
|
75
|
+
if bias_regularizer is not None:
|
76
|
+
self.bias_regularizer = keras.regularizers.get(bias_regularizer)
|
77
|
+
else:
|
78
|
+
self.bias_regularizer = None
|
79
|
+
self.use_group_norm = use_group_norm
|
80
|
+
self.data_format = standardize_data_format(data_format)
|
81
|
+
|
82
|
+
def build(self, input_shape):
    """Build the conv tower and the final prediction (logits) conv.

    Args:
        input_shape: tuple. Shape of the incoming feature map:
            `(batch, H, W, C)` for `"channels_last"`, or
            `(batch, C, H, W)` for `"channels_first"`.
    """
    intermediate_shape = input_shape
    self.conv_layers = []
    self.group_norm_layers = []
    for idx in range(self.num_conv_layers):
        conv = keras.layers.Conv2D(
            self.num_filters,
            kernel_size=3,
            padding="same",
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            # When group norm follows, its beta makes the conv bias
            # redundant.
            use_bias=not self.use_group_norm,
            kernel_regularizer=self.kernel_regularizer,
            bias_regularizer=self.bias_regularizer,
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name=f"conv2d_{idx}",
        )
        conv.build(intermediate_shape)
        self.conv_layers.append(conv)
        # Stride-1 "same" convs keep the spatial dims; only the channel
        # axis becomes `num_filters`.
        # Bug fix: the channels_first branch previously appended
        # `input_shape[1:-1]`, which kept the old channel axis and
        # dropped the last spatial dim; the spatial dims are
        # `input_shape[2:]`.
        intermediate_shape = (
            input_shape[:-1] + (self.num_filters,)
            if self.data_format == "channels_last"
            else (input_shape[0], self.num_filters) + input_shape[2:]
        )
        if self.use_group_norm:
            group_norm = keras.layers.GroupNormalization(
                groups=32,
                axis=-1 if self.data_format == "channels_last" else 1,
                dtype=self.dtype_policy,
                name=f"group_norm_{idx}",
            )
            group_norm.build(intermediate_shape)
            self.group_norm_layers.append(group_norm)
    # Bias initializer that puts the initial sigmoid output at
    # `prior_probability` (the RetinaNet focal-loss initialization
    # trick): b = -log((1 - p) / p).
    prior_probability = keras.initializers.Constant(
        -1
        * keras.ops.log(
            (1 - self.prior_probability) / self.prior_probability
        )
    )
    self.prediction_layer = keras.layers.Conv2D(
        self.output_filters,
        kernel_size=3,
        strides=1,
        padding="same",
        kernel_initializer=self.kernel_initializer,
        bias_initializer=(
            prior_probability
            if self.use_prior_probability
            else self.bias_initializer
        ),
        kernel_regularizer=self.kernel_regularizer,
        bias_regularizer=self.bias_regularizer,
        # Consistency fix: the tower convs pass `data_format` explicitly;
        # without it this layer would fall back to the global Keras image
        # data format, which may disagree with `self.data_format`.
        data_format=self.data_format,
        dtype=self.dtype_policy,
        name="logits_layer",
    )
    self.prediction_layer.build(
        (None, None, None, self.num_filters)
        if self.data_format == "channels_last"
        else (None, self.num_filters, None, None)
    )
    self.built = True
|
144
|
+
|
145
|
+
def call(self, input):
    """Apply the conv (+ optional group norm + activation) tower, then
    the final prediction conv, and return the resulting logits map."""
    x = input
    for layer_index, conv in enumerate(self.conv_layers):
        x = conv(x)
        if self.use_group_norm:
            x = self.group_norm_layers[layer_index](x)
        x = self.activation(x)
    return self.prediction_layer(x)
|
155
|
+
|
156
|
+
def get_config(self):
    """Return the serializable config for this layer.

    Initializers/activations are serialized via the Keras helpers;
    regularizers are serialized only when set, otherwise stored as
    `None` (matching how `__init__` resolves them).
    """
    config = super().get_config()
    config.update(
        {
            "output_filters": self.output_filters,
            "num_filters": self.num_filters,
            "num_conv_layers": self.num_conv_layers,
            "use_group_norm": self.use_group_norm,
            "use_prior_probability": self.use_prior_probability,
            "prior_probability": self.prior_probability,
            "activation": keras.activations.serialize(self.activation),
            "kernel_initializer": keras.initializers.serialize(
                self.kernel_initializer
            ),
            # Bug fix: this previously serialized
            # `self.kernel_initializer`, so a round-tripped layer
            # silently lost its bias initializer.
            "bias_initializer": keras.initializers.serialize(
                self.bias_initializer
            ),
            "kernel_regularizer": (
                keras.regularizers.serialize(self.kernel_regularizer)
                if self.kernel_regularizer is not None
                else None
            ),
            "bias_regularizer": (
                keras.regularizers.serialize(self.bias_regularizer)
                if self.bias_regularizer is not None
                else None
            ),
        }
    )
    return config
|
186
|
+
|
187
|
+
def compute_output_shape(self, input_shape):
    """Compute the output shape: batch and spatial dims are preserved,
    the channel axis becomes `output_filters`.

    Args:
        input_shape: tuple. `(batch, H, W, C)` for `"channels_last"`,
            or `(batch, C, H, W)` for `"channels_first"`.
    """
    if self.data_format == "channels_last":
        return input_shape[:-1] + (self.output_filters,)
    # Bug fix: the channels_first branch previously appended
    # `input_shape[1:-1]`, which kept the old channel axis and dropped
    # the last spatial dimension; the spatial dims are
    # `input_shape[2:]`.
    return (input_shape[0], self.output_filters) + input_shape[2:]
|
@@ -0,0 +1,146 @@
|
|
1
|
+
import keras
|
2
|
+
|
3
|
+
from keras_hub.src.api_export import keras_hub_export
|
4
|
+
from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
|
5
|
+
from keras_hub.src.models.retinanet.feature_pyramid import FeaturePyramid
|
6
|
+
from keras_hub.src.utils.keras_utils import standardize_data_format
|
7
|
+
|
8
|
+
|
9
|
+
@keras_hub_export("keras_hub.models.RetinaNetBackbone")
class RetinaNetBackbone(FeaturePyramidBackbone):
    """RetinaNet Backbone.

    Combines a CNN backbone (e.g., ResNet, MobileNet) with a feature pyramid
    network (FPN) to extract multi-scale features for object detection.

    Args:
        image_encoder: `keras.Model`. The backbone model (e.g., ResNet50,
            MobileNetV2) used to extract features from the input image.
            It should have pyramid outputs (i.e., a dictionary mapping level
            names like `"P2"`, `"P3"`, etc. to their corresponding feature
            tensors).
        min_level: int. The minimum level of the feature pyramid (e.g., 3).
            This determines the finest (highest-resolution) level of
            features used.
        max_level: int. The maximum level of the feature pyramid (e.g., 7).
            This determines the coarsest (lowest-resolution) level of
            features used.
        use_p5: bool. Determines the input source for creating coarser
            feature pyramid levels. If `True`, the output of the last backbone
            layer (typically `'P5'` in an FPN) is used as input to create
            higher-level feature maps (e.g., `'P6'`, `'P7'`) through
            additional convolutional layers. If `False`, the original `'P5'`
            feature map from the backbone is directly used as input for
            creating the coarser levels, bypassing any further processing of
            `'P5'` within the feature pyramid. Defaults to `False`.
        use_fpn_batch_norm: bool. Whether to use batch normalization in the
            feature pyramid network. Defaults to `False`.
        image_shape: tuple. The shape of the input image (H, W, C).
            The height and width can be `None` if they are variable.
        data_format: str. The data format of the input image
            (channels_first or channels_last).
        dtype: str. The dtype used for the model's layers and computations.
        **kwargs: Additional keyword arguments passed to the base class.

    Raises:
        ValueError: If `min_level` is greater than `max_level`.
        ValueError: If `backbone_max_level` is less than 5 and `max_level` is
            greater than or equal to 5.
    """

    def __init__(
        self,
        image_encoder,
        min_level,
        max_level,
        use_p5,
        use_fpn_batch_norm=False,
        image_shape=(None, None, 3),
        data_format=None,
        dtype=None,
        **kwargs,
    ):
        # === Layers ===
        # Validate the requested pyramid range before building anything.
        if min_level > max_level:
            raise ValueError(
                f"Minimum level ({min_level}) must be less than or equal to "
                f"maximum level ({max_level})."
            )

        data_format = standardize_data_format(data_format)
        # Parse the numeric level from each pyramid-output name.
        # NOTE(review): `int(level[1])` reads a single character, so this
        # assumes single-digit level names like "P3" — confirm no encoder
        # produces "P10" or higher.
        input_levels = [
            int(level[1]) for level in image_encoder.pyramid_outputs
        ]
        # The deepest level the encoder itself can supply (capped at
        # `max_level`); levels above this come from the FPN's extra convs.
        backbone_max_level = min(max(input_levels), max_level)

        if backbone_max_level < 5 and max_level >= 5:
            raise ValueError(
                f"Backbone maximum level ({backbone_max_level}) is less than "
                f"the desired maximum level ({max_level}). "
                f"Please ensure that the backbone can generate features up to "
                f"the specified maximum level."
            )
        # Sub-model exposing only the encoder outputs in the requested
        # pyramid range, keyed "P{min_level}" .. "P{backbone_max_level}".
        feature_extractor = keras.Model(
            inputs=image_encoder.inputs,
            outputs={
                f"P{level}": image_encoder.pyramid_outputs[f"P{level}"]
                for level in range(min_level, backbone_max_level + 1)
            },
            name="backbone",
        )

        feature_pyramid = FeaturePyramid(
            min_level=min_level,
            max_level=max_level,
            use_p5=use_p5,
            name="fpn",
            dtype=dtype,
            data_format=data_format,
            use_batch_norm=use_fpn_batch_norm,
        )

        # === Functional model ===
        # Wire input -> encoder features -> FPN outputs, then register the
        # graph with the functional-model base class.
        image_input = keras.layers.Input(image_shape, name="inputs")
        feature_extractor_outputs = feature_extractor(image_input)
        feature_pyramid_outputs = feature_pyramid(feature_extractor_outputs)

        super().__init__(
            inputs=image_input,
            outputs=feature_pyramid_outputs,
            dtype=dtype,
            **kwargs,
        )

        # === config ===
        # Stored for `get_config` and for downstream consumers of the
        # pyramid outputs.
        self.min_level = min_level
        self.max_level = max_level
        self.use_p5 = use_p5
        self.use_fpn_batch_norm = use_fpn_batch_norm
        self.image_encoder = image_encoder
        self.feature_pyramid = feature_pyramid
        self.image_shape = image_shape
        self.pyramid_outputs = feature_pyramid_outputs

    def get_config(self):
        """Return the serializable config, including the nested encoder."""
        config = super().get_config()
        config.update(
            {
                "image_encoder": keras.layers.serialize(self.image_encoder),
                "min_level": self.min_level,
                "max_level": self.max_level,
                "use_p5": self.use_p5,
                "use_fpn_batch_norm": self.use_fpn_batch_norm,
                "image_shape": self.image_shape,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the backbone, deserializing the nested image encoder."""
        config.update(
            {
                "image_encoder": keras.layers.deserialize(
                    config["image_encoder"]
                ),
            }
        )

        return super().from_config(config)
|