keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +21 -3
- keras_hub/api/models/__init__.py +71 -12
- keras_hub/api/tokenizers/__init__.py +1 -1
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
- keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
- keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
- keras_hub/src/layers/modeling/rms_normalization.py +36 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
- keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
- keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
- keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
- keras_hub/src/layers/preprocessing/image_converter.py +170 -34
- keras_hub/src/metrics/bleu.py +4 -3
- keras_hub/src/models/albert/albert_presets.py +4 -12
- keras_hub/src/models/albert/albert_text_classifier.py +7 -7
- keras_hub/src/models/backbone.py +3 -14
- keras_hub/src/models/bart/bart_backbone.py +4 -4
- keras_hub/src/models/bart/bart_presets.py +3 -9
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
- keras_hub/src/models/basnet/__init__.py +5 -0
- keras_hub/src/models/basnet/basnet.py +122 -0
- keras_hub/src/models/basnet/basnet_backbone.py +366 -0
- keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
- keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
- keras_hub/src/models/basnet/basnet_presets.py +17 -0
- keras_hub/src/models/bert/bert_presets.py +14 -32
- keras_hub/src/models/bert/bert_text_classifier.py +3 -3
- keras_hub/src/models/bloom/bloom_presets.py +8 -24
- keras_hub/src/models/causal_lm.py +56 -12
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
- keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
- keras_hub/src/models/densenet/densenet_backbone.py +6 -4
- keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
- keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
- keras_hub/src/models/densenet/densenet_presets.py +9 -15
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
- keras_hub/src/models/efficientnet/mbconv.py +53 -22
- keras_hub/src/models/electra/electra_backbone.py +2 -2
- keras_hub/src/models/electra/electra_presets.py +6 -18
- keras_hub/src/models/f_net/f_net_presets.py +2 -6
- keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
- keras_hub/src/models/falcon/falcon_backbone.py +5 -3
- keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
- keras_hub/src/models/falcon/falcon_presets.py +1 -3
- keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
- keras_hub/src/models/feature_pyramid_backbone.py +1 -1
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +496 -0
- keras_hub/src/models/flux/flux_maths.py +225 -0
- keras_hub/src/models/flux/flux_model.py +236 -0
- keras_hub/src/models/flux/flux_presets.py +3 -0
- keras_hub/src/models/flux/flux_text_to_image.py +146 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_backbone.py +35 -20
- keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
- keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
- keras_hub/src/models/gemma/gemma_presets.py +29 -63
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
- keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
- keras_hub/src/models/image_classifier.py +147 -2
- keras_hub/src/models/image_classifier_preprocessor.py +6 -3
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_segmenter.py +0 -5
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/image_to_image.py +417 -0
- keras_hub/src/models/inpaint.py +520 -0
- keras_hub/src/models/llama/llama_backbone.py +138 -12
- keras_hub/src/models/llama/llama_causal_lm.py +3 -1
- keras_hub/src/models/llama/llama_presets.py +10 -20
- keras_hub/src/models/llama3/llama3_backbone.py +12 -11
- keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
- keras_hub/src/models/llama3/llama3_presets.py +4 -12
- keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
- keras_hub/src/models/mistral/mistral_backbone.py +16 -15
- keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
- keras_hub/src/models/mistral/mistral_presets.py +3 -9
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
- keras_hub/src/models/mit/__init__.py +6 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
- keras_hub/src/models/mit/mit_image_classifier.py +12 -0
- keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/mit/mit_image_converter.py +8 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
- keras_hub/src/models/mit/mit_presets.py +139 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
- keras_hub/src/models/opt/opt_causal_lm.py +2 -2
- keras_hub/src/models/opt/opt_presets.py +4 -12
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
- keras_hub/src/models/phi3/phi3_decoder.py +0 -1
- keras_hub/src/models/phi3/phi3_presets.py +2 -6
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
- keras_hub/src/models/preprocessor.py +25 -11
- keras_hub/src/models/resnet/resnet_backbone.py +3 -14
- keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
- keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
- keras_hub/src/models/resnet/resnet_presets.py +127 -18
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
- keras_hub/src/models/roberta/roberta_backbone.py +2 -2
- keras_hub/src/models/roberta/roberta_presets.py +6 -8
- keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
- keras_hub/src/models/sam/__init__.py +5 -0
- keras_hub/src/models/sam/sam_backbone.py +2 -3
- keras_hub/src/models/sam/sam_image_converter.py +2 -4
- keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
- keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
- keras_hub/src/models/sam/sam_layers.py +5 -3
- keras_hub/src/models/sam/sam_presets.py +3 -9
- keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
- keras_hub/src/models/sam/sam_transformer.py +5 -4
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +167 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +136 -0
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +47 -19
- keras_hub/src/models/task.py +47 -39
- keras_hub/src/models/text_classifier.py +2 -2
- keras_hub/src/models/text_to_image.py +106 -41
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +184 -0
- keras_hub/src/models/vae/vae_layers.py +739 -0
- keras_hub/src/models/vgg/__init__.py +5 -0
- keras_hub/src/models/vgg/vgg_backbone.py +4 -24
- keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
- keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
- keras_hub/src/models/vgg/vgg_presets.py +48 -0
- keras_hub/src/models/vit/__init__.py +5 -0
- keras_hub/src/models/vit/vit_backbone.py +152 -0
- keras_hub/src/models/vit/vit_image_classifier.py +187 -0
- keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vit/vit_image_converter.py +73 -0
- keras_hub/src/models/vit/vit_layers.py +391 -0
- keras_hub/src/models/vit/vit_presets.py +126 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
- keras_hub/src/models/vit_det/vit_layers.py +3 -3
- keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
- keras_hub/src/models/whisper/whisper_backbone.py +6 -5
- keras_hub/src/models/whisper/whisper_decoder.py +3 -5
- keras_hub/src/models/whisper/whisper_presets.py +10 -30
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
- keras_hub/src/models/xlnet/relative_attention.py +20 -19
- keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
- keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
- keras_hub/src/samplers/contrastive_sampler.py +2 -3
- keras_hub/src/samplers/sampler.py +2 -1
- keras_hub/src/tests/test_case.py +41 -6
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
- keras_hub/src/tokenizers/tokenizer.py +10 -13
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
- keras_hub/src/utils/keras_utils.py +2 -13
- keras_hub/src/utils/pipeline_model.py +3 -3
- keras_hub/src/utils/preset_utils.py +196 -144
- keras_hub/src/utils/tensor_utils.py +4 -4
- keras_hub/src/utils/timm/convert_densenet.py +6 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
- keras_hub/src/utils/timm/convert_resnet.py +1 -1
- keras_hub/src/utils/timm/convert_vgg.py +85 -0
- keras_hub/src/utils/timm/preset_loader.py +14 -9
- keras_hub/src/utils/transformers/convert_llama3.py +21 -5
- keras_hub/src/utils/transformers/convert_vit.py +150 -0
- keras_hub/src/utils/transformers/preset_loader.py +23 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
- keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
- keras_hub/src/models/mix_transformer/__init__.py +0 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/src/models/resnet/resnet_image_classifier.py

````diff
@@ -1,5 +1,3 @@
-import keras
-
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.image_classifier import ImageClassifier
 from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
@@ -10,140 +8,5 @@ from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import (
 
 @keras_hub_export("keras_hub.models.ResNetImageClassifier")
 class ResNetImageClassifier(ImageClassifier):
-    """ResNet image classifier task model.
-
-    Args:
-        backbone: A `keras_hub.models.ResNetBackbone` instance.
-        num_classes: int. The number of classes to predict.
-        activation: `None`, str or callable. The activation function to use on
-            the `Dense` layer. Set `activation=None` to return the output
-            logits. Defaults to `"softmax"`.
-        head_dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The
-            dtype to use for the classification head's computations and weights.
-
-    To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
-    where `x` is a tensor and `y` is a integer from `[0, num_classes)`.
-    All `ImageClassifier` tasks include a `from_preset()` constructor which can
-    be used to load a pre-trained config and weights.
-
-    Examples:
-
-    Call `predict()` to run inference.
-    ```python
-    # Load preset and train
-    images = np.ones((2, 224, 224, 3), dtype="float32")
-    classifier = keras_hub.models.ResNetImageClassifier.from_preset(
-        "resnet_50_imagenet"
-    )
-    classifier.predict(images)
-    ```
-
-    Call `fit()` on a single batch.
-    ```python
-    # Load preset and train
-    images = np.ones((2, 224, 224, 3), dtype="float32")
-    labels = [0, 3]
-    classifier = keras_hub.models.ResNetImageClassifier.from_preset(
-        "resnet_50_imagenet"
-    )
-    classifier.fit(x=images, y=labels, batch_size=2)
-    ```
-
-    Call `fit()` with custom loss, optimizer and backbone.
-    ```python
-    classifier = keras_hub.models.ResNetImageClassifier.from_preset(
-        "resnet_50_imagenet"
-    )
-    classifier.compile(
-        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-        optimizer=keras.optimizers.Adam(5e-5),
-    )
-    classifier.backbone.trainable = False
-    classifier.fit(x=images, y=labels, batch_size=2)
-    ```
-
-    Custom backbone.
-    ```python
-    images = np.ones((2, 224, 224, 3), dtype="float32")
-    labels = [0, 3]
-    backbone = keras_hub.models.ResNetBackbone(
-        stackwise_num_filters=[64, 64, 64],
-        stackwise_num_blocks=[2, 2, 2],
-        stackwise_num_strides=[1, 2, 2],
-        block_type="basic_block",
-        use_pre_activation=True,
-        pooling="avg",
-    )
-    classifier = keras_hub.models.ResNetImageClassifier(
-        backbone=backbone,
-        num_classes=4,
-    )
-    classifier.fit(x=images, y=labels, batch_size=2)
-    ```
-    """
-
     backbone_cls = ResNetBackbone
     preprocessor_cls = ResNetImageClassifierPreprocessor
-
-    def __init__(
-        self,
-        backbone,
-        num_classes,
-        preprocessor=None,
-        pooling="avg",
-        activation=None,
-        head_dtype=None,
-        **kwargs,
-    ):
-        head_dtype = head_dtype or backbone.dtype_policy
-
-        # === Layers ===
-        self.backbone = backbone
-        self.preprocessor = preprocessor
-        if pooling == "avg":
-            self.pooler = keras.layers.GlobalAveragePooling2D(
-                data_format=backbone.data_format, dtype=head_dtype
-            )
-        elif pooling == "max":
-            self.pooler = keras.layers.GlobalAveragePooling2D(
-                data_format=backbone.data_format, dtype=head_dtype
-            )
-        else:
-            raise ValueError(
-                "Unknown `pooling` type. Polling should be either `'avg'` or "
-                f"`'max'`. Received: pooling={pooling}."
-            )
-        self.output_dense = keras.layers.Dense(
-            num_classes,
-            activation=activation,
-            dtype=head_dtype,
-            name="predictions",
-        )
-
-        # === Functional Model ===
-        inputs = self.backbone.input
-        x = self.backbone(inputs)
-        x = self.pooler(x)
-        outputs = self.output_dense(x)
-        super().__init__(
-            inputs=inputs,
-            outputs=outputs,
-            **kwargs,
-        )
-
-        # === Config ===
-        self.num_classes = num_classes
-        self.activation = activation
-        self.pooling = pooling
-
-    def get_config(self):
-        # Backbone serialized in `super`
-        config = super().get_config()
-        config.update(
-            {
-                "num_classes": self.num_classes,
-                "pooling": self.pooling,
-                "activation": self.activation,
-            }
-        )
-        return config
````
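The deleted constructor lines up with the `+147` to the shared `keras_hub/src/models/image_classifier.py` in the file list above, and with the matching deletions in the DenseNet, MobileNet, and CSPDarkNet classifiers: the head-building logic moved into the base task, so the subclass now only pins its backbone and preprocessor classes. A minimal usage sketch, assuming the `resnet_50_imagenet` preset stays registered as before:

```python
import numpy as np
import keras_hub

# Inference through the consolidated ImageClassifier task; pooling and
# the dense head are now built by the shared base class.
images = np.ones((2, 224, 224, 3), dtype="float32")
classifier = keras_hub.models.ResNetImageClassifier.from_preset(
    "resnet_50_imagenet"
)
classifier.predict(images)
```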
keras_hub/src/models/resnet/resnet_image_converter.py

```diff
@@ -1,10 +1,8 @@
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.preprocessing.resizing_image_converter import (
-    ResizingImageConverter,
-)
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
 from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
 
 
 @keras_hub_export("keras_hub.layers.ResNetImageConverter")
-class ResNetImageConverter(ResizingImageConverter):
+class ResNetImageConverter(ImageConverter):
     backbone_cls = ResNetBackbone
```
keras_hub/src/models/resnet/resnet_presets.py

```diff
@@ -8,11 +8,9 @@ backbone_presets = {
                 "at a 224x224 resolution."
             ),
             "params": 11186112,
-            "official_name": "ResNet",
             "path": "resnet",
-            "model_card": "https://arxiv.org/abs/2110.00476",
         },
-        "kaggle_handle": "kaggle://
+        "kaggle_handle": "kaggle://keras/resnetv1/keras/resnet_18_imagenet/3",
     },
     "resnet_50_imagenet": {
         "metadata": {
@@ -21,11 +19,9 @@ backbone_presets = {
                 "at a 224x224 resolution."
             ),
             "params": 23561152,
-            "official_name": "ResNet",
             "path": "resnet",
-            "model_card": "https://arxiv.org/abs/2110.00476",
         },
-        "kaggle_handle": "kaggle://
+        "kaggle_handle": "kaggle://keras/resnetv1/keras/resnet_50_imagenet/3",
     },
     "resnet_101_imagenet": {
         "metadata": {
@@ -34,11 +30,9 @@ backbone_presets = {
                 "at a 224x224 resolution."
             ),
             "params": 42605504,
-            "official_name": "ResNet",
             "path": "resnet",
-            "model_card": "https://arxiv.org/abs/2110.00476",
         },
-        "kaggle_handle": "kaggle://
+        "kaggle_handle": "kaggle://keras/resnetv1/keras/resnet_101_imagenet/3",
     },
     "resnet_152_imagenet": {
         "metadata": {
@@ -47,11 +41,9 @@ backbone_presets = {
                 "at a 224x224 resolution."
             ),
             "params": 58295232,
-            "official_name": "ResNet",
             "path": "resnet",
-            "model_card": "https://arxiv.org/abs/2110.00476",
         },
-        "kaggle_handle": "kaggle://
+        "kaggle_handle": "kaggle://keras/resnetv1/keras/resnet_152_imagenet/3",
     },
     "resnet_v2_50_imagenet": {
         "metadata": {
@@ -60,11 +52,9 @@ backbone_presets = {
                 "dataset at a 224x224 resolution."
             ),
             "params": 23561152,
-            "official_name": "ResNet",
             "path": "resnet",
-            "model_card": "https://arxiv.org/abs/2110.00476",
         },
-        "kaggle_handle": "kaggle://
+        "kaggle_handle": "kaggle://keras/resnetv2/keras/resnet_v2_50_imagenet/3",
     },
     "resnet_v2_101_imagenet": {
         "metadata": {
@@ -73,10 +63,129 @@ backbone_presets = {
                 "dataset at a 224x224 resolution."
             ),
             "params": 42605504,
-            "official_name": "ResNet",
             "path": "resnet",
-            "model_card": "https://arxiv.org/abs/2110.00476",
         },
-        "kaggle_handle": "kaggle://
+        "kaggle_handle": "kaggle://keras/resnetv2/keras/resnet_v2_101_imagenet/3",
+    },
+    "resnet_vd_18_imagenet": {
+        "metadata": {
+            "description": (
+                "18-layer ResNetVD (ResNet with bag of tricks) model "
+                "pre-trained on the ImageNet 1k dataset at a 224x224 "
+                "resolution."
+            ),
+            "params": 11722824,
+            "path": "resnet",
+        },
+        "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_18_imagenet/2",
+    },
+    "resnet_vd_34_imagenet": {
+        "metadata": {
+            "description": (
+                "34-layer ResNetVD (ResNet with bag of tricks) model "
+                "pre-trained on the ImageNet 1k dataset at a 224x224 "
+                "resolution."
+            ),
+            "params": 21838408,
+            "path": "resnet",
+        },
+        "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_34_imagenet/2",
+    },
+    "resnet_vd_50_imagenet": {
+        "metadata": {
+            "description": (
+                "50-layer ResNetVD (ResNet with bag of tricks) model "
+                "pre-trained on the ImageNet 1k dataset at a 224x224 "
+                "resolution."
+            ),
+            "params": 25629512,
+            "path": "resnet",
+        },
+        "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_50_imagenet/2",
+    },
+    "resnet_vd_50_ssld_imagenet": {
+        "metadata": {
+            "description": (
+                "50-layer ResNetVD (ResNet with bag of tricks) model "
+                "pre-trained on the ImageNet 1k dataset at a 224x224 "
+                "resolution with knowledge distillation."
+            ),
+            "params": 25629512,
+            "path": "resnet",
+        },
+        "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_50_ssld_imagenet/2",
+    },
+    "resnet_vd_50_ssld_v2_imagenet": {
+        "metadata": {
+            "description": (
+                "50-layer ResNetVD (ResNet with bag of tricks) model "
+                "pre-trained on the ImageNet 1k dataset at a 224x224 "
+                "resolution with knowledge distillation and AutoAugment."
+            ),
+            "params": 25629512,
+            "path": "resnet",
+        },
+        "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_50_ssld_v2_imagenet/2",
+    },
+    "resnet_vd_50_ssld_v2_fix_imagenet": {
+        "metadata": {
+            "description": (
+                "50-layer ResNetVD (ResNet with bag of tricks) model "
+                "pre-trained on the ImageNet 1k dataset at a 224x224 "
+                "resolution with knowledge distillation, AutoAugment and "
+                "additional fine-tuning of the classification head."
+            ),
+            "params": 25629512,
+            "path": "resnet",
+        },
+        "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_50_ssld_v2_fix_imagenet/2",
+    },
+    "resnet_vd_101_imagenet": {
+        "metadata": {
+            "description": (
+                "101-layer ResNetVD (ResNet with bag of tricks) model "
+                "pre-trained on the ImageNet 1k dataset at a 224x224 "
+                "resolution."
+            ),
+            "params": 44673864,
+            "path": "resnet",
+        },
+        "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_101_imagenet/2",
+    },
+    "resnet_vd_101_ssld_imagenet": {
+        "metadata": {
+            "description": (
+                "101-layer ResNetVD (ResNet with bag of tricks) model "
+                "pre-trained on the ImageNet 1k dataset at a 224x224 "
+                "resolution with knowledge distillation."
+            ),
+            "params": 44673864,
+            "path": "resnet",
+        },
+        "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_101_ssld_imagenet/2",
+    },
+    "resnet_vd_152_imagenet": {
+        "metadata": {
+            "description": (
+                "152-layer ResNetVD (ResNet with bag of tricks) model "
+                "pre-trained on the ImageNet 1k dataset at a 224x224 "
+                "resolution."
+            ),
+            "params": 60363592,
+            "path": "resnet",
+        },
+        "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_152_imagenet/2",
+    },
+    "resnet_vd_200_imagenet": {
+        "metadata": {
+            "description": (
+                "200-layer ResNetVD (ResNet with bag of tricks) model "
+                "pre-trained on the ImageNet 1k dataset at a 224x224 "
+                "resolution."
+            ),
+            "params": 74933064,
+            "path": "resnet",
+        },
+        "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_200_imagenet/2",
     },
 }
```
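The ten new ResNetVD entries reuse `"path": "resnet"`, so they should resolve through the same preset flow as the existing handles. A minimal sketch, assuming the preset names above are registered for `ResNetBackbone`:

```python
import keras_hub

# Load a newly added ResNetVD checkpoint by preset name.
backbone = keras_hub.models.ResNetBackbone.from_preset("resnet_vd_50_imagenet")
# The preset metadata above lists 25629512 parameters for this model.
print(backbone.count_params())
```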
keras_hub/src/models/retinanet/__init__.py

```diff
@@ -0,0 +1,5 @@
+from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
+from keras_hub.src.models.retinanet.retinanet_presets import backbone_presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, RetinaNetBackbone)
```
keras_hub/src/models/retinanet/anchor_generator.py

```diff
@@ -3,9 +3,13 @@ import math
 import keras
 from keras import ops
 
+from keras_hub.src.api_export import keras_hub_export
+
+# TODO: https://github.com/keras-team/keras-hub/issues/1965
 from keras_hub.src.bounding_box.converters import convert_format
 
 
+@keras_hub_export("keras_hub.layers.AnchorGenerator")
 class AnchorGenerator(keras.layers.Layer):
     """Generates anchor boxes for object detection tasks.
 
@@ -81,6 +85,7 @@ class AnchorGenerator(keras.layers.Layer):
         self.num_scales = num_scales
         self.aspect_ratios = aspect_ratios
         self.anchor_size = anchor_size
+        self.num_base_anchors = num_scales * len(aspect_ratios)
         self.built = True
 
     def call(self, inputs):
@@ -92,60 +97,61 @@ class AnchorGenerator(keras.layers.Layer):
 
         image_shape = tuple(image_shape)
 
-        multilevel_boxes = {}
+        multilevel_anchors = {}
         for level in range(self.min_level, self.max_level + 1):
-            boxes_l = []
             # Calculate the feature map size for this level
             feat_size_y = math.ceil(image_shape[0] / 2**level)
             feat_size_x = math.ceil(image_shape[1] / 2**level)
 
             # Calculate the stride (step size) for this level
-            stride_y =
-            stride_x =
+            stride_y = image_shape[0] // feat_size_y
+            stride_x = image_shape[1] // feat_size_x
 
             # Generate anchor center points
             # Start from stride/2 to center anchors on pixels
-            cx = ops.arange(
-            cy = ops.arange(
+            cx = ops.arange(0, feat_size_x, dtype="float32") * stride_x
+            cy = ops.arange(0, feat_size_y, dtype="float32") * stride_y
 
             # Create a grid of anchor centers
-            [… 23 removed lines not captured in this diff view …]
-            ],
-                axis=-1,
-            )
-            boxes_l.append(boxes)
-            # Concat anchors on the same level to tensor shape HxWx(Ax4)
-            boxes_l = ops.concatenate(boxes_l, axis=-1)
-            boxes_l = ops.reshape(boxes_l, (-1, 4))
-            # Convert to user defined
-            multilevel_boxes[f"P{level}"] = convert_format(
-                boxes_l,
-                source="yxyx",
+            cy_grid, cx_grid = ops.meshgrid(cy, cx, indexing="ij")
+            cy_grid = ops.reshape(cy_grid, (-1,))
+            cx_grid = ops.reshape(cx_grid, (-1,))
+
+            shifts = ops.stack((cx_grid, cy_grid, cx_grid, cy_grid), axis=1)
+            sizes = [
+                int(
+                    2**level * self.anchor_size * 2 ** (scale / self.num_scales)
+                )
+                for scale in range(self.num_scales)
+            ]
+
+            base_anchors = self.generate_base_anchors(
+                sizes=sizes, aspect_ratios=self.aspect_ratios
+            )
+            shifts = ops.reshape(shifts, (-1, 1, 4))
+            base_anchors = ops.reshape(base_anchors, (1, -1, 4))
+
+            anchors = shifts + base_anchors
+            anchors = ops.reshape(anchors, (-1, 4))
+            multilevel_anchors[f"P{level}"] = convert_format(
+                anchors,
+                source="xyxy",
                 target=self.bounding_box_format,
             )
-        return multilevel_boxes
+        return multilevel_anchors
+
+    def generate_base_anchors(self, sizes, aspect_ratios):
+        sizes = ops.convert_to_tensor(sizes, dtype="float32")
+        aspect_ratios = ops.convert_to_tensor(aspect_ratios)
+        h_ratios = ops.sqrt(aspect_ratios)
+        w_ratios = 1 / h_ratios
+
+        ws = ops.reshape(w_ratios[:, None] * sizes[None, :], (-1,))
+        hs = ops.reshape(h_ratios[:, None] * sizes[None, :], (-1,))
+
+        base_anchors = ops.stack([-1 * ws, -1 * hs, ws, hs], axis=1) / 2
+        base_anchors = ops.round(base_anchors)
+        return base_anchors
 
     def compute_output_shape(self, input_shape):
         multilevel_boxes_shape = {}
@@ -156,18 +162,11 @@ class AnchorGenerator(keras.layers.Layer):
 
         for i in range(self.min_level, self.max_level + 1):
             multilevel_boxes_shape[f"P{i}"] = (
-                (
-                [… 2 removed lines not captured in this diff view …]
+                int(
+                    math.ceil(image_height / 2 ** (i))
+                    * math.ceil(image_width // 2 ** (i))
+                    * self.num_base_anchors
+                ),
                 4,
             )
         return multilevel_boxes_shape
-
-    @property
-    def anchors_per_location(self):
-        """
-        The `anchors_per_location` property returns the number of anchors
-        generated per pixel location, which is equal to
-        `num_scales * len(aspect_ratios)`.
-        """
-        return self.num_scales * len(self.aspect_ratios)
```