keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, exactly as they appear in their public registries. It is provided for informational purposes only.
Files changed (252)
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/src/models/task.py (+47 -39)
@@ -1,19 +1,20 @@
-import os
-
 import keras
 from rich import console as rich_console
 from rich import markup
 from rich import table as rich_table
 
 from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.preprocessor import Preprocessor
+from keras_hub.src.tokenizers.tokenizer import Tokenizer
 from keras_hub.src.utils.keras_utils import print_msg
 from keras_hub.src.utils.pipeline_model import PipelineModel
-from keras_hub.src.utils.preset_utils import TASK_CONFIG_FILE
-from keras_hub.src.utils.preset_utils import TASK_WEIGHTS_FILE
 from keras_hub.src.utils.preset_utils import builtin_presets
 from keras_hub.src.utils.preset_utils import find_subclass
 from keras_hub.src.utils.preset_utils import get_preset_loader
-from keras_hub.src.utils.preset_utils import save_serialized_object
+from keras_hub.src.utils.preset_utils import get_preset_saver
 from keras_hub.src.utils.python_utils import classproperty
 
 
@@ -58,10 +59,15 @@ class Task(PipelineModel):
         self.compile()
 
     def preprocess_samples(self, x, y=None, sample_weight=None):
-        if self.preprocessor is not None:
+        # If `preprocessor` is `None`, return inputs unaltered.
+        if self.preprocessor is None:
+            return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
+        # If `preprocessor` is `Preprocessor` subclass, pass labels as a kwarg.
+        if isinstance(self.preprocessor, Preprocessor):
             return self.preprocessor(x, y=y, sample_weight=sample_weight)
-        else:
-            return super().preprocess_samples(x, y, sample_weight)
+        # For other layers and callable, do not pass the label.
+        x = self.preprocessor(x)
+        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
 
     def __setattr__(self, name, value):
         # Work around setattr issues for Keras 2 and Keras 3 torch backend.
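For context, the new `preprocess_samples` dispatches on three kinds of preprocessor. A minimal standalone sketch of that dispatch follows; the `Rescaling` layer and dummy data are illustrative, not part of the diff:

```python
import keras
import numpy as np

from keras_hub.src.models.preprocessor import Preprocessor


def preprocess_samples(preprocessor, x, y=None, sample_weight=None):
    # Case 1: no preprocessor, inputs pass through untouched.
    if preprocessor is None:
        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
    # Case 2: a `Preprocessor` subclass is called with labels, so it can
    # transform x and y together (e.g. tokenize x and pack y).
    if isinstance(preprocessor, Preprocessor):
        return preprocessor(x, y=y, sample_weight=sample_weight)
    # Case 3: any other layer/callable is applied to x only; y and
    # sample_weight are re-packed unchanged.
    x = preprocessor(x)
    return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)


# A plain Keras layer is now accepted as a task preprocessor:
rescale = keras.layers.Rescaling(1.0 / 255.0)
x, y = preprocess_samples(rescale, np.ones((2, 8, 8, 3)), np.array([0, 1]))
```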
@@ -143,7 +149,8 @@ class Task(PipelineModel):
 
         This constructor can be called in one of two ways. Either from a task
         specific base class like `keras_hub.models.CausalLM.from_preset()`, or
-        from a model class like `keras_hub.models.BertTextClassifier.from_preset()`.
+        from a model class like
+        `keras_hub.models.BertTextClassifier.from_preset()`.
         If calling from the a base class, the subclass of the returning object
         will be inferred from the config in the preset directory.
 
@@ -178,7 +185,10 @@ class Task(PipelineModel):
         loader = get_preset_loader(preset)
         backbone_cls = loader.check_backbone_class()
         # Detect the correct subclass if we need to.
-        if cls.backbone_cls != backbone_cls:
+        if (
+            issubclass(backbone_cls, Backbone)
+            and cls.backbone_cls != backbone_cls
+        ):
             cls = find_subclass(preset, cls, backbone_cls)
         # Specifically for classifiers, we never load task weights if
         # num_classes is supplied. We handle this in the task base class because
@@ -232,17 +242,8 @@ class Task(PipelineModel):
         Args:
             preset_dir: The path to the local model preset directory.
         """
-        if self.preprocessor is None:
-            raise ValueError(
-                "Cannot save `task` to preset: `Preprocessor` is not initialized."
-            )
-
-        save_serialized_object(self, preset_dir, config_file=TASK_CONFIG_FILE)
-        if self.has_task_weights():
-            self.save_task_weights(os.path.join(preset_dir, TASK_WEIGHTS_FILE))
-
-        self.preprocessor.save_to_preset(preset_dir)
-        self.backbone.save_to_preset(preset_dir)
+        saver = get_preset_saver(preset_dir)
+        saver.save_task(self)
 
     @property
     def layers(self):
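Saving is now delegated to a preset saver object rather than being inlined in the task. A hedged round-trip sketch; the local directory name is hypothetical and `bert_base_en_uncased` is just an example preset:

```python
import keras_hub

# Load a built-in preset, then save it as a local preset directory.
classifier = keras_hub.models.TextClassifier.from_preset(
    "bert_base_en_uncased",
    num_classes=2,
)
classifier.save_to_preset("./my_preset")  # routed through get_preset_saver()

# The directory can be reloaded the same way as a built-in preset.
restored = keras_hub.models.TextClassifier.from_preset("./my_preset")
```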
@@ -280,7 +281,7 @@ class Task(PipelineModel):
 
         def highlight_number(x):
             if x is None:
-                f"[color(45)]{x}[/]"
+                return f"[color(45)]{x}[/]"
             return f"[color(34)]{x:,}[/]"  # Format number with commas.
 
         def highlight_symbol(x):
@@ -294,7 +295,8 @@ class Task(PipelineModel):
             return "(" + ", ".join(highlighted) + ")"
 
         if self.preprocessor:
-            # Create a rich console for printing. Capture for non-interactive logging.
+            # Create a rich console for printing. Capture for non-interactive
+            # logging.
             if print_fn:
                 console = rich_console.Console(
                     highlight=False, force_terminal=False, color_system=None
@@ -327,24 +329,30 @@ class Task(PipelineModel):
                 info,
             )
 
-            tokenizer = self.preprocessor.tokenizer
-            if tokenizer:
-                info = "Vocab size: "
-                info += highlight_number(tokenizer.vocabulary_size())
-                add_layer(tokenizer, info)
-            image_converter = self.preprocessor.image_converter
-            if image_converter:
-                info = "Image size: "
-                info += highlight_shape(image_converter.image_size())
-                add_layer(image_converter, info)
-            audio_converter = self.preprocessor.audio_converter
-            if audio_converter:
-                info = "Audio shape: "
-                info += highlight_shape(audio_converter.audio_shape())
-                add_layer(audio_converter, info)
+            # Since the preprocessor might be nested with multiple `Tokenizer`,
+            # `ImageConverter`, `AudioConverter` and even other `Preprocessor`
+            # instances, we should recursively iterate through them.
+            preprocessor = self.preprocessor
+            if preprocessor and isinstance(preprocessor, keras.Layer):
+                for layer in preprocessor._flatten_layers(include_self=False):
+                    if isinstance(layer, Tokenizer):
+                        info = "Vocab size: "
+                        info += highlight_number(layer.vocabulary_size())
+                        add_layer(layer, info)
+                    elif isinstance(layer, ImageConverter):
+                        info = "Image size: "
+                        image_size = layer.image_size
+                        if image_size is None:
+                            image_size = (None, None)
+                        info += highlight_shape(image_size)
+                        add_layer(layer, info)
+                    elif isinstance(layer, AudioConverter):
+                        info = "Audio shape: "
+                        info += highlight_shape(layer.audio_shape())
+                        add_layer(layer, info)
 
             # Print the to the console.
-            preprocessor_name = markup.escape(self.preprocessor.name)
+            preprocessor_name = markup.escape(preprocessor.name)
             console.print(bold_text(f'Preprocessor: "{preprocessor_name}"'))
             console.print(table)
 
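The rewritten summary leans on Keras's private `Layer._flatten_layers` to walk arbitrarily nested preprocessors. A small illustration of that traversal, independent of keras-hub (the `Outer` layer below is made up):

```python
import keras


# `_flatten_layers` is a private Keras API; it yields a layer's tracked
# sublayers recursively, which is what lets the summary above find a
# `Tokenizer` or `ImageConverter` buried inside nested preprocessors.
class Outer(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.inner = keras.layers.Dense(4, name="inner_dense")


outer = Outer(name="outer")
print([layer.name for layer in outer._flatten_layers(include_self=False)])
# Expected to include "inner_dense".
```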
keras_hub/src/models/text_classifier.py (+2 -2)
@@ -21,8 +21,8 @@ class TextClassifier(Task):
     To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
     labels where `x` is a string and `y` is a integer from `[0, num_classes)`.
 
-    All `TextClassifier` tasks include a `from_preset()` constructor which can be
-    used to load a pre-trained config and weights.
+    All `TextClassifier` tasks include a `from_preset()` constructor which can
+    be used to load a pre-trained config and weights.
 
     Some, but not all, classification presets include classification head
     weights in a `task.weights.h5` file. For these presets, you can omit passing
keras_hub/src/models/text_to_image.py (+106 -41)
@@ -56,6 +56,11 @@ class TextToImage(Task):
         # Default compilation.
         self.compile()
 
+    @property
+    def support_negative_prompts(self):
+        """Whether the model supports `negative_prompts` key in `generate()`."""
+        return bool(True)
+
     @property
     def latent_shape(self):
         return tuple(self.backbone.latent_shape)
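`support_negative_prompts` defaults to true on the base class. A subclass whose pipeline does not use classifier-free guidance could presumably opt out by overriding the property; the subclass below is hypothetical:

```python
from keras_hub.src.models.text_to_image import TextToImage


class PositiveOnlyTextToImage(TextToImage):  # hypothetical subclass
    @property
    def support_negative_prompts(self):
        # Opt out: `generate()` then skips negative-prompt handling.
        return False
```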
@@ -171,9 +176,26 @@ class TextToImage(Task):
         This function converts all inputs to tensors, adds a batch dimension if
         necessary, and returns a iterable "dataset like" object (either an
         actual `tf.data.Dataset` or a list with a single batch element).
+
+        The input format must be one of the following:
+        - A single string
+        - A list of strings
+        - A dict with "prompts" and/or "negative_prompts" keys
+        - A tf.data.Dataset with "prompts" and/or "negative_prompts" keys
+
+        The output will be a dict with "prompts" and/or "negative_prompts" keys.
         """
         if tf and isinstance(inputs, tf.data.Dataset):
-            return inputs.as_numpy_iterator(), False
+            _inputs = {
+                "prompts": inputs.map(
+                    lambda x: x["prompts"]
+                ).as_numpy_iterator()
+            }
+            if self.support_negative_prompts:
+                _inputs["negative_prompts"] = inputs.map(
+                    lambda x: x["negative_prompts"]
+                ).as_numpy_iterator()
+            return _inputs, False
 
         def normalize(x):
             if isinstance(x, str):
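With this change, a `tf.data.Dataset` passed to `generate()` is expected to yield dicts keyed by "prompts" (and "negative_prompts" when supported). A sketch of building such a dataset; the prompt strings are made up:

```python
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(
    {
        "prompts": ["a red car", "a blue bird"],
        "negative_prompts": ["blurry", "blurry"],
    }
).batch(2)
# `_normalize_generate_inputs` maps each key out into its own numpy iterator.
```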
@@ -182,13 +204,24 @@ class TextToImage(Task):
                 return x[tf.newaxis], True
             return x, False
 
+        def get_dummy_prompts(x):
+            dummy_prompts = [""] * len(x)
+            if tf and isinstance(x, tf.Tensor):
+                return tf.convert_to_tensor(dummy_prompts)
+            else:
+                return dummy_prompts
+
         if isinstance(inputs, dict):
             for key in inputs:
                 inputs[key], input_is_scalar = normalize(inputs[key])
         else:
             inputs, input_is_scalar = normalize(inputs)
+            inputs = {"prompts": inputs}
 
-        return inputs, input_is_scalar
+        if self.support_negative_prompts and "negative_prompts" not in inputs:
+            inputs["negative_prompts"] = get_dummy_prompts(inputs["prompts"])
+
+        return [inputs], input_is_scalar
 
     def _normalize_generate_outputs(self, outputs, input_is_scalar):
         """Normalize user output from the generate function.
@@ -199,12 +232,11 @@ class TextToImage(Task):
         """
 
         def normalize(x):
-            outputs = ops.clip(ops.divide(ops.add(x, 1.0), 2.0), 0.0, 1.0)
+            outputs = ops.concatenate(x, axis=0)
+            outputs = ops.clip(ops.divide(ops.add(outputs, 1.0), 2.0), 0.0, 1.0)
             outputs = ops.cast(ops.round(ops.multiply(outputs, 255.0)), "uint8")
-            outputs = ops.convert_to_numpy(outputs)
-            if input_is_scalar:
-                outputs = outputs[0]
-            return outputs
+            outputs = ops.squeeze(outputs, 0) if input_is_scalar else outputs
+            return ops.convert_to_numpy(outputs)
 
         if isinstance(outputs[0], dict):
             normalized = {}
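The normalization maps model outputs from `[-1, 1]` to `uint8` pixels via `round(clip((x + 1) / 2, 0, 1) * 255)`. A quick numeric check, using NumPy purely for illustration:

```python
import numpy as np

x = np.array([-1.0, 0.0, 1.0])  # model output range
pixels = np.round(np.clip((x + 1.0) / 2.0, 0.0, 1.0) * 255.0).astype("uint8")
print(pixels)  # [  0 128 255]
```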
@@ -216,33 +248,62 @@ class TextToImage(Task):
     def generate(
         self,
         inputs,
-        negative_inputs,
         num_steps,
-        guidance_scale,
+        guidance_scale=None,
         seed=None,
     ):
-        """Generate image based on the provided `inputs` and `negative_inputs`.
+        """Generate image based on the provided `inputs`.
+
+        Typically, `inputs` contains a text description (known as a prompt) used
+        to guide the image generation.
+
+        Some models support a `negative_prompts` key, which helps steer the
+        model away from generating certain styles and elements. To enable this,
+        pass `prompts` and `negative_prompts` as a dict:
+
+        ```python
+        prompt = (
+            "Astronaut in a jungle, cold color palette, muted colors, "
+            "detailed, 8k"
+        )
+        text_to_image.generate(
+            {
+                "prompts": prompt,
+                "negative_prompts": "green color",
+            }
+        )
+        ```
 
         If `inputs` are a `tf.data.Dataset`, outputs will be generated
         "batch-by-batch" and concatenated. Otherwise, all inputs will be
         processed as batches.
 
         Args:
-            inputs: python data, tensor data, or a `tf.data.Dataset`.
-            negative_inputs: python data, tensor data, or a `tf.data.Dataset`.
-                Unlike `inputs`, these are used as negative inputs to guide the
-                generation. If not provided, it defaults to `""` for each input
-                in `inputs`.
+            inputs: python data, tensor data, or a `tf.data.Dataset`. The format
+                must be one of the following:
+                - A single string
+                - A list of strings
+                - A dict with "prompts" and/or "negative_prompts" keys
+                - A `tf.data.Dataset` with "prompts" and/or "negative_prompts"
+                    keys
             num_steps: int. The number of diffusion steps to take.
-            guidance_scale: float. The classifier free guidance scale defined in
-                [Classifier-Free Diffusion Guidance](
+            guidance_scale: Optional float. The classifier free guidance scale
+                defined in [Classifier-Free Diffusion Guidance](
                 https://arxiv.org/abs/2207.12598). A higher scale encourages
                 generating images more closely related to the prompts, typically
-                at the cost of lower image quality.
+                at the cost of lower image quality. Note that some models don't
+                utilize classifier-free guidance.
             seed: optional int. Used as a random seed.
         """
+        num_steps = int(num_steps)
+        guidance_scale = (
+            float(guidance_scale) if guidance_scale is not None else None
+        )
         num_steps = ops.convert_to_tensor(num_steps, "int32")
-        guidance_scale = ops.convert_to_tensor(guidance_scale)
+        if guidance_scale is not None and guidance_scale > 1.0:
+            guidance_scale = ops.convert_to_tensor(guidance_scale)
+        else:
+            guidance_scale = None
 
         # Setup our three main passes.
         # 1. Preprocessing strings to dense integer tensors.
@@ -251,32 +312,36 @@ class TextToImage(Task):
         generate_function = self.make_generate_function()
 
         def preprocess(x):
-            return self.preprocessor.generate_preprocess(x)
+            if self.preprocessor is not None:
+                return self.preprocessor.generate_preprocess(x)
+            else:
+                return x
+
+        def generate(x):
+            token_ids = x[0] if self.support_negative_prompts else x
+
+            # Initialize latents.
+            if isinstance(token_ids, dict):
+                arbitrary_key = list(token_ids.keys())[0]
+                batch_size = ops.shape(token_ids[arbitrary_key])[0]
+            else:
+                batch_size = ops.shape(token_ids)[0]
+            latent_shape = (batch_size,) + self.latent_shape[1:]
+            latents = random.normal(latent_shape, dtype="float32", seed=seed)
+
+            return generate_function(latents, x, num_steps, guidance_scale)
 
         # Normalize and preprocess inputs.
         inputs, input_is_scalar = self._normalize_generate_inputs(inputs)
-        if negative_inputs is None:
-            negative_inputs = [""] * len(inputs)
-        negative_inputs, _ = self._normalize_generate_inputs(negative_inputs)
-
-        if self.preprocessor is not None:
-            inputs = preprocess(inputs)
-            negative_inputs = preprocess(negative_inputs)
-        if isinstance(inputs, dict):
-            batch_size = len(inputs[list(inputs.keys())[0]])
+        if self.support_negative_prompts:
+            token_ids = [preprocess(x["prompts"]) for x in inputs]
+            negative_token_ids = [
+                preprocess(x["negative_prompts"]) for x in inputs
+            ]
+            inputs = [x for x in zip(token_ids, negative_token_ids)]
         else:
-            batch_size = len(inputs)
-
-        # Initialize random latents.
-        latent_shape = (batch_size,) + self.latent_shape[1:]
-        latents = random.normal(latent_shape, dtype="float32", seed=seed)
+            inputs = [preprocess(x["prompts"]) for x in inputs]
 
         # Text-to-image.
-        outputs = generate_function(
-            latents,
-            inputs,
-            negative_inputs,
-            num_steps,
-            guidance_scale,
-        )
+        outputs = [generate(x) for x in inputs]
         return self._normalize_generate_outputs(outputs, input_is_scalar)
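Taken together, the new `generate()` signature drops `negative_inputs` in favor of dict inputs and an optional guidance scale. A hedged end-to-end sketch; the preset name is an example, and any `TextToImage` preset should work the same way:

```python
import keras_hub

text_to_image = keras_hub.models.TextToImage.from_preset(
    "stable_diffusion_3_medium"  # example preset name
)
images = text_to_image.generate(
    {
        "prompts": "Astronaut in a jungle, detailed, 8k",
        "negative_prompts": "green color",
    },
    num_steps=28,
    guidance_scale=7.0,  # now optional; None disables CFG scaling
)
```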
keras_hub/src/models/vae/__init__.py (+1 -0)
@@ -0,0 +1 @@
+from keras_hub.src.models.vae.vae_backbone import VAEBackbone
keras_hub/src/models/vae/vae_backbone.py (+184 -0)
@@ -0,0 +1,184 @@
+import keras
+
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.vae.vae_layers import (
+    DiagonalGaussianDistributionSampler,
+)
+from keras_hub.src.models.vae.vae_layers import VAEDecoder
+from keras_hub.src.models.vae.vae_layers import VAEEncoder
+from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+class VAEBackbone(Backbone):
+    """Variational Autoencoder(VAE) backbone used in latent diffusion models.
+
+    When encoding, this model generates mean and log variance of the input
+    images. When decoding, it reconstructs images from the latent space.
+
+    Args:
+        encoder_num_filters: list of ints. The number of filters for each
+            block in encoder.
+        encoder_num_blocks: list of ints. The number of blocks for each block in
+            encoder.
+        decoder_num_filters: list of ints. The number of filters for each
+            block in decoder.
+        decoder_num_blocks: list of ints. The number of blocks for each block in
+            decoder.
+        sampler_method: str. The method of the sampler for the intermediate
+            output. Available methods are `"sample"` and `"mode"`. `"sample"`
+            draws from the distribution using both the mean and log variance.
+            `"mode"` draws from the distribution using the mean only. Defaults
+            to `sample`.
+        input_channels: int. The number of channels in the input.
+        sample_channels: int. The number of channels in the sample. Typically,
+            this indicates the intermediate output of VAE, which is mean and
+            log variance.
+        output_channels: int. The number of channels in the output.
+        scale: float. The scaling factor applied to the latent space to ensure
+            it has unit variance during training of the diffusion model.
+            Defaults to `1.5305`, which is the value used in Stable Diffusion 3.
+        shift: float. The shift factor applied to the latent space to ensure it
+            has zero mean during training of the diffusion model. Defaults to
+            `0.0609`, which is the value used in Stable Diffusion 3.
+        data_format: `None` or str. If specified, either `"channels_last"` or
+            `"channels_first"`. The ordering of the dimensions in the
+            inputs. `"channels_last"` corresponds to inputs with shape
+            `(batch_size, height, width, channels)`
+            while `"channels_first"` corresponds to inputs with shape
+            `(batch_size, channels, height, width)`. It defaults to the
+            `image_data_format` value found in your Keras config file at
+            `~/.keras/keras.json`. If you never set it, then it will be
+            `"channels_last"`.
+        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
+            to use for the model's computations and weights.
+
+    Example:
+    ```Python
+    backbone = VAEBackbone(
+        encoder_num_filters=[32, 32, 32, 32],
+        encoder_num_blocks=[1, 1, 1, 1],
+        decoder_num_filters=[32, 32, 32, 32],
+        decoder_num_blocks=[1, 1, 1, 1],
+    )
+    input_data = ops.ones((2, self.height, self.width, 3))
+    output = backbone(input_data)
+    ```
+    """
+
+    def __init__(
+        self,
+        encoder_num_filters,
+        encoder_num_blocks,
+        decoder_num_filters,
+        decoder_num_blocks,
+        sampler_method="sample",
+        input_channels=3,
+        sample_channels=32,
+        output_channels=3,
+        scale=1.5305,
+        shift=0.0609,
+        data_format=None,
+        dtype=None,
+        **kwargs,
+    ):
+        data_format = standardize_data_format(data_format)
+        if data_format == "channels_last":
+            image_shape = (None, None, input_channels)
+            channel_axis = -1
+        else:
+            image_shape = (input_channels, None, None)
+            channel_axis = 1
+
+        # === Layers ===
+        self.encoder = VAEEncoder(
+            encoder_num_filters,
+            encoder_num_blocks,
+            output_channels=sample_channels,
+            data_format=data_format,
+            dtype=dtype,
+            name="encoder",
+        )
+        # Use `sample()` to define the functional model.
+        self.distribution_sampler = DiagonalGaussianDistributionSampler(
+            method=sampler_method,
+            axis=channel_axis,
+            dtype=dtype,
+            name="distribution_sampler",
+        )
+        self.decoder = VAEDecoder(
+            decoder_num_filters,
+            decoder_num_blocks,
+            output_channels=output_channels,
+            data_format=data_format,
+            dtype=dtype,
+            name="decoder",
+        )
+
+        # === Functional Model ===
+        image_input = keras.Input(shape=image_shape)
+        sample = self.encoder(image_input)
+        latent = self.distribution_sampler(sample)
+        image_output = self.decoder(latent)
+        super().__init__(
+            inputs=image_input,
+            outputs=image_output,
+            dtype=dtype,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.encoder_num_filters = encoder_num_filters
+        self.encoder_num_blocks = encoder_num_blocks
+        self.decoder_num_filters = decoder_num_filters
+        self.decoder_num_blocks = decoder_num_blocks
+        self.sampler_method = sampler_method
+        self.input_channels = input_channels
+        self.sample_channels = sample_channels
+        self.output_channels = output_channels
+        self._scale = scale
+        self._shift = shift
+
+    @property
+    def scale(self):
+        """The scaling factor for the latent space.
+
+        This is used to scale the latent space to have unit variance when
+        training the diffusion model.
+        """
+        return self._scale
+
+    @property
+    def shift(self):
+        """The shift factor for the latent space.
+
+        This is used to shift the latent space to have zero mean when
+        training the diffusion model.
+        """
+        return self._shift
+
+    def encode(self, inputs, **kwargs):
+        """Encode the input images into latent space."""
+        sample = self.encoder(inputs, **kwargs)
+        return self.distribution_sampler(sample)
+
+    def decode(self, inputs, **kwargs):
+        """Decode the input latent space into images."""
+        return self.decoder(inputs, **kwargs)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "encoder_num_filters": self.encoder_num_filters,
+                "encoder_num_blocks": self.encoder_num_blocks,
+                "decoder_num_filters": self.decoder_num_filters,
+                "decoder_num_blocks": self.decoder_num_blocks,
+                "sampler_method": self.sampler_method,
+                "input_channels": self.input_channels,
+                "sample_channels": self.sample_channels,
+                "output_channels": self.output_channels,
+                "scale": self.scale,
+                "shift": self.shift,
+            }
+        )
+        return config
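A hedged usage sketch for the new backbone. The `encode`/`decode` calls come from the class above; the shift-then-scale normalization of latents is an assumption inferred from the `scale`/`shift` docstrings, not code shipped in this diff:

```python
import numpy as np

from keras_hub.src.models.vae.vae_backbone import VAEBackbone

backbone = VAEBackbone(
    encoder_num_filters=[32, 32, 32, 32],
    encoder_num_blocks=[1, 1, 1, 1],
    decoder_num_filters=[32, 32, 32, 32],
    decoder_num_blocks=[1, 1, 1, 1],
)
images = np.ones((1, 64, 64, 3), dtype="float32")
latents = backbone.encode(images)
# Assumed convention: shift to zero mean, then scale to unit variance,
# before handing latents to a diffusion model.
normalized = (latents - backbone.shift) * backbone.scale
reconstructed = backbone.decode(latents)
```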