keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
- keras_hub/api/layers/__init__.py +21 -3
- keras_hub/api/models/__init__.py +71 -12
- keras_hub/api/tokenizers/__init__.py +1 -1
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
- keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
- keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
- keras_hub/src/layers/modeling/rms_normalization.py +36 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
- keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
- keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
- keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
- keras_hub/src/layers/preprocessing/image_converter.py +170 -34
- keras_hub/src/metrics/bleu.py +4 -3
- keras_hub/src/models/albert/albert_presets.py +4 -12
- keras_hub/src/models/albert/albert_text_classifier.py +7 -7
- keras_hub/src/models/backbone.py +3 -14
- keras_hub/src/models/bart/bart_backbone.py +4 -4
- keras_hub/src/models/bart/bart_presets.py +3 -9
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
- keras_hub/src/models/basnet/__init__.py +5 -0
- keras_hub/src/models/basnet/basnet.py +122 -0
- keras_hub/src/models/basnet/basnet_backbone.py +366 -0
- keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
- keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
- keras_hub/src/models/basnet/basnet_presets.py +17 -0
- keras_hub/src/models/bert/bert_presets.py +14 -32
- keras_hub/src/models/bert/bert_text_classifier.py +3 -3
- keras_hub/src/models/bloom/bloom_presets.py +8 -24
- keras_hub/src/models/causal_lm.py +56 -12
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
- keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
- keras_hub/src/models/densenet/densenet_backbone.py +6 -4
- keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
- keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
- keras_hub/src/models/densenet/densenet_presets.py +9 -15
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
- keras_hub/src/models/efficientnet/mbconv.py +53 -22
- keras_hub/src/models/electra/electra_backbone.py +2 -2
- keras_hub/src/models/electra/electra_presets.py +6 -18
- keras_hub/src/models/f_net/f_net_presets.py +2 -6
- keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
- keras_hub/src/models/falcon/falcon_backbone.py +5 -3
- keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
- keras_hub/src/models/falcon/falcon_presets.py +1 -3
- keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
- keras_hub/src/models/feature_pyramid_backbone.py +1 -1
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +496 -0
- keras_hub/src/models/flux/flux_maths.py +225 -0
- keras_hub/src/models/flux/flux_model.py +236 -0
- keras_hub/src/models/flux/flux_presets.py +3 -0
- keras_hub/src/models/flux/flux_text_to_image.py +146 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_backbone.py +35 -20
- keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
- keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
- keras_hub/src/models/gemma/gemma_presets.py +29 -63
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
- keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
- keras_hub/src/models/image_classifier.py +147 -2
- keras_hub/src/models/image_classifier_preprocessor.py +6 -3
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_segmenter.py +0 -5
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/image_to_image.py +417 -0
- keras_hub/src/models/inpaint.py +520 -0
- keras_hub/src/models/llama/llama_backbone.py +138 -12
- keras_hub/src/models/llama/llama_causal_lm.py +3 -1
- keras_hub/src/models/llama/llama_presets.py +10 -20
- keras_hub/src/models/llama3/llama3_backbone.py +12 -11
- keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
- keras_hub/src/models/llama3/llama3_presets.py +4 -12
- keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
- keras_hub/src/models/mistral/mistral_backbone.py +16 -15
- keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
- keras_hub/src/models/mistral/mistral_presets.py +3 -9
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
- keras_hub/src/models/mit/__init__.py +6 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
- keras_hub/src/models/mit/mit_image_classifier.py +12 -0
- keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/mit/mit_image_converter.py +8 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
- keras_hub/src/models/mit/mit_presets.py +139 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
- keras_hub/src/models/opt/opt_causal_lm.py +2 -2
- keras_hub/src/models/opt/opt_presets.py +4 -12
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
- keras_hub/src/models/phi3/phi3_decoder.py +0 -1
- keras_hub/src/models/phi3/phi3_presets.py +2 -6
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
- keras_hub/src/models/preprocessor.py +25 -11
- keras_hub/src/models/resnet/resnet_backbone.py +3 -14
- keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
- keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
- keras_hub/src/models/resnet/resnet_presets.py +127 -18
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
- keras_hub/src/models/roberta/roberta_backbone.py +2 -2
- keras_hub/src/models/roberta/roberta_presets.py +6 -8
- keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
- keras_hub/src/models/sam/__init__.py +5 -0
- keras_hub/src/models/sam/sam_backbone.py +2 -3
- keras_hub/src/models/sam/sam_image_converter.py +2 -4
- keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
- keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
- keras_hub/src/models/sam/sam_layers.py +5 -3
- keras_hub/src/models/sam/sam_presets.py +3 -9
- keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
- keras_hub/src/models/sam/sam_transformer.py +5 -4
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +167 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +136 -0
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +47 -19
- keras_hub/src/models/task.py +47 -39
- keras_hub/src/models/text_classifier.py +2 -2
- keras_hub/src/models/text_to_image.py +106 -41
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +184 -0
- keras_hub/src/models/vae/vae_layers.py +739 -0
- keras_hub/src/models/vgg/__init__.py +5 -0
- keras_hub/src/models/vgg/vgg_backbone.py +4 -24
- keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
- keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
- keras_hub/src/models/vgg/vgg_presets.py +48 -0
- keras_hub/src/models/vit/__init__.py +5 -0
- keras_hub/src/models/vit/vit_backbone.py +152 -0
- keras_hub/src/models/vit/vit_image_classifier.py +187 -0
- keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vit/vit_image_converter.py +73 -0
- keras_hub/src/models/vit/vit_layers.py +391 -0
- keras_hub/src/models/vit/vit_presets.py +126 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
- keras_hub/src/models/vit_det/vit_layers.py +3 -3
- keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
- keras_hub/src/models/whisper/whisper_backbone.py +6 -5
- keras_hub/src/models/whisper/whisper_decoder.py +3 -5
- keras_hub/src/models/whisper/whisper_presets.py +10 -30
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
- keras_hub/src/models/xlnet/relative_attention.py +20 -19
- keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
- keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
- keras_hub/src/samplers/contrastive_sampler.py +2 -3
- keras_hub/src/samplers/sampler.py +2 -1
- keras_hub/src/tests/test_case.py +41 -6
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
- keras_hub/src/tokenizers/tokenizer.py +10 -13
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
- keras_hub/src/utils/keras_utils.py +2 -13
- keras_hub/src/utils/pipeline_model.py +3 -3
- keras_hub/src/utils/preset_utils.py +196 -144
- keras_hub/src/utils/tensor_utils.py +4 -4
- keras_hub/src/utils/timm/convert_densenet.py +6 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
- keras_hub/src/utils/timm/convert_resnet.py +1 -1
- keras_hub/src/utils/timm/convert_vgg.py +85 -0
- keras_hub/src/utils/timm/preset_loader.py +14 -9
- keras_hub/src/utils/transformers/convert_llama3.py +21 -5
- keras_hub/src/utils/transformers/convert_vit.py +150 -0
- keras_hub/src/utils/transformers/preset_loader.py +23 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
- keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
- keras_hub/src/models/mix_transformer/__init__.py +0 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
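Among the refactoring and preset churn, the listing adds several new model families (BASNet, CLIP, DeepLabV3, Flux, MiT, RetinaNet, SegFormer, VAE, ViT, plus EfficientNet and VGG presets). All of these plug into the standard KerasHub preset workflow; a minimal, hedged sketch follows — the preset name is illustrative and should be checked against keras_hub's published preset list, not this diff:

```python
import numpy as np

import keras_hub

# Hypothetical preset name for one of the newly added ViT presets;
# the exact string is not confirmed by this diff.
classifier = keras_hub.models.ImageClassifier.from_preset(
    "vit_base_patch16_224_imagenet"
)
images = np.ones((2, 224, 224, 3), dtype="float32")
scores = classifier.predict(images)  # shape (2, num_classes)
```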
--- a/keras_hub/src/models/mix_transformer/mix_transformer_classifier.py
+++ /dev/null
@@ -1,119 +0,0 @@
-import keras
-
-from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.models.image_classifier import ImageClassifier
-from keras_hub.src.models.mix_transformer.mix_transformer_backbone import (
-    MiTBackbone,
-)
-
-
-@keras_hub_export("keras_hub.models.MiTImageClassifier")
-class MiTImageClassifier(ImageClassifier):
-    """MiTImageClassifier image classifier model.
-
-    Args:
-        backbone: A `keras_hub.models.MiTBackbone` instance.
-        num_classes: int. The number of classes to predict.
-        activation: `None`, str or callable. The activation function to use on
-            the `Dense` layer. Set `activation=None` to return the output
-            logits. Defaults to `"softmax"`.
-
-    To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
-    where `x` is a tensor and `y` is a integer from `[0, num_classes)`.
-    All `ImageClassifier` tasks include a `from_preset()` constructor which can
-    be used to load a pre-trained config and weights.
-
-    Examples:
-
-    Call `predict()` to run inference.
-    ```python
-    # Load preset and train
-    images = np.ones((2, 224, 224, 3), dtype="float32")
-    classifier = keras_hub.models.MiTImageClassifier.from_preset(
-        "mit_b0_imagenet")
-    classifier.predict(images)
-    ```
-
-    Call `fit()` on a single batch.
-    ```python
-    # Load preset and train
-    images = np.ones((2, 224, 224, 3), dtype="float32")
-    labels = [0, 3]
-    classifier = keras_hub.models.MixTransformerImageClassifier.from_preset(
-        "mit_b0_imagenet")
-    classifier.fit(x=images, y=labels, batch_size=2)
-    ```
-
-    Call `fit()` with custom loss, optimizer and backbone.
-    ```python
-    classifier = keras_hub.models.MiTImageClassifier.from_preset(
-        "mit_b0_imagenet")
-    classifier.compile(
-        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
-        optimizer=keras.optimizers.Adam(5e-5),
-    )
-    classifier.backbone.trainable = False
-    classifier.fit(x=images, y=labels, batch_size=2)
-    ```
-
-    Custom backbone.
-    ```python
-    images = np.ones((2, 224, 224, 3), dtype="float32")
-    labels = [0, 3]
-    backbone = keras_hub.models.MiTBackbone(
-        stackwise_num_filters=[128, 256, 512, 1024],
-        stackwise_depth=[3, 9, 9, 3],
-        block_type="basic_block",
-        image_shape = (224, 224, 3),
-    )
-    classifier = keras_hub.models.MiTImageClassifier(
-        backbone=backbone,
-        num_classes=4,
-    )
-    classifier.fit(x=images, y=labels, batch_size=2)
-    ```
-    """
-
-    backbone_cls = MiTBackbone
-
-    def __init__(
-        self,
-        backbone,
-        num_classes,
-        activation="softmax",
-        preprocessor=None,  # adding this dummy arg for saved model test
-        # TODO: once preprocessor flow is figured out, this needs to be updated
-        **kwargs,
-    ):
-        # === Layers ===
-        self.backbone = backbone
-        self.output_dense = keras.layers.Dense(
-            num_classes,
-            activation=activation,
-            name="predictions",
-        )
-
-        # === Functional Model ===
-        inputs = self.backbone.input
-        x = self.backbone(inputs)
-        outputs = self.output_dense(x)
-        super().__init__(
-            inputs=inputs,
-            outputs=outputs,
-            **kwargs,
-        )
-
-        # === Config ===
-        self.num_classes = num_classes
-        self.activation = activation
-
-    def get_config(self):
-        # Backbone serialized in `super`
-        config = super().get_config()
-        config.update(
-            {
-                "num_classes": self.num_classes,
-                "activation": self.activation,
-            }
-        )
-        return config
--- a/keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py
+++ /dev/null
@@ -1,320 +0,0 @@
-import math
-
-from keras import layers
-from keras import ops
-
-from keras_hub.src.models.backbone import Backbone
-from keras_hub.src.utils.keras_utils import standardize_data_format
-
-
-class VAEAttention(layers.Layer):
-    def __init__(self, filters, groups=32, data_format=None, **kwargs):
-        super().__init__(**kwargs)
-        self.filters = filters
-        self.data_format = standardize_data_format(data_format)
-        gn_axis = -1 if self.data_format == "channels_last" else 1
-
-        self.group_norm = layers.GroupNormalization(
-            groups=groups,
-            axis=gn_axis,
-            epsilon=1e-6,
-            dtype="float32",
-            name="group_norm",
-        )
-        self.query_conv2d = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=self.data_format,
-            dtype=self.dtype_policy,
-            name="query_conv2d",
-        )
-        self.key_conv2d = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=self.data_format,
-            dtype=self.dtype_policy,
-            name="key_conv2d",
-        )
-        self.value_conv2d = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=self.data_format,
-            dtype=self.dtype_policy,
-            name="value_conv2d",
-        )
-        self.softmax = layers.Softmax(dtype="float32")
-        self.output_conv2d = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=self.data_format,
-            dtype=self.dtype_policy,
-            name="output_conv2d",
-        )
-
-        self.groups = groups
-        self._inverse_sqrt_filters = 1.0 / math.sqrt(float(filters))
-
-    def build(self, input_shape):
-        self.group_norm.build(input_shape)
-        self.query_conv2d.build(input_shape)
-        self.key_conv2d.build(input_shape)
-        self.value_conv2d.build(input_shape)
-        self.output_conv2d.build(input_shape)
-
-    def call(self, inputs, training=None):
-        x = self.group_norm(inputs)
-        query = self.query_conv2d(x)
-        key = self.key_conv2d(x)
-        value = self.value_conv2d(x)
-
-        if self.data_format == "channels_first":
-            query = ops.transpose(query, (0, 2, 3, 1))
-            key = ops.transpose(key, (0, 2, 3, 1))
-            value = ops.transpose(value, (0, 2, 3, 1))
-        shape = ops.shape(inputs)
-        b = shape[0]
-        query = ops.reshape(query, (b, -1, self.filters))
-        key = ops.reshape(key, (b, -1, self.filters))
-        value = ops.reshape(value, (b, -1, self.filters))
-
-        # Compute attention.
-        query = ops.multiply(
-            query, ops.cast(self._inverse_sqrt_filters, query.dtype)
-        )
-        # [B, H0 * W0, C], [B, H1 * W1, C] -> [B, H0 * W0, H1 * W1]
-        attention_scores = ops.einsum("abc,adc->abd", query, key)
-        attention_scores = ops.cast(
-            self.softmax(attention_scores), self.compute_dtype
-        )
-        # [B, H2 * W2, C], [B, H0 * W0, H1 * W1] -> [B, H1 * W1 ,C]
-        attention_output = ops.einsum("abc,adb->adc", value, attention_scores)
-        x = ops.reshape(attention_output, shape)
-
-        x = self.output_conv2d(x)
-        if self.data_format == "channels_first":
-            x = ops.transpose(x, (0, 3, 1, 2))
-        x = ops.add(x, inputs)
-        return x
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "filters": self.filters,
-                "groups": self.groups,
-            }
-        )
-        return config
-
-    def compute_output_shape(self, input_shape):
-        return input_shape
-
-
-def apply_resnet_block(x, filters, data_format=None, dtype=None, name=None):
-    data_format = standardize_data_format(data_format)
-    gn_axis = -1 if data_format == "channels_last" else 1
-    input_filters = x.shape[gn_axis]
-
-    residual = x
-    x = layers.GroupNormalization(
-        groups=32,
-        axis=gn_axis,
-        epsilon=1e-6,
-        dtype="float32",
-        name=f"{name}_norm1",
-    )(x)
-    x = layers.Activation("swish", dtype=dtype)(x)
-    x = layers.Conv2D(
-        filters,
-        3,
-        1,
-        padding="same",
-        data_format=data_format,
-        dtype=dtype,
-        name=f"{name}_conv1",
-    )(x)
-    x = layers.GroupNormalization(
-        groups=32,
-        axis=gn_axis,
-        epsilon=1e-6,
-        dtype="float32",
-        name=f"{name}_norm2",
-    )(x)
-    x = layers.Activation("swish", dtype=dtype)(x)
-    x = layers.Conv2D(
-        filters,
-        3,
-        1,
-        padding="same",
-        data_format=data_format,
-        dtype=dtype,
-        name=f"{name}_conv2",
-    )(x)
-    if input_filters != filters:
-        residual = layers.Conv2D(
-            filters,
-            1,
-            1,
-            data_format=data_format,
-            dtype=dtype,
-            name=f"{name}_residual_projection",
-        )(residual)
-    x = layers.Add(dtype=dtype)([residual, x])
-    return x
-
-
-class VAEImageDecoder(Backbone):
-    """Decoder for the VAE model used in Stable Diffusion 3.
-
-    Args:
-        stackwise_num_filters: list of ints. The number of filters for each
-            stack.
-        stackwise_num_blocks: list of ints. The number of blocks for each stack.
-        output_channels: int. The number of channels in the output.
-        latent_shape: tuple. The shape of the latent image.
-        data_format: `None` or str. If specified, either `"channels_last"` or
-            `"channels_first"`. The ordering of the dimensions in the
-            inputs. `"channels_last"` corresponds to inputs with shape
-            `(batch_size, height, width, channels)`
-            while `"channels_first"` corresponds to inputs with shape
-            `(batch_size, channels, height, width)`. It defaults to the
-            `image_data_format` value found in your Keras config file at
-            `~/.keras/keras.json`. If you never set it, then it will be
-            `"channels_last"`.
-        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
-            to use for the model's computations and weights.
-    """
-
-    def __init__(
-        self,
-        stackwise_num_filters,
-        stackwise_num_blocks,
-        output_channels=3,
-        latent_shape=(None, None, 16),
-        data_format=None,
-        dtype=None,
-        **kwargs,
-    ):
-        data_format = standardize_data_format(data_format)
-        gn_axis = -1 if data_format == "channels_last" else 1
-
-        # === Functional Model ===
-        latent_inputs = layers.Input(shape=latent_shape)
-
-        x = layers.Conv2D(
-            stackwise_num_filters[0],
-            3,
-            1,
-            padding="same",
-            data_format=data_format,
-            dtype=dtype,
-            name="input_projection",
-        )(latent_inputs)
-        x = apply_resnet_block(
-            x,
-            stackwise_num_filters[0],
-            data_format=data_format,
-            dtype=dtype,
-            name="input_block0",
-        )
-        x = VAEAttention(
-            stackwise_num_filters[0],
-            data_format=data_format,
-            dtype=dtype,
-            name="input_attention",
-        )(x)
-        x = apply_resnet_block(
-            x,
-            stackwise_num_filters[0],
-            data_format=data_format,
-            dtype=dtype,
-            name="input_block1",
-        )
-
-        # Stacks.
-        for i, filters in enumerate(stackwise_num_filters):
-            for j in range(stackwise_num_blocks[i]):
-                x = apply_resnet_block(
-                    x,
-                    filters,
-                    data_format=data_format,
-                    dtype=dtype,
-                    name=f"block{i}_{j}",
-                )
-            if i != len(stackwise_num_filters) - 1:
-                # No upsamling in the last blcok.
-                x = layers.UpSampling2D(
-                    2,
-                    data_format=data_format,
-                    dtype=dtype,
-                    name=f"upsample_{i}",
-                )(x)
-                x = layers.Conv2D(
-                    filters,
-                    3,
-                    1,
-                    padding="same",
-                    data_format=data_format,
-                    dtype=dtype,
-                    name=f"upsample_{i}_conv",
-                )(x)
-
-        # Ouput block.
-        x = layers.GroupNormalization(
-            groups=32,
-            axis=gn_axis,
-            epsilon=1e-6,
-            dtype="float32",
-            name="output_norm",
-        )(x)
-        x = layers.Activation("swish", dtype=dtype, name="output_activation")(x)
-        image_outputs = layers.Conv2D(
-            output_channels,
-            3,
-            1,
-            padding="same",
-            data_format=data_format,
-            dtype=dtype,
-            name="output_projection",
-        )(x)
-        super().__init__(inputs=latent_inputs, outputs=image_outputs, **kwargs)
-
-        # === Config ===
-        self.stackwise_num_filters = stackwise_num_filters
-        self.stackwise_num_blocks = stackwise_num_blocks
-        self.output_channels = output_channels
-        self.latent_shape = latent_shape
-
-    @property
-    def scaling_factor(self):
-        """The scaling factor for the latent space.
-
-        This is used to scale the latent space to have unit variance when
-        training the diffusion model.
-        """
-        return 1.5305
-
-    @property
-    def shift_factor(self):
-        """The shift factor for the latent space.
-
-        This is used to shift the latent space to have zero mean when
-        training the diffusion model.
-        """
-        return 0.0609
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "stackwise_num_filters": self.stackwise_num_filters,
-                "stackwise_num_blocks": self.stackwise_num_blocks,
-                "output_channels": self.output_channels,
-                "image_shape": self.latent_shape,
-            }
-        )
-        return config
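The `scaling_factor` and `shift_factor` properties of the removed decoder record the affine normalization applied to latents while training the diffusion model (unit variance, zero mean); the replacement `vae/vae_backbone.py` presumably carries equivalents. A hedged sketch of the decode-time arithmetic those docstrings imply — inferred from the docstrings, not taken from the new code:

```python
import numpy as np

SCALING_FACTOR = 1.5305  # scales latents to unit variance during training
SHIFT_FACTOR = 0.0609  # shifts latents to zero mean during training


def denormalize_latents(latents):
    """Invert the training-time normalization before decoding.

    If training normalized with `z_norm = (z - shift) * scale`, then the
    decoder should receive `z = z_norm / scale + shift`.
    """
    return latents / SCALING_FACTOR + SHIFT_FACTOR


# Example: a batch of SD3-style 16-channel latents.
z = np.random.normal(size=(1, 64, 64, 16)).astype("float32")
decoder_input = denormalize_latents(z)  # ready for the VAE decoder
```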