keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (252)
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/src/models/retinanet/retinanet_image_converter.py
@@ -0,0 +1,53 @@
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+ from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
+ from keras_hub.src.utils.tensor_utils import preprocessing_function
+
+
+ @keras_hub_export("keras_hub.layers.RetinaNetImageConverter")
+ class RetinaNetImageConverter(ImageConverter):
+     backbone_cls = RetinaNetBackbone
+
+     def __init__(
+         self,
+         image_size=None,
+         scale=None,
+         offset=None,
+         norm_mean=[0.485, 0.456, 0.406],
+         norm_std=[0.229, 0.224, 0.225],
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.image_size = image_size
+         self.scale = scale
+         self.offset = offset
+         self.norm_mean = norm_mean
+         self.norm_std = norm_std
+         self.built = True
+
+     @preprocessing_function
+     def call(self, inputs):
+         # TODO: https://github.com/keras-team/keras-hub/issues/1965
+         x = inputs
+         # Rescaling Image
+         if self.scale is not None:
+             x = x * self._expand_non_channel_dims(self.scale, x)
+         if self.offset is not None:
+             x = x + self._expand_non_channel_dims(self.offset, x)
+         # By default normalize using imagenet mean and std
+         if self.norm_mean:
+             x = x - self._expand_non_channel_dims(self.norm_mean, x)
+         if self.norm_std:
+             x = x / self._expand_non_channel_dims(self.norm_std, x)
+
+         return x
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "norm_mean": self.norm_mean,
+                 "norm_std": self.norm_std,
+             }
+         )
+         return config
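For orientation, a hedged usage sketch of the converter added above. The `scale` value and input shape are illustrative; `norm_mean`/`norm_std` default to the ImageNet statistics and are applied after the optional `scale`/`offset` rescaling, mirroring the `call` method.

import numpy as np

import keras_hub

# Illustrative only: scale maps uint8 pixel values into [0, 1] before the
# default ImageNet mean/std normalization kicks in.
converter = keras_hub.layers.RetinaNetImageConverter(
    scale=1.0 / 255.0,
)
images = np.random.randint(0, 256, size=(2, 800, 800, 3)).astype("float32")
outputs = converter(images)  # rescaled, then (x - mean) / std per channel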
keras_hub/src/models/retinanet/retinanet_label_encoder.py
@@ -1,9 +1,12 @@
+ import math
+
  import keras
  from keras import ops

- from keras_hub.src.bounding_box.converters import _encode_box_to_deltas
+ # TODO: https://github.com/keras-team/keras-hub/issues/1965
+ from keras_hub.src.bounding_box.converters import convert_format
+ from keras_hub.src.bounding_box.converters import encode_box_to_deltas
  from keras_hub.src.bounding_box.iou import compute_iou
- from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator
  from keras_hub.src.models.retinanet.box_matcher import BoxMatcher
  from keras_hub.src.utils import tensor_utils

@@ -24,17 +27,10 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
      consistency during training, regardless of the input format.

      Args:
-         bounding_box_format: str. The format of bounding boxes of input dataset.
-             Refer TODO: Add link to Keras Core Docs.
-         min_level: int. Minimum level of the output feature pyramid.
-         max_level: int. Maximum level of the output feature pyramid.
-         num_scales: int. Number of intermediate scales added on each level.
-             For example, num_scales=2 adds one additional intermediate anchor
-             scale [2^0, 2^0.5] on each level.
-         aspect_ratios: List[float]. Aspect ratios of anchors added on
-             each level. Each number indicates the ratio of width to height.
-         anchor_size: float. Scale of size of the base anchor relative to the
-             feature stride 2^level.
+         anchor_generator: A `keras_hub.layers.AnchorGenerator`.
+         bounding_box_format: str. Ground truth format of bounding boxes.
+         encoding_format: str. The desired target encoding format for the boxes.
+             TODO: https://github.com/keras-team/keras-hub/issues/1907
          positive_threshold: float. the threshold to set an anchor to positive
              match to gt box. Values above it are positive matches.
              Defaults to `0.5`
@@ -43,7 +39,7 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
              Defaults to `0.4`
          box_variance: List[float]. The scaling factors used to scale the
              bounding box targets.
-             Defaults to `[0.1, 0.1, 0.2, 0.2]`.
+             Defaults to `[1.0, 1.0, 1.0, 1.0]`.
          background_class: int. The class ID used for the background class,
              Defaults to `-1`.
          ignore_class: int. The class ID used for the ignore class,
@@ -63,15 +59,12 @@ class RetinaNetLabelEncoder(keras.layers.Layer):

      def __init__(
          self,
+         anchor_generator,
          bounding_box_format,
-         min_level,
-         max_level,
-         num_scales,
-         aspect_ratios,
-         anchor_size,
+         encoding_format="center_yxhw",
          positive_threshold=0.5,
          negative_threshold=0.4,
-         box_variance=[0.1, 0.1, 0.2, 0.2],
+         box_variance=[1.0, 1.0, 1.0, 1.0],
          background_class=-1.0,
          ignore_class=-2.0,
          box_matcher_match_values=[-1, -2, 1],
@@ -79,27 +72,15 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
          **kwargs,
      ):
          super().__init__(**kwargs)
+         self.anchor_generator = anchor_generator
          self.bounding_box_format = bounding_box_format
-         self.min_level = min_level
-         self.max_level = max_level
-         self.num_scales = num_scales
-         self.aspect_ratios = aspect_ratios
-         self.anchor_size = anchor_size
+         self.encoding_format = encoding_format
          self.positive_threshold = positive_threshold
          self.box_variance = box_variance
          self.negative_threshold = negative_threshold
          self.background_class = background_class
          self.ignore_class = ignore_class

-         self.anchor_generator = AnchorGenerator(
-             bounding_box_format=bounding_box_format,
-             min_level=min_level,
-             max_level=max_level,
-             num_scales=num_scales,
-             aspect_ratios=aspect_ratios,
-             anchor_size=anchor_size,
-         )
-
          self.box_matcher = BoxMatcher(
              thresholds=[negative_threshold, positive_threshold],
              match_values=box_matcher_match_values,
@@ -116,7 +97,7 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
              images: A Tensor. The input images argument should be
                  of shape `[B, H, W, C]` or `[B, C, H, W]`.
              gt_boxes: A Tensor with shape of `[B, num_boxes, 4]`.
-             gt_labels: A Tensor with shape of `[B, num_boxes, num_classes]`
+             gt_classes: A Tensor with shape of `[B, num_boxes, num_classes]`

          Returns:
              box_targets: A Tensor of shape `[batch_size, num_anchors, 4]`
@@ -171,10 +152,15 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
              image_shape: Tuple indicating the image shape `[H, W, C]`.

          Returns:
-             Encoded boudning boxes in the format of `center_yxwh` and
+             Encoded bounding boxes in the format of `center_yxwh` and
              corresponding labels for each encoded bounding box.
          """
-
+         anchor_boxes = convert_format(
+             anchor_boxes,
+             source=self.anchor_generator.bounding_box_format,
+             target=self.bounding_box_format,
+             image_shape=image_shape,
+         )
          iou_matrix = compute_iou(
              anchor_boxes,
              gt_boxes,
@@ -193,11 +179,12 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
              matched_gt_boxes, (-1, ops.shape(matched_gt_boxes)[1], 4)
          )

-         box_target = _encode_box_to_deltas(
+         box_targets = encode_box_to_deltas(
              anchors=anchor_boxes,
              boxes=matched_gt_boxes,
              anchor_format=self.bounding_box_format,
              box_format=self.bounding_box_format,
+             encoding_format=self.encoding_format,
              variance=self.box_variance,
              image_shape=image_shape,
          )
@@ -205,16 +192,16 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
          matched_gt_cls_ids = tensor_utils.target_gather(
              gt_classes, matched_gt_idx
          )
-         cls_target = ops.where(
+         class_targets = ops.where(
              ops.not_equal(positive_mask, 1.0),
              self.background_class,
              matched_gt_cls_ids,
          )
-         cls_target = ops.where(
-             ops.equal(ignore_mask, 1.0), self.ignore_class, cls_target
+         class_targets = ops.where(
+             ops.equal(ignore_mask, 1.0), self.ignore_class, class_targets
          )
          label = ops.concatenate(
-             [box_target, ops.cast(cls_target, box_target.dtype)], axis=-1
+             [box_targets, ops.cast(class_targets, box_targets.dtype)], axis=-1
          )

          # In the case that a box in the corner of an image matches with an all
@@ -234,12 +221,11 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
          config = super().get_config()
          config.update(
              {
+                 "anchor_generator": keras.layers.serialize(
+                     self.anchor_generator
+                 ),
                  "bounding_box_format": self.bounding_box_format,
-                 "min_level": self.min_level,
-                 "max_level": self.max_level,
-                 "num_scales": self.num_scales,
-                 "aspect_ratios": self.aspect_ratios,
-                 "anchor_size": self.anchor_size,
+                 "encoding_format": self.encoding_format,
                  "positive_threshold": self.positive_threshold,
                  "box_variance": self.box_variance,
                  "negative_threshold": self.negative_threshold,
@@ -249,6 +235,18 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
          )
          return config

+     @classmethod
+     def from_config(cls, config):
+         config.update(
+             {
+                 "anchor_generator": keras.layers.deserialize(
+                     config["anchor_generator"]
+                 ),
+             }
+         )
+
+         return super().from_config(config)
+
      def compute_output_shape(
          self, images_shape, gt_boxes_shape, gt_classes_shape
      ):
@@ -258,10 +256,10 @@ class RetinaNetLabelEncoder(keras.layers.Layer):

          total_num_anchors = 0
          for i in range(min_level, max_level + 1):
-             total_num_anchors += (
-                 (image_H // 2 ** (i))
-                 * (image_W // 2 ** (i))
-                 * self.anchor_generator.anchors_per_location
+             total_num_anchors += int(
+                 math.ceil(image_H / 2 ** (i))
+                 * math.ceil(image_W / 2 ** (i))
+                 * self.anchor_generator.num_base_anchors
              )

          return (batch_size, total_num_anchors, 4), (
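The changes above invert the dependency: the encoder no longer constructs its own `AnchorGenerator` from `min_level`/`max_level`/`num_scales`/`aspect_ratios`/`anchor_size`, but receives one, plus an explicit `encoding_format`. A hedged wiring sketch; argument values are illustrative, and the encoder is an internal, non-exported layer, hence the `src` import:

from keras_hub.layers import AnchorGenerator
from keras_hub.src.models.retinanet.retinanet_label_encoder import (
    RetinaNetLabelEncoder,
)

# The generator now owns all anchor geometry.
anchor_generator = AnchorGenerator(
    bounding_box_format="yxyx",
    min_level=3,
    max_level=7,
    num_scales=3,
    aspect_ratios=[0.5, 1.0, 2.0],
    anchor_size=4.0,
)
# The encoder just consumes it, plus an explicit delta encoding format.
label_encoder = RetinaNetLabelEncoder(
    anchor_generator,
    bounding_box_format="yxyx",
    encoding_format="center_xywh",  # matches the torch-ported weights
)

With `num_scales=3` and three aspect ratios there are 9 base anchors per location, so the reworked `compute_output_shape` counts, for a 640x640 input over levels 3 to 7, 9 * (80^2 + 40^2 + 20^2 + 10^2 + 5^2) = 76,725 anchors.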
keras_hub/src/models/retinanet/retinanet_object_detector.py
@@ -0,0 +1,381 @@
+ import keras
+ from keras import ops
+
+ from keras_hub.src.api_export import keras_hub_export
+
+ # TODO: https://github.com/keras-team/keras-hub/issues/1965
+ from keras_hub.src.bounding_box.converters import convert_format
+ from keras_hub.src.bounding_box.converters import decode_deltas_to_boxes
+ from keras_hub.src.models.image_object_detector import ImageObjectDetector
+ from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator
+ from keras_hub.src.models.retinanet.non_max_supression import NonMaxSuppression
+ from keras_hub.src.models.retinanet.prediction_head import PredictionHead
+ from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
+ from keras_hub.src.models.retinanet.retinanet_label_encoder import (
+     RetinaNetLabelEncoder,
+ )
+ from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import (  # noqa: E501
+     RetinaNetObjectDetectorPreprocessor,
+ )
+
+
+ @keras_hub_export("keras_hub.models.RetinaNetObjectDetector")
+ class RetinaNetObjectDetector(ImageObjectDetector):
+     """RetinaNet object detector model.
+
+     This class implements the RetinaNet object detection architecture.
+     It consists of a feature extractor backbone, a feature pyramid network(FPN),
+     and two prediction heads (for classification and bounding box regression).
+
+     Args:
+         backbone: `keras.Model`. A `keras.models.RetinaNetBackbone` class,
+             defining the backbone network architecture. Provides feature maps
+             for detection.
+         anchor_generator: A `keras_hub.layers.AnchorGenerator` instance.
+             Generates anchor boxes at different scales and aspect ratios
+             across the image. If None, a default `AnchorGenerator` is
+             created with the following parameters:
+             - `bounding_box_format`: Same as the model's
+                 `bounding_box_format`.
+             - `min_level`: The backbone's `min_level`.
+             - `max_level`: The backbone's `max_level`.
+             - `num_scales`: 3.
+             - `aspect_ratios`: [0.5, 1.0, 2.0].
+             - `anchor_size`: 4.0.
+             You can create a custom `AnchorGenerator` by instantiating the
+             `keras_hub.layers.AnchorGenerator` class and passing the desired
+             arguments.
+         num_classes: int. The number of object classes to be detected.
+         bounding_box_format: str. Dataset bounding box format (e.g., "xyxy",
+             "yxyx"). The supported formats are
+             refer TODO: https://github.com/keras-team/keras-hub/issues/1907.
+             Defaults to `yxyx`.
+         label_encoder: Optional. A `RetinaNetLabelEncoder` instance. Encodes
+             ground truth boxes and classes into training targets. It matches
+             ground truth boxes to anchors based on IoU and encodes box
+             coordinates as offsets. If `None`, a default encoder is created.
+             See the `RetinaNetLabelEncoder` class for details. If None, a
+             default encoder is created with standard parameters.
+             - `anchor_generator`: Same as the model's.
+             - `bounding_box_format`: Same as the model's
+                 `bounding_box_format`.
+             - `positive_threshold`: 0.5
+             - `negative_threshold`: 0.4
+             - `encoding_format`: "center_xywh"
+             - `box_variance`: [1.0, 1.0, 1.0, 1.0]
+             - `background_class`: -1
+             - `ignore_class`: -2
+         use_prediction_head_norm: bool. Whether to use Group Normalization after
+             the convolution layers in the prediction heads. Defaults to `False`.
+         classification_head_prior_probability: float. Prior probability for the
+             classification head (used for focal loss). Defaults to 0.01.
+         pre_logits_num_conv_layers: int. The number of convolutional layers in
+             the head before the logits layer. These convolutional layers are
+             applied before the final linear layer (logits) that produces the
+             output predictions (bounding box regressions,
+             classification scores).
+         preprocessor: Optional. An instance of
+             `RetinaNetObjectDetectorPreprocessor`or a custom preprocessor.
+             Handles image preprocessing before feeding into the backbone.
+         activation: Optional. The activation function to be used in the
+             classification head. If None, sigmoid is used.
+         dtype: Optional. The data type for the prediction heads. Defaults to the
+             backbone's dtype policy.
+         prediction_decoder: Optional. A `keras.layers.Layer` instance
+             responsible for transforming RetinaNet predictions
+             (box regressions and classifications) into final bounding boxes and
+             classes with confidence scores. Defaults to a `NonMaxSuppression`
+             instance.
+     """
+
+     backbone_cls = RetinaNetBackbone
+     preprocessor_cls = RetinaNetObjectDetectorPreprocessor
+
+     def __init__(
+         self,
+         backbone,
+         num_classes,
+         bounding_box_format="yxyx",
+         anchor_generator=None,
+         label_encoder=None,
+         use_prediction_head_norm=False,
+         classification_head_prior_probability=0.01,
+         pre_logits_num_conv_layers=4,
+         preprocessor=None,
+         activation=None,
+         dtype=None,
+         prediction_decoder=None,
+         **kwargs,
+     ):
+         # === Layers ===
+         image_input = keras.layers.Input(backbone.image_shape, name="images")
+         head_dtype = dtype or backbone.dtype_policy
+
+         anchor_generator = anchor_generator or AnchorGenerator(
+             bounding_box_format,
+             min_level=backbone.min_level,
+             max_level=backbone.max_level,
+             num_scales=3,
+             aspect_ratios=[0.5, 1.0, 2.0],
+             anchor_size=4,
+         )
+         # As weights are ported from torch they use encoded format
+         # as "center_xywh"
+         label_encoder = label_encoder or RetinaNetLabelEncoder(
+             anchor_generator,
+             bounding_box_format=bounding_box_format,
+             encoding_format="center_xywh",
+         )
+
+         box_head = PredictionHead(
+             output_filters=anchor_generator.num_base_anchors * 4,
+             num_conv_layers=pre_logits_num_conv_layers,
+             num_filters=256,
+             use_group_norm=use_prediction_head_norm,
+             use_prior_probability=True,
+             prior_probability=classification_head_prior_probability,
+             dtype=head_dtype,
+             name="box_head",
+         )
+         classification_head = PredictionHead(
+             output_filters=anchor_generator.num_base_anchors * num_classes,
+             num_conv_layers=pre_logits_num_conv_layers,
+             num_filters=256,
+             use_group_norm=use_prediction_head_norm,
+             dtype=head_dtype,
+             name="classification_head",
+         )
+
+         # === Functional Model ===
+         feature_map = backbone(image_input)
+
+         class_predictions = []
+         box_predictions = []
+
+         # Iterate through the feature pyramid levels (e.g., P3, P4, P5, P6, P7).
+         for level in feature_map:
+             box_predictions.append(
+                 keras.layers.Reshape((-1, 4), name=f"box_pred_{level}")(
+                     box_head(feature_map[level])
+                 )
+             )
+             class_predictions.append(
+                 keras.layers.Reshape(
+                     (-1, num_classes), name=f"cls_pred_{level}"
+                 )(classification_head(feature_map[level]))
+             )
+
+         # Concatenate predictions from all FPN levels.
+         class_predictions = keras.layers.Concatenate(axis=1, name="cls_logits")(
+             class_predictions
+         )
+         # box_pred is always in "center_xywh" delta-encoded no matter what
+         # format you pass in.
+         box_predictions = keras.layers.Concatenate(
+             axis=1, name="bbox_regression"
+         )(box_predictions)
+
+         outputs = {
+             "bbox_regression": box_predictions,
+             "cls_logits": class_predictions,
+         }
+
+         super().__init__(
+             inputs=image_input,
+             outputs=outputs,
+             **kwargs,
+         )
+
+         # === Config ===
+         self.bounding_box_format = bounding_box_format
+         self.use_prediction_head_norm = use_prediction_head_norm
+         self.num_classes = num_classes
+         self.backbone = backbone
+         self.preprocessor = preprocessor
+         self.activation = activation
+         self.pre_logits_num_conv_layers = pre_logits_num_conv_layers
+         self.box_head = box_head
+         self.classification_head = classification_head
+         self.anchor_generator = anchor_generator
+         self.label_encoder = label_encoder
+         self._prediction_decoder = prediction_decoder or NonMaxSuppression(
+             from_logits=(activation != keras.activations.sigmoid),
+             bounding_box_format=bounding_box_format,
+         )
+
+     def compute_loss(self, x, y, y_pred, sample_weight, **kwargs):
+         y_for_label_encoder = convert_format(
+             y,
+             source=self.bounding_box_format,
+             target=self.label_encoder.bounding_box_format,
+             images=x,
+         )
+
+         boxes, classes = self.label_encoder(
+             images=x,
+             gt_boxes=y_for_label_encoder["boxes"],
+             gt_classes=y_for_label_encoder["classes"],
+         )
+
+         box_pred = y_pred["bbox_regression"]
+         cls_pred = y_pred["cls_logits"]
+
+         if boxes.shape[-1] != 4:
+             raise ValueError(
+                 "boxes should have shape (None, None, 4). Got "
+                 f"boxes.shape={tuple(boxes.shape)}"
+             )
+
+         if box_pred.shape[-1] != 4:
+             raise ValueError(
+                 "box_pred should have shape (None, None, 4). Got "
+                 f"box_pred.shape={tuple(box_pred.shape)}. Does your model's "
+                 "`num_classes` parameter match your losses `num_classes` "
+                 "parameter?"
+             )
+         if cls_pred.shape[-1] != self.num_classes:
+             raise ValueError(
+                 "cls_pred should have shape (None, None, 4). Got "
+                 f"cls_pred.shape={tuple(cls_pred.shape)}. Does your model's "
+                 "`num_classes` parameter match your losses `num_classes` "
+                 "parameter?"
+             )
+
+         cls_labels = ops.one_hot(
+             ops.cast(classes, "int32"), self.num_classes, dtype="float32"
+         )
+         positive_mask = ops.cast(ops.greater(classes, -1.0), dtype="float32")
+         normalizer = ops.sum(positive_mask)
+         cls_weights = ops.cast(ops.not_equal(classes, -2.0), dtype="float32")
+         cls_weights /= normalizer
+         box_weights = positive_mask / normalizer
+
+         y_true = {
+             "bbox_regression": boxes,
+             "cls_logits": cls_labels,
+         }
+         sample_weights = {
+             "bbox_regression": box_weights,
+             "cls_logits": cls_weights,
+         }
+         zero_weight = {
+             "bbox_regression": ops.zeros_like(box_weights),
+             "cls_logits": ops.zeros_like(cls_weights),
+         }
+
+         sample_weight = ops.cond(
+             normalizer == 0,
+             lambda: zero_weight,
+             lambda: sample_weights,
+         )
+         return super().compute_loss(
+             x=x, y=y_true, y_pred=y_pred, sample_weight=sample_weight, **kwargs
+         )
+
+     def predict_step(self, *args):
+         outputs = super().predict_step(*args)
+         if isinstance(outputs, tuple):
+             return self.decode_predictions(outputs[0], args[-1]), outputs[1]
+         return self.decode_predictions(outputs, *args)
+
+     @property
+     def prediction_decoder(self):
+         return self._prediction_decoder
+
+     @prediction_decoder.setter
+     def prediction_decoder(self, prediction_decoder):
+         if prediction_decoder.bounding_box_format != self.bounding_box_format:
+             raise ValueError(
+                 "Expected `prediction_decoder` and `RetinaNet` to "
+                 "use the same `bounding_box_format`, but got "
+                 "`prediction_decoder.bounding_box_format="
+                 f"{prediction_decoder.bounding_box_format}`, and "
+                 "`self.bounding_box_format="
+                 f"{self.bounding_box_format}`."
+             )
+         self._prediction_decoder = prediction_decoder
+         self.make_predict_function(force=True)
+         self.make_train_function(force=True)
+         self.make_test_function(force=True)
+
+     def decode_predictions(self, predictions, data):
+         box_pred = predictions["bbox_regression"]
+         cls_pred = predictions["cls_logits"]
+         # box_pred is on "center_yxhw" format, convert to target format.
+         if isinstance(data, list) or isinstance(data, tuple):
+             images, _ = data
+         else:
+             images = data
+         image_shape = ops.shape(images)[1:]
+         anchor_boxes = self.anchor_generator(images)
+         anchor_boxes = ops.concatenate(list(anchor_boxes.values()), axis=0)
+         box_pred = decode_deltas_to_boxes(
+             anchors=anchor_boxes,
+             boxes_delta=box_pred,
+             encoded_format="center_xywh",
+             anchor_format=self.anchor_generator.bounding_box_format,
+             box_format=self.bounding_box_format,
+             image_shape=image_shape,
+         )
+         # box_pred is now in "self.bounding_box_format" format
+         box_pred = convert_format(
+             box_pred,
+             source=self.bounding_box_format,
+             target=self.prediction_decoder.bounding_box_format,
+             image_shape=image_shape,
+         )
+         y_pred = self.prediction_decoder(
+             box_pred, cls_pred, image_shape=image_shape
+         )
+         y_pred["boxes"] = convert_format(
+             y_pred["boxes"],
+             source=self.prediction_decoder.bounding_box_format,
+             target=self.bounding_box_format,
+             image_shape=image_shape,
+         )
+         return y_pred
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "num_classes": self.num_classes,
+                 "use_prediction_head_norm": self.use_prediction_head_norm,
+                 "pre_logits_num_conv_layers": self.pre_logits_num_conv_layers,
+                 "bounding_box_format": self.bounding_box_format,
+                 "anchor_generator": keras.layers.serialize(
+                     self.anchor_generator
+                 ),
+                 "label_encoder": keras.layers.serialize(self.label_encoder),
+                 "prediction_decoder": keras.layers.serialize(
+                     self._prediction_decoder
+                 ),
+             }
+         )
+
+         return config
+
+     @classmethod
+     def from_config(cls, config):
+         if "label_encoder" in config and isinstance(
+             config["label_encoder"], dict
+         ):
+             config["label_encoder"] = keras.layers.deserialize(
+                 config["label_encoder"]
+             )
+
+         if "anchor_generator" in config and isinstance(
+             config["anchor_generator"], dict
+         ):
+             config["anchor_generator"] = keras.layers.deserialize(
+                 config["anchor_generator"]
+             )
+
+         if "prediction_decoder" in config and isinstance(
+             config["prediction_decoder"], dict
+         ):
+             config["prediction_decoder"] = keras.layers.deserialize(
+                 config["prediction_decoder"]
+             )
+
+         return super().from_config(config)
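To show how the pieces above connect, a hedged construction sketch. It assumes `RetinaNetBackbone` is exported under `keras_hub.models` (only the detector's export is shown in this diff) and that the preset handle registered in `retinanet_presets.py` below also resolves for the backbone; all other arguments fall back to the defaults documented in the docstring above.

import keras_hub

backbone = keras_hub.models.RetinaNetBackbone.from_preset(
    "retinanet_resnet50_fpn_coco"
)
detector = keras_hub.models.RetinaNetObjectDetector(
    backbone=backbone,
    num_classes=80,  # e.g. the COCO label set
    bounding_box_format="yxyx",
    # anchor_generator and label_encoder default as documented above:
    # 3 scales, aspect ratios [0.5, 1.0, 2.0], anchor_size 4.0, and
    # "center_xywh" delta encoding to match the torch-ported weights.
)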
keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py
@@ -0,0 +1,14 @@
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.image_object_detector_preprocessor import (
+     ImageObjectDetectorPreprocessor,
+ )
+ from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
+ from keras_hub.src.models.retinanet.retinanet_image_converter import (
+     RetinaNetImageConverter,
+ )
+
+
+ @keras_hub_export("keras_hub.models.RetinaNetObjectDetectorPreprocessor")
+ class RetinaNetObjectDetectorPreprocessor(ImageObjectDetectorPreprocessor):
+     backbone_cls = RetinaNetBackbone
+     image_converter_cls = RetinaNetImageConverter
keras_hub/src/models/retinanet/retinanet_presets.py
@@ -0,0 +1,16 @@
+ """RetinaNet model preset configurations."""
+
+ # Metadata for loading pretrained model weights.
+ backbone_presets = {
+     "retinanet_resnet50_fpn_coco": {
+         "metadata": {
+             "description": (
+                 "RetinaNet model with ResNet50 backbone fine-tuned on COCO in "
+                 "800x800 resolution."
+             ),
+             "params": 34121239,
+             "path": "retinanet",
+         },
+         "kaggle_handle": "kaggle://keras/retinanet/keras/retinanet_resnet50_fpn_coco/2",
+     }
+ }
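A hedged end-to-end sketch for the preset above; the `from_preset` API mirrors other keras-hub task models, and the input resolution is illustrative. The "boxes" key is the one `decode_predictions` writes; the remaining score/class keys come from the default `NonMaxSuppression` decoder.

import numpy as np

import keras_hub

detector = keras_hub.models.RetinaNetObjectDetector.from_preset(
    "retinanet_resnet50_fpn_coco"
)
images = np.random.uniform(0, 255, size=(1, 800, 800, 3)).astype("float32")
predictions = detector.predict(images)
boxes = predictions["boxes"]  # decoded, in the model's "yxyx" format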
keras_hub/src/models/roberta/roberta_backbone.py
@@ -23,8 +23,8 @@ class RobertaBackbone(Backbone):

      The default constructor gives a fully customizable, randomly initialized
      RoBERTa encoder with any number of layers, heads, and embedding
-     dimensions. To load preset architectures and weights, use the `from_preset()`
-     constructor.
+     dimensions. To load preset architectures and weights, use the
+     `from_preset()` constructor.

      Disclaimer: Pre-trained models are provided on an "as is" basis, without
      warranties or conditions of any kind. The underlying model is provided by a