keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +21 -3
- keras_hub/api/models/__init__.py +71 -12
- keras_hub/api/tokenizers/__init__.py +1 -1
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
- keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
- keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
- keras_hub/src/layers/modeling/rms_normalization.py +36 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
- keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
- keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
- keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
- keras_hub/src/layers/preprocessing/image_converter.py +170 -34
- keras_hub/src/metrics/bleu.py +4 -3
- keras_hub/src/models/albert/albert_presets.py +4 -12
- keras_hub/src/models/albert/albert_text_classifier.py +7 -7
- keras_hub/src/models/backbone.py +3 -14
- keras_hub/src/models/bart/bart_backbone.py +4 -4
- keras_hub/src/models/bart/bart_presets.py +3 -9
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
- keras_hub/src/models/basnet/__init__.py +5 -0
- keras_hub/src/models/basnet/basnet.py +122 -0
- keras_hub/src/models/basnet/basnet_backbone.py +366 -0
- keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
- keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
- keras_hub/src/models/basnet/basnet_presets.py +17 -0
- keras_hub/src/models/bert/bert_presets.py +14 -32
- keras_hub/src/models/bert/bert_text_classifier.py +3 -3
- keras_hub/src/models/bloom/bloom_presets.py +8 -24
- keras_hub/src/models/causal_lm.py +56 -12
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
- keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
- keras_hub/src/models/densenet/densenet_backbone.py +6 -4
- keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
- keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
- keras_hub/src/models/densenet/densenet_presets.py +9 -15
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
- keras_hub/src/models/efficientnet/mbconv.py +53 -22
- keras_hub/src/models/electra/electra_backbone.py +2 -2
- keras_hub/src/models/electra/electra_presets.py +6 -18
- keras_hub/src/models/f_net/f_net_presets.py +2 -6
- keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
- keras_hub/src/models/falcon/falcon_backbone.py +5 -3
- keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
- keras_hub/src/models/falcon/falcon_presets.py +1 -3
- keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
- keras_hub/src/models/feature_pyramid_backbone.py +1 -1
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +496 -0
- keras_hub/src/models/flux/flux_maths.py +225 -0
- keras_hub/src/models/flux/flux_model.py +236 -0
- keras_hub/src/models/flux/flux_presets.py +3 -0
- keras_hub/src/models/flux/flux_text_to_image.py +146 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_backbone.py +35 -20
- keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
- keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
- keras_hub/src/models/gemma/gemma_presets.py +29 -63
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
- keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
- keras_hub/src/models/image_classifier.py +147 -2
- keras_hub/src/models/image_classifier_preprocessor.py +6 -3
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_segmenter.py +0 -5
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/image_to_image.py +417 -0
- keras_hub/src/models/inpaint.py +520 -0
- keras_hub/src/models/llama/llama_backbone.py +138 -12
- keras_hub/src/models/llama/llama_causal_lm.py +3 -1
- keras_hub/src/models/llama/llama_presets.py +10 -20
- keras_hub/src/models/llama3/llama3_backbone.py +12 -11
- keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
- keras_hub/src/models/llama3/llama3_presets.py +4 -12
- keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
- keras_hub/src/models/mistral/mistral_backbone.py +16 -15
- keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
- keras_hub/src/models/mistral/mistral_presets.py +3 -9
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
- keras_hub/src/models/mit/__init__.py +6 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
- keras_hub/src/models/mit/mit_image_classifier.py +12 -0
- keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/mit/mit_image_converter.py +8 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
- keras_hub/src/models/mit/mit_presets.py +139 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
- keras_hub/src/models/opt/opt_causal_lm.py +2 -2
- keras_hub/src/models/opt/opt_presets.py +4 -12
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
- keras_hub/src/models/phi3/phi3_decoder.py +0 -1
- keras_hub/src/models/phi3/phi3_presets.py +2 -6
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
- keras_hub/src/models/preprocessor.py +25 -11
- keras_hub/src/models/resnet/resnet_backbone.py +3 -14
- keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
- keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
- keras_hub/src/models/resnet/resnet_presets.py +127 -18
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
- keras_hub/src/models/roberta/roberta_backbone.py +2 -2
- keras_hub/src/models/roberta/roberta_presets.py +6 -8
- keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
- keras_hub/src/models/sam/__init__.py +5 -0
- keras_hub/src/models/sam/sam_backbone.py +2 -3
- keras_hub/src/models/sam/sam_image_converter.py +2 -4
- keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
- keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
- keras_hub/src/models/sam/sam_layers.py +5 -3
- keras_hub/src/models/sam/sam_presets.py +3 -9
- keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
- keras_hub/src/models/sam/sam_transformer.py +5 -4
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +167 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +136 -0
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +47 -19
- keras_hub/src/models/task.py +47 -39
- keras_hub/src/models/text_classifier.py +2 -2
- keras_hub/src/models/text_to_image.py +106 -41
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +184 -0
- keras_hub/src/models/vae/vae_layers.py +739 -0
- keras_hub/src/models/vgg/__init__.py +5 -0
- keras_hub/src/models/vgg/vgg_backbone.py +4 -24
- keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
- keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
- keras_hub/src/models/vgg/vgg_presets.py +48 -0
- keras_hub/src/models/vit/__init__.py +5 -0
- keras_hub/src/models/vit/vit_backbone.py +152 -0
- keras_hub/src/models/vit/vit_image_classifier.py +187 -0
- keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vit/vit_image_converter.py +73 -0
- keras_hub/src/models/vit/vit_layers.py +391 -0
- keras_hub/src/models/vit/vit_presets.py +126 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
- keras_hub/src/models/vit_det/vit_layers.py +3 -3
- keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
- keras_hub/src/models/whisper/whisper_backbone.py +6 -5
- keras_hub/src/models/whisper/whisper_decoder.py +3 -5
- keras_hub/src/models/whisper/whisper_presets.py +10 -30
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
- keras_hub/src/models/xlnet/relative_attention.py +20 -19
- keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
- keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
- keras_hub/src/samplers/contrastive_sampler.py +2 -3
- keras_hub/src/samplers/sampler.py +2 -1
- keras_hub/src/tests/test_case.py +41 -6
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
- keras_hub/src/tokenizers/tokenizer.py +10 -13
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
- keras_hub/src/utils/keras_utils.py +2 -13
- keras_hub/src/utils/pipeline_model.py +3 -3
- keras_hub/src/utils/preset_utils.py +196 -144
- keras_hub/src/utils/tensor_utils.py +4 -4
- keras_hub/src/utils/timm/convert_densenet.py +6 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
- keras_hub/src/utils/timm/convert_resnet.py +1 -1
- keras_hub/src/utils/timm/convert_vgg.py +85 -0
- keras_hub/src/utils/timm/preset_loader.py +14 -9
- keras_hub/src/utils/transformers/convert_llama3.py +21 -5
- keras_hub/src/utils/transformers/convert_vit.py +150 -0
- keras_hub/src/utils/transformers/preset_loader.py +23 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
- keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
- keras_hub/src/models/mix_transformer/__init__.py +0 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/src/layers/preprocessing/image_converter.py
CHANGED
@@ -1,47 +1,190 @@
+import math
+
+import keras
+import numpy as np
+from keras import ops
+
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.preprocessing.preprocessing_layer import (
     PreprocessingLayer,
 )
-from keras_hub.src.utils.preset_utils import IMAGE_CONVERTER_CONFIG_FILE
+from keras_hub.src.utils.keras_utils import standardize_data_format
 from keras_hub.src.utils.preset_utils import builtin_presets
 from keras_hub.src.utils.preset_utils import find_subclass
 from keras_hub.src.utils.preset_utils import get_preset_loader
-from keras_hub.src.utils.preset_utils import save_serialized_object
+from keras_hub.src.utils.preset_utils import get_preset_saver
 from keras_hub.src.utils.python_utils import classproperty
+from keras_hub.src.utils.tensor_utils import preprocessing_function
 
 
 @keras_hub_export("keras_hub.layers.ImageConverter")
 class ImageConverter(PreprocessingLayer):
-    """
+    """Preprocess raw images into model ready inputs.
+
+    This class converts from raw images to model ready inputs. This conversion
+    proceeds in the following steps:
 
-
-
-
-
-
+    1. Resize the image using to `image_size`. If `image_size` is `None`, this
+       step will be skipped.
+    2. Rescale the image by multiplying by `scale`, which can be either global
+       or per channel. If `scale` is `None`, this step will be skipped.
+    3. Offset the image by adding `offset`, which can be either global
+       or per channel. If `offset` is `None`, this step will be skipped.
 
     The layer will take as input a raw image tensor in the channels last or
     channels first format, and output a preprocessed image input for modeling.
-
-
-
+    This tensor can be batched (rank 4), or unbatched (rank 3).
+
+    This layer can be used with the `from_preset()` constructor to load a layer
+    that will rescale and resize an image for a specific pretrained model.
+    Using the layer this way allows writing preprocessing code that does not
+    need updating when switching between model checkpoints.
+
+    Args:
+        image_size: `(int, int)` tuple or `None`. The output size of the image,
+            not including the channels axis. If `None`, the input will not be
+            resized.
+        scale: float, tuple of floats, or `None`. The scale to apply to the
+            inputs. If `scale` is a single float, the entire input will be
+            multiplied by `scale`. If `scale` is a tuple, it's assumed to
+            contain per-channel scale value multiplied against each channel of
+            the input images. If `scale` is `None`, no scaling is applied.
+        offset: float, tuple of floats, or `None`. The offset to apply to the
+            inputs. If `offset` is a single float, the entire input will be
+            summed with `offset`. If `offset` is a tuple, it's assumed to
+            contain per-channel offset value summed against each channel of the
+            input images. If `offset` is `None`, no scaling is applied.
+        crop_to_aspect_ratio: If `True`, resize the images without aspect
+            ratio distortion. When the original aspect ratio differs
+            from the target aspect ratio, the output image will be
+            cropped so as to return the
+            largest possible window in the image (of size `(height, width)`)
+            that matches the target aspect ratio. By default
+            (`crop_to_aspect_ratio=False`), aspect ratio may not be preserved.
+        interpolation: String, the interpolation method.
+            Supports `"bilinear"`, `"nearest"`, `"bicubic"`,
+            `"lanczos3"`, `"lanczos5"`. Defaults to `"bilinear"`.
+        data_format: String, either `"channels_last"` or `"channels_first"`.
+            The ordering of the dimensions in the inputs. `"channels_last"`
+            corresponds to inputs with shape `(batch, height, width, channels)`
+            while `"channels_first"` corresponds to inputs with shape
+            `(batch, channels, height, width)`. It defaults to the
+            `image_data_format` value found in your Keras config file at
+            `~/.keras/keras.json`. If you never set it, then it will be
+            `"channels_last"`.
 
     Examples:
     ```python
-    # Resize images
-    converter = keras_hub.layers.ImageConverter
-
-
-
-    converter(np.
+    # Resize raw images and scale them to [0, 1].
+    converter = keras_hub.layers.ImageConverter(
+        image_size=(128, 128),
+        scale=1. / 255,
+    )
+    converter(np.random.randint(0, 256, size=(2, 512, 512, 3)))
+
+    # Resize images to the specific size needed for a PaliGemma preset.
+    converter = keras_hub.layers.ImageConverter.from_preset(
+        "pali_gemma_3b_224"
+    )
+    converter(np.random.randint(0, 256, size=(2, 512, 512, 3)))
     ```
     """
 
     backbone_cls = None
 
+    def __init__(
+        self,
+        image_size=None,
+        scale=None,
+        offset=None,
+        crop_to_aspect_ratio=True,
+        interpolation="bilinear",
+        data_format=None,
+        **kwargs,
+    ):
+        # TODO: old arg names. Delete this block after resaving Kaggle assets.
+        if "height" in kwargs and "width" in kwargs:
+            image_size = (kwargs.pop("height"), kwargs.pop("width"))
+        if "variance" in kwargs and "mean" in kwargs:
+            std = [math.sqrt(v) for v in kwargs.pop("variance")]
+            scale = [scale / s for s in std]
+            offset = [-m / s for m, s in zip(kwargs.pop("mean"), std)]
+
+        super().__init__(**kwargs)
+
+        # Create the `Resizing` layer here even if it's not being used. That
+        # allows us to make `image_size` a settable property.
+        self.resizing = keras.layers.Resizing(
+            height=image_size[0] if image_size else None,
+            width=image_size[1] if image_size else None,
+            crop_to_aspect_ratio=crop_to_aspect_ratio,
+            interpolation=interpolation,
+            data_format=data_format,
+            dtype=self.dtype_policy,
+            name="resizing",
+        )
+        self.scale = scale
+        self.offset = offset
+        self.crop_to_aspect_ratio = crop_to_aspect_ratio
+        self.interpolation = interpolation
+        self.data_format = standardize_data_format(data_format)
+
+    @property
     def image_size(self):
-        """
-
+        """Settable tuple of `(height, width)` ints. The output image shape."""
+        if self.resizing.height is None:
+            return None
+        return (self.resizing.height, self.resizing.width)
+
+    @image_size.setter
+    def image_size(self, value):
+        if value is None:
+            value = (None, None)
+        self.resizing.height = value[0]
+        self.resizing.width = value[1]
+
+    @preprocessing_function
+    def call(self, inputs):
+        x = inputs
+        if self.image_size is not None:
+            x = self.resizing(x)
+        if self.scale is not None:
+            x = x * self._expand_non_channel_dims(self.scale, x)
+        if self.offset is not None:
+            x = x + self._expand_non_channel_dims(self.offset, x)
+        return x
+
+    def _expand_non_channel_dims(self, value, inputs):
+        unbatched = len(ops.shape(inputs)) == 3
+        channels_first = self.data_format == "channels_first"
+        if unbatched:
+            broadcast_dims = (1, 2) if channels_first else (0, 1)
+        else:
+            broadcast_dims = (0, 2, 3) if channels_first else (0, 1, 2)
+        # If inputs are not a tensor type, return a numpy array.
+        # This might happen when running under tf.data.
+        if ops.is_tensor(inputs):
+            # preprocessing decorator moves tensors to cpu in torch backend and
+            # processed on CPU, and then converted back to the appropriate
+            # device (potentially GPU) after preprocessing.
+            if keras.backend.backend() == "torch" and self.image_size is None:
+                return ops.expand_dims(value, broadcast_dims).cpu()
+            return ops.expand_dims(value, broadcast_dims)
+        else:
+            return np.expand_dims(value, broadcast_dims)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "image_size": self.image_size,
+                "scale": self.scale,
+                "offset": self.offset,
+                "interpolation": self.interpolation,
+                "crop_to_aspect_ratio": self.crop_to_aspect_ratio,
+            }
+        )
+        return config
 
     @classproperty
     def presets(cls):
@@ -69,13 +212,6 @@ class ImageConverter(PreprocessingLayer):
         You can run `cls.presets.keys()` to list all built-in presets available
        on the class.
 
-        This constructor can be called in one of two ways. Either from the base
-        class like `keras_hub.models.ImageConverter.from_preset()`, or from a
-        model class like
-        `keras_hub.models.PaliGemmaImageConverter.from_preset()`. If calling
-        from the base class, the subclass of the returning object will be
-        inferred from the config in the preset directory.
-
         Args:
             preset: string. A built-in preset identifier, a Kaggle Models
                 handle, a Hugging Face handle, or a path to a local directory.
@@ -85,17 +221,20 @@ class ImageConverter(PreprocessingLayer):
 
         Examples:
         ```python
+        batch = np.random.randint(0, 256, size=(2, 512, 512, 3))
+
        # Resize images for `"pali_gemma_3b_224"`.
        converter = keras_hub.layers.ImageConverter.from_preset(
            "pali_gemma_3b_224"
        )
-        converter(
-
+        converter(batch) # Output shape: (2, 224, 224, 3)
+
+        # Resize images for `"pali_gemma_3b_448"` without cropping.
        converter = keras_hub.layers.ImageConverter.from_preset(
            "pali_gemma_3b_448",
            crop_to_aspect_ratio=False,
        )
-        converter(
+        converter(batch) # Output shape: (2, 448, 448, 3)
        ```
        """
        loader = get_preset_loader(preset)
@@ -110,8 +249,5 @@ class ImageConverter(PreprocessingLayer):
         Args:
             preset_dir: The path to the local model preset directory.
         """
-        save_serialized_object(
-            self,
-            preset_dir,
-            config_file=IMAGE_CONVERTER_CONFIG_FILE,
-        )
+        saver = get_preset_saver(preset_dir)
+        saver.save_image_converter(self)
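
The new `ImageConverter` folds the legacy `mean`/`variance` arguments into the `scale`/`offset` pair, so the layer's `x * scale + offset` pipeline reproduces the old rescale-then-normalize behavior. A minimal sketch of that equivalence, using illustrative ImageNet-style statistics rather than values from any real preset:

```python
import math

import numpy as np

# Assumed per-channel statistics, for illustration only.
mean = [0.485, 0.456, 0.406]
variance = [0.229**2, 0.224**2, 0.225**2]
old_scale = 1.0 / 255  # pre-existing global rescaling factor

# The same folding done in `ImageConverter.__init__` above.
std = [math.sqrt(v) for v in variance]
scale = [old_scale / s for s in std]
offset = [-m / s for m, s in zip(mean, std)]

x = np.random.randint(0, 256, size=(2, 224, 224, 3)).astype("float32")
assert np.allclose(x * scale + offset, (x * old_scale - mean) / std)
```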
keras_hub/src/metrics/bleu.py
CHANGED
@@ -164,7 +164,7 @@ class Bleu(keras.metrics.Metric):
         return inputs
 
     def _get_ngrams(self, segment, max_order):
-        """Extracts all n-grams up to a given maximum order from an input
+        """Extracts all n-grams up to a given maximum order from an input.
 
         Uses Python ops. Inspired from
         https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py.
@@ -329,8 +329,9 @@ class Bleu(keras.metrics.Metric):
                     return tf.squeeze(inputs, axis=-1)
             else:
                 raise ValueError(
-                    f"{tensor_name} must be of rank {base_rank},
-                    f"or {base_rank+2}.
+                    f"{tensor_name} must be of rank {base_rank}, "
+                    f"{base_rank + 1}, or {base_rank + 2}. "
+                    f"Found rank: {inputs.shape.rank}"
                 )
 
         y_true = validate_and_fix_rank(y_true, "y_true", 1)
keras_hub/src/models/albert/albert_presets.py
CHANGED
@@ -8,11 +8,9 @@ backbone_presets = {
                 "Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 11683584,
-            "official_name": "ALBERT",
             "path": "albert",
-            "model_card": "https://github.com/google-research/albert/blob/master/README.md",
         },
-        "kaggle_handle": "kaggle://keras/albert/keras/albert_base_en_uncased/
+        "kaggle_handle": "kaggle://keras/albert/keras/albert_base_en_uncased/5",
     },
     "albert_large_en_uncased": {
         "metadata": {
@@ -21,11 +19,9 @@ backbone_presets = {
                 "Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 17683968,
-            "official_name": "ALBERT",
             "path": "albert",
-            "model_card": "https://github.com/google-research/albert/blob/master/README.md",
         },
-        "kaggle_handle": "kaggle://keras/albert/keras/albert_large_en_uncased/
+        "kaggle_handle": "kaggle://keras/albert/keras/albert_large_en_uncased/3",
     },
     "albert_extra_large_en_uncased": {
         "metadata": {
@@ -34,11 +30,9 @@ backbone_presets = {
                 "Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 58724864,
-            "official_name": "ALBERT",
             "path": "albert",
-            "model_card": "https://github.com/google-research/albert/blob/master/README.md",
         },
-        "kaggle_handle": "kaggle://keras/albert/keras/albert_extra_large_en_uncased/
+        "kaggle_handle": "kaggle://keras/albert/keras/albert_extra_large_en_uncased/3",
     },
     "albert_extra_extra_large_en_uncased": {
         "metadata": {
@@ -47,10 +41,8 @@ backbone_presets = {
                 "Trained on English Wikipedia + BooksCorpus."
             ),
             "params": 222595584,
-            "official_name": "ALBERT",
             "path": "albert",
-            "model_card": "https://github.com/google-research/albert/blob/master/README.md",
         },
-        "kaggle_handle": "kaggle://keras/albert/keras/albert_extra_extra_large_en_uncased/
+        "kaggle_handle": "kaggle://keras/albert/keras/albert_extra_extra_large_en_uncased/3",
     },
 }
keras_hub/src/models/albert/albert_text_classifier.py
CHANGED
@@ -20,10 +20,10 @@ from keras_hub.src.models.text_classifier import TextClassifier
 class AlbertTextClassifier(TextClassifier):
     """An end-to-end ALBERT model for classification tasks
 
-    This model attaches a classification head to a
-    backbone, mapping from the backbone outputs
-    a classification task. For usage of this model
-    the `from_preset()` method.
+    This model attaches a classification head to a
+    `keras_hub.model.AlbertBackbone` backbone, mapping from the backbone outputs
+    to logit output suitable for a classification task. For usage of this model
+    with pre-trained weights, see the `from_preset()` method.
 
     This model can optionally be configured with a `preprocessor` layer, in
     which case it will automatically apply preprocessing to raw inputs during
@@ -36,9 +36,9 @@ class AlbertTextClassifier(TextClassifier):
     Args:
         backbone: A `keras_hub.models.AlertBackbone` instance.
         num_classes: int. Number of classes to predict.
-        preprocessor: A `keras_hub.models.AlbertTextClassifierPreprocessor` or
-            `None`, this model will not apply preprocessing, and
-            be preprocessed before calling the model.
+        preprocessor: A `keras_hub.models.AlbertTextClassifierPreprocessor` or
+            `None`. If `None`, this model will not apply preprocessing, and
+            inputs should be preprocessed before calling the model.
         activation: Optional `str` or callable. The
             activation function to use on the model outputs. Set
             `activation="softmax"` to return output probabilities.
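
The corrected docstring directs users to `from_preset()` for pretrained weights. A minimal usage sketch, taking the preset name from the `albert_presets.py` entries above:

```python
import keras_hub

# Pretrained backbone weights plus a randomly initialized classification head.
classifier = keras_hub.models.AlbertTextClassifier.from_preset(
    "albert_base_en_uncased",
    num_classes=2,
)
# The attached preprocessor lets the task consume raw strings directly.
predictions = classifier.predict(["What an amazing movie!"])
```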
keras_hub/src/models/backbone.py
CHANGED
@@ -1,15 +1,9 @@
-import os
-
 import keras
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.utils.keras_utils import assert_quantization_support
-from keras_hub.src.utils.preset_utils import CONFIG_FILE
-from keras_hub.src.utils.preset_utils import MODEL_WEIGHTS_FILE
 from keras_hub.src.utils.preset_utils import builtin_presets
 from keras_hub.src.utils.preset_utils import get_preset_loader
-from keras_hub.src.utils.preset_utils import save_metadata
-from keras_hub.src.utils.preset_utils import save_serialized_object
+from keras_hub.src.utils.preset_utils import get_preset_saver
 from keras_hub.src.utils.python_utils import classproperty
 
 
@@ -88,10 +82,6 @@ class Backbone(keras.Model):
     def token_embedding(self, value):
         self._token_embedding = value
 
-    def quantize(self, mode, **kwargs):
-        assert_quantization_support()
-        return super().quantize(mode, **kwargs)
-
     def get_config(self):
         # Don't chain to super here. `get_config()` for functional models is
         # a nested layer config and cannot be passed to Backbone constructors.
@@ -193,9 +183,8 @@ class Backbone(keras.Model):
         Args:
             preset_dir: The path to the local model preset directory.
         """
-        save_serialized_object(self, preset_dir, config_file=CONFIG_FILE)
-        self.save_weights(os.path.join(preset_dir, MODEL_WEIGHTS_FILE))
-        save_metadata(self, preset_dir)
+        saver = get_preset_saver(preset_dir)
+        saver.save_backbone(self)
 
     def enable_lora(self, rank):
         """Enable Lora on the backbone.
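
From the public API side this refactor is behavior-preserving: `save_to_preset` now delegates to a preset saver obtained via `get_preset_saver`, but the save/load round trip reads the same. A sketch, with an illustrative preset name and target directory:

```python
import keras_hub

backbone = keras_hub.models.Backbone.from_preset("bert_tiny_en_uncased")
backbone.save_to_preset("./my_preset")  # internally: get_preset_saver(...)
restored = keras_hub.models.Backbone.from_preset("./my_preset")
```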
keras_hub/src/models/bart/bart_backbone.py
CHANGED
@@ -22,9 +22,9 @@ class BartBackbone(Backbone):
     described in
     ["BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension"](https://arxiv.org/abs/1910.13461).
 
-    The default constructor gives a fully customizable, randomly initialized
-    model with any number of layers, heads, and embedding dimensions. To
-    preset architectures and weights, use the `from_preset` constructor.
+    The default constructor gives a fully customizable, randomly initialized
+    BART model with any number of layers, heads, and embedding dimensions. To
+    load preset architectures and weights, use the `from_preset` constructor.
 
     Disclaimer: Pre-trained models are provided on an "as is" basis, without
     warranties or conditions of any kind. The underlying model is provided by a
@@ -78,7 +78,7 @@ class BartBackbone(Backbone):
     )
     output = model(input_data)
     ```
-    """
+    """  # noqa: E501
 
     def __init__(
         self,
keras_hub/src/models/bart/bart_presets.py
CHANGED
@@ -8,11 +8,9 @@ backbone_presets = {
                 "Trained on BookCorpus, English Wikipedia and CommonCrawl."
             ),
             "params": 139417344,
-            "official_name": "BART",
             "path": "bart",
-            "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/bart/README.md",
         },
-        "kaggle_handle": "kaggle://keras/bart/keras/bart_base_en/
+        "kaggle_handle": "kaggle://keras/bart/keras/bart_base_en/3",
     },
     "bart_large_en": {
         "metadata": {
@@ -21,9 +19,7 @@ backbone_presets = {
                 "Trained on BookCorpus, English Wikipedia and CommonCrawl."
             ),
             "params": 406287360,
-            "official_name": "BART",
             "path": "bart",
-            "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/bart/README.md",
         },
         "config": {
             "vocabulary_size": 50265,
@@ -34,7 +30,7 @@ backbone_presets = {
             "dropout": 0.1,
             "max_sequence_length": 1024,
         },
-        "kaggle_handle": "kaggle://keras/bart/keras/bart_large_en/
+        "kaggle_handle": "kaggle://keras/bart/keras/bart_large_en/3",
     },
     "bart_large_en_cnn": {
         "metadata": {
@@ -43,9 +39,7 @@ backbone_presets = {
                 "summarization dataset."
             ),
             "params": 406287360,
-            "official_name": "BART",
             "path": "bart",
-            "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/bart/README.md",
         },
         "config": {
             "vocabulary_size": 50264,
@@ -56,6 +50,6 @@ backbone_presets = {
             "dropout": 0.1,
             "max_sequence_length": 1024,
         },
-        "kaggle_handle": "kaggle://keras/bart/keras/bart_large_en_cnn/
+        "kaggle_handle": "kaggle://keras/bart/keras/bart_large_en_cnn/3",
     },
 }
keras_hub/src/models/bart/bart_seq_2_seq_lm.py
CHANGED
@@ -60,7 +60,8 @@ class BartSeq2SeqLM(Seq2SeqLM):
     bart_lm.generate("The quick brown fox", max_length=30)
     ```
 
-    Use `generate()` with encoder inputs and an incomplete decoder input
+    Use `generate()` with encoder inputs and an incomplete decoder input
+    (prompt).
     ```python
     bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("bart_base_en")
     bart_lm.generate(
@@ -79,10 +80,10 @@ class BartSeq2SeqLM(Seq2SeqLM):
     prompt = {
         "encoder_token_ids": np.array([[0, 133, 2119, 6219, 23602, 2, 1, 1]]),
         "encoder_padding_mask": np.array(
-            [[
+            [[1, 1, 1, 1, 1, 1, 0, 0]]
         ),
         "decoder_token_ids": np.array([[2, 0, 133, 1769, 2, 1, 1]]),
-        "decoder_padding_mask": np.array([[
+        "decoder_padding_mask": np.array([[1, 1, 1, 1, 0, 0]])
     }
 
     bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset(
@@ -95,7 +96,7 @@ class BartSeq2SeqLM(Seq2SeqLM):
     Call `fit()` on a single batch.
     ```python
     features = {
-        "encoder_text": ["The quick
+        "encoder_text": ["The quick fox jumped.", "I forgot my homework."],
         "decoder_text": ["The fast hazel fox leapt.", "I forgot my assignment."]
     }
     bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset("bart_base_en")
@@ -195,7 +196,7 @@ class BartSeq2SeqLM(Seq2SeqLM):
         cross_attention_cache=None,
         cross_attention_cache_update_index=None,
     ):
-        """Forward pass with a key/value caches for generative decoding
+        """Forward pass with a key/value caches for generative decoding.
 
         `call_decoder_with_cache` adds an additional inference-time forward pass
         for the model for seq2seq text generation. Unlike calling the model
@@ -241,7 +242,7 @@ class BartSeq2SeqLM(Seq2SeqLM):
             key/value cache in the decoder's self-attention layer and
             `cross_attention_cache` is the key/value cache in the decoder's
             cross-attention layer.
-        """
+        """  # noqa: E501
         # Embedding layers.
         tokens = self.backbone.token_embedding(decoder_token_ids)
         positions = self.backbone.decoder_position_embedding(
@@ -331,7 +332,7 @@ class BartSeq2SeqLM(Seq2SeqLM):
     def _build_cache(
         self, encoder_token_ids, encoder_padding_mask, decoder_token_ids
     ):
-        """Builds the self-attention cache and the cross-attention cache
+        """Builds the self-attention cache and the cross-attention cache."""
         encoder_hidden_states = self.call_encoder(
             token_ids=encoder_token_ids, padding_mask=encoder_padding_mask
         )
@@ -417,7 +418,7 @@ class BartSeq2SeqLM(Seq2SeqLM):
         prompt = ops.slice(prompt, [0, cache_index], [num_samples, 1])
 
         def repeat_tensor(x):
-            """Repeats
+            """Repeats along batch axis to match dim for beam search."""
            if ops.shape(x)[0] == num_samples:
                return x
            return ops.repeat(x, repeats=num_samples // batch_size, axis=0)
keras_hub/src/models/basnet/basnet.py
ADDED
@@ -0,0 +1,122 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone
+from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor
+from keras_hub.src.models.image_segmenter import ImageSegmenter
+
+
+@keras_hub_export("keras_hub.models.BASNetImageSegmenter")
+class BASNetImageSegmenter(ImageSegmenter):
+    """BASNet image segmentation task.
+
+    Args:
+        backbone: A `keras_hub.models.BASNetBackbone` instance.
+        preprocessor: `None`, a `keras_hub.models.Preprocessor` instance,
+            a `keras.Layer` instance, or a callable. If `None` no preprocessing
+            will be applied to the inputs.
+
+    Example:
+    ```python
+    import keras_hub
+
+    images = np.ones(shape=(1, 288, 288, 3))
+    labels = np.zeros(shape=(1, 288, 288, 1))
+
+    image_encoder = keras_hub.models.ResNetBackbone.from_preset(
+        "resnet_18_imagenet",
+        load_weights=False
+    )
+    backbone = keras_hub.models.BASNetBackbone(
+        image_encoder,
+        num_classes=1,
+        image_shape=[288, 288, 3]
+    )
+    model = keras_hub.models.BASNetImageSegmenter(backbone)
+
+    # Evaluate the model
+    pred_labels = model(images)
+
+    # Train the model
+    model.compile(
+        optimizer="adam",
+        loss=keras.losses.BinaryCrossentropy(from_logits=False),
+        metrics=["accuracy"],
+    )
+    model.fit(images, labels, epochs=3)
+    ```
+    """
+
+    backbone_cls = BASNetBackbone
+    preprocessor_cls = BASNetPreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        preprocessor=None,
+        **kwargs,
+    ):
+        # === Functional Model ===
+        x = backbone.input
+        outputs = backbone(x)
+        # only return the refinement module's output as final prediction
+        outputs = outputs["refine_out"]
+        super().__init__(inputs=x, outputs=outputs, **kwargs)
+
+        # === Config ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+    def compute_loss(self, x, y, y_pred, *args, **kwargs):
+        # train BASNet's prediction and refinement module outputs against the
+        # same ground truth data
+        outputs = self.backbone(x)
+        losses = []
+        for output in outputs.values():
+            losses.append(super().compute_loss(x, y, output, *args, **kwargs))
+        return keras.ops.sum(losses, axis=0)
+
+    def compile(
+        self,
+        optimizer="auto",
+        loss="auto",
+        metrics="auto",
+        **kwargs,
+    ):
+        """Configures the `BASNet` task for training.
+
+        `BASNet` extends the default compilation signature
+        of `keras.Model.compile` with defaults for `optimizer` and `loss`. To
+        override these defaults, pass any value to these arguments during
+        compilation.
+
+        Args:
+            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
+                instance. Defaults to `"auto"`, which uses the default
+                optimizer for `BASNet`. See `keras.Model.compile` and
+                `keras.optimizers` for more info on possible `optimizer`
+                values.
+            loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
+                Defaults to `"auto"`, in which case the default loss
+                computation of `BASNet` will be applied.
+                See `keras.Model.compile` and `keras.losses` for more info on
+                possible `loss` values.
+            metrics: `"auto"`, or a list of metrics to be evaluated by
+                the model during training and testing. Defaults to `"auto"`,
+                where a `keras.metrics.Accuracy` will be applied to track the
+                accuracy of the model during training.
+                See `keras.Model.compile` and `keras.metrics` for
+                more info on possible `metrics` values.
+            **kwargs: See `keras.Model.compile` for a full list of arguments
+                supported by the compile method.
+        """
+        if loss == "auto":
+            loss = keras.losses.BinaryCrossentropy()
+        if metrics == "auto":
+            metrics = [keras.metrics.Accuracy()]
+        super().compile(
+            optimizer=optimizer,
+            loss=loss,
+            metrics=metrics,
+            **kwargs,
+        )