keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- keras_hub/api/layers/__init__.py +21 -3
- keras_hub/api/models/__init__.py +71 -12
- keras_hub/api/tokenizers/__init__.py +1 -1
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
- keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
- keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
- keras_hub/src/layers/modeling/rms_normalization.py +36 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
- keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
- keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
- keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
- keras_hub/src/layers/preprocessing/image_converter.py +170 -34
- keras_hub/src/metrics/bleu.py +4 -3
- keras_hub/src/models/albert/albert_presets.py +4 -12
- keras_hub/src/models/albert/albert_text_classifier.py +7 -7
- keras_hub/src/models/backbone.py +3 -14
- keras_hub/src/models/bart/bart_backbone.py +4 -4
- keras_hub/src/models/bart/bart_presets.py +3 -9
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
- keras_hub/src/models/basnet/__init__.py +5 -0
- keras_hub/src/models/basnet/basnet.py +122 -0
- keras_hub/src/models/basnet/basnet_backbone.py +366 -0
- keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
- keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
- keras_hub/src/models/basnet/basnet_presets.py +17 -0
- keras_hub/src/models/bert/bert_presets.py +14 -32
- keras_hub/src/models/bert/bert_text_classifier.py +3 -3
- keras_hub/src/models/bloom/bloom_presets.py +8 -24
- keras_hub/src/models/causal_lm.py +56 -12
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
- keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
- keras_hub/src/models/densenet/densenet_backbone.py +6 -4
- keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
- keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
- keras_hub/src/models/densenet/densenet_presets.py +9 -15
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
- keras_hub/src/models/efficientnet/mbconv.py +53 -22
- keras_hub/src/models/electra/electra_backbone.py +2 -2
- keras_hub/src/models/electra/electra_presets.py +6 -18
- keras_hub/src/models/f_net/f_net_presets.py +2 -6
- keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
- keras_hub/src/models/falcon/falcon_backbone.py +5 -3
- keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
- keras_hub/src/models/falcon/falcon_presets.py +1 -3
- keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
- keras_hub/src/models/feature_pyramid_backbone.py +1 -1
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +496 -0
- keras_hub/src/models/flux/flux_maths.py +225 -0
- keras_hub/src/models/flux/flux_model.py +236 -0
- keras_hub/src/models/flux/flux_presets.py +3 -0
- keras_hub/src/models/flux/flux_text_to_image.py +146 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_backbone.py +35 -20
- keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
- keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
- keras_hub/src/models/gemma/gemma_presets.py +29 -63
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
- keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
- keras_hub/src/models/image_classifier.py +147 -2
- keras_hub/src/models/image_classifier_preprocessor.py +6 -3
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_segmenter.py +0 -5
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/image_to_image.py +417 -0
- keras_hub/src/models/inpaint.py +520 -0
- keras_hub/src/models/llama/llama_backbone.py +138 -12
- keras_hub/src/models/llama/llama_causal_lm.py +3 -1
- keras_hub/src/models/llama/llama_presets.py +10 -20
- keras_hub/src/models/llama3/llama3_backbone.py +12 -11
- keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
- keras_hub/src/models/llama3/llama3_presets.py +4 -12
- keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
- keras_hub/src/models/mistral/mistral_backbone.py +16 -15
- keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
- keras_hub/src/models/mistral/mistral_presets.py +3 -9
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
- keras_hub/src/models/mit/__init__.py +6 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
- keras_hub/src/models/mit/mit_image_classifier.py +12 -0
- keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/mit/mit_image_converter.py +8 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
- keras_hub/src/models/mit/mit_presets.py +139 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
- keras_hub/src/models/opt/opt_causal_lm.py +2 -2
- keras_hub/src/models/opt/opt_presets.py +4 -12
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
- keras_hub/src/models/phi3/phi3_decoder.py +0 -1
- keras_hub/src/models/phi3/phi3_presets.py +2 -6
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
- keras_hub/src/models/preprocessor.py +25 -11
- keras_hub/src/models/resnet/resnet_backbone.py +3 -14
- keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
- keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
- keras_hub/src/models/resnet/resnet_presets.py +127 -18
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
- keras_hub/src/models/roberta/roberta_backbone.py +2 -2
- keras_hub/src/models/roberta/roberta_presets.py +6 -8
- keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
- keras_hub/src/models/sam/__init__.py +5 -0
- keras_hub/src/models/sam/sam_backbone.py +2 -3
- keras_hub/src/models/sam/sam_image_converter.py +2 -4
- keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
- keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
- keras_hub/src/models/sam/sam_layers.py +5 -3
- keras_hub/src/models/sam/sam_presets.py +3 -9
- keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
- keras_hub/src/models/sam/sam_transformer.py +5 -4
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +167 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +136 -0
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +47 -19
- keras_hub/src/models/task.py +47 -39
- keras_hub/src/models/text_classifier.py +2 -2
- keras_hub/src/models/text_to_image.py +106 -41
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +184 -0
- keras_hub/src/models/vae/vae_layers.py +739 -0
- keras_hub/src/models/vgg/__init__.py +5 -0
- keras_hub/src/models/vgg/vgg_backbone.py +4 -24
- keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
- keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
- keras_hub/src/models/vgg/vgg_presets.py +48 -0
- keras_hub/src/models/vit/__init__.py +5 -0
- keras_hub/src/models/vit/vit_backbone.py +152 -0
- keras_hub/src/models/vit/vit_image_classifier.py +187 -0
- keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vit/vit_image_converter.py +73 -0
- keras_hub/src/models/vit/vit_layers.py +391 -0
- keras_hub/src/models/vit/vit_presets.py +126 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
- keras_hub/src/models/vit_det/vit_layers.py +3 -3
- keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
- keras_hub/src/models/whisper/whisper_backbone.py +6 -5
- keras_hub/src/models/whisper/whisper_decoder.py +3 -5
- keras_hub/src/models/whisper/whisper_presets.py +10 -30
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
- keras_hub/src/models/xlnet/relative_attention.py +20 -19
- keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
- keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
- keras_hub/src/samplers/contrastive_sampler.py +2 -3
- keras_hub/src/samplers/sampler.py +2 -1
- keras_hub/src/tests/test_case.py +41 -6
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
- keras_hub/src/tokenizers/tokenizer.py +10 -13
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
- keras_hub/src/utils/keras_utils.py +2 -13
- keras_hub/src/utils/pipeline_model.py +3 -3
- keras_hub/src/utils/preset_utils.py +196 -144
- keras_hub/src/utils/tensor_utils.py +4 -4
- keras_hub/src/utils/timm/convert_densenet.py +6 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
- keras_hub/src/utils/timm/convert_resnet.py +1 -1
- keras_hub/src/utils/timm/convert_vgg.py +85 -0
- keras_hub/src/utils/timm/preset_loader.py +14 -9
- keras_hub/src/utils/transformers/convert_llama3.py +21 -5
- keras_hub/src/utils/transformers/convert_vit.py +150 -0
- keras_hub/src/utils/transformers/preset_loader.py +23 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
- keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
- keras_hub/src/models/mix_transformer/__init__.py +0 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/src/models/vae/vae_layers.py
@@ -0,0 +1,739 @@
import math

import keras
from keras import ops

from keras_hub.src.utils.keras_utils import standardize_data_format


class Conv2DMultiHeadAttention(keras.layers.Layer):
    """A MultiHeadAttention layer utilizing `Conv2D` and `GroupNormalization`.

    Args:
        filters: int. The number of the filters for the convolutional layers.
        groups: int. The number of the groups for the group normalization
            layers. Defaults to `32`.
        data_format: `None` or str. If specified, either `"channels_last"` or
            `"channels_first"`. The ordering of the dimensions in the
            inputs. `"channels_last"` corresponds to inputs with shape
            `(batch_size, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch_size, channels, height, width)`. It defaults to the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`. If you never set it, then it will be
            `"channels_last"`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(self, filters, groups=32, data_format=None, **kwargs):
        super().__init__(**kwargs)
        data_format = standardize_data_format(data_format)
        channel_axis = -1 if data_format == "channels_last" else 1
        self.filters = int(filters)
        self.groups = int(groups)
        self._inverse_sqrt_filters = 1.0 / math.sqrt(float(filters))
        self.data_format = data_format

        self.group_norm = keras.layers.GroupNormalization(
            groups=groups,
            axis=channel_axis,
            epsilon=1e-6,
            dtype=self.dtype_policy,
            name="group_norm",
        )
        self.query_conv2d = keras.layers.Conv2D(
            filters,
            1,
            1,
            data_format=data_format,
            dtype=self.dtype_policy,
            name="query_conv2d",
        )
        self.key_conv2d = keras.layers.Conv2D(
            filters,
            1,
            1,
            data_format=data_format,
            dtype=self.dtype_policy,
            name="key_conv2d",
        )
        self.value_conv2d = keras.layers.Conv2D(
            filters,
            1,
            1,
            data_format=data_format,
            dtype=self.dtype_policy,
            name="value_conv2d",
        )
        self.softmax = keras.layers.Softmax(dtype="float32")
        self.output_conv2d = keras.layers.Conv2D(
            filters,
            1,
            1,
            data_format=data_format,
            dtype=self.dtype_policy,
            name="output_conv2d",
        )

    def build(self, input_shape):
        self.group_norm.build(input_shape)
        self.query_conv2d.build(input_shape)
        self.key_conv2d.build(input_shape)
        self.value_conv2d.build(input_shape)
        self.output_conv2d.build(input_shape)

    def call(self, inputs, training=None):
        x = self.group_norm(inputs, training=training)
        query = self.query_conv2d(x, training=training)
        key = self.key_conv2d(x, training=training)
        value = self.value_conv2d(x, training=training)

        if self.data_format == "channels_first":
            query = ops.transpose(query, (0, 2, 3, 1))
            key = ops.transpose(key, (0, 2, 3, 1))
            value = ops.transpose(value, (0, 2, 3, 1))
        shape = ops.shape(inputs)
        b = shape[0]
        query = ops.reshape(query, (b, -1, self.filters))
        key = ops.reshape(key, (b, -1, self.filters))
        value = ops.reshape(value, (b, -1, self.filters))

        # Compute attention.
        query = ops.multiply(
            query, ops.cast(self._inverse_sqrt_filters, query.dtype)
        )
        # [B, H0 * W0, C], [B, H1 * W1, C] -> [B, H0 * W0, H1 * W1]
        attention_scores = ops.einsum("abc,adc->abd", query, key)
        attention_scores = ops.cast(
            self.softmax(attention_scores), self.compute_dtype
        )
        # [B, H2 * W2, C], [B, H0 * W0, H1 * W1] -> [B, H1 * W1, C]
        attention_output = ops.einsum("abc,adb->adc", value, attention_scores)
        x = ops.reshape(attention_output, shape)

        x = self.output_conv2d(x, training=training)
        if self.data_format == "channels_first":
            x = ops.transpose(x, (0, 3, 1, 2))
        x = ops.add(x, inputs)
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "groups": self.groups,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        return input_shape


class ResNetBlock(keras.layers.Layer):
    """A ResNet block utilizing `GroupNormalization` and SiLU activation.

    Args:
        filters: The number of filters in the block.
        has_residual_projection: Whether to add a projection layer for the
            residual connection. Defaults to `False`.
        data_format: `None` or str. If specified, either `"channels_last"` or
            `"channels_first"`. The ordering of the dimensions in the
            inputs. `"channels_last"` corresponds to inputs with shape
            `(batch_size, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch_size, channels, height, width)`. It defaults to the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`. If you never set it, then it will be
            `"channels_last"`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(
        self,
        filters,
        has_residual_projection=False,
        data_format=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        data_format = standardize_data_format(data_format)
        channel_axis = -1 if data_format == "channels_last" else 1
        self.filters = int(filters)
        self.has_residual_projection = bool(has_residual_projection)

        # === Layers ===
        self.norm1 = keras.layers.GroupNormalization(
            groups=32,
            axis=channel_axis,
            epsilon=1e-6,
            dtype=self.dtype_policy,
            name="norm1",
        )
        self.act1 = keras.layers.Activation("silu", dtype=self.dtype_policy)
        self.conv1 = keras.layers.Conv2D(
            filters,
            3,
            1,
            padding="same",
            data_format=data_format,
            dtype=self.dtype_policy,
            name="conv1",
        )
        self.norm2 = keras.layers.GroupNormalization(
            groups=32,
            axis=channel_axis,
            epsilon=1e-6,
            dtype=self.dtype_policy,
            name="norm2",
        )
        self.act2 = keras.layers.Activation("silu", dtype=self.dtype_policy)
        self.conv2 = keras.layers.Conv2D(
            filters,
            3,
            1,
            padding="same",
            data_format=data_format,
            dtype=self.dtype_policy,
            name="conv2",
        )
        if self.has_residual_projection:
            self.residual_projection = keras.layers.Conv2D(
                filters,
                1,
                1,
                data_format=data_format,
                dtype=self.dtype_policy,
                name="residual_projection",
            )
        self.add = keras.layers.Add(dtype=self.dtype_policy)

    def build(self, input_shape):
        residual_shape = list(input_shape)
        self.norm1.build(input_shape)
        self.act1.build(input_shape)
        self.conv1.build(input_shape)
        input_shape = self.conv1.compute_output_shape(input_shape)
        self.norm2.build(input_shape)
        self.act2.build(input_shape)
        self.conv2.build(input_shape)
        input_shape = self.conv2.compute_output_shape(input_shape)
        if self.has_residual_projection:
            self.residual_projection.build(residual_shape)
        self.add.build([input_shape, input_shape])

    def call(self, inputs, training=None):
        x = inputs
        residual = x
        x = self.norm1(x, training=training)
        x = self.act1(x, training=training)
        x = self.conv1(x, training=training)
        x = self.norm2(x, training=training)
        x = self.act2(x, training=training)
        x = self.conv2(x, training=training)
        if self.has_residual_projection:
            residual = self.residual_projection(residual, training=training)
        x = self.add([residual, x])
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "has_residual_projection": self.has_residual_projection,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        outputs_shape = list(input_shape)
        if self.has_residual_projection:
            outputs_shape = self.residual_projection.compute_output_shape(
                outputs_shape
            )
        return outputs_shape


class VAEEncoder(keras.layers.Layer):
    """The encoder layer of VAE.

    Args:
        stackwise_num_filters: list of ints. The number of filters for each
            stack.
        stackwise_num_blocks: list of ints. The number of blocks for each stack.
        output_channels: int. The number of channels in the output. Defaults to
            `32`.
        data_format: `None` or str. If specified, either `"channels_last"` or
            `"channels_first"`. The ordering of the dimensions in the
            inputs. `"channels_last"` corresponds to inputs with shape
            `(batch_size, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch_size, channels, height, width)`. It defaults to the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`. If you never set it, then it will be
            `"channels_last"`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(
        self,
        stackwise_num_filters,
        stackwise_num_blocks,
        output_channels=32,
        data_format=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        data_format = standardize_data_format(data_format)
        channel_axis = -1 if data_format == "channels_last" else 1
        self.stackwise_num_filters = stackwise_num_filters
        self.stackwise_num_blocks = stackwise_num_blocks
        self.output_channels = int(output_channels)
        self.data_format = data_format

        # === Layers ===
        self.input_projection = keras.layers.Conv2D(
            stackwise_num_filters[0],
            3,
            1,
            padding="same",
            data_format=data_format,
            dtype=self.dtype_policy,
            name="input_projection",
        )

        # Blocks.
        input_filters = stackwise_num_filters[0]
        self.blocks = []
        self.downsamples = []
        for i, filters in enumerate(stackwise_num_filters):
            for j in range(stackwise_num_blocks[i]):
                self.blocks.append(
                    ResNetBlock(
                        filters,
                        has_residual_projection=input_filters != filters,
                        data_format=data_format,
                        dtype=self.dtype_policy,
                        name=f"block_{i}_{j}",
                    )
                )
                input_filters = filters
            # No downsample in the last block.
            if i != len(stackwise_num_filters) - 1:
                self.downsamples.append(
                    keras.layers.ZeroPadding2D(
                        padding=((0, 1), (0, 1)),
                        data_format=data_format,
                        dtype=self.dtype_policy,
                        name=f"downsample_{i}_pad",
                    )
                )
                self.downsamples.append(
                    keras.layers.Conv2D(
                        filters,
                        3,
                        2,
                        data_format=data_format,
                        dtype=self.dtype_policy,
                        name=f"downsample_{i}_conv",
                    )
                )

        # Mid block.
        self.mid_block_0 = ResNetBlock(
            stackwise_num_filters[-1],
            has_residual_projection=False,
            data_format=data_format,
            dtype=self.dtype_policy,
            name="mid_block_0",
        )
        self.mid_attention = Conv2DMultiHeadAttention(
            stackwise_num_filters[-1],
            data_format=data_format,
            dtype=self.dtype_policy,
            name="mid_attention",
        )
        self.mid_block_1 = ResNetBlock(
            stackwise_num_filters[-1],
            has_residual_projection=False,
            data_format=data_format,
            dtype=self.dtype_policy,
            name="mid_block_1",
        )

        # Output layers.
        self.output_norm = keras.layers.GroupNormalization(
            groups=32,
            axis=channel_axis,
            epsilon=1e-6,
            dtype=self.dtype_policy,
            name="output_norm",
        )
        self.output_act = keras.layers.Activation(
            "swish", dtype=self.dtype_policy
        )
        self.output_projection = keras.layers.Conv2D(
            output_channels,
            3,
            1,
            padding="same",
            data_format=data_format,
            dtype=self.dtype_policy,
            name="output_projection",
        )

    def build(self, input_shape):
        self.input_projection.build(input_shape)
        input_shape = self.input_projection.compute_output_shape(input_shape)
        blocks_idx = 0
        downsamples_idx = 0
        for i, _ in enumerate(self.stackwise_num_filters):
            for _ in range(self.stackwise_num_blocks[i]):
                self.blocks[blocks_idx].build(input_shape)
                input_shape = self.blocks[blocks_idx].compute_output_shape(
                    input_shape
                )
                blocks_idx += 1
            if i != len(self.stackwise_num_filters) - 1:
                self.downsamples[downsamples_idx].build(input_shape)
                input_shape = self.downsamples[
                    downsamples_idx
                ].compute_output_shape(input_shape)
                downsamples_idx += 1
                self.downsamples[downsamples_idx].build(input_shape)
                input_shape = self.downsamples[
                    downsamples_idx
                ].compute_output_shape(input_shape)
                downsamples_idx += 1
        self.mid_block_0.build(input_shape)
        input_shape = self.mid_block_0.compute_output_shape(input_shape)
        self.mid_attention.build(input_shape)
        input_shape = self.mid_attention.compute_output_shape(input_shape)
        self.mid_block_1.build(input_shape)
        input_shape = self.mid_block_1.compute_output_shape(input_shape)
        self.output_norm.build(input_shape)
        self.output_act.build(input_shape)
        self.output_projection.build(input_shape)

    def call(self, inputs, training=None):
        x = inputs
        x = self.input_projection(x, training=training)
        blocks_idx = 0
        upsamples_idx = 0
        for i, _ in enumerate(self.stackwise_num_filters):
            for _ in range(self.stackwise_num_blocks[i]):
                x = self.blocks[blocks_idx](x, training=training)
                blocks_idx += 1
            if i != len(self.stackwise_num_filters) - 1:
                x = self.downsamples[upsamples_idx](x, training=training)
                x = self.downsamples[upsamples_idx + 1](x, training=training)
                upsamples_idx += 2
        x = self.mid_block_0(x, training=training)
        x = self.mid_attention(x, training=training)
        x = self.mid_block_1(x, training=training)
        x = self.output_norm(x, training=training)
        x = self.output_act(x, training=training)
        x = self.output_projection(x, training=training)
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "stackwise_num_filters": self.stackwise_num_filters,
                "stackwise_num_blocks": self.stackwise_num_blocks,
                "output_channels": self.output_channels,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        if self.data_format == "channels_last":
            h_axis, w_axis, c_axis = 1, 2, 3
        else:
            c_axis, h_axis, w_axis = 1, 2, 3
        scale_factor = 2 ** (len(self.stackwise_num_filters) - 1)
        outputs_shape = list(input_shape)
        if (
            outputs_shape[h_axis] is not None
            and outputs_shape[w_axis] is not None
        ):
            outputs_shape[h_axis] = outputs_shape[h_axis] // scale_factor
            outputs_shape[w_axis] = outputs_shape[w_axis] // scale_factor
        outputs_shape[c_axis] = self.output_channels
        return outputs_shape


class VAEDecoder(keras.layers.Layer):
    """The decoder layer of VAE.

    Args:
        stackwise_num_filters: list of ints. The number of filters for each
            stack.
        stackwise_num_blocks: list of ints. The number of blocks for each stack.
        output_channels: int. The number of channels in the output. Defaults to
            `3`.
        data_format: `None` or str. If specified, either `"channels_last"` or
            `"channels_first"`. The ordering of the dimensions in the
            inputs. `"channels_last"` corresponds to inputs with shape
            `(batch_size, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch_size, channels, height, width)`. It defaults to the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`. If you never set it, then it will be
            `"channels_last"`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(
        self,
        stackwise_num_filters,
        stackwise_num_blocks,
        output_channels=3,
        data_format=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        data_format = standardize_data_format(data_format)
        channel_axis = -1 if data_format == "channels_last" else 1
        self.stackwise_num_filters = stackwise_num_filters
        self.stackwise_num_blocks = stackwise_num_blocks
        self.output_channels = int(output_channels)
        self.data_format = data_format

        # === Layers ===
        self.input_projection = keras.layers.Conv2D(
            stackwise_num_filters[0],
            3,
            1,
            padding="same",
            data_format=data_format,
            dtype=self.dtype_policy,
            name="input_projection",
        )

        # Mid block.
        self.mid_block_0 = ResNetBlock(
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=self.dtype_policy,
            name="mid_block_0",
        )
        self.mid_attention = Conv2DMultiHeadAttention(
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=self.dtype_policy,
            name="mid_attention",
        )
        self.mid_block_1 = ResNetBlock(
            stackwise_num_filters[0],
            data_format=data_format,
            dtype=self.dtype_policy,
            name="mid_block_1",
        )

        # Blocks.
        input_filters = stackwise_num_filters[0]
        self.blocks = []
        self.upsamples = []
        for i, filters in enumerate(stackwise_num_filters):
            for j in range(stackwise_num_blocks[i]):
                self.blocks.append(
                    ResNetBlock(
                        filters,
                        has_residual_projection=input_filters != filters,
                        data_format=data_format,
                        dtype=self.dtype_policy,
                        name=f"block_{i}_{j}",
                    )
                )
                input_filters = filters
            # No upsample in the last block.
            if i != len(stackwise_num_filters) - 1:
                self.upsamples.append(
                    keras.layers.UpSampling2D(
                        2,
                        data_format=data_format,
                        dtype=self.dtype_policy,
                        name=f"upsample_{i}",
                    )
                )
                self.upsamples.append(
                    keras.layers.Conv2D(
                        filters,
                        3,
                        1,
                        padding="same",
                        data_format=data_format,
                        dtype=self.dtype_policy,
                        name=f"upsample_{i}_conv",
                    )
                )

        # Output layers.
        self.output_norm = keras.layers.GroupNormalization(
            groups=32,
            axis=channel_axis,
            epsilon=1e-6,
            dtype=self.dtype_policy,
            name="output_norm",
        )
        self.output_act = keras.layers.Activation(
            "swish", dtype=self.dtype_policy
        )
        self.output_projection = keras.layers.Conv2D(
            output_channels,
            3,
            1,
            padding="same",
            data_format=data_format,
            dtype=self.dtype_policy,
            name="output_projection",
        )

    def build(self, input_shape):
        self.input_projection.build(input_shape)
        input_shape = self.input_projection.compute_output_shape(input_shape)
        self.mid_block_0.build(input_shape)
        input_shape = self.mid_block_0.compute_output_shape(input_shape)
        self.mid_attention.build(input_shape)
        input_shape = self.mid_attention.compute_output_shape(input_shape)
        self.mid_block_1.build(input_shape)
        input_shape = self.mid_block_1.compute_output_shape(input_shape)
        blocks_idx = 0
        upsamples_idx = 0
        for i, _ in enumerate(self.stackwise_num_filters):
            for _ in range(self.stackwise_num_blocks[i]):
                self.blocks[blocks_idx].build(input_shape)
                input_shape = self.blocks[blocks_idx].compute_output_shape(
                    input_shape
                )
                blocks_idx += 1
            if i != len(self.stackwise_num_filters) - 1:
                self.upsamples[upsamples_idx].build(input_shape)
                input_shape = self.upsamples[
                    upsamples_idx
                ].compute_output_shape(input_shape)
                self.upsamples[upsamples_idx + 1].build(input_shape)
                input_shape = self.upsamples[
                    upsamples_idx + 1
                ].compute_output_shape(input_shape)
                upsamples_idx += 2
        self.output_norm.build(input_shape)
        self.output_act.build(input_shape)
        self.output_projection.build(input_shape)

    def call(self, inputs, training=None):
        x = inputs
        x = self.input_projection(x, training=training)
        x = self.mid_block_0(x, training=training)
        x = self.mid_attention(x, training=training)
        x = self.mid_block_1(x, training=training)
        blocks_idx = 0
        upsamples_idx = 0
        for i, _ in enumerate(self.stackwise_num_filters):
            for _ in range(self.stackwise_num_blocks[i]):
                x = self.blocks[blocks_idx](x, training=training)
                blocks_idx += 1
            if i != len(self.stackwise_num_filters) - 1:
                x = self.upsamples[upsamples_idx](x, training=training)
                x = self.upsamples[upsamples_idx + 1](x, training=training)
                upsamples_idx += 2
        x = self.output_norm(x, training=training)
        x = self.output_act(x, training=training)
        x = self.output_projection(x, training=training)
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "stackwise_num_filters": self.stackwise_num_filters,
                "stackwise_num_blocks": self.stackwise_num_blocks,
                "output_channels": self.output_channels,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        if self.data_format == "channels_last":
            h_axis, w_axis, c_axis = 1, 2, 3
        else:
            c_axis, h_axis, w_axis = 1, 2, 3
        scale_factor = 2 ** (len(self.stackwise_num_filters) - 1)
        outputs_shape = list(input_shape)
        if (
            outputs_shape[h_axis] is not None
            and outputs_shape[w_axis] is not None
        ):
            outputs_shape[h_axis] = outputs_shape[h_axis] * scale_factor
            outputs_shape[w_axis] = outputs_shape[w_axis] * scale_factor
        outputs_shape[c_axis] = self.output_channels
        return outputs_shape


class DiagonalGaussianDistributionSampler(keras.layers.Layer):
    """A sampler for a diagonal Gaussian distribution.

    This layer samples latent variables from a diagonal Gaussian distribution.

    Args:
        method: str. The method used to sample from the distribution. Available
            methods are `"sample"` and `"mode"`. `"sample"` draws from the
            distribution using both the mean and log variance. `"mode"` draws
            from the distribution using the mean only.
        axis: int. The axis along which to split the mean and log variance.
            Defaults to `-1`.
        seed: optional int. Used as a random seed.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(self, method, axis=-1, seed=None, **kwargs):
        super().__init__(**kwargs)
        # TODO: Support `kl` and `nll` modes.
        valid_methods = ("sample", "mode")
        if method not in valid_methods:
            raise ValueError(
                f"Invalid method {method}. Valid methods are "
                f"{list(valid_methods)}."
            )
        self.method = method
        self.axis = axis
        self.seed = seed
        self.seed_generator = keras.random.SeedGenerator(seed)

    def call(self, inputs):
        x = inputs
        if self.method == "sample":
            x_mean, x_logvar = ops.split(x, 2, axis=self.axis)
            x_logvar = ops.clip(x_logvar, -30.0, 20.0)
            x_std = ops.exp(ops.multiply(0.5, x_logvar))
            sample = keras.random.normal(
                ops.shape(x_mean), dtype=x_mean.dtype, seed=self.seed_generator
            )
            x = ops.add(x_mean, ops.multiply(x_std, sample))
        else:
            x, _ = ops.split(x, 2, axis=self.axis)
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "axis": self.axis,
                "seed": self.seed,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        output_shape = list(input_shape)
        output_shape[self.axis] = output_shape[self.axis] // 2
        return output_shape