keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252) hide show
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,17 @@
1
+ import math
2
+
1
3
  import keras
2
4
 
5
+ from keras_hub.src.utils.keras_utils import standardize_data_format
6
+
3
7
 
4
8
  class FeaturePyramid(keras.layers.Layer):
5
9
  """A Feature Pyramid Network (FPN) layer.
6
10
 
7
11
  This implements the paper:
8
- Tsung-Yi Lin, Piotr Dollar, Ross Girshick, Kaiming He, Bharath Hariharan,
9
- and Serge Belongie. Feature Pyramid Networks for Object Detection.
12
+ Tsung-Yi Lin, Piotr Dollar, Ross Girshick, Kaiming He,
13
+ Bharath Hariharan, and Serge Belongie.
14
+ Feature Pyramid Networks for Object Detection.
10
15
  (https://arxiv.org/pdf/1612.03144)
11
16
 
12
17
  Feature Pyramid Networks (FPNs) are basic components that are added to an
@@ -37,14 +42,18 @@ class FeaturePyramid(keras.layers.Layer):
37
42
  Args:
38
43
  min_level: int. The minimum level of the feature pyramid.
39
44
  max_level: int. The maximum level of the feature pyramid.
45
+ use_p5: bool. If True, uses the output of the last layer (`P5` from
46
+ Feature Pyramid Network) as input for creating coarser convolution
47
+ layers (`P6`, `P7`). If False, uses the direct input `P5`
48
+ for creating coarser convolution layers.
40
49
  num_filters: int. The number of filters in each feature map.
41
50
  activation: string or `keras.activations`. The activation function
42
51
  to be used in network.
43
52
  Defaults to `"relu"`.
44
- kernel_initializer: `str` or `keras.initializers` initializer.
53
+ kernel_initializer: `str` or `keras.initializers`.
45
54
  The kernel initializer for the convolution layers.
46
55
  Defaults to `"VarianceScaling"`.
47
- bias_initializer: `str` or `keras.initializers` initializer.
56
+ bias_initializer: `str` or `keras.initializers`.
48
57
  The bias initializer for the convolution layers.
49
58
  Defaults to `"zeros"`.
50
59
  batch_norm_momentum: float.
@@ -53,10 +62,10 @@ class FeaturePyramid(keras.layers.Layer):
53
62
  batch_norm_epsilon: float.
54
63
  The epsilon for the batch normalization layers.
55
64
  Defaults to `0.001`.
56
- kernel_regularizer: `str` or `keras.regularizers` regularizer.
65
+ kernel_regularizer: `str` or `keras.regularizers`.
57
66
  The kernel regularizer for the convolution layers.
58
67
  Defaults to `None`.
59
- bias_regularizer: `str` or `keras.regularizers` regularizer.
68
+ bias_regularizer: `str` or `keras.regularizers`.
60
69
  The bias regularizer for the convolution layers.
61
70
  Defaults to `None`.
62
71
  use_batch_norm: bool. Whether to use batch normalization.
@@ -69,6 +78,7 @@ class FeaturePyramid(keras.layers.Layer):
69
78
  self,
70
79
  min_level,
71
80
  max_level,
81
+ use_p5,
72
82
  num_filters=256,
73
83
  activation="relu",
74
84
  kernel_initializer="VarianceScaling",
@@ -78,6 +88,7 @@ class FeaturePyramid(keras.layers.Layer):
78
88
  kernel_regularizer=None,
79
89
  bias_regularizer=None,
80
90
  use_batch_norm=False,
91
+ data_format=None,
81
92
  **kwargs,
82
93
  ):
83
94
  super().__init__(**kwargs)
@@ -89,6 +100,7 @@ class FeaturePyramid(keras.layers.Layer):
89
100
  self.min_level = min_level
90
101
  self.max_level = max_level
91
102
  self.num_filters = num_filters
103
+ self.use_p5 = use_p5
92
104
  self.activation = keras.activations.get(activation)
93
105
  self.kernel_initializer = keras.initializers.get(kernel_initializer)
94
106
  self.bias_initializer = keras.initializers.get(bias_initializer)
@@ -103,8 +115,8 @@ class FeaturePyramid(keras.layers.Layer):
103
115
  self.bias_regularizer = keras.regularizers.get(bias_regularizer)
104
116
  else:
105
117
  self.bias_regularizer = None
106
- self.data_format = keras.backend.image_data_format()
107
- self.batch_norm_axis = -1 if self.data_format == "channels_last" else 1
118
+ self.data_format = standardize_data_format(data_format)
119
+ self.batch_norm_axis = -1 if data_format == "channels_last" else 1
108
120
 
109
121
  def build(self, input_shapes):
110
122
  input_shapes = {
@@ -117,7 +129,6 @@ class FeaturePyramid(keras.layers.Layer):
117
129
  }
118
130
  input_levels = [int(level[1]) for level in input_shapes]
119
131
  backbone_max_level = min(max(input_levels), self.max_level)
120
-
121
132
  # Build lateral layers
122
133
  self.lateral_conv_layers = {}
123
134
  for i in range(self.min_level, backbone_max_level + 1):
@@ -134,7 +145,11 @@ class FeaturePyramid(keras.layers.Layer):
134
145
  dtype=self.dtype_policy,
135
146
  name=f"lateral_conv_{level}",
136
147
  )
137
- self.lateral_conv_layers[level].build(input_shapes[level])
148
+ self.lateral_conv_layers[level].build(
149
+ (None, None, None, input_shapes[level][-1])
150
+ if self.data_format == "channels_last"
151
+ else (None, input_shapes[level][1], None, None)
152
+ )
138
153
 
139
154
  self.lateral_batch_norm_layers = {}
140
155
  if self.use_batch_norm:
@@ -149,9 +164,9 @@ class FeaturePyramid(keras.layers.Layer):
149
164
  )
150
165
  )
151
166
  self.lateral_batch_norm_layers[level].build(
152
- (None, None, None, 256)
167
+ (None, None, None, self.num_filters)
153
168
  if self.data_format == "channels_last"
154
- else (None, 256, None, None)
169
+ else (None, self.num_filters, None, None)
155
170
  )
156
171
 
157
172
  # Build output layers
@@ -171,9 +186,9 @@ class FeaturePyramid(keras.layers.Layer):
171
186
  name=f"output_conv_{level}",
172
187
  )
173
188
  self.output_conv_layers[level].build(
174
- (None, None, None, 256)
189
+ (None, None, None, self.num_filters)
175
190
  if self.data_format == "channels_last"
176
- else (None, 256, None, None)
191
+ else (None, self.num_filters, None, None)
177
192
  )
178
193
 
179
194
  # Build coarser layers
@@ -192,11 +207,18 @@ class FeaturePyramid(keras.layers.Layer):
192
207
  dtype=self.dtype_policy,
193
208
  name=f"coarser_{level}",
194
209
  )
195
- self.output_conv_layers[level].build(
196
- (None, None, None, 256)
197
- if self.data_format == "channels_last"
198
- else (None, 256, None, None)
199
- )
210
+ if i == backbone_max_level + 1 and self.use_p5:
211
+ self.output_conv_layers[level].build(
212
+ (None, None, None, input_shapes[f"P{i - 1}"][-1])
213
+ if self.data_format == "channels_last"
214
+ else (None, input_shapes[f"P{i - 1}"][1], None, None)
215
+ )
216
+ else:
217
+ self.output_conv_layers[level].build(
218
+ (None, None, None, self.num_filters)
219
+ if self.data_format == "channels_last"
220
+ else (None, self.num_filters, None, None)
221
+ )
200
222
 
201
223
  # Build batch norm layers
202
224
  self.output_batch_norms = {}
@@ -212,9 +234,9 @@ class FeaturePyramid(keras.layers.Layer):
212
234
  )
213
235
  )
214
236
  self.output_batch_norms[level].build(
215
- (None, None, None, 256)
237
+ (None, None, None, self.num_filters)
216
238
  if self.data_format == "channels_last"
217
- else (None, 256, None, None)
239
+ else (None, self.num_filters, None, None)
218
240
  )
219
241
 
220
242
  # The same upsampling layer is used for all levels
@@ -255,7 +277,7 @@ class FeaturePyramid(keras.layers.Layer):
255
277
  if i < backbone_max_level:
256
278
  # for the top most output, it doesn't need to merge with any
257
279
  # upper stream outputs
258
- upstream_output = self.top_down_op(output_features[f"P{i+1}"])
280
+ upstream_output = self.top_down_op(output_features[f"P{i + 1}"])
259
281
  output = self.merge_op([output, upstream_output])
260
282
  output_features[level] = (
261
283
  self.lateral_batch_norm_layers[level](output)
@@ -273,7 +295,11 @@ class FeaturePyramid(keras.layers.Layer):
273
295
 
274
296
  for i in range(backbone_max_level + 1, self.max_level + 1):
275
297
  level = f"P{i}"
276
- feats_in = output_features[f"P{i-1}"]
298
+ feats_in = (
299
+ inputs[f"P{i - 1}"]
300
+ if i == backbone_max_level + 1 and self.use_p5
301
+ else output_features[f"P{i - 1}"]
302
+ )
277
303
  if i > backbone_max_level + 1:
278
304
  feats_in = self.activation(feats_in)
279
305
  output_features[level] = (
@@ -283,7 +309,10 @@ class FeaturePyramid(keras.layers.Layer):
283
309
  if self.use_batch_norm
284
310
  else self.output_conv_layers[level](feats_in)
285
311
  )
286
-
312
+ output_features = {
313
+ f"P{i}": output_features[f"P{i}"]
314
+ for i in range(self.min_level, self.max_level + 1)
315
+ }
287
316
  return output_features
288
317
 
289
318
  def get_config(self):
@@ -293,7 +322,9 @@ class FeaturePyramid(keras.layers.Layer):
293
322
  "min_level": self.min_level,
294
323
  "max_level": self.max_level,
295
324
  "num_filters": self.num_filters,
325
+ "use_p5": self.use_p5,
296
326
  "use_batch_norm": self.use_batch_norm,
327
+ "data_format": self.data_format,
297
328
  "activation": keras.activations.serialize(self.activation),
298
329
  "kernel_initializer": keras.initializers.serialize(
299
330
  self.kernel_initializer
@@ -320,34 +351,51 @@ class FeaturePyramid(keras.layers.Layer):
320
351
 
321
352
  def compute_output_shape(self, input_shapes):
322
353
  output_shape = {}
323
- print(input_shapes)
324
354
  input_levels = [int(level[1]) for level in input_shapes]
325
355
  backbone_max_level = min(max(input_levels), self.max_level)
326
356
 
327
357
  for i in range(self.min_level, backbone_max_level + 1):
328
358
  level = f"P{i}"
329
359
  if self.data_format == "channels_last":
330
- output_shape[level] = input_shapes[level][:-1] + (256,)
360
+ output_shape[level] = input_shapes[level][:-1] + (
361
+ self.num_filters,
362
+ )
331
363
  else:
332
364
  output_shape[level] = (
333
365
  input_shapes[level][0],
334
- 256,
366
+ self.num_filters,
335
367
  ) + input_shapes[level][1:3]
336
368
 
337
369
  intermediate_shape = input_shapes[f"P{backbone_max_level}"]
338
370
  intermediate_shape = (
339
371
  (
340
372
  intermediate_shape[0],
341
- intermediate_shape[1] // 2,
342
- intermediate_shape[2] // 2,
343
- 256,
373
+ (
374
+ int(math.ceil(intermediate_shape[1] / 2))
375
+ if intermediate_shape[1] is not None
376
+ else None
377
+ ),
378
+ (
379
+ int(math.ceil(intermediate_shape[1] / 2))
380
+ if intermediate_shape[1] is not None
381
+ else None
382
+ ),
383
+ self.num_filters,
344
384
  )
345
385
  if self.data_format == "channels_last"
346
386
  else (
347
387
  intermediate_shape[0],
348
- 256,
349
- intermediate_shape[1] // 2,
350
- intermediate_shape[2] // 2,
388
+ self.num_filters,
389
+ (
390
+ int(math.ceil(intermediate_shape[1] / 2))
391
+ if intermediate_shape[1] is not None
392
+ else None
393
+ ),
394
+ (
395
+ int(math.ceil(intermediate_shape[1] / 2))
396
+ if intermediate_shape[1] is not None
397
+ else None
398
+ ),
351
399
  )
352
400
  )
353
401
 
@@ -357,16 +405,32 @@ class FeaturePyramid(keras.layers.Layer):
357
405
  intermediate_shape = (
358
406
  (
359
407
  intermediate_shape[0],
360
- intermediate_shape[1] // 2,
361
- intermediate_shape[2] // 2,
362
- 256,
408
+ (
409
+ int(math.ceil(intermediate_shape[1] / 2))
410
+ if intermediate_shape[1] is not None
411
+ else None
412
+ ),
413
+ (
414
+ int(math.ceil(intermediate_shape[1] / 2))
415
+ if intermediate_shape[1] is not None
416
+ else None
417
+ ),
418
+ self.num_filters,
363
419
  )
364
420
  if self.data_format == "channels_last"
365
421
  else (
366
422
  intermediate_shape[0],
367
- 256,
368
- intermediate_shape[1] // 2,
369
- intermediate_shape[2] // 2,
423
+ self.num_filters,
424
+ (
425
+ int(math.ceil(intermediate_shape[1] / 2))
426
+ if intermediate_shape[1] is not None
427
+ else None
428
+ ),
429
+ (
430
+ int(math.ceil(intermediate_shape[1] / 2))
431
+ if intermediate_shape[1] is not None
432
+ else None
433
+ ),
370
434
  )
371
435
  )
372
436
 
@@ -3,6 +3,7 @@ import math
3
3
  import keras
4
4
  from keras import ops
5
5
 
6
+ # TODO: https://github.com/keras-team/keras-hub/issues/1965
6
7
  from keras_hub.src.bounding_box import converters
7
8
  from keras_hub.src.bounding_box import utils
8
9
  from keras_hub.src.bounding_box import validate_format
@@ -0,0 +1,192 @@
1
+ import keras
2
+
3
+ from keras_hub.src.utils.keras_utils import standardize_data_format
4
+
5
+
6
class PredictionHead(keras.layers.Layer):
    """A head for classification or bounding box regression predictions.

    Args:
        output_filters: int. The number of convolution filters in the final
            layer. The number of output channels determines the prediction
            type:
            - **Classification**:
                `output_filters = num_anchors * num_classes`
                Predicts class probabilities for each anchor.
            - **Bounding Box Regression**:
                `output_filters = num_anchors * 4` Predicts bounding box
                offsets (x1, y1, x2, y2) for each anchor.
        num_filters: int. The number of convolution filters to use in the base
            layer.
        num_conv_layers: int. The number of convolution layers before the final
            layer.
        use_prior_probability: bool. Set to True to use prior probability in
            the bias initializer for the final convolution layer.
            Defaults to `False`.
        prior_probability: float. The prior probability value to use for
            initializing the bias. Only used if `use_prior_probability` is
            `True`. Defaults to `0.01`.
        activation: str or callable. The activation applied after each base
            convolution (and group norm, if enabled). Defaults to `"relu"`.
        kernel_initializer: `str` or `keras.initializers`. The kernel
            initializer for the convolution layers. Defaults to
            `"random_normal"`.
        bias_initializer: `str` or `keras.initializers`. The bias initializer
            for the convolution layers. Defaults to `"zeros"`.
        kernel_regularizer: `str` or `keras.regularizers`. The kernel
            regularizer for the convolution layers. Defaults to `None`.
        bias_regularizer: `str` or `keras.regularizers`. The bias regularizer
            for the convolution layers. Defaults to `None`.
        use_group_norm: bool. Whether to use Group Normalization after
            the convolution layers. Defaults to `False`.
        data_format: str. `"channels_last"` or `"channels_first"`. If `None`,
            falls back to the global Keras image data format.
    """

    def __init__(
        self,
        output_filters,
        num_filters,
        num_conv_layers,
        use_prior_probability=False,
        prior_probability=0.01,
        activation="relu",
        kernel_initializer="random_normal",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        use_group_norm=False,
        data_format=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.output_filters = output_filters
        self.num_filters = num_filters
        self.num_conv_layers = num_conv_layers
        self.use_prior_probability = use_prior_probability
        self.prior_probability = prior_probability
        self.activation = keras.activations.get(activation)
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        if kernel_regularizer is not None:
            self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        else:
            self.kernel_regularizer = None
        if bias_regularizer is not None:
            self.bias_regularizer = keras.regularizers.get(bias_regularizer)
        else:
            self.bias_regularizer = None
        self.use_group_norm = use_group_norm
        self.data_format = standardize_data_format(data_format)

    def build(self, input_shape):
        intermediate_shape = input_shape
        self.conv_layers = []
        self.group_norm_layers = []
        for idx in range(self.num_conv_layers):
            conv = keras.layers.Conv2D(
                self.num_filters,
                kernel_size=3,
                padding="same",
                kernel_initializer=self.kernel_initializer,
                bias_initializer=self.bias_initializer,
                # Group norm supplies its own affine shift, so the conv bias
                # is redundant when it is enabled.
                use_bias=not self.use_group_norm,
                kernel_regularizer=self.kernel_regularizer,
                bias_regularizer=self.bias_regularizer,
                data_format=self.data_format,
                dtype=self.dtype_policy,
                name=f"conv2d_{idx}",
            )
            conv.build(intermediate_shape)
            self.conv_layers.append(conv)
            # Stride-1 "same" convs preserve spatial dims; only the channel
            # dimension changes to `num_filters`. For channels_first the
            # spatial dims live at indices 2:, not 1:-1 (which would wrongly
            # keep the channel dim and drop the last spatial dim).
            intermediate_shape = (
                input_shape[:-1] + (self.num_filters,)
                if self.data_format == "channels_last"
                else (input_shape[0], self.num_filters)
                + tuple(input_shape[2:])
            )
            if self.use_group_norm:
                group_norm = keras.layers.GroupNormalization(
                    groups=32,
                    axis=-1 if self.data_format == "channels_last" else 1,
                    dtype=self.dtype_policy,
                    name=f"group_norm_{idx}",
                )
                group_norm.build(intermediate_shape)
                self.group_norm_layers.append(group_norm)
        # Focal-loss style bias init: start the logits at the prior
        # probability so that rare-positive classification heads do not
        # produce a huge loss on the first step (RetinaNet paper, sec. 4.1).
        prior_probability = keras.initializers.Constant(
            -1
            * keras.ops.log(
                (1 - self.prior_probability) / self.prior_probability
            )
        )
        self.prediction_layer = keras.layers.Conv2D(
            self.output_filters,
            kernel_size=3,
            strides=1,
            padding="same",
            kernel_initializer=self.kernel_initializer,
            bias_initializer=(
                prior_probability
                if self.use_prior_probability
                else self.bias_initializer
            ),
            kernel_regularizer=self.kernel_regularizer,
            bias_regularizer=self.bias_regularizer,
            # Pass the data format explicitly so the final layer honors a
            # `data_format` that differs from the global Keras default.
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name="logits_layer",
        )
        self.prediction_layer.build(
            (None, None, None, self.num_filters)
            if self.data_format == "channels_last"
            else (None, self.num_filters, None, None)
        )
        self.built = True

    def call(self, input):
        x = input
        for idx in range(self.num_conv_layers):
            x = self.conv_layers[idx](x)
            if self.use_group_norm:
                x = self.group_norm_layers[idx](x)
            x = self.activation(x)

        output = self.prediction_layer(x)
        return output

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "output_filters": self.output_filters,
                "num_filters": self.num_filters,
                "num_conv_layers": self.num_conv_layers,
                "use_group_norm": self.use_group_norm,
                "use_prior_probability": self.use_prior_probability,
                "prior_probability": self.prior_probability,
                "activation": keras.activations.serialize(self.activation),
                "kernel_initializer": keras.initializers.serialize(
                    self.kernel_initializer
                ),
                # Fixed: previously serialized `kernel_initializer` under
                # this key, silently dropping the bias initializer on
                # round-trip.
                "bias_initializer": keras.initializers.serialize(
                    self.bias_initializer
                ),
                "kernel_regularizer": (
                    keras.regularizers.serialize(self.kernel_regularizer)
                    if self.kernel_regularizer is not None
                    else None
                ),
                "bias_regularizer": (
                    keras.regularizers.serialize(self.bias_regularizer)
                    if self.bias_regularizer is not None
                    else None
                ),
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        # Only the channel dimension changes. For channels_first the spatial
        # dims are input_shape[2:]; the previous `input_shape[1:-1]` slice
        # wrongly included the channel dim and dropped the last spatial dim.
        return (
            input_shape[:-1] + (self.output_filters,)
            if self.data_format == "channels_last"
            else (input_shape[0], self.output_filters)
            + tuple(input_shape[2:])
        )
@@ -0,0 +1,146 @@
1
+ import keras
2
+
3
+ from keras_hub.src.api_export import keras_hub_export
4
+ from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
5
+ from keras_hub.src.models.retinanet.feature_pyramid import FeaturePyramid
6
+ from keras_hub.src.utils.keras_utils import standardize_data_format
7
+
8
+
9
@keras_hub_export("keras_hub.models.RetinaNetBackbone")
class RetinaNetBackbone(FeaturePyramidBackbone):
    """RetinaNet Backbone.

    Combines a CNN backbone (e.g., ResNet, MobileNet) with a feature pyramid
    network (FPN) to extract multi-scale features for object detection.

    Args:
        image_encoder: `keras.Model`. The backbone model (e.g., ResNet50,
            MobileNetV2) used to extract features from the input image.
            It should have pyramid outputs (i.e., a dictionary mapping level
            names like `"P2"`, `"P3"`, etc. to their corresponding feature
            tensors).
        min_level: int. The minimum level of the feature pyramid (e.g., 3).
            This determines the coarsest level of features used.
        max_level: int. The maximum level of the feature pyramid (e.g., 7).
            This determines the finest level of features used.
        use_p5: bool. Determines the input source for creating coarser
            feature pyramid levels. If `True`, the output of the last backbone
            layer (typically `'P5'` in an FPN) is used as input to create
            higher-level feature maps (e.g., `'P6'`, `'P7'`) through
            additional convolutional layers. If `False`, the original `'P5'`
            feature map from the backbone is directly used as input for
            creating the coarser levels, bypassing any further processing of
            `'P5'` within the feature pyramid. Defaults to `False`.
        use_fpn_batch_norm: bool. Whether to use batch normalization in the
            feature pyramid network. Defaults to `False`.
        image_shape: tuple. The shape of the input image (H, W, C).
            The height and width can be `None` if they are variable.
        data_format: str. The data format of the input image
            (channels_first or channels_last).
        dtype: str. The data type of the input image.
        **kwargs: Additional keyword arguments passed to the base class.

    Raises:
        ValueError: If `min_level` is greater than `max_level`.
        ValueError: If `backbone_max_level` is less than 5 and `max_level` is
            greater than or equal to 5.
    """

    def __init__(
        self,
        image_encoder,
        min_level,
        max_level,
        use_p5,
        use_fpn_batch_norm=False,
        image_shape=(None, None, 3),
        data_format=None,
        dtype=None,
        **kwargs,
    ):
        # === Layers ===
        if min_level > max_level:
            raise ValueError(
                f"Minimum level ({min_level}) must be less than or equal to "
                f"maximum level ({max_level})."
            )

        data_format = standardize_data_format(data_format)
        # Pyramid output keys follow the "P{level}" convention; parse the
        # full numeric suffix so multi-digit levels (e.g. "P10") work too.
        input_levels = [
            int(level[1:]) for level in image_encoder.pyramid_outputs
        ]
        backbone_max_level = min(max(input_levels), max_level)

        if backbone_max_level < 5 and max_level >= 5:
            raise ValueError(
                f"Backbone maximum level ({backbone_max_level}) is less than "
                f"the desired maximum level ({max_level}). "
                f"Please ensure that the backbone can generate features up to "
                f"the specified maximum level."
            )
        # Restrict the encoder to the pyramid levels the FPN consumes.
        feature_extractor = keras.Model(
            inputs=image_encoder.inputs,
            outputs={
                f"P{level}": image_encoder.pyramid_outputs[f"P{level}"]
                for level in range(min_level, backbone_max_level + 1)
            },
            name="backbone",
        )

        feature_pyramid = FeaturePyramid(
            min_level=min_level,
            max_level=max_level,
            use_p5=use_p5,
            name="fpn",
            dtype=dtype,
            data_format=data_format,
            use_batch_norm=use_fpn_batch_norm,
        )

        # === Functional model ===
        image_input = keras.layers.Input(image_shape, name="inputs")
        feature_extractor_outputs = feature_extractor(image_input)
        feature_pyramid_outputs = feature_pyramid(feature_extractor_outputs)

        super().__init__(
            inputs=image_input,
            outputs=feature_pyramid_outputs,
            dtype=dtype,
            **kwargs,
        )

        # === config ===
        self.min_level = min_level
        self.max_level = max_level
        self.use_p5 = use_p5
        self.use_fpn_batch_norm = use_fpn_batch_norm
        self.image_encoder = image_encoder
        self.feature_pyramid = feature_pyramid
        self.image_shape = image_shape
        self.pyramid_outputs = feature_pyramid_outputs

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "image_encoder": keras.layers.serialize(self.image_encoder),
                "min_level": self.min_level,
                "max_level": self.max_level,
                "use_p5": self.use_p5,
                "use_fpn_batch_norm": self.use_fpn_batch_norm,
                "image_shape": self.image_shape,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        # Copy before deserializing nested layers so the caller's dict is
        # not mutated in place.
        config = dict(config)
        config.update(
            {
                "image_encoder": keras.layers.deserialize(
                    config["image_encoder"]
                ),
            }
        )

        return super().from_config(config)