keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/src/models/image_object_detector.py
@@ -0,0 +1,87 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.task import Task
+
+
+@keras_hub_export("keras_hub.models.ImageObjectDetector")
+class ImageObjectDetector(Task):
+    """Base class for all image object detection tasks.
+
+    The `ImageObjectDetector` tasks wrap a `keras_hub.models.Backbone` and
+    a `keras_hub.models.Preprocessor` to create a model that can be used for
+    object detection. `ImageObjectDetector` tasks take an additional
+    `num_classes` argument, controlling the number of predicted output classes.
+
+    To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
+    labels where `x` is an image and `y` is a dictionary with `boxes` and
+    `classes`.
+
+    All `ImageObjectDetector` tasks include a `from_preset()` constructor which
+    can be used to load a pre-trained config and weights.
+    """
+
+    def compile(
+        self,
+        optimizer="auto",
+        box_loss="auto",
+        classification_loss="auto",
+        metrics=None,
+        **kwargs,
+    ):
+        """Configures the `ImageObjectDetector` task for training.
+
+        The `ImageObjectDetector` task extends the default compilation
+        signature of `keras.Model.compile` with defaults for `optimizer`,
+        `loss`, and `metrics`. To override these defaults, pass any value
+        to these arguments during compilation.
+
+        Args:
+            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
+                instance. Defaults to `"auto"`, which uses the default
+                optimizer for the given model and task. See
+                `keras.Model.compile` and `keras.optimizers` for more info on
+                possible `optimizer` values.
+            box_loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
+                Defaults to `"auto"`, where a `keras.losses.Huber` loss will be
+                applied for the object detector task. See `keras.Model.compile`
+                and `keras.losses` for more info on possible `loss` values.
+            classification_loss: `"auto"`, a loss name, or a `keras.losses.Loss`
+                instance. Defaults to `"auto"`, where a
+                `keras.losses.BinaryFocalCrossentropy` loss will be applied for
+                the object detector task. See `keras.Model.compile` and
+                `keras.losses` for more info on possible `loss` values.
+            metrics: a list of metrics to be evaluated by the model during
+                training and testing. Defaults to `None`. See
+                `keras.Model.compile` and `keras.metrics` for more info on
+                possible `metrics` values.
+            **kwargs: See `keras.Model.compile` for a full list of arguments
+                supported by the compile method.
+        """
+        if optimizer == "auto":
+            optimizer = keras.optimizers.Adam(5e-5)
+        if box_loss == "auto":
+            box_loss = keras.losses.Huber(reduction="sum")
+        if classification_loss == "auto":
+            activation = getattr(self, "activation", None)
+            activation = keras.activations.get(activation)
+            from_logits = activation != keras.activations.sigmoid
+            classification_loss = keras.losses.BinaryFocalCrossentropy(
+                from_logits=from_logits, reduction="sum"
+            )
+        if metrics is not None:
+            raise ValueError("User metrics not yet supported")
+
+        losses = {
+            "bbox_regression": box_loss,
+            "cls_logits": classification_loss,
+        }
+
+        super().compile(
+            optimizer=optimizer,
+            loss=losses,
+            metrics=metrics,
+            **kwargs,
+        )
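The all-`"auto"` defaults in `compile()` above resolve to concrete objects. As a reference, the following sketch spells them out explicitly (the preset name is borrowed from the preprocessor docstring in the next file and, like `num_classes=20`, is illustrative only):

```python
import keras
import keras_hub

# Illustrative preset and num_classes; any ImageObjectDetector preset works.
detector = keras_hub.models.ImageObjectDetector.from_preset(
    "retinanet_resnet50",
    num_classes=20,
)

# Equivalent to `detector.compile()` for a detector that outputs logits
# (no sigmoid activation, hence from_logits=True):
detector.compile(
    optimizer=keras.optimizers.Adam(5e-5),
    box_loss=keras.losses.Huber(reduction="sum"),
    classification_loss=keras.losses.BinaryFocalCrossentropy(
        from_logits=True, reduction="sum"
    ),
)
```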
keras_hub/src/models/image_object_detector_preprocessor.py
@@ -0,0 +1,57 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.preprocessor import Preprocessor
+from keras_hub.src.utils.tensor_utils import preprocessing_function
+
+
+@keras_hub_export("keras_hub.models.ImageObjectDetectorPreprocessor")
+class ImageObjectDetectorPreprocessor(Preprocessor):
+    """Base class for object detector preprocessing layers.
+
+    `ImageObjectDetectorPreprocessor` wraps a
+    `keras_hub.layers.Preprocessor` to create a preprocessing layer for
+    object detection tasks. It is intended to be paired with a
+    `keras_hub.models.ImageObjectDetector` task.
+
+    All `ImageObjectDetectorPreprocessor` layers take three inputs: `x`, `y`,
+    and `sample_weight`. `x`, the first input, should always be included. It
+    can be an image or a batch of images. See examples below. `y` and
+    `sample_weight` are optional inputs that will be passed through unaltered.
+    Usually, `y` will be a dict of `{"boxes": Tensor(batch_size, num_boxes, 4),
+    "classes": (batch_size, num_boxes)}`.
+
+    The layer will return either `x`, an `(x, y)` tuple if labels were
+    provided, or an `(x, y, sample_weight)` tuple if labels and sample weight
+    were provided. `x` will be the input images after all model preprocessing
+    has been applied.
+
+    All `ImageObjectDetectorPreprocessor` tasks include a `from_preset()`
+    constructor which can be used to load a pre-trained config and vocabularies.
+    You can call the `from_preset()` constructor directly on this base class, in
+    which case the correct class for your model will be automatically
+    instantiated.
+
+    Args:
+        image_converter: Preprocessing pipeline for images.
+
+    Examples:
+    ```python
+    preprocessor = keras_hub.models.ImageObjectDetectorPreprocessor.from_preset(
+        "retinanet_resnet50",
+    )
+    ```
+    """
+
+    def __init__(
+        self,
+        image_converter=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.image_converter = image_converter
+
+    @preprocessing_function
+    def call(self, x, y=None, sample_weight=None):
+        if self.image_converter:
+            x = self.image_converter(x)
+        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
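A minimal usage sketch for the preprocessor above, following the `(x, y)` format its docstring describes (batch size, box coordinates, and label values are illustrative):

```python
import numpy as np
import keras_hub

preprocessor = keras_hub.models.ImageObjectDetectorPreprocessor.from_preset(
    "retinanet_resnet50",
)

# A batch of two images, each with a single box and class label.
x = np.random.uniform(0, 255, size=(2, 512, 512, 3)).astype("float32")
y = {
    "boxes": np.array([[[10.0, 10.0, 100.0, 100.0]]] * 2, dtype="float32"),
    "classes": np.zeros((2, 1), dtype="int32"),
}
# `x` goes through the image converter; `y` passes through unaltered.
x, y = preprocessor(x, y)
```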
keras_hub/src/models/image_segmenter.py
@@ -16,11 +16,6 @@ class ImageSegmenter(Task):
     be used to load a pre-trained config and weights.
     """
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        # Default compilation.
-        self.compile()
-
     def compile(
         self,
         optimizer="auto",
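The removal above means `ImageSegmenter` no longer compiles itself eagerly on construction. If downstream code relied on that implicit step, an explicit call restores it; a sketch, assuming the `"auto"` defaults of `compile()` (shown truncated above) are otherwise unchanged:

```python
import keras_hub

# Illustrative preset name; any ImageSegmenter preset works here.
segmenter = keras_hub.models.ImageSegmenter.from_preset("sam_base_sa1b")

# Previously invoked implicitly by __init__; now called explicitly to
# resolve the "auto" optimizer/loss/metrics defaults before fit()/evaluate().
segmenter.compile()
```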
keras_hub/src/models/image_segmenter_preprocessor.py
@@ -19,9 +19,12 @@ class ImageSegmenterPreprocessor(Preprocessor):
 
     - `x`: The first input, should always be included. It can be an image or
       a batch of images.
-    - `y`: (Optional) Usually the segmentation mask(s), will be passed through
-      unaltered.
+    - `y`: (Optional) Usually the segmentation mask(s). If `resize_output_mask`
+      is set to `True`, this will be resized to the input image shape;
+      otherwise it will be passed through unaltered.
     - `sample_weight`: (Optional) Will be passed through unaltered.
+    - `resize_output_mask`: bool. If set to `True`, the output mask will be
+      resized to the same size as the input image. Defaults to `False`.
 
     The layer will output either `x`, an `(x, y)` tuple if labels were provided,
     or an `(x, y, sample_weight)` tuple if labels and sample weight were
@@ -29,7 +32,7 @@ class ImageSegmenterPreprocessor(Preprocessor):
     been applied.
 
     All `ImageSegmenterPreprocessor` tasks include a `from_preset()`
-    constructor which can be used to load a pre-trained config and vocabularies.
+    constructor which can be used to load a pre-trained config.
     You can call the `from_preset()` constructor directly on this base class, in
     which case the correct class for your model will be automatically
     instantiated.
@@ -49,7 +52,8 @@ class ImageSegmenterPreprocessor(Preprocessor):
     x, y = preprocessor(x, y)
 
     # Resize a batch of images and masks.
-    x, y = [np.ones((512, 512, 3)), np.zeros((512, 512, 3))], [np.ones((512, 512, 1)), np.zeros((512, 512, 1))]
+    x, y = [np.ones((512, 512, 3)), np.zeros((512, 512, 3))],
+    [np.ones((512, 512, 1)), np.zeros((512, 512, 1))]
     x, y = preprocessor(x, y)
 
     # Use a `tf.data.Dataset`.
@@ -61,13 +65,34 @@ class ImageSegmenterPreprocessor(Preprocessor):
     def __init__(
         self,
         image_converter=None,
+        resize_output_mask=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.image_converter = image_converter
+        self.resize_output_mask = resize_output_mask
 
     @preprocessing_function
     def call(self, x, y=None, sample_weight=None):
         if self.image_converter:
             x = self.image_converter(x)
+
+        if y is not None and self.image_converter and self.resize_output_mask:
+            y = keras.layers.Resizing(
+                height=(
+                    self.image_converter.image_size[0]
+                    if self.image_converter.image_size
+                    else None
+                ),
+                width=(
+                    self.image_converter.image_size[1]
+                    if self.image_converter.image_size
+                    else None
+                ),
+                crop_to_aspect_ratio=self.image_converter.crop_to_aspect_ratio,
+                interpolation="nearest",
+                data_format=self.image_converter.data_format,
+                dtype=self.dtype_policy,
+                name="mask_resizing",
+            )(y)
         return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
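The new mask-resizing branch only runs when an image converter is attached and `resize_output_mask=True`. A sketch of the intended round trip (the preset name is illustrative; nearest-neighbor interpolation is used so mask label values survive resizing, per the code above):

```python
import numpy as np
import keras_hub

# Illustrative preset; any ImageSegmenterPreprocessor whose image
# converter defines image_size would behave the same way.
preprocessor = keras_hub.models.ImageSegmenterPreprocessor.from_preset(
    "deeplab_v3_plus_resnet50_pascalvoc",
)
preprocessor.resize_output_mask = True

x = np.ones((1024, 1024, 3), dtype="float32")   # input image
y = np.zeros((1024, 1024, 1), dtype="float32")  # mask at input resolution
# `y` is resized to the converter's image_size, matching the preprocessed
# `x`; with resize_output_mask=False it would pass through unaltered.
x, y = preprocessor(x, y)
```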
keras_hub/src/models/image_to_image.py
@@ -0,0 +1,417 @@
+import itertools
+from functools import partial
+
+import keras
+from keras import ops
+from keras import random
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.task import Task
+from keras_hub.src.utils.keras_utils import standardize_data_format
+
+try:
+    import tensorflow as tf
+except ImportError:
+    tf = None
+
+
+@keras_hub_export("keras_hub.models.ImageToImage")
+class ImageToImage(Task):
+    """Base class for image-to-image tasks.
+
+    `ImageToImage` tasks wrap a `keras_hub.models.Backbone` and
+    a `keras_hub.models.Preprocessor` to create a model that can be used for
+    generation and generative fine-tuning.
+
+    `ImageToImage` tasks provide an additional, high-level `generate()`
+    function, which can be used to generate images with an (image, string)
+    in, image out signature.
+
+    All `ImageToImage` tasks include a `from_preset()` constructor which can be
+    used to load a pre-trained config and weights.
+
+    Example:
+
+    ```python
+    # Load a Stable Diffusion 3 backbone with pre-trained weights.
+    reference_image = np.ones((1024, 1024, 3), dtype="float32")
+    image_to_image = keras_hub.models.ImageToImage.from_preset(
+        "stable_diffusion_3_medium",
+    )
+    image_to_image.generate(
+        reference_image,
+        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+    )
+
+    # Load a Stable Diffusion 3 backbone at bfloat16 precision.
+    image_to_image = keras_hub.models.ImageToImage.from_preset(
+        "stable_diffusion_3_medium",
+        dtype="bfloat16",
+    )
+    image_to_image.generate(
+        reference_image,
+        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+    )
+    ```
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Default compilation.
+        self.compile()
+
+    @property
+    def support_negative_prompts(self):
+        """Whether the model supports `negative_prompts` key in `generate()`."""
+        return bool(True)
+
+    @property
+    def image_shape(self):
+        return tuple(self.backbone.image_shape)
+
+    @property
+    def latent_shape(self):
+        return tuple(self.backbone.latent_shape)
+
+    def compile(
+        self,
+        optimizer="auto",
+        loss="auto",
+        *,
+        metrics="auto",
+        **kwargs,
+    ):
+        """Configures the `ImageToImage` task for training.
+
+        The `ImageToImage` task extends the default compilation signature of
+        `keras.Model.compile` with defaults for `optimizer`, `loss`, and
+        `metrics`. To override these defaults, pass any value
+        to these arguments during compilation.
+
+        Args:
+            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
+                instance. Defaults to `"auto"`, which uses the default
+                optimizer for the given model and task. See
+                `keras.Model.compile` and `keras.optimizers` for more info on
+                possible `optimizer` values.
+            loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
+                Defaults to `"auto"`, where a `keras.losses.MeanSquaredError`
+                loss will be applied. See `keras.Model.compile` and
+                `keras.losses` for more info on possible `loss` values.
+            metrics: `"auto"`, or a list of metrics to be evaluated by
+                the model during training and testing. Defaults to `"auto"`,
+                where a `keras.metrics.MeanSquaredError` will be applied to
+                track the loss of the model during training. See
+                `keras.Model.compile` and `keras.metrics` for more info on
+                possible `metrics` values.
+            **kwargs: See `keras.Model.compile` for a full list of arguments
+                supported by the compile method.
+        """
+        # Ref: https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py#L410-L414
+        if optimizer == "auto":
+            optimizer = keras.optimizers.AdamW(
+                1e-4, weight_decay=1e-2, epsilon=1e-8, clipnorm=1.0
+            )
+        if loss == "auto":
+            loss = keras.losses.MeanSquaredError()
+        if metrics == "auto":
+            metrics = [keras.metrics.MeanSquaredError()]
+        super().compile(
+            optimizer=optimizer,
+            loss=loss,
+            metrics=metrics,
+            **kwargs,
+        )
+        self.generate_function = None
+
+    def generate_step(self, *args, **kwargs):
+        """Run generation on batches of input."""
+        raise NotImplementedError
+
+    def make_generate_function(self):
+        """Create or return the compiled generation function."""
+        if self.generate_function is not None:
+            return self.generate_function
+
+        self.generate_function = self.generate_step
+        if keras.config.backend() == "torch":
+            import torch
+
+            def wrapped_function(*args, **kwargs):
+                with torch.no_grad():
+                    return self.generate_step(*args, **kwargs)
+
+            self.generate_function = wrapped_function
+        elif keras.config.backend() == "tensorflow" and not self.run_eagerly:
+            self.generate_function = tf.function(
+                self.generate_step, jit_compile=self.jit_compile
+            )
+        elif keras.config.backend() == "jax" and not self.run_eagerly:
+            import jax
+
+            @partial(jax.jit)
+            def compiled_function(state, *args, **kwargs):
+                (
+                    trainable_variables,
+                    non_trainable_variables,
+                ) = state
+                mapping = itertools.chain(
+                    zip(self.trainable_variables, trainable_variables),
+                    zip(self.non_trainable_variables, non_trainable_variables),
+                )
+
+                with keras.StatelessScope(state_mapping=mapping):
+                    outputs = self.generate_step(*args, **kwargs)
+                return outputs
+
+            def wrapped_function(*args, **kwargs):
+                # Create an explicit tuple of all variable state.
+                state = (
+                    # Use the explicit variable.value to preserve the
+                    # sharding spec of distribution.
+                    [v.value for v in self.trainable_variables],
+                    [v.value for v in self.non_trainable_variables],
+                )
+                outputs = compiled_function(state, *args, **kwargs)
+                return outputs
+
+            self.generate_function = wrapped_function
+        return self.generate_function
+
+    def _normalize_generate_inputs(self, inputs):
+        """Normalize user input to the generate function.
+
+        This function converts all inputs to tensors, adds a batch dimension
+        if necessary, and returns an iterable "dataset like" object (either
+        an actual `tf.data.Dataset` or a list with a single batch element).
+
+        The input format must be one of the following:
+        - A dict with "images", "prompts" and/or "negative_prompts" keys
+        - A tf.data.Dataset with "images", "prompts" and/or "negative_prompts"
+          keys
+
+        The output will be a dict with "images", "prompts" and/or
+        "negative_prompts" keys.
+        """
+        if tf and isinstance(inputs, tf.data.Dataset):
+            _inputs = {
+                "images": inputs.map(lambda x: x["images"]).as_numpy_iterator(),
+                "prompts": inputs.map(
+                    lambda x: x["prompts"]
+                ).as_numpy_iterator(),
+            }
+            if self.support_negative_prompts:
+                _inputs["negative_prompts"] = inputs.map(
+                    lambda x: x["negative_prompts"]
+                ).as_numpy_iterator()
+            return _inputs, False
+
+        if (
+            not isinstance(inputs, dict)
+            or "images" not in inputs
+            or "prompts" not in inputs
+        ):
+            raise ValueError(
+                '`inputs` must be a dict with "images" and "prompts" keys or '
+                f"a tf.data.Dataset. Received: inputs={inputs}"
+            )
+
+        def normalize(x):
+            if isinstance(x, str):
+                return [x], True
+            if tf and isinstance(x, tf.Tensor) and x.shape.rank == 0:
+                return x[tf.newaxis], True
+            return x, False
+
+        def normalize_images(x):
+            data_format = getattr(
+                self.backbone, "data_format", standardize_data_format(None)
+            )
+            input_is_scalar = False
+            x = ops.convert_to_tensor(x)
+            if len(ops.shape(x)) < 4:
+                x = ops.expand_dims(x, axis=0)
+                input_is_scalar = True
+            x = ops.image.resize(
+                x,
+                (self.backbone.image_shape[0], self.backbone.image_shape[1]),
+                interpolation="nearest",
+                data_format=data_format,
+            )
+            return x, input_is_scalar
+
+        def get_dummy_prompts(x):
+            dummy_prompts = [""] * len(x)
+            if tf and isinstance(x, tf.Tensor):
+                return tf.convert_to_tensor(dummy_prompts)
+            else:
+                return dummy_prompts
+
+        for key in inputs:
+            if key == "images":
+                inputs[key], input_is_scalar = normalize_images(inputs[key])
+            else:
+                inputs[key], input_is_scalar = normalize(inputs[key])
+
+        if self.support_negative_prompts and "negative_prompts" not in inputs:
+            inputs["negative_prompts"] = get_dummy_prompts(inputs["prompts"])
+
+        return [inputs], input_is_scalar
+
+    def _normalize_generate_outputs(self, outputs, input_is_scalar):
+        """Normalize user output from the generate function.
+
+        This function converts all output to numpy with a value range of
+        `[0, 255]`. If a batch dimension was added to the input, it is removed
+        from the output.
+        """
+
+        def normalize(x):
+            outputs = ops.concatenate(x, axis=0)
+            outputs = ops.clip(ops.divide(ops.add(outputs, 1.0), 2.0), 0.0, 1.0)
+            outputs = ops.cast(ops.round(ops.multiply(outputs, 255.0)), "uint8")
+            outputs = ops.squeeze(outputs, 0) if input_is_scalar else outputs
+            return ops.convert_to_numpy(outputs)
+
+        if isinstance(outputs[0], dict):
+            normalized = {}
+            for key in outputs[0]:
+                normalized[key] = normalize([x[key] for x in outputs])
+            return normalized
+        return normalize([x for x in outputs])
+
+    def generate(
+        self,
+        inputs,
+        num_steps,
+        strength,
+        guidance_scale=None,
+        seed=None,
+    ):
+        """Generate images based on the provided `inputs`.
+
+        Typically, `inputs` is a dict with `"images"` and `"prompts"` keys.
+        `"images"` are reference images within a value range of
+        `[-1.0, 1.0]`, which will be resized to `self.backbone.height` and
+        `self.backbone.width`, then encoded into latent space by the VAE
+        encoder. `"prompts"` are strings that will be tokenized and encoded by
+        the text encoder.
+
+        Some models support a `"negative_prompts"` key, which helps steer the
+        model away from generating certain styles and elements. To enable this,
+        add `"negative_prompts"` to the input dict.
+
+        If `inputs` is a `tf.data.Dataset`, outputs will be generated
+        "batch-by-batch" and concatenated. Otherwise, all inputs will be
+        processed as batches.
+
+        Args:
+            inputs: python data, tensor data, or a `tf.data.Dataset`. The
+                format must be one of the following:
+                - A dict with `"images"`, `"prompts"` and/or
+                  `"negative_prompts"` keys.
+                - A `tf.data.Dataset` with `"images"`, `"prompts"` and/or
+                  `"negative_prompts"` keys.
+            num_steps: int. The number of diffusion steps to take.
+            strength: float. Indicates the extent to which the reference
+                `images` are transformed. Must be between `0.0` and `1.0`. When
+                `strength=1.0`, `images` is essentially ignored, added noise
+                is maximal, and the denoising process runs for the full number
+                of iterations specified in `num_steps`.
+            guidance_scale: Optional float. The classifier free guidance scale
+                defined in [Classifier-Free Diffusion Guidance](
+                https://arxiv.org/abs/2207.12598). A higher scale encourages
+                generating images more closely related to the prompts,
+                typically at the cost of lower image quality. Note that some
+                models don't utilize classifier-free guidance.
+            seed: optional int. Used as a random seed.
+        """
+        num_steps = int(num_steps)
+        strength = float(strength)
+        guidance_scale = (
+            float(guidance_scale) if guidance_scale is not None else None
+        )
+        if strength < 0.0 or strength > 1.0:
+            raise ValueError(
+                "`strength` must be between `0.0` and `1.0`. "
+                f"Received strength={strength}."
+            )
+        if guidance_scale is not None and guidance_scale > 1.0:
+            guidance_scale = ops.convert_to_tensor(float(guidance_scale))
+        else:
+            guidance_scale = None
+        starting_step = int(num_steps * (1.0 - strength))
+        starting_step = ops.convert_to_tensor(starting_step, "int32")
+        num_steps = ops.convert_to_tensor(int(num_steps), "int32")
+
+        # Check `inputs` format.
+        required_keys = ["images", "prompts"]
+        if tf and isinstance(inputs, tf.data.Dataset):
+            spec = inputs.element_spec
+            if not all(key in spec for key in required_keys):
+                raise ValueError(
+                    "Expected a `tf.data.Dataset` with the following keys: "
+                    f"{required_keys}. Received: inputs.element_spec={spec}"
+                )
+        else:
+            if not isinstance(inputs, dict):
+                raise ValueError(
+                    "Expected a `dict` or `tf.data.Dataset`. "
+                    f"Received: inputs={inputs} of type {type(inputs)}."
+                )
+            if not all(key in inputs for key in required_keys):
+                raise ValueError(
+                    "Expected a `dict` with the following keys: "
+                    f"{required_keys}. "
+                    f"Received: inputs.keys={list(inputs.keys())}"
+                )
+
+        # Setup our three main passes.
+        # 1. Preprocessing strings to dense integer tensors.
+        # 2. Generate outputs via a compiled function on dense tensors.
+        # 3. Postprocess dense tensors to a value range of `[0, 255]`.
+        generate_function = self.make_generate_function()
+
+        def preprocess(x):
+            if self.preprocessor is not None:
+                return self.preprocessor.generate_preprocess(x)
+            else:
+                return x
+
+        def generate(images, x):
+            token_ids = x[0] if self.support_negative_prompts else x
+
+            # Initialize noises.
+            if isinstance(token_ids, dict):
+                arbitrary_key = list(token_ids.keys())[0]
+                batch_size = ops.shape(token_ids[arbitrary_key])[0]
+            else:
+                batch_size = ops.shape(token_ids)[0]
+            noise_shape = (batch_size,) + self.latent_shape[1:]
+            noises = random.normal(noise_shape, dtype="float32", seed=seed)
+
+            return generate_function(
+                images, noises, x, starting_step, num_steps, guidance_scale
+            )
+
+        # Normalize and preprocess inputs.
+        inputs, input_is_scalar = self._normalize_generate_inputs(inputs)
+        if self.support_negative_prompts:
+            images = [x["images"] for x in inputs]
+            token_ids = [preprocess(x["prompts"]) for x in inputs]
+            negative_token_ids = [
+                preprocess(x["negative_prompts"]) for x in inputs
+            ]
+            # Tuple format: (images, (token_ids, negative_token_ids)).
+            inputs = [
+                x for x in zip(images, zip(token_ids, negative_token_ids))
+            ]
+        else:
+            images = [x["images"] for x in inputs]
+            token_ids = [preprocess(x["prompts"]) for x in inputs]
+            # Tuple format: (images, token_ids).
+            inputs = [x for x in zip(images, token_ids)]
+
+        # Image-to-image.
+        outputs = [generate(*x) for x in inputs]
+        return self._normalize_generate_outputs(outputs, input_is_scalar)
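Tying `image_to_image.py` together: `generate()` insists on dict (or `tf.data.Dataset`) inputs carrying `"images"` and `"prompts"` keys, and takes required `num_steps` and `strength` arguments. A minimal end-to-end sketch (hyperparameter values are illustrative; the preset name comes from the class docstring):

```python
import numpy as np
import keras_hub

image_to_image = keras_hub.models.ImageToImage.from_preset(
    "stable_diffusion_3_medium",
    dtype="bfloat16",
)

output = image_to_image.generate(
    {
        # Reference image in [-1.0, 1.0]; resized to the backbone image shape.
        "images": np.zeros((1024, 1024, 3), dtype="float32"),
        "prompts": "Astronaut in a jungle, cold color palette, detailed, 8k",
        "negative_prompts": "blurry, low quality",  # optional key
    },
    num_steps=30,
    strength=0.7,  # starting_step = int(30 * (1 - 0.7)) = 9, so 21 steps run
    guidance_scale=7.0,
)
# `output` is a uint8 numpy array in [0, 255]; the batch dimension is
# squeezed because a single (unbatched) image was passed.
```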