keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252) hide show
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
@@ -1,119 +0,0 @@
1
- import keras
2
-
3
- from keras_hub.src.api_export import keras_hub_export
4
- from keras_hub.src.models.image_classifier import ImageClassifier
5
- from keras_hub.src.models.mix_transformer.mix_transformer_backbone import (
6
- MiTBackbone,
7
- )
8
-
9
-
10
- @keras_hub_export("keras_hub.models.MiTImageClassifier")
11
- class MiTImageClassifier(ImageClassifier):
12
- """MiTImageClassifier image classifier model.
13
-
14
- Args:
15
- backbone: A `keras_hub.models.MiTBackbone` instance.
16
- num_classes: int. The number of classes to predict.
17
- activation: `None`, str or callable. The activation function to use on
18
- the `Dense` layer. Set `activation=None` to return the output
19
- logits. Defaults to `"softmax"`.
20
-
21
- To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
22
- where `x` is a tensor and `y` is a integer from `[0, num_classes)`.
23
- All `ImageClassifier` tasks include a `from_preset()` constructor which can
24
- be used to load a pre-trained config and weights.
25
-
26
- Examples:
27
-
28
- Call `predict()` to run inference.
29
- ```python
30
- # Load preset and train
31
- images = np.ones((2, 224, 224, 3), dtype="float32")
32
- classifier = keras_hub.models.MiTImageClassifier.from_preset(
33
- "mit_b0_imagenet")
34
- classifier.predict(images)
35
- ```
36
-
37
- Call `fit()` on a single batch.
38
- ```python
39
- # Load preset and train
40
- images = np.ones((2, 224, 224, 3), dtype="float32")
41
- labels = [0, 3]
42
- classifier = keras_hub.models.MixTransformerImageClassifier.from_preset(
43
- "mit_b0_imagenet")
44
- classifier.fit(x=images, y=labels, batch_size=2)
45
- ```
46
-
47
- Call `fit()` with custom loss, optimizer and backbone.
48
- ```python
49
- classifier = keras_hub.models.MiTImageClassifier.from_preset(
50
- "mit_b0_imagenet")
51
- classifier.compile(
52
- loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
53
- optimizer=keras.optimizers.Adam(5e-5),
54
- )
55
- classifier.backbone.trainable = False
56
- classifier.fit(x=images, y=labels, batch_size=2)
57
- ```
58
-
59
- Custom backbone.
60
- ```python
61
- images = np.ones((2, 224, 224, 3), dtype="float32")
62
- labels = [0, 3]
63
- backbone = keras_hub.models.MiTBackbone(
64
- stackwise_num_filters=[128, 256, 512, 1024],
65
- stackwise_depth=[3, 9, 9, 3],
66
- block_type="basic_block",
67
- image_shape = (224, 224, 3),
68
- )
69
- classifier = keras_hub.models.MiTImageClassifier(
70
- backbone=backbone,
71
- num_classes=4,
72
- )
73
- classifier.fit(x=images, y=labels, batch_size=2)
74
- ```
75
- """
76
-
77
- backbone_cls = MiTBackbone
78
-
79
- def __init__(
80
- self,
81
- backbone,
82
- num_classes,
83
- activation="softmax",
84
- preprocessor=None, # adding this dummy arg for saved model test
85
- # TODO: once preprocessor flow is figured out, this needs to be updated
86
- **kwargs,
87
- ):
88
- # === Layers ===
89
- self.backbone = backbone
90
- self.output_dense = keras.layers.Dense(
91
- num_classes,
92
- activation=activation,
93
- name="predictions",
94
- )
95
-
96
- # === Functional Model ===
97
- inputs = self.backbone.input
98
- x = self.backbone(inputs)
99
- outputs = self.output_dense(x)
100
- super().__init__(
101
- inputs=inputs,
102
- outputs=outputs,
103
- **kwargs,
104
- )
105
-
106
- # === Config ===
107
- self.num_classes = num_classes
108
- self.activation = activation
109
-
110
- def get_config(self):
111
- # Backbone serialized in `super`
112
- config = super().get_config()
113
- config.update(
114
- {
115
- "num_classes": self.num_classes,
116
- "activation": self.activation,
117
- }
118
- )
119
- return config
@@ -1,320 +0,0 @@
1
- import math
2
-
3
- from keras import layers
4
- from keras import ops
5
-
6
- from keras_hub.src.models.backbone import Backbone
7
- from keras_hub.src.utils.keras_utils import standardize_data_format
8
-
9
-
10
- class VAEAttention(layers.Layer):
11
- def __init__(self, filters, groups=32, data_format=None, **kwargs):
12
- super().__init__(**kwargs)
13
- self.filters = filters
14
- self.data_format = standardize_data_format(data_format)
15
- gn_axis = -1 if self.data_format == "channels_last" else 1
16
-
17
- self.group_norm = layers.GroupNormalization(
18
- groups=groups,
19
- axis=gn_axis,
20
- epsilon=1e-6,
21
- dtype="float32",
22
- name="group_norm",
23
- )
24
- self.query_conv2d = layers.Conv2D(
25
- filters,
26
- 1,
27
- 1,
28
- data_format=self.data_format,
29
- dtype=self.dtype_policy,
30
- name="query_conv2d",
31
- )
32
- self.key_conv2d = layers.Conv2D(
33
- filters,
34
- 1,
35
- 1,
36
- data_format=self.data_format,
37
- dtype=self.dtype_policy,
38
- name="key_conv2d",
39
- )
40
- self.value_conv2d = layers.Conv2D(
41
- filters,
42
- 1,
43
- 1,
44
- data_format=self.data_format,
45
- dtype=self.dtype_policy,
46
- name="value_conv2d",
47
- )
48
- self.softmax = layers.Softmax(dtype="float32")
49
- self.output_conv2d = layers.Conv2D(
50
- filters,
51
- 1,
52
- 1,
53
- data_format=self.data_format,
54
- dtype=self.dtype_policy,
55
- name="output_conv2d",
56
- )
57
-
58
- self.groups = groups
59
- self._inverse_sqrt_filters = 1.0 / math.sqrt(float(filters))
60
-
61
- def build(self, input_shape):
62
- self.group_norm.build(input_shape)
63
- self.query_conv2d.build(input_shape)
64
- self.key_conv2d.build(input_shape)
65
- self.value_conv2d.build(input_shape)
66
- self.output_conv2d.build(input_shape)
67
-
68
- def call(self, inputs, training=None):
69
- x = self.group_norm(inputs)
70
- query = self.query_conv2d(x)
71
- key = self.key_conv2d(x)
72
- value = self.value_conv2d(x)
73
-
74
- if self.data_format == "channels_first":
75
- query = ops.transpose(query, (0, 2, 3, 1))
76
- key = ops.transpose(key, (0, 2, 3, 1))
77
- value = ops.transpose(value, (0, 2, 3, 1))
78
- shape = ops.shape(inputs)
79
- b = shape[0]
80
- query = ops.reshape(query, (b, -1, self.filters))
81
- key = ops.reshape(key, (b, -1, self.filters))
82
- value = ops.reshape(value, (b, -1, self.filters))
83
-
84
- # Compute attention.
85
- query = ops.multiply(
86
- query, ops.cast(self._inverse_sqrt_filters, query.dtype)
87
- )
88
- # [B, H0 * W0, C], [B, H1 * W1, C] -> [B, H0 * W0, H1 * W1]
89
- attention_scores = ops.einsum("abc,adc->abd", query, key)
90
- attention_scores = ops.cast(
91
- self.softmax(attention_scores), self.compute_dtype
92
- )
93
- # [B, H2 * W2, C], [B, H0 * W0, H1 * W1] -> [B, H1 * W1 ,C]
94
- attention_output = ops.einsum("abc,adb->adc", value, attention_scores)
95
- x = ops.reshape(attention_output, shape)
96
-
97
- x = self.output_conv2d(x)
98
- if self.data_format == "channels_first":
99
- x = ops.transpose(x, (0, 3, 1, 2))
100
- x = ops.add(x, inputs)
101
- return x
102
-
103
- def get_config(self):
104
- config = super().get_config()
105
- config.update(
106
- {
107
- "filters": self.filters,
108
- "groups": self.groups,
109
- }
110
- )
111
- return config
112
-
113
- def compute_output_shape(self, input_shape):
114
- return input_shape
115
-
116
-
117
- def apply_resnet_block(x, filters, data_format=None, dtype=None, name=None):
118
- data_format = standardize_data_format(data_format)
119
- gn_axis = -1 if data_format == "channels_last" else 1
120
- input_filters = x.shape[gn_axis]
121
-
122
- residual = x
123
- x = layers.GroupNormalization(
124
- groups=32,
125
- axis=gn_axis,
126
- epsilon=1e-6,
127
- dtype="float32",
128
- name=f"{name}_norm1",
129
- )(x)
130
- x = layers.Activation("swish", dtype=dtype)(x)
131
- x = layers.Conv2D(
132
- filters,
133
- 3,
134
- 1,
135
- padding="same",
136
- data_format=data_format,
137
- dtype=dtype,
138
- name=f"{name}_conv1",
139
- )(x)
140
- x = layers.GroupNormalization(
141
- groups=32,
142
- axis=gn_axis,
143
- epsilon=1e-6,
144
- dtype="float32",
145
- name=f"{name}_norm2",
146
- )(x)
147
- x = layers.Activation("swish", dtype=dtype)(x)
148
- x = layers.Conv2D(
149
- filters,
150
- 3,
151
- 1,
152
- padding="same",
153
- data_format=data_format,
154
- dtype=dtype,
155
- name=f"{name}_conv2",
156
- )(x)
157
- if input_filters != filters:
158
- residual = layers.Conv2D(
159
- filters,
160
- 1,
161
- 1,
162
- data_format=data_format,
163
- dtype=dtype,
164
- name=f"{name}_residual_projection",
165
- )(residual)
166
- x = layers.Add(dtype=dtype)([residual, x])
167
- return x
168
-
169
-
170
- class VAEImageDecoder(Backbone):
171
- """Decoder for the VAE model used in Stable Diffusion 3.
172
-
173
- Args:
174
- stackwise_num_filters: list of ints. The number of filters for each
175
- stack.
176
- stackwise_num_blocks: list of ints. The number of blocks for each stack.
177
- output_channels: int. The number of channels in the output.
178
- latent_shape: tuple. The shape of the latent image.
179
- data_format: `None` or str. If specified, either `"channels_last"` or
180
- `"channels_first"`. The ordering of the dimensions in the
181
- inputs. `"channels_last"` corresponds to inputs with shape
182
- `(batch_size, height, width, channels)`
183
- while `"channels_first"` corresponds to inputs with shape
184
- `(batch_size, channels, height, width)`. It defaults to the
185
- `image_data_format` value found in your Keras config file at
186
- `~/.keras/keras.json`. If you never set it, then it will be
187
- `"channels_last"`.
188
- dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
189
- to use for the model's computations and weights.
190
- """
191
-
192
- def __init__(
193
- self,
194
- stackwise_num_filters,
195
- stackwise_num_blocks,
196
- output_channels=3,
197
- latent_shape=(None, None, 16),
198
- data_format=None,
199
- dtype=None,
200
- **kwargs,
201
- ):
202
- data_format = standardize_data_format(data_format)
203
- gn_axis = -1 if data_format == "channels_last" else 1
204
-
205
- # === Functional Model ===
206
- latent_inputs = layers.Input(shape=latent_shape)
207
-
208
- x = layers.Conv2D(
209
- stackwise_num_filters[0],
210
- 3,
211
- 1,
212
- padding="same",
213
- data_format=data_format,
214
- dtype=dtype,
215
- name="input_projection",
216
- )(latent_inputs)
217
- x = apply_resnet_block(
218
- x,
219
- stackwise_num_filters[0],
220
- data_format=data_format,
221
- dtype=dtype,
222
- name="input_block0",
223
- )
224
- x = VAEAttention(
225
- stackwise_num_filters[0],
226
- data_format=data_format,
227
- dtype=dtype,
228
- name="input_attention",
229
- )(x)
230
- x = apply_resnet_block(
231
- x,
232
- stackwise_num_filters[0],
233
- data_format=data_format,
234
- dtype=dtype,
235
- name="input_block1",
236
- )
237
-
238
- # Stacks.
239
- for i, filters in enumerate(stackwise_num_filters):
240
- for j in range(stackwise_num_blocks[i]):
241
- x = apply_resnet_block(
242
- x,
243
- filters,
244
- data_format=data_format,
245
- dtype=dtype,
246
- name=f"block{i}_{j}",
247
- )
248
- if i != len(stackwise_num_filters) - 1:
249
- # No upsamling in the last blcok.
250
- x = layers.UpSampling2D(
251
- 2,
252
- data_format=data_format,
253
- dtype=dtype,
254
- name=f"upsample_{i}",
255
- )(x)
256
- x = layers.Conv2D(
257
- filters,
258
- 3,
259
- 1,
260
- padding="same",
261
- data_format=data_format,
262
- dtype=dtype,
263
- name=f"upsample_{i}_conv",
264
- )(x)
265
-
266
- # Ouput block.
267
- x = layers.GroupNormalization(
268
- groups=32,
269
- axis=gn_axis,
270
- epsilon=1e-6,
271
- dtype="float32",
272
- name="output_norm",
273
- )(x)
274
- x = layers.Activation("swish", dtype=dtype, name="output_activation")(x)
275
- image_outputs = layers.Conv2D(
276
- output_channels,
277
- 3,
278
- 1,
279
- padding="same",
280
- data_format=data_format,
281
- dtype=dtype,
282
- name="output_projection",
283
- )(x)
284
- super().__init__(inputs=latent_inputs, outputs=image_outputs, **kwargs)
285
-
286
- # === Config ===
287
- self.stackwise_num_filters = stackwise_num_filters
288
- self.stackwise_num_blocks = stackwise_num_blocks
289
- self.output_channels = output_channels
290
- self.latent_shape = latent_shape
291
-
292
- @property
293
- def scaling_factor(self):
294
- """The scaling factor for the latent space.
295
-
296
- This is used to scale the latent space to have unit variance when
297
- training the diffusion model.
298
- """
299
- return 1.5305
300
-
301
- @property
302
- def shift_factor(self):
303
- """The shift factor for the latent space.
304
-
305
- This is used to shift the latent space to have zero mean when
306
- training the diffusion model.
307
- """
308
- return 0.0609
309
-
310
- def get_config(self):
311
- config = super().get_config()
312
- config.update(
313
- {
314
- "stackwise_num_filters": self.stackwise_num_filters,
315
- "stackwise_num_blocks": self.stackwise_num_blocks,
316
- "output_channels": self.output_channels,
317
- "image_shape": self.latent_shape,
318
- }
319
- )
320
- return config