keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
Files changed (252)
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/keras_hub/src/models/vae/vae_layers.py
@@ -0,0 +1,739 @@
+ import math
+
+ import keras
+ from keras import ops
+
+ from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+ class Conv2DMultiHeadAttention(keras.layers.Layer):
+     """A MultiHeadAttention layer utilizing `Conv2D` and `GroupNormalization`.
+
+     Args:
+         filters: int. The number of filters for the convolutional layers.
+         groups: int. The number of groups for the group normalization
+             layers. Defaults to `32`.
+         data_format: `None` or str. If specified, either `"channels_last"` or
+             `"channels_first"`. The ordering of the dimensions in the
+             inputs. `"channels_last"` corresponds to inputs with shape
+             `(batch_size, height, width, channels)`
+             while `"channels_first"` corresponds to inputs with shape
+             `(batch_size, channels, height, width)`. It defaults to the
+             `image_data_format` value found in your Keras config file at
+             `~/.keras/keras.json`. If you never set it, then it will be
+             `"channels_last"`.
+         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+             including `name`, `dtype` etc.
+     """
+
+     def __init__(self, filters, groups=32, data_format=None, **kwargs):
+         super().__init__(**kwargs)
+         data_format = standardize_data_format(data_format)
+         channel_axis = -1 if data_format == "channels_last" else 1
+         self.filters = int(filters)
+         self.groups = int(groups)
+         self._inverse_sqrt_filters = 1.0 / math.sqrt(float(filters))
+         self.data_format = data_format
+
+         self.group_norm = keras.layers.GroupNormalization(
+             groups=groups,
+             axis=channel_axis,
+             epsilon=1e-6,
+             dtype=self.dtype_policy,
+             name="group_norm",
+         )
+         self.query_conv2d = keras.layers.Conv2D(
+             filters,
+             1,
+             1,
+             data_format=data_format,
+             dtype=self.dtype_policy,
+             name="query_conv2d",
+         )
+         self.key_conv2d = keras.layers.Conv2D(
+             filters,
+             1,
+             1,
+             data_format=data_format,
+             dtype=self.dtype_policy,
+             name="key_conv2d",
+         )
+         self.value_conv2d = keras.layers.Conv2D(
+             filters,
+             1,
+             1,
+             data_format=data_format,
+             dtype=self.dtype_policy,
+             name="value_conv2d",
+         )
+         self.softmax = keras.layers.Softmax(dtype="float32")
+         self.output_conv2d = keras.layers.Conv2D(
+             filters,
+             1,
+             1,
+             data_format=data_format,
+             dtype=self.dtype_policy,
+             name="output_conv2d",
+         )
+
+     def build(self, input_shape):
+         self.group_norm.build(input_shape)
+         self.query_conv2d.build(input_shape)
+         self.key_conv2d.build(input_shape)
+         self.value_conv2d.build(input_shape)
+         self.output_conv2d.build(input_shape)
+
+     def call(self, inputs, training=None):
+         x = self.group_norm(inputs, training=training)
+         query = self.query_conv2d(x, training=training)
+         key = self.key_conv2d(x, training=training)
+         value = self.value_conv2d(x, training=training)
+
+         if self.data_format == "channels_first":
+             query = ops.transpose(query, (0, 2, 3, 1))
+             key = ops.transpose(key, (0, 2, 3, 1))
+             value = ops.transpose(value, (0, 2, 3, 1))
+         shape = ops.shape(inputs)
+         b = shape[0]
+         query = ops.reshape(query, (b, -1, self.filters))
+         key = ops.reshape(key, (b, -1, self.filters))
+         value = ops.reshape(value, (b, -1, self.filters))
+
+         # Compute attention.
+         query = ops.multiply(
+             query, ops.cast(self._inverse_sqrt_filters, query.dtype)
+         )
+         # [B, H0 * W0, C], [B, H1 * W1, C] -> [B, H0 * W0, H1 * W1]
+         attention_scores = ops.einsum("abc,adc->abd", query, key)
+         attention_scores = ops.cast(
+             self.softmax(attention_scores), self.compute_dtype
+         )
+         # [B, H2 * W2, C], [B, H0 * W0, H1 * W1] -> [B, H1 * W1, C]
+         attention_output = ops.einsum("abc,adb->adc", value, attention_scores)
+         x = ops.reshape(attention_output, shape)
+
+         x = self.output_conv2d(x, training=training)
+         if self.data_format == "channels_first":
+             x = ops.transpose(x, (0, 3, 1, 2))
+         x = ops.add(x, inputs)
+         return x
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "filters": self.filters,
+                 "groups": self.groups,
+             }
+         )
+         return config
+
+     def compute_output_shape(self, input_shape):
+         return input_shape
+
+
+ class ResNetBlock(keras.layers.Layer):
+     """A ResNet block utilizing `GroupNormalization` and SiLU activation.
+
+     Args:
+         filters: int. The number of filters in the block.
+         has_residual_projection: Whether to add a projection layer for the
+             residual connection. Defaults to `False`.
+         data_format: `None` or str. If specified, either `"channels_last"` or
+             `"channels_first"`. The ordering of the dimensions in the
+             inputs. `"channels_last"` corresponds to inputs with shape
+             `(batch_size, height, width, channels)`
+             while `"channels_first"` corresponds to inputs with shape
+             `(batch_size, channels, height, width)`. It defaults to the
+             `image_data_format` value found in your Keras config file at
+             `~/.keras/keras.json`. If you never set it, then it will be
+             `"channels_last"`.
+         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+             including `name`, `dtype` etc.
+     """
+
+     def __init__(
+         self,
+         filters,
+         has_residual_projection=False,
+         data_format=None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         data_format = standardize_data_format(data_format)
+         channel_axis = -1 if data_format == "channels_last" else 1
+         self.filters = int(filters)
+         self.has_residual_projection = bool(has_residual_projection)
+
+         # === Layers ===
+         self.norm1 = keras.layers.GroupNormalization(
+             groups=32,
+             axis=channel_axis,
+             epsilon=1e-6,
+             dtype=self.dtype_policy,
+             name="norm1",
+         )
+         self.act1 = keras.layers.Activation("silu", dtype=self.dtype_policy)
+         self.conv1 = keras.layers.Conv2D(
+             filters,
+             3,
+             1,
+             padding="same",
+             data_format=data_format,
+             dtype=self.dtype_policy,
+             name="conv1",
+         )
+         self.norm2 = keras.layers.GroupNormalization(
+             groups=32,
+             axis=channel_axis,
+             epsilon=1e-6,
+             dtype=self.dtype_policy,
+             name="norm2",
+         )
+         self.act2 = keras.layers.Activation("silu", dtype=self.dtype_policy)
+         self.conv2 = keras.layers.Conv2D(
+             filters,
+             3,
+             1,
+             padding="same",
+             data_format=data_format,
+             dtype=self.dtype_policy,
+             name="conv2",
+         )
+         if self.has_residual_projection:
+             self.residual_projection = keras.layers.Conv2D(
+                 filters,
+                 1,
+                 1,
+                 data_format=data_format,
+                 dtype=self.dtype_policy,
+                 name="residual_projection",
+             )
+         self.add = keras.layers.Add(dtype=self.dtype_policy)
+
+     def build(self, input_shape):
+         residual_shape = list(input_shape)
+         self.norm1.build(input_shape)
+         self.act1.build(input_shape)
+         self.conv1.build(input_shape)
+         input_shape = self.conv1.compute_output_shape(input_shape)
+         self.norm2.build(input_shape)
+         self.act2.build(input_shape)
+         self.conv2.build(input_shape)
+         input_shape = self.conv2.compute_output_shape(input_shape)
+         if self.has_residual_projection:
+             self.residual_projection.build(residual_shape)
+         self.add.build([input_shape, input_shape])
+
+     def call(self, inputs, training=None):
+         x = inputs
+         residual = x
+         x = self.norm1(x, training=training)
+         x = self.act1(x, training=training)
+         x = self.conv1(x, training=training)
+         x = self.norm2(x, training=training)
+         x = self.act2(x, training=training)
+         x = self.conv2(x, training=training)
+         if self.has_residual_projection:
+             residual = self.residual_projection(residual, training=training)
+         x = self.add([residual, x])
+         return x
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "filters": self.filters,
+                 "has_residual_projection": self.has_residual_projection,
+             }
+         )
+         return config
+
+     def compute_output_shape(self, input_shape):
+         outputs_shape = list(input_shape)
+         if self.has_residual_projection:
+             outputs_shape = self.residual_projection.compute_output_shape(
+                 outputs_shape
+             )
+         return outputs_shape
+
+
+ class VAEEncoder(keras.layers.Layer):
+     """The encoder layer of the VAE.
+
+     Args:
+         stackwise_num_filters: list of ints. The number of filters for each
+             stack.
+         stackwise_num_blocks: list of ints. The number of blocks for each stack.
+         output_channels: int. The number of channels in the output. Defaults to
+             `32`.
+         data_format: `None` or str. If specified, either `"channels_last"` or
+             `"channels_first"`. The ordering of the dimensions in the
+             inputs. `"channels_last"` corresponds to inputs with shape
+             `(batch_size, height, width, channels)`
+             while `"channels_first"` corresponds to inputs with shape
+             `(batch_size, channels, height, width)`. It defaults to the
+             `image_data_format` value found in your Keras config file at
+             `~/.keras/keras.json`. If you never set it, then it will be
+             `"channels_last"`.
+         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+             including `name`, `dtype` etc.
+     """
+
+     def __init__(
+         self,
+         stackwise_num_filters,
+         stackwise_num_blocks,
+         output_channels=32,
+         data_format=None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         data_format = standardize_data_format(data_format)
+         channel_axis = -1 if data_format == "channels_last" else 1
+         self.stackwise_num_filters = stackwise_num_filters
+         self.stackwise_num_blocks = stackwise_num_blocks
+         self.output_channels = int(output_channels)
+         self.data_format = data_format
+
+         # === Layers ===
+         self.input_projection = keras.layers.Conv2D(
+             stackwise_num_filters[0],
+             3,
+             1,
+             padding="same",
+             data_format=data_format,
+             dtype=self.dtype_policy,
+             name="input_projection",
+         )
+
+         # Blocks.
+         input_filters = stackwise_num_filters[0]
+         self.blocks = []
+         self.downsamples = []
+         for i, filters in enumerate(stackwise_num_filters):
+             for j in range(stackwise_num_blocks[i]):
+                 self.blocks.append(
+                     ResNetBlock(
+                         filters,
+                         has_residual_projection=input_filters != filters,
+                         data_format=data_format,
+                         dtype=self.dtype_policy,
+                         name=f"block_{i}_{j}",
+                     )
+                 )
+                 input_filters = filters
+             # No downsample after the last stack.
+             if i != len(stackwise_num_filters) - 1:
+                 self.downsamples.append(
+                     keras.layers.ZeroPadding2D(
+                         padding=((0, 1), (0, 1)),
+                         data_format=data_format,
+                         dtype=self.dtype_policy,
+                         name=f"downsample_{i}_pad",
+                     )
+                 )
+                 self.downsamples.append(
+                     keras.layers.Conv2D(
+                         filters,
+                         3,
+                         2,
+                         data_format=data_format,
+                         dtype=self.dtype_policy,
+                         name=f"downsample_{i}_conv",
+                     )
+                 )
+
+         # Mid block.
+         self.mid_block_0 = ResNetBlock(
+             stackwise_num_filters[-1],
+             has_residual_projection=False,
+             data_format=data_format,
+             dtype=self.dtype_policy,
+             name="mid_block_0",
+         )
+         self.mid_attention = Conv2DMultiHeadAttention(
+             stackwise_num_filters[-1],
+             data_format=data_format,
+             dtype=self.dtype_policy,
+             name="mid_attention",
+         )
+         self.mid_block_1 = ResNetBlock(
+             stackwise_num_filters[-1],
+             has_residual_projection=False,
+             data_format=data_format,
+             dtype=self.dtype_policy,
+             name="mid_block_1",
+         )
+
+         # Output layers.
+         self.output_norm = keras.layers.GroupNormalization(
+             groups=32,
+             axis=channel_axis,
+             epsilon=1e-6,
+             dtype=self.dtype_policy,
+             name="output_norm",
+         )
+         self.output_act = keras.layers.Activation(
+             "swish", dtype=self.dtype_policy
+         )
+         self.output_projection = keras.layers.Conv2D(
+             output_channels,
+             3,
+             1,
+             padding="same",
+             data_format=data_format,
+             dtype=self.dtype_policy,
+             name="output_projection",
+         )
+
+     def build(self, input_shape):
+         self.input_projection.build(input_shape)
+         input_shape = self.input_projection.compute_output_shape(input_shape)
+         blocks_idx = 0
+         downsamples_idx = 0
+         for i, _ in enumerate(self.stackwise_num_filters):
+             for _ in range(self.stackwise_num_blocks[i]):
+                 self.blocks[blocks_idx].build(input_shape)
+                 input_shape = self.blocks[blocks_idx].compute_output_shape(
+                     input_shape
+                 )
+                 blocks_idx += 1
+             if i != len(self.stackwise_num_filters) - 1:
+                 self.downsamples[downsamples_idx].build(input_shape)
+                 input_shape = self.downsamples[
+                     downsamples_idx
+                 ].compute_output_shape(input_shape)
+                 downsamples_idx += 1
+                 self.downsamples[downsamples_idx].build(input_shape)
+                 input_shape = self.downsamples[
+                     downsamples_idx
+                 ].compute_output_shape(input_shape)
+                 downsamples_idx += 1
+         self.mid_block_0.build(input_shape)
+         input_shape = self.mid_block_0.compute_output_shape(input_shape)
+         self.mid_attention.build(input_shape)
+         input_shape = self.mid_attention.compute_output_shape(input_shape)
+         self.mid_block_1.build(input_shape)
+         input_shape = self.mid_block_1.compute_output_shape(input_shape)
+         self.output_norm.build(input_shape)
+         self.output_act.build(input_shape)
+         self.output_projection.build(input_shape)
+
+     def call(self, inputs, training=None):
+         x = inputs
+         x = self.input_projection(x, training=training)
+         blocks_idx = 0
+         downsamples_idx = 0
+         for i, _ in enumerate(self.stackwise_num_filters):
+             for _ in range(self.stackwise_num_blocks[i]):
+                 x = self.blocks[blocks_idx](x, training=training)
+                 blocks_idx += 1
+             if i != len(self.stackwise_num_filters) - 1:
+                 x = self.downsamples[downsamples_idx](x, training=training)
+                 x = self.downsamples[downsamples_idx + 1](x, training=training)
+                 downsamples_idx += 2
+         x = self.mid_block_0(x, training=training)
+         x = self.mid_attention(x, training=training)
+         x = self.mid_block_1(x, training=training)
+         x = self.output_norm(x, training=training)
+         x = self.output_act(x, training=training)
+         x = self.output_projection(x, training=training)
+         return x
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "stackwise_num_filters": self.stackwise_num_filters,
+                 "stackwise_num_blocks": self.stackwise_num_blocks,
+                 "output_channels": self.output_channels,
+             }
+         )
+         return config
+
+     def compute_output_shape(self, input_shape):
+         if self.data_format == "channels_last":
+             h_axis, w_axis, c_axis = 1, 2, 3
+         else:
+             c_axis, h_axis, w_axis = 1, 2, 3
+         scale_factor = 2 ** (len(self.stackwise_num_filters) - 1)
+         outputs_shape = list(input_shape)
+         if (
+             outputs_shape[h_axis] is not None
+             and outputs_shape[w_axis] is not None
+         ):
+             outputs_shape[h_axis] = outputs_shape[h_axis] // scale_factor
+             outputs_shape[w_axis] = outputs_shape[w_axis] // scale_factor
+         outputs_shape[c_axis] = self.output_channels
+         return outputs_shape
+
+
+ class VAEDecoder(keras.layers.Layer):
+     """The decoder layer of the VAE.
+
+     Args:
+         stackwise_num_filters: list of ints. The number of filters for each
+             stack.
+         stackwise_num_blocks: list of ints. The number of blocks for each stack.
+         output_channels: int. The number of channels in the output. Defaults to
+             `3`.
+         data_format: `None` or str. If specified, either `"channels_last"` or
+             `"channels_first"`. The ordering of the dimensions in the
+             inputs. `"channels_last"` corresponds to inputs with shape
+             `(batch_size, height, width, channels)`
+             while `"channels_first"` corresponds to inputs with shape
+             `(batch_size, channels, height, width)`. It defaults to the
+             `image_data_format` value found in your Keras config file at
+             `~/.keras/keras.json`. If you never set it, then it will be
+             `"channels_last"`.
+         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+             including `name`, `dtype` etc.
+     """
+
+     def __init__(
+         self,
+         stackwise_num_filters,
+         stackwise_num_blocks,
+         output_channels=3,
+         data_format=None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         data_format = standardize_data_format(data_format)
+         channel_axis = -1 if data_format == "channels_last" else 1
+         self.stackwise_num_filters = stackwise_num_filters
+         self.stackwise_num_blocks = stackwise_num_blocks
+         self.output_channels = int(output_channels)
+         self.data_format = data_format
+
+         # === Layers ===
+         self.input_projection = keras.layers.Conv2D(
+             stackwise_num_filters[0],
+             3,
+             1,
+             padding="same",
+             data_format=data_format,
+             dtype=self.dtype_policy,
+             name="input_projection",
+         )
+
+         # Mid block.
+         self.mid_block_0 = ResNetBlock(
+             stackwise_num_filters[0],
+             data_format=data_format,
+             dtype=self.dtype_policy,
+             name="mid_block_0",
+         )
+         self.mid_attention = Conv2DMultiHeadAttention(
+             stackwise_num_filters[0],
+             data_format=data_format,
+             dtype=self.dtype_policy,
+             name="mid_attention",
+         )
+         self.mid_block_1 = ResNetBlock(
+             stackwise_num_filters[0],
+             data_format=data_format,
+             dtype=self.dtype_policy,
+             name="mid_block_1",
+         )
+
+         # Blocks.
+         input_filters = stackwise_num_filters[0]
+         self.blocks = []
+         self.upsamples = []
+         for i, filters in enumerate(stackwise_num_filters):
+             for j in range(stackwise_num_blocks[i]):
+                 self.blocks.append(
+                     ResNetBlock(
+                         filters,
+                         has_residual_projection=input_filters != filters,
+                         data_format=data_format,
+                         dtype=self.dtype_policy,
+                         name=f"block_{i}_{j}",
+                     )
+                 )
+                 input_filters = filters
+             # No upsample after the last stack.
+             if i != len(stackwise_num_filters) - 1:
+                 self.upsamples.append(
+                     keras.layers.UpSampling2D(
+                         2,
+                         data_format=data_format,
+                         dtype=self.dtype_policy,
+                         name=f"upsample_{i}",
+                     )
+                 )
+                 self.upsamples.append(
+                     keras.layers.Conv2D(
+                         filters,
+                         3,
+                         1,
+                         padding="same",
+                         data_format=data_format,
+                         dtype=self.dtype_policy,
+                         name=f"upsample_{i}_conv",
+                     )
+                 )
+
+         # Output layers.
+         self.output_norm = keras.layers.GroupNormalization(
+             groups=32,
+             axis=channel_axis,
+             epsilon=1e-6,
+             dtype=self.dtype_policy,
+             name="output_norm",
+         )
+         self.output_act = keras.layers.Activation(
+             "swish", dtype=self.dtype_policy
+         )
+         self.output_projection = keras.layers.Conv2D(
+             output_channels,
+             3,
+             1,
+             padding="same",
+             data_format=data_format,
+             dtype=self.dtype_policy,
+             name="output_projection",
+         )
+
+     def build(self, input_shape):
+         self.input_projection.build(input_shape)
+         input_shape = self.input_projection.compute_output_shape(input_shape)
+         self.mid_block_0.build(input_shape)
+         input_shape = self.mid_block_0.compute_output_shape(input_shape)
+         self.mid_attention.build(input_shape)
+         input_shape = self.mid_attention.compute_output_shape(input_shape)
+         self.mid_block_1.build(input_shape)
+         input_shape = self.mid_block_1.compute_output_shape(input_shape)
+         blocks_idx = 0
+         upsamples_idx = 0
+         for i, _ in enumerate(self.stackwise_num_filters):
+             for _ in range(self.stackwise_num_blocks[i]):
+                 self.blocks[blocks_idx].build(input_shape)
+                 input_shape = self.blocks[blocks_idx].compute_output_shape(
+                     input_shape
+                 )
+                 blocks_idx += 1
+             if i != len(self.stackwise_num_filters) - 1:
+                 self.upsamples[upsamples_idx].build(input_shape)
+                 input_shape = self.upsamples[
+                     upsamples_idx
+                 ].compute_output_shape(input_shape)
+                 self.upsamples[upsamples_idx + 1].build(input_shape)
+                 input_shape = self.upsamples[
+                     upsamples_idx + 1
+                 ].compute_output_shape(input_shape)
+                 upsamples_idx += 2
+         self.output_norm.build(input_shape)
+         self.output_act.build(input_shape)
+         self.output_projection.build(input_shape)
+
+     def call(self, inputs, training=None):
+         x = inputs
+         x = self.input_projection(x, training=training)
+         x = self.mid_block_0(x, training=training)
+         x = self.mid_attention(x, training=training)
+         x = self.mid_block_1(x, training=training)
+         blocks_idx = 0
+         upsamples_idx = 0
+         for i, _ in enumerate(self.stackwise_num_filters):
+             for _ in range(self.stackwise_num_blocks[i]):
+                 x = self.blocks[blocks_idx](x, training=training)
+                 blocks_idx += 1
+             if i != len(self.stackwise_num_filters) - 1:
+                 x = self.upsamples[upsamples_idx](x, training=training)
+                 x = self.upsamples[upsamples_idx + 1](x, training=training)
+                 upsamples_idx += 2
+         x = self.output_norm(x, training=training)
+         x = self.output_act(x, training=training)
+         x = self.output_projection(x, training=training)
+         return x
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "stackwise_num_filters": self.stackwise_num_filters,
+                 "stackwise_num_blocks": self.stackwise_num_blocks,
+                 "output_channels": self.output_channels,
+             }
+         )
+         return config
+
+     def compute_output_shape(self, input_shape):
+         if self.data_format == "channels_last":
+             h_axis, w_axis, c_axis = 1, 2, 3
+         else:
+             c_axis, h_axis, w_axis = 1, 2, 3
+         scale_factor = 2 ** (len(self.stackwise_num_filters) - 1)
+         outputs_shape = list(input_shape)
+         if (
+             outputs_shape[h_axis] is not None
+             and outputs_shape[w_axis] is not None
+         ):
+             outputs_shape[h_axis] = outputs_shape[h_axis] * scale_factor
+             outputs_shape[w_axis] = outputs_shape[w_axis] * scale_factor
+         outputs_shape[c_axis] = self.output_channels
+         return outputs_shape
+
+
+ class DiagonalGaussianDistributionSampler(keras.layers.Layer):
+     """A sampler for a diagonal Gaussian distribution.
+
+     This layer samples latent variables from a diagonal Gaussian distribution.
+
+     Args:
+         method: str. The method used to sample from the distribution. Available
+             methods are `"sample"` and `"mode"`. `"sample"` draws from the
+             distribution using both the mean and log variance. `"mode"` draws
+             from the distribution using the mean only.
+         axis: int. The axis along which to split the mean and log variance.
+             Defaults to `-1`.
+         seed: optional int. Used as a random seed.
+         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
+             including `name`, `dtype` etc.
+     """
+
+     def __init__(self, method, axis=-1, seed=None, **kwargs):
+         super().__init__(**kwargs)
+         # TODO: Support `kl` and `nll` modes.
+         valid_methods = ("sample", "mode")
+         if method not in valid_methods:
+             raise ValueError(
+                 f"Invalid method {method}. Valid methods are "
+                 f"{list(valid_methods)}."
+             )
+         self.method = method
+         self.axis = axis
+         self.seed = seed
+         self.seed_generator = keras.random.SeedGenerator(seed)
+
+     def call(self, inputs):
+         x = inputs
+         if self.method == "sample":
+             x_mean, x_logvar = ops.split(x, 2, axis=self.axis)
+             x_logvar = ops.clip(x_logvar, -30.0, 20.0)
+             x_std = ops.exp(ops.multiply(0.5, x_logvar))
+             sample = keras.random.normal(
+                 ops.shape(x_mean), dtype=x_mean.dtype, seed=self.seed_generator
+             )
+             x = ops.add(x_mean, ops.multiply(x_std, sample))
+         else:
+             x, _ = ops.split(x, 2, axis=self.axis)
+         return x
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "method": self.method,
+                 "axis": self.axis,
+                 "seed": self.seed,
+             }
+         )
+         return config
+
+     def compute_output_shape(self, input_shape):
+         output_shape = list(input_shape)
+         output_shape[self.axis] = output_shape[self.axis] // 2
+         return output_shape
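
An editorial note on the mid-block attention (not part of the diff): despite the class name, `Conv2DMultiHeadAttention` computes a single attention head over flattened spatial positions; the `Conv2D` layers are 1x1 projections, and there is no head split. The einsum pair in `call` is ordinary scaled dot-product attention. A minimal NumPy sketch of the equivalence, using hypothetical toy shapes and assuming `channels_last`:

import numpy as np

b, hw, c = 2, 16, 8  # batch, flattened spatial positions, channels
q = np.random.randn(b, hw, c).astype("float32")
k = np.random.randn(b, hw, c).astype("float32")
v = np.random.randn(b, hw, c).astype("float32")

# "abc,adc->abd" is q @ k^T, scaled by 1/sqrt(c) as in the layer.
scores = np.einsum("abc,adc->abd", q / np.sqrt(c), k)
# Softmax over the key positions (the last axis).
probs = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
# "abc,adb->adc": for each query position, a probs-weighted sum of values.
out = np.einsum("abc,adb->adc", v, probs)
assert np.allclose(out, probs @ v, atol=1e-5)  # plain batched matmul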
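`DiagonalGaussianDistributionSampler` with `method="sample"` is the standard VAE reparameterization trick: the input tensor is split into a mean half and a log-variance half, and a sample is drawn as `mean + exp(0.5 * logvar) * eps` with `eps ~ N(0, I)`. A standalone sketch of the same computation (again, not part of the diff), using only public `keras.ops` calls:

import keras
from keras import ops

def sample_diagonal_gaussian(moments, seed=None):
    """Reparameterized sample from concatenated (mean, log-variance)."""
    mean, logvar = ops.split(moments, 2, axis=-1)
    # Clip the log-variance for numerical stability, as the layer does.
    logvar = ops.clip(logvar, -30.0, 20.0)
    std = ops.exp(0.5 * logvar)
    eps = keras.random.normal(ops.shape(mean), dtype=mean.dtype, seed=seed)
    return mean + std * eps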
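Putting the three pieces together, the encode, sample, decode round trip looks like the sketch below. The configuration is a deliberately tiny, hypothetical one (diffusion VAE presets typically use much larger stacks such as `[128, 256, 512, 512]`), and `channels_last` is assumed. Spatial resolution shrinks by `2 ** (len(stackwise_num_filters) - 1)` in the encoder and grows back by the same factor in the decoder, while the sampler halves the channel axis:

import keras

from keras_hub.src.models.vae.vae_layers import (
    DiagonalGaussianDistributionSampler,
    VAEDecoder,
    VAEEncoder,
)

encoder = VAEEncoder(
    stackwise_num_filters=[32, 64],
    stackwise_num_blocks=[1, 1],
    output_channels=32,  # 16 latent channels x 2 (mean, log-variance)
)
sampler = DiagonalGaussianDistributionSampler(method="sample", seed=42)
decoder = VAEDecoder(
    stackwise_num_filters=[64, 32],
    stackwise_num_blocks=[1, 1],
    output_channels=3,
)

images = keras.random.normal((1, 64, 64, 3))
moments = encoder(images)  # (1, 32, 32, 32): one 2x downsample
latents = sampler(moments)  # (1, 32, 32, 16): channel axis halved
reconstruction = decoder(latents)  # (1, 64, 64, 3): one 2x upsample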