keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,200 @@
1
+ import keras
2
+
3
+ from keras_hub.src.api_export import keras_hub_export
4
+ from keras_hub.src.models.backbone import Backbone
5
+ from keras_hub.src.models.deeplab_v3.deeplab_v3_layers import (
6
+ SpatialPyramidPooling,
7
+ )
8
+
9
+
10
@keras_hub_export("keras_hub.models.DeepLabV3Backbone")
class DeepLabV3Backbone(Backbone):
    """DeepLabV3 & DeepLabV3Plus architecture for semantic segmentation.

    This class implements a DeepLabV3 & DeepLabV3Plus architecture as described
    in [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1802.02611)
    (ECCV 2018)
    and [Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1706.05587)
    (CVPR 2017)

    Args:
        image_encoder: `keras.Model`. An instance that is used as a feature
            extractor for the Encoder. Should either be a
            `keras_hub.models.Backbone` or a `keras.Model` that implements the
            `pyramid_outputs` property, a mapping keyed by level names such as
            "P2", "P3", etc.
            A somewhat sensible backbone to use in many cases is
            the `keras_hub.models.ResNetBackbone.from_preset("resnet_v2_50")`.
        projection_filters: int. Number of filters in the convolution layer
            projecting low-level features from the `image_encoder`.
        spatial_pyramid_pooling_key: str. A layer level to extract and perform
            `spatial_pyramid_pooling` on, one of the keys from the
            `image_encoder` `pyramid_outputs` property such as "P4", "P5" etc.
            The last character must be a digit: the final upsampling factor
            is computed as `2 ** int(key[-1]) // upsampling_size`.
        upsampling_size: int or tuple/list of 2 integers. The upsampling
            factors for rows and columns of the `spatial_pyramid_pooling`
            layer. If `low_level_feature_key` is given then the
            `spatial_pyramid_pooling` layer's resolution should match the
            `low_level_feature` layer's resolution so that both layers can be
            concatenated into the combined encoder outputs.
        dilation_rates: list. A `list` of integers for parallel dilated conv
            applied to `SpatialPyramidPooling`. A common choice of rates is
            `[6, 12, 18]`.
        low_level_feature_key: str optional. A layer level to extract the
            feature from, one of the keys from the `image_encoder`'s
            `pyramid_outputs` property such as "P2", "P3" etc, which will be
            the Decoder block. Required only when the DeepLabV3Plus
            architecture needs to be applied.
        image_shape: tuple. The input shape without the batch size.
            Defaults to `(None, None, 3)`.

    Example:
    ```python
    # Load a trained backbone to extract features from its `pyramid_outputs`.
    image_encoder = keras_hub.models.ResNetBackbone.from_preset(
        "resnet_50_imagenet"
    )

    model = keras_hub.models.DeepLabV3Backbone(
        image_encoder=image_encoder,
        projection_filters=48,
        low_level_feature_key="P2",
        spatial_pyramid_pooling_key="P5",
        upsampling_size=8,
        dilation_rates=[6, 12, 18],
    )
    ```
    """  # noqa: E501

    def __init__(
        self,
        image_encoder,
        spatial_pyramid_pooling_key,
        upsampling_size,
        dilation_rates,
        low_level_feature_key=None,
        projection_filters=48,
        image_shape=(None, None, 3),
        **kwargs,
    ):
        if not isinstance(image_encoder, keras.Model):
            raise ValueError(
                "Argument `image_encoder` must be a `keras.Model` instance. "
                "Received instead "
                f"{image_encoder} (of type {type(image_encoder)})."
            )
        data_format = keras.config.image_data_format()
        channel_axis = -1 if data_format == "channels_last" else 1

        # === Layers ===
        inputs = keras.layers.Input(image_shape, name="inputs")

        # Re-wrap the encoder so that a single call yields every pyramid
        # level, addressable by key.
        fpn_model = keras.Model(
            image_encoder.inputs, image_encoder.pyramid_outputs
        )

        fpn_outputs = fpn_model(inputs)

        spatial_pyramid_pooling = SpatialPyramidPooling(
            dilation_rates=dilation_rates
        )
        spatial_backbone_features = fpn_outputs[spatial_pyramid_pooling_key]
        spp_outputs = spatial_pyramid_pooling(spatial_backbone_features)

        encoder_outputs = keras.layers.UpSampling2D(
            size=upsampling_size,
            interpolation="bilinear",
            name="encoder_output_upsampling",
            data_format=data_format,
        )(spp_outputs)

        # DeepLabV3Plus only: merge projected low-level features into the
        # upsampled encoder output.
        if low_level_feature_key:
            decoder_feature = fpn_outputs[low_level_feature_key]
            low_level_projected_features = apply_low_level_feature_network(
                decoder_feature, projection_filters, channel_axis
            )

            encoder_outputs = keras.layers.Concatenate(
                axis=channel_axis, name="encoder_decoder_concat"
            )([encoder_outputs, low_level_projected_features])

        # Upsampling to the original image size. A two-element
        # `upsampling_size` may arrive as a list rather than a tuple (e.g.
        # after the config has been round-tripped through JSON), so both
        # sequence types are accepted; only the first (row) factor is used.
        if isinstance(upsampling_size, (tuple, list)):
            first_upsampling_factor = int(upsampling_size[0])
        else:
            first_upsampling_factor = upsampling_size
        upsampling = (
            2 ** int(spatial_pyramid_pooling_key[-1])
        ) // first_upsampling_factor

        # === Functional Model ===
        x = keras.layers.Conv2D(
            name="segmentation_head_conv",
            filters=256,
            kernel_size=1,
            padding="same",
            use_bias=False,
            data_format=data_format,
        )(encoder_outputs)
        x = keras.layers.BatchNormalization(
            name="segmentation_head_norm", axis=channel_axis
        )(x)
        x = keras.layers.ReLU(name="segmentation_head_relu")(x)
        x = keras.layers.UpSampling2D(
            size=upsampling,
            interpolation="bilinear",
            data_format=data_format,
            name="backbone_output_upsampling",
        )(x)

        super().__init__(inputs=inputs, outputs=x, **kwargs)

        # === Config ===
        self.image_shape = image_shape
        self.image_encoder = image_encoder
        self.projection_filters = projection_filters
        self.upsampling_size = upsampling_size
        self.dilation_rates = dilation_rates
        self.low_level_feature_key = low_level_feature_key
        self.spatial_pyramid_pooling_key = spatial_pyramid_pooling_key

    def get_config(self):
        """Return the constructor arguments, with the encoder serialized."""
        config = super().get_config()
        config.update(
            {
                "image_encoder": keras.saving.serialize_keras_object(
                    self.image_encoder
                ),
                "projection_filters": self.projection_filters,
                "dilation_rates": self.dilation_rates,
                "upsampling_size": self.upsampling_size,
                "low_level_feature_key": self.low_level_feature_key,
                "spatial_pyramid_pooling_key": self.spatial_pyramid_pooling_key,
                "image_shape": self.image_shape,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the backbone, deserializing a dict-valued `image_encoder`.

        The incoming `config` is shallow-copied first so that the caller's
        dictionary is not modified when the serialized encoder is replaced by
        the deserialized instance.
        """
        config = dict(config)
        if isinstance(config.get("image_encoder"), dict):
            config["image_encoder"] = keras.layers.deserialize(
                config["image_encoder"]
            )
        return super().from_config(config)
181
+
182
+
183
def apply_low_level_feature_network(
    input_tensor, projection_filters, channel_axis
):
    """Project low-level features with a 1x1 conv, batch norm and ReLU.

    Args:
        input_tensor: The feature tensor to project.
        projection_filters: int. Number of filters of the 1x1 convolution.
        channel_axis: int. Axis that batch normalization normalizes over.

    Returns:
        The tensor produced by applying the convolution, the batch
        normalization and the ReLU, in that order.
    """
    projection = keras.layers.Conv2D(
        filters=projection_filters,
        kernel_size=1,
        padding="same",
        use_bias=False,
        data_format=keras.config.image_data_format(),
        name="decoder_conv",
    )
    projected = projection(input_tensor)

    normalization = keras.layers.BatchNormalization(
        axis=channel_axis, name="decoder_norm"
    )
    normalized = normalization(projected)

    return keras.layers.ReLU(name="decoder_relu")(normalized)
@@ -0,0 +1,10 @@
1
+ from keras_hub.src.api_export import keras_hub_export
2
+ from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
3
+ from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import (
4
+ DeepLabV3Backbone,
5
+ )
6
+
7
+
8
@keras_hub_export("keras_hub.layers.DeepLabV3ImageConverter")
class DeepLabV3ImageConverter(ImageConverter):
    """`ImageConverter` subclass associated with `DeepLabV3Backbone`.

    Adds no behavior of its own; it only declares the backbone class it
    belongs to.
    """

    # NOTE(review): presumably read by the base class to pair this converter
    # with its backbone (e.g. for preset handling) -- confirm against
    # `ImageConverter`.
    backbone_cls = DeepLabV3Backbone
@@ -0,0 +1,16 @@
1
+ from keras_hub.src.api_export import keras_hub_export
2
+ from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import (
3
+ DeepLabV3Backbone,
4
+ )
5
+ from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import (
6
+ DeepLabV3ImageConverter,
7
+ )
8
+ from keras_hub.src.models.image_segmenter_preprocessor import (
9
+ ImageSegmenterPreprocessor,
10
+ )
11
+
12
+
13
@keras_hub_export("keras_hub.models.DeepLabV3ImageSegmenterPreprocessor")
class DeepLabV3ImageSegmenterPreprocessor(ImageSegmenterPreprocessor):
    """`ImageSegmenterPreprocessor` subclass associated with DeepLabV3.

    Adds no behavior of its own; it only declares which backbone and image
    converter classes it is paired with.
    """

    # NOTE(review): presumably consumed by the base preprocessor when
    # resolving presets -- confirm against `ImageSegmenterPreprocessor`.
    backbone_cls = DeepLabV3Backbone
    image_converter_cls = DeepLabV3ImageConverter
@@ -0,0 +1,215 @@
1
+ import keras
2
+ from keras import ops
3
+
4
+
5
class SpatialPyramidPooling(keras.layers.Layer):
    """Implements the Atrous Spatial Pyramid Pooling.

    Reference for Atrous Spatial Pyramid Pooling [Rethinking Atrous Convolution
    for Semantic Image Segmentation](https://arxiv.org/pdf/1706.05587.pdf) and
    [Encoder-Decoder with Atrous Separable Convolution for Semantic Image
    Segmentation](https://arxiv.org/pdf/1802.02611.pdf)

    Args:
        dilation_rates: list of ints. The dilation rates for the parallel
            dilated convolutions. A common choice of rates is `[6, 12, 18]`.
        num_channels: int. The number of output channels, defaults to `256`.
        activation: str. Activation to be used, defaults to `relu`.
        dropout: float. The dropout rate of the final projection output after
            the activations and batch norm, defaults to `0.0`, which means no
            dropout is applied to the output.

    Example:
    ```python
    inp = keras.layers.Input((384, 384, 3))
    backbone = keras.applications.EfficientNetB0(
        input_tensor=inp,
        include_top=False)
    output = backbone(inp)
    output = SpatialPyramidPooling(
        dilation_rates=[6, 12, 18])(output)
    ```
    """

    def __init__(
        self,
        dilation_rates,
        num_channels=256,
        activation="relu",
        dropout=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dilation_rates = dilation_rates
        self.num_channels = num_channels
        self.activation = activation
        self.dropout = dropout
        # The data format is captured once, at construction time.
        self.data_format = keras.config.image_data_format()
        self.channel_axis = -1 if self.data_format == "channels_last" else 1

    def build(self, input_shape):
        """Create and build the parallel branches and the projection."""
        # Normalize to a tuple so the shape arithmetic below also works when
        # the shape is handed over as a list.
        input_shape = tuple(input_shape)
        channels = input_shape[self.channel_axis]

        # These are the parallel networks that process the input features
        # with different dilation rates. The output of each branch is merged
        # together and fed to the projection.
        self.aspp_parallel_channels = []

        # Branch 1: Conv2D with a 1x1 kernel.
        conv_sequential = keras.Sequential(
            [
                keras.layers.Conv2D(
                    filters=self.num_channels,
                    kernel_size=(1, 1),
                    use_bias=False,
                    data_format=self.data_format,
                    name="aspp_conv_1",
                ),
                keras.layers.BatchNormalization(
                    axis=self.channel_axis, name="aspp_bn_1"
                ),
                keras.layers.Activation(
                    self.activation, name="aspp_activation_1"
                ),
            ]
        )
        conv_sequential.build(input_shape)
        self.aspp_parallel_channels.append(conv_sequential)

        # Branch 2 and afterwards are based on self.dilation_rates, and each
        # of them has a Conv2D with a 3x3 kernel.
        for i, dilation_rate in enumerate(self.dilation_rates):
            conv_sequential = keras.Sequential(
                [
                    keras.layers.Conv2D(
                        filters=self.num_channels,
                        kernel_size=(3, 3),
                        padding="same",
                        dilation_rate=dilation_rate,
                        use_bias=False,
                        data_format=self.data_format,
                        name=f"aspp_conv_{i + 2}",
                    ),
                    keras.layers.BatchNormalization(
                        axis=self.channel_axis, name=f"aspp_bn_{i + 2}"
                    ),
                    keras.layers.Activation(
                        self.activation, name=f"aspp_activation_{i + 2}"
                    ),
                ]
            )
            conv_sequential.build(input_shape)
            self.aspp_parallel_channels.append(conv_sequential)

        # Last branch: global average pooling followed by a 1x1 Conv2D. The
        # pooled vector is reshaped back to a 1x1 feature map in the active
        # data format so that the convolution can be applied to it.
        if self.channel_axis == -1:
            reshape = keras.layers.Reshape((1, 1, channels), name="reshape")
        else:
            reshape = keras.layers.Reshape((channels, 1, 1), name="reshape")
        pool_sequential = keras.Sequential(
            [
                keras.layers.GlobalAveragePooling2D(
                    data_format=self.data_format, name="average_pooling"
                ),
                reshape,
                keras.layers.Conv2D(
                    filters=self.num_channels,
                    kernel_size=(1, 1),
                    use_bias=False,
                    data_format=self.data_format,
                    name="conv_pooling",
                ),
                keras.layers.BatchNormalization(
                    axis=self.channel_axis, name="bn_pooling"
                ),
                keras.layers.Activation(
                    self.activation, name="activation_pooling"
                ),
            ]
        )
        pool_sequential.build(input_shape)
        self.aspp_parallel_channels.append(pool_sequential)

        # Final projection layers.
        projection = keras.Sequential(
            [
                keras.layers.Conv2D(
                    filters=self.num_channels,
                    kernel_size=(1, 1),
                    use_bias=False,
                    data_format=self.data_format,
                    name="conv_projection",
                ),
                keras.layers.BatchNormalization(
                    axis=self.channel_axis, name="bn_projection"
                ),
                keras.layers.Activation(
                    self.activation, name="activation_projection"
                ),
                keras.layers.Dropout(rate=self.dropout, name="dropout"),
            ],
        )
        # The projection consumes the concatenation of every branch: the 1x1
        # branch, the pooling branch and one branch per dilation rate.
        projection_input_channels = (
            2 + len(self.dilation_rates)
        ) * self.num_channels
        if self.data_format == "channels_first":
            projection.build(
                (input_shape[0], projection_input_channels) + input_shape[2:]
            )
        else:
            projection.build(input_shape[:-1] + (projection_input_channels,))
        self.projection = projection
        self.built = True

    def call(self, inputs):
        """Calls the Atrous Spatial Pyramid Pooling layer on an input.

        Args:
            inputs: A tensor of shape [batch, height, width, channels] for
                the "channels_last" data format, or
                [batch, channels, height, width] for "channels_first".

        Returns:
            A tensor of shape [batch, height, width, num_channels] for
            "channels_last", or [batch, num_channels, height, width] for
            "channels_first".
        """
        result = []

        for channel in self.aspp_parallel_channels:
            temp = ops.cast(channel(inputs), inputs.dtype)
            result.append(temp)

        image_shape = ops.shape(inputs)
        if self.channel_axis == -1:
            height, width = image_shape[1], image_shape[2]
        else:
            height, width = image_shape[2], image_shape[3]
        # The pooling branch (appended last in `build`) outputs a 1x1 map;
        # resize it back to the input's spatial size before concatenating.
        result[-1] = keras.layers.Resizing(
            height,
            width,
            interpolation="bilinear",
            data_format=self.data_format,
            name="resizing",
        )(result[-1])

        result = ops.concatenate(result, axis=self.channel_axis)
        return self.projection(result)

    def compute_output_shape(self, inputs_shape):
        """Return the input shape with the channel size set to num_channels."""
        # Convert first: concatenating a list slice with a tuple would raise
        # a TypeError before any outer `tuple(...)` could normalize it.
        inputs_shape = tuple(inputs_shape)
        if self.data_format == "channels_first":
            return (inputs_shape[0], self.num_channels) + inputs_shape[2:]
        return inputs_shape[:-1] + (self.num_channels,)

    def get_config(self):
        """Return the constructor arguments needed to recreate the layer."""
        config = super().get_config()
        config.update(
            {
                "dilation_rates": self.dilation_rates,
                "num_channels": self.num_channels,
                "activation": self.activation,
                "dropout": self.dropout,
            }
        )
        return config
@@ -0,0 +1,17 @@
1
+ """DeepLabV3 preset configurations."""
2
+
3
+ backbone_presets = {
4
+ "deeplab_v3_plus_resnet50_pascalvoc": {
5
+ "metadata": {
6
+ "description": (
7
+ "DeepLabV3+ model with ResNet50 as image encoder and trained "
8
+ "on augmented Pascal VOC dataset by Semantic Boundaries "
9
+ "Dataset(SBD) which is having categorical accuracy of 90.01 "
10
+ "and 0.63 Mean IoU."
11
+ ),
12
+ "params": 39190656,
13
+ "path": "deeplab_v3",
14
+ },
15
+ "kaggle_handle": "kaggle://keras/deeplabv3plus/keras/deeplab_v3_plus_resnet50_pascalvoc/4",
16
+ },
17
+ }
@@ -0,0 +1,111 @@
1
+ import keras
2
+
3
+ from keras_hub.src.api_export import keras_hub_export
4
+ from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import (
5
+ DeepLabV3Backbone,
6
+ )
7
+ from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor import ( # noqa: E501
8
+ DeepLabV3ImageSegmenterPreprocessor,
9
+ )
10
+ from keras_hub.src.models.image_segmenter import ImageSegmenter
11
+
12
+
13
+ @keras_hub_export("keras_hub.models.DeepLabV3ImageSegmenter")
14
+ class DeepLabV3ImageSegmenter(ImageSegmenter):
15
+ """DeepLabV3 and DeepLabV3+ image segmentation task.
16
+
17
+ Args:
18
+ backbone: A `keras_hub.models.DeepLabV3Backbone` instance.
19
+ num_classes: int. The number of classes for the segmentation model. Note
20
+ that the `num_classes` contains the background class, and the
21
+ classes from the data should be represented by integers with range
22
+ `[0, num_classes)`.
23
+ activation: str or callable. The activation function to use on
24
+ the output `Conv2D` layer. Set `activation=None` to return the output
25
+ logits. Defaults to `None`.
26
+ preprocessor: A `keras_hub.models.DeepLabV3ImageSegmenterPreprocessor`
27
+ or `None`. If `None`, this model will not apply preprocessing, and
28
+ inputs should be preprocessed before calling the model.
29
+
30
+ Example:
31
+ Load a DeepLabV3 preset with the pretrained 21-class segmentation head.
32
+ ```python
33
+ images = np.ones(shape=(1, 96, 96, 3))
34
+ labels = np.zeros(shape=(1, 96, 96, 2))
35
+ segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset(
36
+ "deeplab_v3_plus_resnet50_pascalvoc",
37
+ )
38
+ segmenter.predict(images)
39
+ ```
40
+
41
+ Specify `num_classes` to load randomly initialized segmentation head.
42
+ ```python
43
+ segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset(
44
+ "deeplab_v3_plus_resnet50_pascalvoc",
45
+ num_classes=2,
46
+ )
47
+ segmenter.preprocessor.image_size = (96, 96)
48
+ segmenter.fit(images, labels, epochs=3)
49
+ segmenter.predict(images) # Trained 2 class segmentation.
50
+ ```
51
+
52
+ Load a DeepLabV3+ preset. DeepLabV3+ extends DeepLabV3 by adding a simple
53
+ yet effective decoder module that refines the segmentation results,
54
+ especially along object boundaries.
55
+ ```python
56
+ segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset(
57
+ "deeplabv3_plus_resnet50_pascalvoc",
58
+ )
59
+ segmenter.predict(images)
60
+ ```
61
+ """
62
+
63
+ backbone_cls = DeepLabV3Backbone
64
+ preprocessor_cls = DeepLabV3ImageSegmenterPreprocessor
65
+
66
+ def __init__(
67
+ self,
68
+ backbone,
69
+ num_classes,
70
+ activation=None,
71
+ preprocessor=None,
72
+ **kwargs,
73
+ ):
74
+ data_format = keras.config.image_data_format()
75
+ # === Layers ===
76
+ self.output_conv = keras.layers.Conv2D(
77
+ name="segmentation_output",
78
+ filters=num_classes,
79
+ kernel_size=1,
80
+ use_bias=False,
81
+ padding="same",
82
+ activation=activation,
83
+ data_format=data_format,
84
+ )
85
+
86
+ # === Functional Model ===
87
+ inputs = backbone.input
88
+ x = backbone(inputs)
89
+ outputs = self.output_conv(x)
90
+ super().__init__(
91
+ inputs=inputs,
92
+ outputs=outputs,
93
+ **kwargs,
94
+ )
95
+
96
+ # === Config ===
97
+ self.backbone = backbone
98
+ self.num_classes = num_classes
99
+ self.activation = activation
100
+ self.preprocessor = preprocessor
101
+
102
+ def get_config(self):
103
+ # Backbone serialized in `super`
104
+ config = super().get_config()
105
+ config.update(
106
+ {
107
+ "num_classes": self.num_classes,
108
+ "activation": self.activation,
109
+ }
110
+ )
111
+ return config
@@ -29,7 +29,9 @@ class DenseNetBackbone(FeaturePyramidBackbone):
29
29
  input_data = np.ones(shape=(8, 224, 224, 3))
30
30
 
31
31
  # Pretrained backbone
32
- model = keras_hub.models.DenseNetBackbone.from_preset("densenet121_imagenet")
32
+ model = keras_hub.models.DenseNetBackbone.from_preset(
33
+ "densenet_121_imagenet"
34
+ )
33
35
  model(input_data)
34
36
 
35
37
  # Randomly initialized backbone with a custom config
@@ -79,14 +81,14 @@ class DenseNetBackbone(FeaturePyramidBackbone):
79
81
  channel_axis,
80
82
  stackwise_num_repeats[stack_index],
81
83
  growth_rate,
82
- name=f"stack{stack_index+1}",
84
+ name=f"stack{stack_index + 1}",
83
85
  )
84
86
  pyramid_outputs[f"P{index}"] = x
85
87
  x = apply_transition_block(
86
88
  x,
87
89
  channel_axis,
88
90
  compression_ratio,
89
- name=f"transition{stack_index+1}",
91
+ name=f"transition{stack_index + 1}",
90
92
  )
91
93
 
92
94
  x = apply_dense_block(
@@ -138,7 +140,7 @@ def apply_dense_block(x, channel_axis, num_repeats, growth_rate, name=None):
138
140
 
139
141
  for i in range(num_repeats):
140
142
  x = apply_conv_block(
141
- x, channel_axis, growth_rate, name=f"{name}_block{i+1}"
143
+ x, channel_axis, growth_rate, name=f"{name}_block{i + 1}"
142
144
  )
143
145
  return x
144
146