keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
- keras_hub/api/layers/__init__.py +21 -3
- keras_hub/api/models/__init__.py +71 -12
- keras_hub/api/tokenizers/__init__.py +1 -1
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
- keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
- keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
- keras_hub/src/layers/modeling/rms_normalization.py +36 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
- keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
- keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
- keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
- keras_hub/src/layers/preprocessing/image_converter.py +170 -34
- keras_hub/src/metrics/bleu.py +4 -3
- keras_hub/src/models/albert/albert_presets.py +4 -12
- keras_hub/src/models/albert/albert_text_classifier.py +7 -7
- keras_hub/src/models/backbone.py +3 -14
- keras_hub/src/models/bart/bart_backbone.py +4 -4
- keras_hub/src/models/bart/bart_presets.py +3 -9
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
- keras_hub/src/models/basnet/__init__.py +5 -0
- keras_hub/src/models/basnet/basnet.py +122 -0
- keras_hub/src/models/basnet/basnet_backbone.py +366 -0
- keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
- keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
- keras_hub/src/models/basnet/basnet_presets.py +17 -0
- keras_hub/src/models/bert/bert_presets.py +14 -32
- keras_hub/src/models/bert/bert_text_classifier.py +3 -3
- keras_hub/src/models/bloom/bloom_presets.py +8 -24
- keras_hub/src/models/causal_lm.py +56 -12
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
- keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
- keras_hub/src/models/densenet/densenet_backbone.py +6 -4
- keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
- keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
- keras_hub/src/models/densenet/densenet_presets.py +9 -15
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
- keras_hub/src/models/efficientnet/mbconv.py +53 -22
- keras_hub/src/models/electra/electra_backbone.py +2 -2
- keras_hub/src/models/electra/electra_presets.py +6 -18
- keras_hub/src/models/f_net/f_net_presets.py +2 -6
- keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
- keras_hub/src/models/falcon/falcon_backbone.py +5 -3
- keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
- keras_hub/src/models/falcon/falcon_presets.py +1 -3
- keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
- keras_hub/src/models/feature_pyramid_backbone.py +1 -1
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +496 -0
- keras_hub/src/models/flux/flux_maths.py +225 -0
- keras_hub/src/models/flux/flux_model.py +236 -0
- keras_hub/src/models/flux/flux_presets.py +3 -0
- keras_hub/src/models/flux/flux_text_to_image.py +146 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_backbone.py +35 -20
- keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
- keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
- keras_hub/src/models/gemma/gemma_presets.py +29 -63
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
- keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
- keras_hub/src/models/image_classifier.py +147 -2
- keras_hub/src/models/image_classifier_preprocessor.py +6 -3
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_segmenter.py +0 -5
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/image_to_image.py +417 -0
- keras_hub/src/models/inpaint.py +520 -0
- keras_hub/src/models/llama/llama_backbone.py +138 -12
- keras_hub/src/models/llama/llama_causal_lm.py +3 -1
- keras_hub/src/models/llama/llama_presets.py +10 -20
- keras_hub/src/models/llama3/llama3_backbone.py +12 -11
- keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
- keras_hub/src/models/llama3/llama3_presets.py +4 -12
- keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
- keras_hub/src/models/mistral/mistral_backbone.py +16 -15
- keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
- keras_hub/src/models/mistral/mistral_presets.py +3 -9
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
- keras_hub/src/models/mit/__init__.py +6 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
- keras_hub/src/models/mit/mit_image_classifier.py +12 -0
- keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/mit/mit_image_converter.py +8 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
- keras_hub/src/models/mit/mit_presets.py +139 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
- keras_hub/src/models/opt/opt_causal_lm.py +2 -2
- keras_hub/src/models/opt/opt_presets.py +4 -12
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
- keras_hub/src/models/phi3/phi3_decoder.py +0 -1
- keras_hub/src/models/phi3/phi3_presets.py +2 -6
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
- keras_hub/src/models/preprocessor.py +25 -11
- keras_hub/src/models/resnet/resnet_backbone.py +3 -14
- keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
- keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
- keras_hub/src/models/resnet/resnet_presets.py +127 -18
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
- keras_hub/src/models/roberta/roberta_backbone.py +2 -2
- keras_hub/src/models/roberta/roberta_presets.py +6 -8
- keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
- keras_hub/src/models/sam/__init__.py +5 -0
- keras_hub/src/models/sam/sam_backbone.py +2 -3
- keras_hub/src/models/sam/sam_image_converter.py +2 -4
- keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
- keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
- keras_hub/src/models/sam/sam_layers.py +5 -3
- keras_hub/src/models/sam/sam_presets.py +3 -9
- keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
- keras_hub/src/models/sam/sam_transformer.py +5 -4
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +167 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +136 -0
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +47 -19
- keras_hub/src/models/task.py +47 -39
- keras_hub/src/models/text_classifier.py +2 -2
- keras_hub/src/models/text_to_image.py +106 -41
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +184 -0
- keras_hub/src/models/vae/vae_layers.py +739 -0
- keras_hub/src/models/vgg/__init__.py +5 -0
- keras_hub/src/models/vgg/vgg_backbone.py +4 -24
- keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
- keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
- keras_hub/src/models/vgg/vgg_presets.py +48 -0
- keras_hub/src/models/vit/__init__.py +5 -0
- keras_hub/src/models/vit/vit_backbone.py +152 -0
- keras_hub/src/models/vit/vit_image_classifier.py +187 -0
- keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vit/vit_image_converter.py +73 -0
- keras_hub/src/models/vit/vit_layers.py +391 -0
- keras_hub/src/models/vit/vit_presets.py +126 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
- keras_hub/src/models/vit_det/vit_layers.py +3 -3
- keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
- keras_hub/src/models/whisper/whisper_backbone.py +6 -5
- keras_hub/src/models/whisper/whisper_decoder.py +3 -5
- keras_hub/src/models/whisper/whisper_presets.py +10 -30
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
- keras_hub/src/models/xlnet/relative_attention.py +20 -19
- keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
- keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
- keras_hub/src/samplers/contrastive_sampler.py +2 -3
- keras_hub/src/samplers/sampler.py +2 -1
- keras_hub/src/tests/test_case.py +41 -6
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
- keras_hub/src/tokenizers/tokenizer.py +10 -13
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
- keras_hub/src/utils/keras_utils.py +2 -13
- keras_hub/src/utils/pipeline_model.py +3 -3
- keras_hub/src/utils/preset_utils.py +196 -144
- keras_hub/src/utils/tensor_utils.py +4 -4
- keras_hub/src/utils/timm/convert_densenet.py +6 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
- keras_hub/src/utils/timm/convert_resnet.py +1 -1
- keras_hub/src/utils/timm/convert_vgg.py +85 -0
- keras_hub/src/utils/timm/preset_loader.py +14 -9
- keras_hub/src/utils/transformers/convert_llama3.py +21 -5
- keras_hub/src/utils/transformers/convert_vit.py +150 -0
- keras_hub/src/utils/transformers/preset_loader.py +23 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
- keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
- keras_hub/src/models/mix_transformer/__init__.py +0 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
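A note on the `mix_transformer` → `mit` renames in the list above: only the private `keras_hub.src...` module paths moved. A minimal migration sketch, assuming (this diff does not show it directly) that the `MiTBackbone` class name and its public export are unchanged across the move:

```python
# Old private path (0.16.x), removed by this release:
# from keras_hub.src.models.mix_transformer.mix_transformer_backbone import (
#     MiTBackbone,
# )
# New private path (0.19.x):
from keras_hub.src.models.mit.mit_backbone import MiTBackbone

# The public entry point is the safer import in either version:
import keras_hub

backbone_cls = keras_hub.models.MiTBackbone
```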
keras_hub/src/models/roberta/roberta_presets.py

@@ -5,26 +5,24 @@ backbone_presets = {
         "metadata": {
             "description": (
                 "12-layer RoBERTa model where case is maintained."
-                "Trained on English Wikipedia, BooksCorpus, CommonCraw, and
+                "Trained on English Wikipedia, BooksCorpus, CommonCraw, and "
+                "OpenWebText."
             ),
             "params": 124052736,
-            "official_name": "RoBERTa",
             "path": "roberta",
-            "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/roberta/README.md",
         },
-        "kaggle_handle": "kaggle://keras/roberta/keras/roberta_base_en/
+        "kaggle_handle": "kaggle://keras/roberta/keras/roberta_base_en/3",
     },
     "roberta_large_en": {
         "metadata": {
             "description": (
                 "24-layer RoBERTa model where case is maintained."
-                "Trained on English Wikipedia, BooksCorpus, CommonCraw, and
+                "Trained on English Wikipedia, BooksCorpus, CommonCraw, and "
+                "OpenWebText."
             ),
             "params": 354307072,
-            "official_name": "RoBERTa",
             "path": "roberta",
-            "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/roberta/README.md",
         },
-        "kaggle_handle": "kaggle://keras/roberta/keras/roberta_large_en/
+        "kaggle_handle": "kaggle://keras/roberta/keras/roberta_large_en/3",
     },
 }
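The preset metadata drops the `official_name` and `model_card` fields and bumps the Kaggle handles to `/3`; the preset names themselves are unchanged. A minimal sketch of the unaffected user-facing call:

```python
import keras_hub

# "roberta_base_en" still resolves; the new kaggle_handle is internal.
backbone = keras_hub.models.RobertaBackbone.from_preset("roberta_base_en")
```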
keras_hub/src/models/roberta/roberta_text_classifier.py

@@ -38,9 +38,9 @@ class RobertaTextClassifier(TextClassifier):
     Args:
         backbone: A `keras_hub.models.RobertaBackbone` instance.
         num_classes: int. Number of classes to predict.
-        preprocessor: A `keras_hub.models.RobertaTextClassifierPreprocessor` or
-            `None`, this model will not apply preprocessing, and
-            be preprocessed before calling the model.
+        preprocessor: A `keras_hub.models.RobertaTextClassifierPreprocessor` or
+            `None`. If `None`, this model will not apply preprocessing, and
+            inputs should be preprocessed before calling the model.
         activation: Optional `str` or callable. The activation function to use
             on the model outputs. Set `activation="softmax"` to return output
             probabilities. Defaults to `None`.
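The clarified docstring spells out the `preprocessor=None` contract. A minimal sketch of both modes, reusing the `roberta_base_en` preset from the previous file:

```python
import keras_hub

# Default: the attached preprocessor tokenizes raw strings internally.
classifier = keras_hub.models.RobertaTextClassifier.from_preset(
    "roberta_base_en", num_classes=2
)
classifier.predict(["The quick brown fox jumped."])

# With preprocessor=None, inputs must be pre-tokenized before calling.
bare = keras_hub.models.RobertaTextClassifier.from_preset(
    "roberta_base_en", num_classes=2, preprocessor=None
)
```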
keras_hub/src/models/sam/sam_backbone.py

@@ -9,8 +9,8 @@ class SAMBackbone(Backbone):
     """A backbone for the Segment Anything Model (SAM).

     Args:
-        image_encoder: `keras_hub.models.ViTDetBackbone`. A feature extractor
-            the input images.
+        image_encoder: `keras_hub.models.ViTDetBackbone`. A feature extractor
+            for the input images.
         prompt_encoder: `keras_hub.layers.SAMPromptEncoder`. A Keras layer to
             compute embeddings for points, box, and mask prompt.
         mask_decoder: `keras_hub.layers.SAMMaskDecoder`. A Keras layer to
@@ -68,7 +68,6 @@ class SAMBackbone(Backbone):
         image_encoder=image_encoder,
         prompt_encoder=prompt_encoder,
         mask_decoder=mask_decoder,
-        image_shape=(image_size, image_size, 3),
     )
     backbone(input_data)
     ```
keras_hub/src/models/sam/sam_image_converter.py

@@ -1,10 +1,8 @@
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.preprocessing.
-    ResizingImageConverter,
-)
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
 from keras_hub.src.models.sam.sam_backbone import SAMBackbone


 @keras_hub_export("keras_hub.layers.SAMImageConverter")
-class SAMImageConverter(
+class SAMImageConverter(ImageConverter):
     backbone_cls = SAMBackbone
keras_hub/src/models/sam/sam_image_segmenter.py

@@ -31,7 +31,7 @@ class SAMImageSegmenter(ImageSegmenter):


     Args:
-        backbone: A `keras_hub.models.
+        backbone: A `keras_hub.models.SAMBackbone` instance.

     Example:
     Load pretrained model using `from_preset`.
@@ -51,9 +51,9 @@ class SAMImageSegmenter(ImageSegmenter):
             (batch_size, 0, image_size, image_size, 1)
         ),
     }
-
-
-
+    sam = keras_hub.models.SAMImageSegmenter.from_preset('sam_base_sa1b')
+    outputs = sam.predict(input_data)
+    masks, iou_pred = outputs["masks"], outputs["iou_pred"]
     ```

     Load segment anything image segmenter with custom backbone
@@ -65,7 +65,7 @@ class SAMImageSegmenter(ImageSegmenter):
         (batch_size, image_size, image_size, 3),
         dtype="float32",
     )
-    image_encoder = ViTDetBackbone(
+    image_encoder = keras_hub.models.ViTDetBackbone(
         hidden_size=16,
         num_layers=16,
         intermediate_dim=16 * 4,
@@ -76,7 +76,7 @@ class SAMImageSegmenter(ImageSegmenter):
         window_size=2,
         image_shape=(image_size, image_size, 3),
     )
-    prompt_encoder = SAMPromptEncoder(
+    prompt_encoder = keras_hub.layers.SAMPromptEncoder(
         hidden_size=8,
         image_embedding_size=(8, 8),
         input_image_size=(
@@ -85,7 +85,7 @@ class SAMImageSegmenter(ImageSegmenter):
         ),
         mask_in_channels=16,
     )
-    mask_decoder = SAMMaskDecoder(
+    mask_decoder = keras_hub.layers.SAMMaskDecoder(
         num_layers=2,
         hidden_size=8,
         intermediate_dim=32,
@@ -95,13 +95,12 @@ class SAMImageSegmenter(ImageSegmenter):
         iou_head_depth=3,
         iou_head_hidden_dim=8,
     )
-    backbone = SAMBackbone(
+    backbone = keras_hub.models.SAMBackbone(
         image_encoder=image_encoder,
         prompt_encoder=prompt_encoder,
         mask_decoder=mask_decoder,
-        image_shape=(image_size, image_size, 3),
     )
-    sam = SAMImageSegmenter(
+    sam = keras_hub.models.SAMImageSegmenter(
         backbone=backbone
     )
     ```
@@ -115,7 +114,7 @@ class SAMImageSegmenter(ImageSegmenter):
     labels = np.array([[1., 0.]])
     box = np.array([[[[384., 384.], [640., 640.]]]])
     input_mask = np.ones((1, 1, 256, 256, 1))
-    Prepare an input dictionary:
+    # Prepare an input dictionary:
     inputs = {
         "images": image,
         "points": points,
@@ -201,17 +200,18 @@ class SAMImageSegmenter(ImageSegmenter):
     def _add_placeholder_prompts(self, inputs):
         """Adds placeholder prompt inputs for a call to SAM.

-        Because SAM is a functional subclass model, all inputs must be specified
-        calls to the model. However, prompt inputs are all optional, so we
-        add placeholders when they're not specified by the user.
+        Because SAM is a functional subclass model, all inputs must be specified
+        in calls to the model. However, prompt inputs are all optional, so we
+        have to add placeholders when they're not specified by the user.
         """
         inputs = inputs.copy()

         # Get the batch shape based on the image input
         batch_size = ops.shape(inputs["images"])[0]

-        # The type of the placeholders must match the existing inputs with
-        # to whether or not they are tensors (as opposed to Numpy
+        # The type of the placeholders must match the existing inputs with
+        # respect to whether or not they are tensors (as opposed to Numpy
+        # arrays).
         zeros = ops.zeros if ops.is_tensor(inputs["images"]) else np.zeros

         # Fill in missing inputs.
keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py

@@ -1,12 +1,22 @@
+import keras
+
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.image_segmenter_preprocessor import (
     ImageSegmenterPreprocessor,
 )
 from keras_hub.src.models.sam.sam_backbone import SAMBackbone
 from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter
+from keras_hub.src.utils.tensor_utils import preprocessing_function


-@keras_hub_export("keras_hub.models.
+@keras_hub_export("keras_hub.models.SAMImageSegmenterPreprocessor")
 class SAMImageSegmenterPreprocessor(ImageSegmenterPreprocessor):
     backbone_cls = SAMBackbone
     image_converter_cls = SAMImageConverter
+
+    @preprocessing_function
+    def call(self, x, y=None, sample_weight=None):
+        images = x["images"]
+        if self.image_converter:
+            x["images"] = self.image_converter(images)
+        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
keras_hub/src/models/sam/sam_layers.py

@@ -170,8 +170,8 @@ class TwoWayMultiHeadAttention(keras.layers.Layer):
         key_dim: int. Size of each attention head for query, key, and
             value.
         intermediate_dim: int. Number of hidden dims to use in the mlp block.
-        skip_first_layer_pos_embedding: bool. A boolean indicating whether to
-            first layer positional embeddings.
+        skip_first_layer_pos_embedding: bool. A boolean indicating whether to
+            skip the first layer positional embeddings.
         attention_downsample_rate: int, optional. The downsample rate to use
             in the attention layers. Defaults to 2.
         activation: str, optional. The activation for the mlp block's output
@@ -296,7 +296,9 @@ class TwoWayMultiHeadAttention(keras.layers.Layer):
             "num_heads": self.num_heads,
             "key_dim": self.key_dim,
             "intermediate_dim": self.intermediate_dim,
-            "skip_first_layer_pos_embedding":
+            "skip_first_layer_pos_embedding": (
+                self.skip_first_layer_pos_embedding
+            ),
             "attention_downsample_rate": self.attention_downsample_rate,
             "activation": self.activation,
         }
keras_hub/src/models/sam/sam_presets.py

@@ -5,30 +5,24 @@ backbone_presets = {
         "metadata": {
             "description": ("The base SAM model trained on the SA1B dataset."),
             "params": 93735728,
-            "official_name": "SAMImageSegmenter",
             "path": "sam",
-            "model_card": "https://arxiv.org/abs/2304.02643",
         },
-        "kaggle_handle": "kaggle://
+        "kaggle_handle": "kaggle://keras/sam/keras/sam_base_sa1b/5",
     },
     "sam_large_sa1b": {
         "metadata": {
             "description": ("The large SAM model trained on the SA1B dataset."),
             "params": 641090864,
-            "official_name": "SAMImageSegmenter",
             "path": "sam",
-            "model_card": "https://arxiv.org/abs/2304.02643",
         },
-        "kaggle_handle": "kaggle://
+        "kaggle_handle": "kaggle://keras/sam/keras/sam_large_sa1b/5",
     },
     "sam_huge_sa1b": {
         "metadata": {
             "description": ("The huge SAM model trained on the SA1B dataset."),
             "params": 312343088,
-            "official_name": "SAMImageSegmenter",
             "path": "sam",
-            "model_card": "https://arxiv.org/abs/2304.02643",
         },
-        "kaggle_handle": "kaggle://
+        "kaggle_handle": "kaggle://keras/sam/keras/sam_huge_sa1b/5",
     },
 }
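As with the RoBERTa presets above, only the metadata slims down and the Kaggle handles move to `/5`; the preset names stay stable:

```python
import keras_hub

# "sam_base_sa1b" still resolves; the updated kaggle_handle is internal.
sam = keras_hub.models.SAMImageSegmenter.from_preset("sam_base_sa1b")
```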
keras_hub/src/models/sam/sam_prompt_encoder.py

@@ -57,7 +57,7 @@ class SAMPromptEncoder(keras.layers.Layer):
         input_image_size=(1024, 1024),
         mask_in_channels=16,
         activation="gelu",
-        **kwargs
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.hidden_size = hidden_size
@@ -305,7 +305,9 @@ class SAMPromptEncoder(keras.layers.Layer):
         return {
             "prompt_sparse_embeddings": sparse_embeddings,
             "prompt_dense_embeddings": dense_embeddings,
-            "prompt_dense_positional_embeddings":
+            "prompt_dense_positional_embeddings": (
+                prompt_dense_positional_embeddings
+            ),
         }

     def get_config(self):
keras_hub/src/models/sam/sam_transformer.py

@@ -31,14 +31,15 @@ class TwoWayTransformer(keras.layers.Layer):
         location and type.

     Args:
-        num_layers: int, optional. The num_layers of the attention blocks
-            of attention blocks to use). Defaults to `2`.
+        num_layers: int, optional. The num_layers of the attention blocks
+            (the number of attention blocks to use). Defaults to `2`.
         hidden_size: int, optional. The number of features of the input image
             and point embeddings. Defaults to `256`.
         num_heads: int, optional. Number of heads to use in the attention
             layers. Defaults to `8`.
-        intermediate_dim: int, optional. The number of units in the hidden
-            the MLP block used in the attention layers.
+        intermediate_dim: int, optional. The number of units in the hidden
+            layer of the MLP block used in the attention layers.
+            Defaults to `2048`.
         activation: str, optional. The activation of the MLP block's output
             layer used in the attention layers. Defaults to `"relu"`.
         attention_downsample_rate: int, optional. The downsample rate of the
keras_hub/src/models/segformer/__init__.py

@@ -0,0 +1,8 @@
+from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone
+from keras_hub.src.models.segformer.segformer_image_segmenter import (
+    SegFormerImageSegmenter,
+)
+from keras_hub.src.models.segformer.segformer_presets import presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(presets, SegFormerImageSegmenter)
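`register_presets` attaches the preset table to the task class, so the names defined in `segformer_presets.py` resolve through `from_preset`. A minimal sketch, assuming the `segformer_b0_ade20k_512` preset named later in this diff:

```python
import keras_hub

segmenter = keras_hub.models.SegFormerImageSegmenter.from_preset(
    "segformer_b0_ade20k_512"
)
```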
keras_hub/src/models/segformer/segformer_backbone.py

@@ -0,0 +1,167 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.backbone import Backbone
+
+
+@keras_hub_export("keras_hub.models.SegFormerBackbone")
+class SegFormerBackbone(Backbone):
+    """A Keras model implementing SegFormer for semantic segmentation.
+
+    This class implements the majority of the SegFormer architecture described
+    in [SegFormer: Simple and Efficient Design for Semantic Segmentation](https://arxiv.org/abs/2105.15203)
+    and based on the TensorFlow implementation
+    [from DeepVision](https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/segmentation/segformer).
+
+    SegFormers are meant to be used with the MixTransformer (MiT) encoder
+    family, and use a very lightweight all-MLP decoder head.
+
+    The MiT encoder uses a hierarchical transformer which outputs features at
+    multiple scales, similar to that of the hierarchical outputs typically
+    associated with CNNs.
+
+    Args:
+        image_encoder: `keras.Model`. The backbone network for the model that is
+            used as a feature extractor for the SegFormer encoder.
+            Should be used with the MiT backbone model
+            (`keras_hub.models.MiTBackbone`) which was created
+            specifically for SegFormers.
+        num_classes: int, the number of classes for the detection model,
+            including the background class.
+        projection_filters: int, number of filters in the
+            convolution layer projecting the concatenated features into
+            a segmentation map. Defaults to 256`.
+
+    Example:
+
+    Using the class with a custom `backbone`:
+
+    ```python
+    import keras_hub
+
+    backbone = keras_hub.models.MiTBackbone(
+        depths=[2, 2, 2, 2],
+        image_shape=(224, 224, 3),
+        hidden_dims=[32, 64, 160, 256],
+        num_layers=4,
+        blockwise_num_heads=[1, 2, 5, 8],
+        blockwise_sr_ratios=[8, 4, 2, 1],
+        max_drop_path_rate=0.1,
+        patch_sizes=[7, 3, 3, 3],
+        strides=[4, 2, 2, 2],
+    )
+
+    segformer_backbone = keras_hub.models.SegFormerBackbone(
+        image_encoder=backbone, projection_filters=256)
+    ```
+
+    Using the class with a preset `backbone`:
+
+    ```python
+    import keras_hub
+
+    backbone = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512")
+    segformer_backbone = keras_hub.models.SegFormerBackbone(
+        image_encoder=backbone, projection_filters=256)
+    ```
+
+    """
+
+    def __init__(
+        self,
+        image_encoder,
+        projection_filters,
+        **kwargs,
+    ):
+        if not isinstance(image_encoder, keras.layers.Layer) or not isinstance(
+            image_encoder, keras.Model
+        ):
+            raise ValueError(
+                "Argument `image_encoder` must be a `keras.layers.Layer` "
+                f"instance or `keras.Model`. Received instead "
+                f"image_encoder={image_encoder} "
+                f"(of type {type(image_encoder)})."
+            )
+
+        # === Layers ===
+        inputs = keras.layers.Input(shape=image_encoder.input.shape[1:])
+
+        self.feature_extractor = keras.Model(
+            image_encoder.inputs, image_encoder.pyramid_outputs
+        )
+
+        features = self.feature_extractor(inputs)
+        # Get height and width of level one output
+        _, height, width, _ = features["P1"].shape
+
+        self.mlp_blocks = []
+
+        for feature_dim, feature in zip(image_encoder.hidden_dims, features):
+            self.mlp_blocks.append(
+                keras.layers.Dense(
+                    projection_filters, name=f"linear_{feature_dim}"
+                )
+            )
+
+        self.resizing = keras.layers.Resizing(
+            height, width, interpolation="bilinear"
+        )
+        self.concat = keras.layers.Concatenate(axis=-1)
+        self.linear_fuse = keras.Sequential(
+            [
+                keras.layers.Conv2D(
+                    filters=projection_filters, kernel_size=1, use_bias=False
+                ),
+                keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9),
+                keras.layers.Activation("relu"),
+            ]
+        )
+
+        # === Functional Model ===
+        # Project all multi-level outputs onto
+        # the same dimensionality and feature map shape
+        multi_layer_outs = []
+        for index, (feature_dim, feature) in enumerate(
+            zip(image_encoder.hidden_dims, features)
+        ):
+            out = self.mlp_blocks[index](features[feature])
+            out = self.resizing(out)
+            multi_layer_outs.append(out)
+
+        # Concat now-equal feature maps
+        concatenated_outs = self.concat(multi_layer_outs[::-1])
+
+        # Fuse concatenated features into a segmentation map
+        seg = self.linear_fuse(concatenated_outs)
+
+        super().__init__(
+            inputs=inputs,
+            outputs=seg,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.projection_filters = projection_filters
+        self.image_encoder = image_encoder
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "projection_filters": self.projection_filters,
+                "image_encoder": keras.saving.serialize_keras_object(
+                    self.image_encoder
+                ),
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        if "image_encoder" in config and isinstance(
+            config["image_encoder"], dict
+        ):
+            config["image_encoder"] = keras.layers.deserialize(
+                config["image_encoder"]
+            )
+        return super().from_config(config)
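The backbone projects each of the encoder's `pyramid_outputs` to `projection_filters` channels, resizes them all to the "P1" resolution, concatenates them, and fuses them with a 1x1 convolution. A shape-check sketch, assuming `mit_b0_ade20k_512` takes 512x512 inputs and its first pyramid level sits at 1/4 resolution (per the `strides=[4, 2, 2, 2]` pattern above):

```python
import numpy as np
import keras_hub

encoder = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512")
backbone = keras_hub.models.SegFormerBackbone(
    image_encoder=encoder, projection_filters=256
)
features = backbone(np.ones((1, 512, 512, 3), dtype="float32"))
# Expected (1, 128, 128, 256) under the assumptions above.
print(features.shape)
```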
keras_hub/src/models/segformer/segformer_image_converter.py

@@ -0,0 +1,8 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone
+
+
+@keras_hub_export("keras_hub.layers.SegFormerImageConverter")
+class SegFormerImageConverter(ImageConverter):
+    backbone_cls = SegFormerBackbone
keras_hub/src/models/segformer/segformer_image_segmenter.py

@@ -0,0 +1,184 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.image_segmenter import ImageSegmenter
+from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone
+from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import (  # noqa: E501
+    SegFormerImageSegmenterPreprocessor,
+)
+
+
+@keras_hub_export("keras_hub.models.SegFormerImageSegmenter")
+class SegFormerImageSegmenter(ImageSegmenter):
+    """A Keras model implementing SegFormer for semantic segmentation.
+
+    This class implements the segmentation head of the SegFormer architecture
+    described in [SegFormer: Simple and Efficient Design for Semantic
+    Segmentation with Transformers] (https://arxiv.org/abs/2105.15203) and
+    [based on the TensorFlow implementation from DeepVision]
+    (https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/segmentation/segformer).
+
+    SegFormers are meant to be used with the MixTransformer (MiT) encoder
+    family, and and use a very lightweight all-MLP decoder head.
+
+    The MiT encoder uses a hierarchical transformer which outputs features at
+    multiple scales, similar to that of the hierarchical outputs typically
+    associated with CNNs.
+
+    Args:
+        image_encoder: `keras.Model`. The backbone network for the model that is
+            used as a feature extractor for the SegFormer encoder. It is
+            *intended* to be used only with the MiT backbone model
+            (`keras_hub.models.MiTBackbone`) which was created specifically for
+            SegFormers. Alternatively, can be a `keras_hub.models.Backbone` a
+            model subclassing `keras_hub.models.FeaturePyramidBackbone`, or a
+            `keras.Model` that has a `pyramid_outputs` property which is a
+            dictionary with keys "P2", "P3", "P4", and "P5" and layer names as
+            values.
+        num_classes: int, the number of classes for the detection model,
+            including the background class.
+        projection_filters: int, number of filters in the
+            convolution layer projecting the concatenated features into a
+            segmentation map. Defaults to 256`.
+
+
+    Example:
+
+    Using presets:
+
+    ```python
+    segmenter = keras_hub.models.SegFormerImageSegmenter.from_preset(
+        "segformer_b0_ade20k_512"
+    )
+
+    images = np.random.rand(1, 512, 512, 3)
+    segformer(images)
+    ```
+
+    Using the SegFormer backbone:
+
+    ```python
+    encoder = keras_hub.models.MiTBackbone.from_preset(
+        "mit_b0_ade20k_512"
+    )
+    backbone = keras_hub.models.SegFormerBackbone(
+        image_encoder=encoder,
+        projection_filters=256,
+    )
+    ```
+
+    Using the SegFormer backbone with a custom encoder:
+
+    ```python
+    images = np.ones(shape=(1, 96, 96, 3))
+    labels = np.zeros(shape=(1, 96, 96, 1))
+
+    encoder = keras_hub.models.MiTBackbone(
+        depths=[2, 2, 2, 2],
+        image_shape=(96, 96, 3),
+        hidden_dims=[32, 64, 160, 256],
+        num_layers=4,
+        blockwise_num_heads=[1, 2, 5, 8],
+        blockwise_sr_ratios=[8, 4, 2, 1],
+        max_drop_path_rate=0.1,
+        patch_sizes=[7, 3, 3, 3],
+        strides=[4, 2, 2, 2],
+    )
+
+    backbone = keras_hub.models.SegFormerBackbone(
+        image_encoder=encoder,
+        projection_filters=256,
+    )
+    segformer = keras_hub.models.SegFormerImageSegmenter(
+        backbone=backbone,
+        num_classes=4,
+    )
+    segformer(images)
+    ```
+
+    Using the segmentor class with a preset backbone:
+
+    ```python
+    image_encoder = keras_hub.models.MiTBackbone.from_preset(
+        "mit_b0_ade20k_512"
+    )
+    backbone = keras_hub.models.SegFormerBackbone(
+        image_encoder=encoder,
+        projection_filters=256,
+    )
+    segformer = keras_hub.models.SegFormerImageSegmenter(
+        backbone=backbone,
+        num_classes=4,
+    )
+    ```
+    """
+
+    backbone_cls = SegFormerBackbone
+    preprocessor_cls = SegFormerImageSegmenterPreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        num_classes,
+        preprocessor=None,
+        **kwargs,
+    ):
+        if not isinstance(backbone, keras.layers.Layer) or not isinstance(
+            backbone, keras.Model
+        ):
+            raise ValueError(
+                "Argument `backbone` must be a `keras.layers.Layer` instance "
+                f" or `keras.Model`. Received instead "
+                f"backbone={backbone} (of type {type(backbone)})."
+            )
+
+        # === Layers ===
+        inputs = backbone.input
+
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+        self.dropout = keras.layers.Dropout(0.1)
+        self.output_segmentation_head = keras.layers.Conv2D(
+            filters=num_classes, kernel_size=1, strides=1
+        )
+        self.resizing = keras.layers.Resizing(
+            height=inputs.shape[1],
+            width=inputs.shape[2],
+            interpolation="bilinear",
+        )
+
+        # === Functional Model ===
+        x = self.backbone(inputs)
+        x = self.dropout(x)
+        x = self.output_segmentation_head(x)
+        output = self.resizing(x)
+
+        super().__init__(
+            inputs=inputs,
+            outputs=output,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.num_classes = num_classes
+        self.backbone = backbone
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_classes": self.num_classes,
+                "backbone": keras.saving.serialize_keras_object(self.backbone),
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        if "image_encoder" in config and isinstance(
+            config["image_encoder"], dict
+        ):
+            config["image_encoder"] = keras.layers.deserialize(
+                config["image_encoder"]
+            )
+        return super().from_config(config)
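The segmenter stacks dropout, a 1x1 convolution head with `num_classes` filters, and a bilinear resize back to the input resolution, so predictions are per-pixel class scores at input size. A minimal inference sketch, assuming the ADE20K preset above:

```python
import numpy as np
import keras_hub

segmenter = keras_hub.models.SegFormerImageSegmenter.from_preset(
    "segformer_b0_ade20k_512"
)
images = np.random.rand(1, 512, 512, 3).astype("float32")
preds = segmenter.predict(images)  # shape (1, 512, 512, num_classes)
class_map = preds.argmax(axis=-1)  # per-pixel class indices
```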