keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252) hide show
  1. keras_hub/api/layers/__init__.py +21 -3
  2. keras_hub/api/models/__init__.py +71 -12
  3. keras_hub/api/tokenizers/__init__.py +1 -1
  4. keras_hub/src/bounding_box/__init__.py +2 -0
  5. keras_hub/src/bounding_box/converters.py +102 -12
  6. keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
  7. keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
  8. keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
  9. keras_hub/src/layers/modeling/rms_normalization.py +36 -0
  10. keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
  11. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
  12. keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
  13. keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
  14. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  15. keras_hub/src/layers/preprocessing/image_converter.py +170 -34
  16. keras_hub/src/metrics/bleu.py +4 -3
  17. keras_hub/src/models/albert/albert_presets.py +4 -12
  18. keras_hub/src/models/albert/albert_text_classifier.py +7 -7
  19. keras_hub/src/models/backbone.py +3 -14
  20. keras_hub/src/models/bart/bart_backbone.py +4 -4
  21. keras_hub/src/models/bart/bart_presets.py +3 -9
  22. keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
  23. keras_hub/src/models/basnet/__init__.py +5 -0
  24. keras_hub/src/models/basnet/basnet.py +122 -0
  25. keras_hub/src/models/basnet/basnet_backbone.py +366 -0
  26. keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
  27. keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
  28. keras_hub/src/models/basnet/basnet_presets.py +17 -0
  29. keras_hub/src/models/bert/bert_presets.py +14 -32
  30. keras_hub/src/models/bert/bert_text_classifier.py +3 -3
  31. keras_hub/src/models/bloom/bloom_presets.py +8 -24
  32. keras_hub/src/models/causal_lm.py +56 -12
  33. keras_hub/src/models/clip/__init__.py +5 -0
  34. keras_hub/src/models/clip/clip_backbone.py +286 -0
  35. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  36. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  37. keras_hub/src/models/clip/clip_presets.py +93 -0
  38. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  39. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  40. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  41. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  42. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
  43. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  44. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
  45. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
  46. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
  47. keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
  48. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
  49. keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
  50. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  51. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
  52. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  53. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  54. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  55. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
  56. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
  57. keras_hub/src/models/densenet/densenet_backbone.py +6 -4
  58. keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
  59. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  60. keras_hub/src/models/densenet/densenet_presets.py +9 -15
  61. keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
  62. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
  63. keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
  64. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
  65. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
  66. keras_hub/src/models/efficientnet/__init__.py +9 -0
  67. keras_hub/src/models/efficientnet/cba.py +141 -0
  68. keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
  69. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  70. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  71. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  72. keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
  73. keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
  74. keras_hub/src/models/efficientnet/mbconv.py +53 -22
  75. keras_hub/src/models/electra/electra_backbone.py +2 -2
  76. keras_hub/src/models/electra/electra_presets.py +6 -18
  77. keras_hub/src/models/f_net/f_net_presets.py +2 -6
  78. keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
  79. keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
  80. keras_hub/src/models/falcon/falcon_backbone.py +5 -3
  81. keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
  82. keras_hub/src/models/falcon/falcon_presets.py +1 -3
  83. keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
  84. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  85. keras_hub/src/models/flux/__init__.py +5 -0
  86. keras_hub/src/models/flux/flux_layers.py +496 -0
  87. keras_hub/src/models/flux/flux_maths.py +225 -0
  88. keras_hub/src/models/flux/flux_model.py +236 -0
  89. keras_hub/src/models/flux/flux_presets.py +3 -0
  90. keras_hub/src/models/flux/flux_text_to_image.py +146 -0
  91. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  92. keras_hub/src/models/gemma/gemma_backbone.py +35 -20
  93. keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
  94. keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
  95. keras_hub/src/models/gemma/gemma_presets.py +29 -63
  96. keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
  97. keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
  98. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
  99. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
  100. keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
  101. keras_hub/src/models/image_classifier.py +147 -2
  102. keras_hub/src/models/image_classifier_preprocessor.py +6 -3
  103. keras_hub/src/models/image_object_detector.py +87 -0
  104. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  105. keras_hub/src/models/image_segmenter.py +0 -5
  106. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  107. keras_hub/src/models/image_to_image.py +417 -0
  108. keras_hub/src/models/inpaint.py +520 -0
  109. keras_hub/src/models/llama/llama_backbone.py +138 -12
  110. keras_hub/src/models/llama/llama_causal_lm.py +3 -1
  111. keras_hub/src/models/llama/llama_presets.py +10 -20
  112. keras_hub/src/models/llama3/llama3_backbone.py +12 -11
  113. keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
  114. keras_hub/src/models/llama3/llama3_presets.py +4 -12
  115. keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
  116. keras_hub/src/models/mistral/mistral_backbone.py +16 -15
  117. keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
  118. keras_hub/src/models/mistral/mistral_presets.py +3 -9
  119. keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
  120. keras_hub/src/models/mit/__init__.py +6 -0
  121. keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
  122. keras_hub/src/models/mit/mit_image_classifier.py +12 -0
  123. keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
  124. keras_hub/src/models/mit/mit_image_converter.py +8 -0
  125. keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
  126. keras_hub/src/models/mit/mit_presets.py +139 -0
  127. keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
  128. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  129. keras_hub/src/models/opt/opt_causal_lm.py +2 -2
  130. keras_hub/src/models/opt/opt_presets.py +4 -12
  131. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
  132. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
  133. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  134. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  135. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
  136. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
  137. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
  138. keras_hub/src/models/phi3/phi3_decoder.py +0 -1
  139. keras_hub/src/models/phi3/phi3_presets.py +2 -6
  140. keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
  141. keras_hub/src/models/preprocessor.py +25 -11
  142. keras_hub/src/models/resnet/resnet_backbone.py +3 -14
  143. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  144. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  145. keras_hub/src/models/resnet/resnet_presets.py +127 -18
  146. keras_hub/src/models/retinanet/__init__.py +5 -0
  147. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  148. keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
  149. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  150. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  151. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  152. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  153. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  154. keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
  155. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  156. keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
  157. keras_hub/src/models/roberta/roberta_backbone.py +2 -2
  158. keras_hub/src/models/roberta/roberta_presets.py +6 -8
  159. keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
  160. keras_hub/src/models/sam/__init__.py +5 -0
  161. keras_hub/src/models/sam/sam_backbone.py +2 -3
  162. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  163. keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
  164. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  165. keras_hub/src/models/sam/sam_layers.py +5 -3
  166. keras_hub/src/models/sam/sam_presets.py +3 -9
  167. keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
  168. keras_hub/src/models/sam/sam_transformer.py +5 -4
  169. keras_hub/src/models/segformer/__init__.py +8 -0
  170. keras_hub/src/models/segformer/segformer_backbone.py +167 -0
  171. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  172. keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
  173. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  174. keras_hub/src/models/segformer/segformer_presets.py +136 -0
  175. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
  176. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  177. keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
  178. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
  179. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
  180. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
  181. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
  182. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
  183. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
  184. keras_hub/src/models/t5/t5_backbone.py +5 -4
  185. keras_hub/src/models/t5/t5_presets.py +47 -19
  186. keras_hub/src/models/task.py +47 -39
  187. keras_hub/src/models/text_classifier.py +2 -2
  188. keras_hub/src/models/text_to_image.py +106 -41
  189. keras_hub/src/models/vae/__init__.py +1 -0
  190. keras_hub/src/models/vae/vae_backbone.py +184 -0
  191. keras_hub/src/models/vae/vae_layers.py +739 -0
  192. keras_hub/src/models/vgg/__init__.py +5 -0
  193. keras_hub/src/models/vgg/vgg_backbone.py +4 -24
  194. keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
  195. keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
  196. keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
  197. keras_hub/src/models/vgg/vgg_presets.py +48 -0
  198. keras_hub/src/models/vit/__init__.py +5 -0
  199. keras_hub/src/models/vit/vit_backbone.py +152 -0
  200. keras_hub/src/models/vit/vit_image_classifier.py +187 -0
  201. keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
  202. keras_hub/src/models/vit/vit_image_converter.py +73 -0
  203. keras_hub/src/models/vit/vit_layers.py +391 -0
  204. keras_hub/src/models/vit/vit_presets.py +126 -0
  205. keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
  206. keras_hub/src/models/vit_det/vit_layers.py +3 -3
  207. keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
  208. keras_hub/src/models/whisper/whisper_backbone.py +6 -5
  209. keras_hub/src/models/whisper/whisper_decoder.py +3 -5
  210. keras_hub/src/models/whisper/whisper_presets.py +10 -30
  211. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
  212. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
  213. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
  214. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
  215. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
  216. keras_hub/src/models/xlnet/relative_attention.py +20 -19
  217. keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
  218. keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
  219. keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
  220. keras_hub/src/samplers/contrastive_sampler.py +2 -3
  221. keras_hub/src/samplers/sampler.py +2 -1
  222. keras_hub/src/tests/test_case.py +41 -6
  223. keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
  224. keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
  225. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
  226. keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
  227. keras_hub/src/tokenizers/tokenizer.py +10 -13
  228. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
  229. keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
  230. keras_hub/src/utils/keras_utils.py +2 -13
  231. keras_hub/src/utils/pipeline_model.py +3 -3
  232. keras_hub/src/utils/preset_utils.py +196 -144
  233. keras_hub/src/utils/tensor_utils.py +4 -4
  234. keras_hub/src/utils/timm/convert_densenet.py +6 -4
  235. keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
  236. keras_hub/src/utils/timm/convert_resnet.py +1 -1
  237. keras_hub/src/utils/timm/convert_vgg.py +85 -0
  238. keras_hub/src/utils/timm/preset_loader.py +14 -9
  239. keras_hub/src/utils/transformers/convert_llama3.py +21 -5
  240. keras_hub/src/utils/transformers/convert_vit.py +150 -0
  241. keras_hub/src/utils/transformers/preset_loader.py +23 -0
  242. keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
  243. keras_hub/src/version_utils.py +1 -1
  244. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
  245. keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
  246. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
  247. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  248. keras_hub/src/models/mix_transformer/__init__.py +0 -0
  249. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
  250. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  251. keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
  252. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,3 @@
1
- import keras
2
-
3
1
  from keras_hub.src.api_export import keras_hub_export
4
2
  from keras_hub.src.models.image_classifier import ImageClassifier
5
3
  from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
@@ -10,140 +8,5 @@ from keras_hub.src.models.resnet.resnet_image_classifier_preprocessor import (
10
8
 
11
9
  @keras_hub_export("keras_hub.models.ResNetImageClassifier")
12
10
  class ResNetImageClassifier(ImageClassifier):
13
- """ResNet image classifier task model.
14
-
15
- Args:
16
- backbone: A `keras_hub.models.ResNetBackbone` instance.
17
- num_classes: int. The number of classes to predict.
18
- activation: `None`, str or callable. The activation function to use on
19
- the `Dense` layer. Set `activation=None` to return the output
20
- logits. Defaults to `"softmax"`.
21
- head_dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The
22
- dtype to use for the classification head's computations and weights.
23
-
24
- To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
25
- where `x` is a tensor and `y` is a integer from `[0, num_classes)`.
26
- All `ImageClassifier` tasks include a `from_preset()` constructor which can
27
- be used to load a pre-trained config and weights.
28
-
29
- Examples:
30
-
31
- Call `predict()` to run inference.
32
- ```python
33
- # Load preset and train
34
- images = np.ones((2, 224, 224, 3), dtype="float32")
35
- classifier = keras_hub.models.ResNetImageClassifier.from_preset(
36
- "resnet_50_imagenet"
37
- )
38
- classifier.predict(images)
39
- ```
40
-
41
- Call `fit()` on a single batch.
42
- ```python
43
- # Load preset and train
44
- images = np.ones((2, 224, 224, 3), dtype="float32")
45
- labels = [0, 3]
46
- classifier = keras_hub.models.ResNetImageClassifier.from_preset(
47
- "resnet_50_imagenet"
48
- )
49
- classifier.fit(x=images, y=labels, batch_size=2)
50
- ```
51
-
52
- Call `fit()` with custom loss, optimizer and backbone.
53
- ```python
54
- classifier = keras_hub.models.ResNetImageClassifier.from_preset(
55
- "resnet_50_imagenet"
56
- )
57
- classifier.compile(
58
- loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
59
- optimizer=keras.optimizers.Adam(5e-5),
60
- )
61
- classifier.backbone.trainable = False
62
- classifier.fit(x=images, y=labels, batch_size=2)
63
- ```
64
-
65
- Custom backbone.
66
- ```python
67
- images = np.ones((2, 224, 224, 3), dtype="float32")
68
- labels = [0, 3]
69
- backbone = keras_hub.models.ResNetBackbone(
70
- stackwise_num_filters=[64, 64, 64],
71
- stackwise_num_blocks=[2, 2, 2],
72
- stackwise_num_strides=[1, 2, 2],
73
- block_type="basic_block",
74
- use_pre_activation=True,
75
- pooling="avg",
76
- )
77
- classifier = keras_hub.models.ResNetImageClassifier(
78
- backbone=backbone,
79
- num_classes=4,
80
- )
81
- classifier.fit(x=images, y=labels, batch_size=2)
82
- ```
83
- """
84
-
85
11
  backbone_cls = ResNetBackbone
86
12
  preprocessor_cls = ResNetImageClassifierPreprocessor
87
-
88
- def __init__(
89
- self,
90
- backbone,
91
- num_classes,
92
- preprocessor=None,
93
- pooling="avg",
94
- activation=None,
95
- head_dtype=None,
96
- **kwargs,
97
- ):
98
- head_dtype = head_dtype or backbone.dtype_policy
99
-
100
- # === Layers ===
101
- self.backbone = backbone
102
- self.preprocessor = preprocessor
103
- if pooling == "avg":
104
- self.pooler = keras.layers.GlobalAveragePooling2D(
105
- data_format=backbone.data_format, dtype=head_dtype
106
- )
107
- elif pooling == "max":
108
- self.pooler = keras.layers.GlobalAveragePooling2D(
109
- data_format=backbone.data_format, dtype=head_dtype
110
- )
111
- else:
112
- raise ValueError(
113
- "Unknown `pooling` type. Polling should be either `'avg'` or "
114
- f"`'max'`. Received: pooling={pooling}."
115
- )
116
- self.output_dense = keras.layers.Dense(
117
- num_classes,
118
- activation=activation,
119
- dtype=head_dtype,
120
- name="predictions",
121
- )
122
-
123
- # === Functional Model ===
124
- inputs = self.backbone.input
125
- x = self.backbone(inputs)
126
- x = self.pooler(x)
127
- outputs = self.output_dense(x)
128
- super().__init__(
129
- inputs=inputs,
130
- outputs=outputs,
131
- **kwargs,
132
- )
133
-
134
- # === Config ===
135
- self.num_classes = num_classes
136
- self.activation = activation
137
- self.pooling = pooling
138
-
139
- def get_config(self):
140
- # Backbone serialized in `super`
141
- config = super().get_config()
142
- config.update(
143
- {
144
- "num_classes": self.num_classes,
145
- "pooling": self.pooling,
146
- "activation": self.activation,
147
- }
148
- )
149
- return config
@@ -1,10 +1,8 @@
1
1
  from keras_hub.src.api_export import keras_hub_export
2
- from keras_hub.src.layers.preprocessing.resizing_image_converter import (
3
- ResizingImageConverter,
4
- )
2
+ from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
5
3
  from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
6
4
 
7
5
 
8
6
  @keras_hub_export("keras_hub.layers.ResNetImageConverter")
9
- class ResNetImageConverter(ResizingImageConverter):
7
+ class ResNetImageConverter(ImageConverter):
10
8
  backbone_cls = ResNetBackbone
@@ -8,11 +8,9 @@ backbone_presets = {
8
8
  "at a 224x224 resolution."
9
9
  ),
10
10
  "params": 11186112,
11
- "official_name": "ResNet",
12
11
  "path": "resnet",
13
- "model_card": "https://arxiv.org/abs/2110.00476",
14
12
  },
15
- "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_18_imagenet/3",
13
+ "kaggle_handle": "kaggle://keras/resnetv1/keras/resnet_18_imagenet/3",
16
14
  },
17
15
  "resnet_50_imagenet": {
18
16
  "metadata": {
@@ -21,11 +19,9 @@ backbone_presets = {
21
19
  "at a 224x224 resolution."
22
20
  ),
23
21
  "params": 23561152,
24
- "official_name": "ResNet",
25
22
  "path": "resnet",
26
- "model_card": "https://arxiv.org/abs/2110.00476",
27
23
  },
28
- "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_50_imagenet/3",
24
+ "kaggle_handle": "kaggle://keras/resnetv1/keras/resnet_50_imagenet/3",
29
25
  },
30
26
  "resnet_101_imagenet": {
31
27
  "metadata": {
@@ -34,11 +30,9 @@ backbone_presets = {
34
30
  "at a 224x224 resolution."
35
31
  ),
36
32
  "params": 42605504,
37
- "official_name": "ResNet",
38
33
  "path": "resnet",
39
- "model_card": "https://arxiv.org/abs/2110.00476",
40
34
  },
41
- "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_101_imagenet/3",
35
+ "kaggle_handle": "kaggle://keras/resnetv1/keras/resnet_101_imagenet/3",
42
36
  },
43
37
  "resnet_152_imagenet": {
44
38
  "metadata": {
@@ -47,11 +41,9 @@ backbone_presets = {
47
41
  "at a 224x224 resolution."
48
42
  ),
49
43
  "params": 58295232,
50
- "official_name": "ResNet",
51
44
  "path": "resnet",
52
- "model_card": "https://arxiv.org/abs/2110.00476",
53
45
  },
54
- "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_152_imagenet/3",
46
+ "kaggle_handle": "kaggle://keras/resnetv1/keras/resnet_152_imagenet/3",
55
47
  },
56
48
  "resnet_v2_50_imagenet": {
57
49
  "metadata": {
@@ -60,11 +52,9 @@ backbone_presets = {
60
52
  "dataset at a 224x224 resolution."
61
53
  ),
62
54
  "params": 23561152,
63
- "official_name": "ResNet",
64
55
  "path": "resnet",
65
- "model_card": "https://arxiv.org/abs/2110.00476",
66
56
  },
67
- "kaggle_handle": "kaggle://kerashub/resnetv2/keras/resnet_v2_50_imagenet/3",
57
+ "kaggle_handle": "kaggle://keras/resnetv2/keras/resnet_v2_50_imagenet/3",
68
58
  },
69
59
  "resnet_v2_101_imagenet": {
70
60
  "metadata": {
@@ -73,10 +63,129 @@ backbone_presets = {
73
63
  "dataset at a 224x224 resolution."
74
64
  ),
75
65
  "params": 42605504,
76
- "official_name": "ResNet",
77
66
  "path": "resnet",
78
- "model_card": "https://arxiv.org/abs/2110.00476",
79
67
  },
80
- "kaggle_handle": "kaggle://kerashub/resnetv2/keras/resnet_v2_101_imagenet/3",
68
+ "kaggle_handle": "kaggle://keras/resnetv2/keras/resnet_v2_101_imagenet/3",
69
+ },
70
+ "resnet_vd_18_imagenet": {
71
+ "metadata": {
72
+ "description": (
73
+ "18-layer ResNetVD (ResNet with bag of tricks) model "
74
+ "pre-trained on the ImageNet 1k dataset at a 224x224 "
75
+ "resolution."
76
+ ),
77
+ "params": 11722824,
78
+ "path": "resnet",
79
+ },
80
+ "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_18_imagenet/2",
81
+ },
82
+ "resnet_vd_34_imagenet": {
83
+ "metadata": {
84
+ "description": (
85
+ "34-layer ResNetVD (ResNet with bag of tricks) model "
86
+ "pre-trained on the ImageNet 1k dataset at a 224x224 "
87
+ "resolution."
88
+ ),
89
+ "params": 21838408,
90
+ "path": "resnet",
91
+ },
92
+ "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_34_imagenet/2",
93
+ },
94
+ "resnet_vd_50_imagenet": {
95
+ "metadata": {
96
+ "description": (
97
+ "50-layer ResNetVD (ResNet with bag of tricks) model "
98
+ "pre-trained on the ImageNet 1k dataset at a 224x224 "
99
+ "resolution."
100
+ ),
101
+ "params": 25629512,
102
+ "path": "resnet",
103
+ },
104
+ "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_50_imagenet/2",
105
+ },
106
+ "resnet_vd_50_ssld_imagenet": {
107
+ "metadata": {
108
+ "description": (
109
+ "50-layer ResNetVD (ResNet with bag of tricks) model "
110
+ "pre-trained on the ImageNet 1k dataset at a 224x224 "
111
+ "resolution with knowledge distillation."
112
+ ),
113
+ "params": 25629512,
114
+ "path": "resnet",
115
+ },
116
+ "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_50_ssld_imagenet/2",
117
+ },
118
+ "resnet_vd_50_ssld_v2_imagenet": {
119
+ "metadata": {
120
+ "description": (
121
+ "50-layer ResNetVD (ResNet with bag of tricks) model "
122
+ "pre-trained on the ImageNet 1k dataset at a 224x224 "
123
+ "resolution with knowledge distillation and AutoAugment."
124
+ ),
125
+ "params": 25629512,
126
+ "path": "resnet",
127
+ },
128
+ "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_50_ssld_v2_imagenet/2",
129
+ },
130
+ "resnet_vd_50_ssld_v2_fix_imagenet": {
131
+ "metadata": {
132
+ "description": (
133
+ "50-layer ResNetVD (ResNet with bag of tricks) model "
134
+ "pre-trained on the ImageNet 1k dataset at a 224x224 "
135
+ "resolution with knowledge distillation, AutoAugment and "
136
+ "additional fine-tuning of the classification head."
137
+ ),
138
+ "params": 25629512,
139
+ "path": "resnet",
140
+ },
141
+ "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_50_ssld_v2_fix_imagenet/2",
142
+ },
143
+ "resnet_vd_101_imagenet": {
144
+ "metadata": {
145
+ "description": (
146
+ "101-layer ResNetVD (ResNet with bag of tricks) model "
147
+ "pre-trained on the ImageNet 1k dataset at a 224x224 "
148
+ "resolution."
149
+ ),
150
+ "params": 44673864,
151
+ "path": "resnet",
152
+ },
153
+ "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_101_imagenet/2",
154
+ },
155
+ "resnet_vd_101_ssld_imagenet": {
156
+ "metadata": {
157
+ "description": (
158
+ "101-layer ResNetVD (ResNet with bag of tricks) model "
159
+ "pre-trained on the ImageNet 1k dataset at a 224x224 "
160
+ "resolution with knowledge distillation."
161
+ ),
162
+ "params": 44673864,
163
+ "path": "resnet",
164
+ },
165
+ "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_101_ssld_imagenet/2",
166
+ },
167
+ "resnet_vd_152_imagenet": {
168
+ "metadata": {
169
+ "description": (
170
+ "152-layer ResNetVD (ResNet with bag of tricks) model "
171
+ "pre-trained on the ImageNet 1k dataset at a 224x224 "
172
+ "resolution."
173
+ ),
174
+ "params": 60363592,
175
+ "path": "resnet",
176
+ },
177
+ "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_152_imagenet/2",
178
+ },
179
+ "resnet_vd_200_imagenet": {
180
+ "metadata": {
181
+ "description": (
182
+ "200-layer ResNetVD (ResNet with bag of tricks) model "
183
+ "pre-trained on the ImageNet 1k dataset at a 224x224 "
184
+ "resolution."
185
+ ),
186
+ "params": 74933064,
187
+ "path": "resnet",
188
+ },
189
+ "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_200_imagenet/2",
81
190
  },
82
191
  }
@@ -0,0 +1,5 @@
1
+ from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
2
+ from keras_hub.src.models.retinanet.retinanet_presets import backbone_presets
3
+ from keras_hub.src.utils.preset_utils import register_presets
4
+
5
+ register_presets(backbone_presets, RetinaNetBackbone)
@@ -3,9 +3,13 @@ import math
3
3
  import keras
4
4
  from keras import ops
5
5
 
6
+ from keras_hub.src.api_export import keras_hub_export
7
+
8
+ # TODO: https://github.com/keras-team/keras-hub/issues/1965
6
9
  from keras_hub.src.bounding_box.converters import convert_format
7
10
 
8
11
 
12
+ @keras_hub_export("keras_hub.layers.AnchorGenerator")
9
13
  class AnchorGenerator(keras.layers.Layer):
10
14
  """Generates anchor boxes for object detection tasks.
11
15
 
@@ -81,6 +85,7 @@ class AnchorGenerator(keras.layers.Layer):
81
85
  self.num_scales = num_scales
82
86
  self.aspect_ratios = aspect_ratios
83
87
  self.anchor_size = anchor_size
88
+ self.num_base_anchors = num_scales * len(aspect_ratios)
84
89
  self.built = True
85
90
 
86
91
  def call(self, inputs):
@@ -92,60 +97,61 @@ class AnchorGenerator(keras.layers.Layer):
92
97
 
93
98
  image_shape = tuple(image_shape)
94
99
 
95
- multilevel_boxes = {}
100
+ multilevel_anchors = {}
96
101
  for level in range(self.min_level, self.max_level + 1):
97
- boxes_l = []
98
102
  # Calculate the feature map size for this level
99
103
  feat_size_y = math.ceil(image_shape[0] / 2**level)
100
104
  feat_size_x = math.ceil(image_shape[1] / 2**level)
101
105
 
102
106
  # Calculate the stride (step size) for this level
103
- stride_y = ops.cast(image_shape[0] / feat_size_y, "float32")
104
- stride_x = ops.cast(image_shape[1] / feat_size_x, "float32")
107
+ stride_y = image_shape[0] // feat_size_y
108
+ stride_x = image_shape[1] // feat_size_x
105
109
 
106
110
  # Generate anchor center points
107
111
  # Start from stride/2 to center anchors on pixels
108
- cx = ops.arange(stride_x / 2, image_shape[1], stride_x)
109
- cy = ops.arange(stride_y / 2, image_shape[0], stride_y)
112
+ cx = ops.arange(0, feat_size_x, dtype="float32") * stride_x
113
+ cy = ops.arange(0, feat_size_y, dtype="float32") * stride_y
110
114
 
111
115
  # Create a grid of anchor centers
112
- cx_grid, cy_grid = ops.meshgrid(cx, cy)
113
-
114
- for scale in range(self.num_scales):
115
- for aspect_ratio in self.aspect_ratios:
116
- # Calculate the intermediate scale factor
117
- intermidate_scale = 2 ** (scale / self.num_scales)
118
- # Calculate the base anchor size for this level and scale
119
- base_anchor_size = (
120
- self.anchor_size * 2**level * intermidate_scale
121
- )
122
- # Adjust anchor dimensions based on aspect ratio
123
- aspect_x = aspect_ratio**0.5
124
- aspect_y = aspect_ratio**-0.5
125
- half_anchor_size_x = base_anchor_size * aspect_x / 2.0
126
- half_anchor_size_y = base_anchor_size * aspect_y / 2.0
127
-
128
- # Generate anchor boxes (y1, x1, y2, x2 format)
129
- boxes = ops.stack(
130
- [
131
- cy_grid - half_anchor_size_y,
132
- cx_grid - half_anchor_size_x,
133
- cy_grid + half_anchor_size_y,
134
- cx_grid + half_anchor_size_x,
135
- ],
136
- axis=-1,
137
- )
138
- boxes_l.append(boxes)
139
- # Concat anchors on the same level to tensor shape HxWx(Ax4)
140
- boxes_l = ops.concatenate(boxes_l, axis=-1)
141
- boxes_l = ops.reshape(boxes_l, (-1, 4))
142
- # Convert to user defined
143
- multilevel_boxes[f"P{level}"] = convert_format(
144
- boxes_l,
145
- source="yxyx",
116
+ cy_grid, cx_grid = ops.meshgrid(cy, cx, indexing="ij")
117
+ cy_grid = ops.reshape(cy_grid, (-1,))
118
+ cx_grid = ops.reshape(cx_grid, (-1,))
119
+
120
+ shifts = ops.stack((cx_grid, cy_grid, cx_grid, cy_grid), axis=1)
121
+ sizes = [
122
+ int(
123
+ 2**level * self.anchor_size * 2 ** (scale / self.num_scales)
124
+ )
125
+ for scale in range(self.num_scales)
126
+ ]
127
+
128
+ base_anchors = self.generate_base_anchors(
129
+ sizes=sizes, aspect_ratios=self.aspect_ratios
130
+ )
131
+ shifts = ops.reshape(shifts, (-1, 1, 4))
132
+ base_anchors = ops.reshape(base_anchors, (1, -1, 4))
133
+
134
+ anchors = shifts + base_anchors
135
+ anchors = ops.reshape(anchors, (-1, 4))
136
+ multilevel_anchors[f"P{level}"] = convert_format(
137
+ anchors,
138
+ source="xyxy",
146
139
  target=self.bounding_box_format,
147
140
  )
148
- return multilevel_boxes
141
+ return multilevel_anchors
142
+
143
+ def generate_base_anchors(self, sizes, aspect_ratios):
144
+ sizes = ops.convert_to_tensor(sizes, dtype="float32")
145
+ aspect_ratios = ops.convert_to_tensor(aspect_ratios)
146
+ h_ratios = ops.sqrt(aspect_ratios)
147
+ w_ratios = 1 / h_ratios
148
+
149
+ ws = ops.reshape(w_ratios[:, None] * sizes[None, :], (-1,))
150
+ hs = ops.reshape(h_ratios[:, None] * sizes[None, :], (-1,))
151
+
152
+ base_anchors = ops.stack([-1 * ws, -1 * hs, ws, hs], axis=1) / 2
153
+ base_anchors = ops.round(base_anchors)
154
+ return base_anchors
149
155
 
150
156
  def compute_output_shape(self, input_shape):
151
157
  multilevel_boxes_shape = {}
@@ -156,18 +162,11 @@ class AnchorGenerator(keras.layers.Layer):
156
162
 
157
163
  for i in range(self.min_level, self.max_level + 1):
158
164
  multilevel_boxes_shape[f"P{i}"] = (
159
- (image_height // 2 ** (i))
160
- * (image_width // 2 ** (i))
161
- * self.anchors_per_location,
165
+ int(
166
+ math.ceil(image_height / 2 ** (i))
167
+ * math.ceil(image_width // 2 ** (i))
168
+ * self.num_base_anchors
169
+ ),
162
170
  4,
163
171
  )
164
172
  return multilevel_boxes_shape
165
-
166
- @property
167
- def anchors_per_location(self):
168
- """
169
- The `anchors_per_location` property returns the number of anchors
170
- generated per pixel location, which is equal to
171
- `num_scales * len(aspect_ratios)`.
172
- """
173
- return self.num_scales * len(self.aspect_ratios)