keras-hub-nightly 0.16.1.dev202410200345-py3-none-any.whl → 0.19.0.dev202412070351-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. keras_hub/api/layers/__init__.py +12 -0
  2. keras_hub/api/models/__init__.py +32 -0
  3. keras_hub/src/bounding_box/__init__.py +2 -0
  4. keras_hub/src/bounding_box/converters.py +102 -12
  5. keras_hub/src/layers/modeling/rms_normalization.py +34 -0
  6. keras_hub/src/layers/modeling/transformer_encoder.py +27 -7
  7. keras_hub/src/layers/preprocessing/image_converter.py +5 -0
  8. keras_hub/src/models/albert/albert_presets.py +0 -8
  9. keras_hub/src/models/bart/bart_presets.py +0 -6
  10. keras_hub/src/models/bert/bert_presets.py +0 -20
  11. keras_hub/src/models/bloom/bloom_presets.py +0 -16
  12. keras_hub/src/models/clip/__init__.py +5 -0
  13. keras_hub/src/models/clip/clip_backbone.py +286 -0
  14. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  15. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  16. keras_hub/src/models/clip/clip_presets.py +93 -0
  17. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  18. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  19. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  20. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  21. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +0 -10
  22. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +0 -2
  23. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +5 -3
  24. keras_hub/src/models/densenet/densenet_backbone.py +1 -1
  25. keras_hub/src/models/densenet/densenet_presets.py +0 -6
  26. keras_hub/src/models/distil_bert/distil_bert_presets.py +0 -6
  27. keras_hub/src/models/efficientnet/__init__.py +9 -0
  28. keras_hub/src/models/efficientnet/cba.py +141 -0
  29. keras_hub/src/models/efficientnet/efficientnet_backbone.py +139 -56
  30. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  31. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  32. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  33. keras_hub/src/models/efficientnet/efficientnet_presets.py +192 -0
  34. keras_hub/src/models/efficientnet/fusedmbconv.py +81 -36
  35. keras_hub/src/models/efficientnet/mbconv.py +52 -21
  36. keras_hub/src/models/electra/electra_presets.py +0 -12
  37. keras_hub/src/models/f_net/f_net_presets.py +0 -4
  38. keras_hub/src/models/falcon/falcon_presets.py +0 -2
  39. keras_hub/src/models/flux/__init__.py +5 -0
  40. keras_hub/src/models/flux/flux_layers.py +494 -0
  41. keras_hub/src/models/flux/flux_maths.py +218 -0
  42. keras_hub/src/models/flux/flux_model.py +231 -0
  43. keras_hub/src/models/flux/flux_presets.py +14 -0
  44. keras_hub/src/models/flux/flux_text_to_image.py +142 -0
  45. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  46. keras_hub/src/models/gemma/gemma_presets.py +0 -40
  47. keras_hub/src/models/gpt2/gpt2_presets.py +0 -9
  48. keras_hub/src/models/image_object_detector.py +87 -0
  49. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  50. keras_hub/src/models/image_to_image.py +16 -10
  51. keras_hub/src/models/inpaint.py +20 -13
  52. keras_hub/src/models/llama/llama_backbone.py +1 -1
  53. keras_hub/src/models/llama/llama_presets.py +5 -15
  54. keras_hub/src/models/llama3/llama3_presets.py +0 -8
  55. keras_hub/src/models/mistral/mistral_presets.py +0 -6
  56. keras_hub/src/models/mit/mit_backbone.py +41 -27
  57. keras_hub/src/models/mit/mit_layers.py +9 -7
  58. keras_hub/src/models/mit/mit_presets.py +12 -24
  59. keras_hub/src/models/opt/opt_presets.py +0 -8
  60. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +61 -11
  61. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  62. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +166 -10
  63. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +12 -11
  64. keras_hub/src/models/phi3/phi3_presets.py +0 -4
  65. keras_hub/src/models/resnet/resnet_presets.py +10 -42
  66. keras_hub/src/models/retinanet/__init__.py +5 -0
  67. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  68. keras_hub/src/models/retinanet/feature_pyramid.py +99 -36
  69. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  70. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  71. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  72. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  73. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  74. keras_hub/src/models/retinanet/retinanet_object_detector.py +382 -0
  75. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  76. keras_hub/src/models/retinanet/retinanet_presets.py +15 -0
  77. keras_hub/src/models/roberta/roberta_presets.py +0 -4
  78. keras_hub/src/models/sam/sam_backbone.py +0 -1
  79. keras_hub/src/models/sam/sam_image_segmenter.py +9 -10
  80. keras_hub/src/models/sam/sam_presets.py +0 -6
  81. keras_hub/src/models/segformer/__init__.py +8 -0
  82. keras_hub/src/models/segformer/segformer_backbone.py +163 -0
  83. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  84. keras_hub/src/models/segformer/segformer_image_segmenter.py +171 -0
  85. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  86. keras_hub/src/models/segformer/segformer_presets.py +124 -0
  87. keras_hub/src/models/stable_diffusion_3/mmdit.py +41 -0
  88. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +38 -21
  89. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +3 -3
  90. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +3 -3
  91. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +28 -4
  92. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -1
  93. keras_hub/src/models/t5/t5_backbone.py +5 -4
  94. keras_hub/src/models/t5/t5_presets.py +41 -13
  95. keras_hub/src/models/text_to_image.py +13 -5
  96. keras_hub/src/models/vgg/vgg_backbone.py +1 -1
  97. keras_hub/src/models/vgg/vgg_presets.py +0 -8
  98. keras_hub/src/models/whisper/whisper_audio_converter.py +1 -1
  99. keras_hub/src/models/whisper/whisper_presets.py +0 -20
  100. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +0 -4
  101. keras_hub/src/tests/test_case.py +25 -0
  102. keras_hub/src/utils/preset_utils.py +17 -4
  103. keras_hub/src/utils/timm/convert_efficientnet.py +449 -0
  104. keras_hub/src/utils/timm/preset_loader.py +3 -0
  105. keras_hub/src/version_utils.py +1 -1
  106. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/METADATA +15 -26
  107. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/RECORD +109 -76
  108. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/WHEEL +1 -1
  109. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/top_level.txt +0 -0
keras_hub/src/models/pali_gemma/pali_gemma_presets.py

@@ -8,9 +8,7 @@ backbone_presets = {
                 "image size 224, mix fine tuned, text sequence " "length is 256"
             ),
             "params": 2923335408,
-            "official_name": "PaliGemma",
             "path": "pali_gemma",
-            "model_card": "https://www.kaggle.com/models/google/paligemma",
         },
         "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_mix_224/3",
     },
@@ -20,9 +18,7 @@ backbone_presets = {
                 "image size 448, mix fine tuned, text sequence length is 512"
             ),
             "params": 2924220144,
-            "official_name": "PaliGemma",
             "path": "pali_gemma",
-            "model_card": "https://www.kaggle.com/models/google/paligemma",
         },
         "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_mix_448/3",
     },
@@ -32,9 +28,7 @@ backbone_presets = {
                 "image size 224, pre trained, text sequence length is 128"
             ),
             "params": 2923335408,
-            "official_name": "PaliGemma",
             "path": "pali_gemma",
-            "model_card": "https://www.kaggle.com/models/google/paligemma",
         },
         "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_224/3",
     },
@@ -44,9 +38,7 @@ backbone_presets = {
                 "image size 448, pre trained, text sequence length is 512"
             ),
             "params": 2924220144,
-            "official_name": "PaliGemma",
             "path": "pali_gemma",
-            "model_card": "https://www.kaggle.com/models/google/paligemma",
         },
         "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_448/3",
     },
@@ -56,10 +48,174 @@ backbone_presets = {
                 "image size 896, pre trained, text sequence length " "is 512"
             ),
             "params": 2927759088,
-            "official_name": "PaliGemma",
             "path": "pali_gemma",
-            "model_card": "https://www.kaggle.com/models/google/paligemma",
         },
         "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_896/3",
     },
+    # PaliGemma2
+    "pali_gemma2_3b_ft_docci_448": {
+        "metadata": {
+            "description": (
+                "3 billion parameter, image size 448, 27-layer for "
+                "SigLIP-So400m vision encoder and 26-layer Gemma2 2B lanuage "
+                "model. This model has been fine-tuned on the DOCCI dataset "
+                "for improved descriptions with fine-grained details."
+            ),
+            "params": 3032979696,
+            "official_name": "PaliGemma2",
+            "path": "pali_gemma2",
+            "model_card": "https://www.kaggle.com/models/google/paligemma-2",
+        },
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_3b_ft_docci_448/1",
+    },
+    "pali_gemma2_10b_ft_docci_448": {
+        "metadata": {
+            "description": (
+                "10 billion parameter, 27-layer for SigLIP-So400m vision "
+                "encoder and 42-layer Gemma2 9B lanuage model. This model has "
+                "been fine-tuned on the DOCCI dataset for improved "
+                "descriptions with fine-grained details."
+            ),
+            "params": 9663294192,
+            "official_name": "PaliGemma2",
+            "path": "pali_gemma2",
+            "model_card": "https://www.kaggle.com/models/google/paligemma-2",
+        },
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_10b_ft_docci_448/1",
+    },
+    "pali_gemma2_pt_3b_224": {
+        "metadata": {
+            "description": (
+                "3 billion parameter, image size 224, 27-layer for "
+                "SigLIP-So400m vision encoder and 26-layer Gemma2 2B lanuage "
+                "model. This model has been pre-trained on a mixture of "
+                "datasets."
+            ),
+            "params": 3032094960,
+            "official_name": "PaliGemma2",
+            "path": "pali_gemma2",
+            "model_card": "https://www.kaggle.com/models/google/paligemma-2",
+        },
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_3b_224/1",
+    },
+    "pali_gemma2_pt_3b_448": {
+        "metadata": {
+            "description": (
+                "3 billion parameter, image size 448, 27-layer for "
+                "SigLIP-So400m vision encoder and 26-layer Gemma2 2B lanuage "
+                "model. This model has been pre-trained on a mixture of "
+                "datasets."
+            ),
+            "params": 3032979696,
+            "official_name": "PaliGemma2",
+            "path": "pali_gemma2",
+            "model_card": "https://www.kaggle.com/models/google/paligemma-2",
+        },
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_3b_448/1",
+    },
+    "pali_gemma2_pt_3b_896": {
+        "metadata": {
+            "description": (
+                "3 billion parameter, image size 896, 27-layer for "
+                "SigLIP-So400m vision encoder and 26-layer Gemma2 2B lanuage "
+                "model. This model has been pre-trained on a mixture of "
+                "datasets."
+            ),
+            "params": 3036518640,
+            "official_name": "PaliGemma2",
+            "path": "pali_gemma2",
+            "model_card": "https://www.kaggle.com/models/google/paligemma-2",
+        },
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_3b_896/1",
+    },
+    "pali_gemma2_pt_10b_224": {
+        "metadata": {
+            "description": (
+                "10 billion parameter, image size 224, 27-layer for "
+                "SigLIP-So400m vision encoder and 42-layer Gemma2 9B lanuage "
+                "model. This model has been pre-trained on a mixture of "
+                "datasets."
+            ),
+            "params": 9662409456,
+            "official_name": "PaliGemma2",
+            "path": "pali_gemma2",
+            "model_card": "https://www.kaggle.com/models/google/paligemma-2",
+        },
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_10b_224/1",
+    },
+    "pali_gemma2_pt_10b_448": {
+        "metadata": {
+            "description": (
+                "10 billion parameter, image size 448, 27-layer for "
+                "SigLIP-So400m vision encoder and 42-layer Gemma2 9B lanuage "
+                "model. This model has been pre-trained on a mixture of "
+                "datasets."
+            ),
+            "params": 9663294192,
+            "official_name": "PaliGemma2",
+            "path": "pali_gemma2",
+            "model_card": "https://www.kaggle.com/models/google/paligemma-2",
+        },
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_10b_448/1",
+    },
+    "pali_gemma2_pt_10b_896": {
+        "metadata": {
+            "description": (
+                "10 billion parameter, image size 896, 27-layer for "
+                "SigLIP-So400m vision encoder and 42-layer Gemma2 9B lanuage "
+                "model. This model has been pre-trained on a mixture of "
+                "datasets."
+            ),
+            "params": 9666833136,
+            "official_name": "PaliGemma2",
+            "path": "pali_gemma2",
+            "model_card": "https://www.kaggle.com/models/google/paligemma-2",
+        },
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_10b_896/1",
+    },
+    "pali_gemma2_pt_28b_224": {
+        "metadata": {
+            "description": (
+                "28 billion parameter, image size 224, 27-layer for "
+                "SigLIP-So400m vision encoder and 46-layer Gemma2 27B lanuage "
+                "model. This model has been pre-trained on a mixture of "
+                "datasets."
+            ),
+            "params": 9662409456,
+            "official_name": "PaliGemma2",
+            "path": "pali_gemma2",
+            "model_card": "https://www.kaggle.com/models/google/paligemma-2",
+        },
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_28b_224/1",
+    },
+    "pali_gemma2_pt_28b_448": {
+        "metadata": {
+            "description": (
+                "28 billion parameter, image size 448, 27-layer for "
+                "SigLIP-So400m vision encoder and 46-layer Gemma2 27B lanuage "
+                "model. This model has been pre-trained on a mixture of "
+                "datasets."
+            ),
+            "params": 9663294192,
+            "official_name": "PaliGemma2",
+            "path": "pali_gemma2",
+            "model_card": "https://www.kaggle.com/models/google/paligemma-2",
+        },
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_28b_448/1",
+    },
+    "pali_gemma2_pt_28b_896": {
+        "metadata": {
+            "description": (
+                "28 billion parameter, image size 896, 27-layer for "
+                "SigLIP-So400m vision encoder and 46-layer Gemma2 27B lanuage "
+                "model. This model has been pre-trained on a mixture of "
+                "datasets."
+            ),
+            "params": 9666833136,
+            "official_name": "PaliGemma2",
+            "path": "pali_gemma2",
+            "model_card": "https://www.kaggle.com/models/google/paligemma-2",
+        },
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_28b_896/1",
+    },
 }
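
As a usage note, every entry above resolves through the standard KerasHub preset mechanism. A minimal sketch of loading one of the new PaliGemma2 checkpoints, assuming Kaggle access and the existing PaliGemmaCausalLM task class (which this diff does not touch):

    import keras_hub

    # Weights are fetched from the preset's "kaggle_handle" on first use
    # and cached locally afterwards.
    pali_gemma_lm = keras_hub.models.PaliGemmaCausalLM.from_preset(
        "pali_gemma2_pt_3b_224"
    )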
keras_hub/src/models/pali_gemma/pali_gemma_vit.py

@@ -12,7 +12,7 @@ class PaliGemmaVitEmbeddings(keras.layers.Layer):
         dtype=None,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        super().__init__(dtype=dtype, **kwargs)
         self.hidden_dim = hidden_dim
         self.image_size = image_size
         self.patch_size = patch_size
@@ -72,7 +72,7 @@ class PaliGemmaVitAttention(keras.layers.Layer):
         dtype=None,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        super().__init__(dtype=dtype, **kwargs)
 
         self.hidden_dim = hidden_dim
         self.num_heads = num_heads
@@ -282,7 +282,7 @@ class PaliGemmaVitEncoder(keras.layers.Layer):
         dtype=None,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        super().__init__(dtype=dtype, **kwargs)
         self.hidden_dim = hidden_dim
         self.num_layers = num_layers
         self.num_heads = num_heads
@@ -311,25 +311,26 @@ class PaliGemmaVitEncoder(keras.layers.Layer):
             for i in range(self.num_layers)
         ]
 
-    def build(self, input_shape):
-        self.vision_embeddings.build(input_shape)
+    def build(self, inputs_shape):
+        self.vision_embeddings.build(inputs_shape)
         for block in self.resblocks:
             block.build([None, None, self.hidden_dim])
         self.encoder_layer_norm.build([None, None, self.hidden_dim])
         self.built = True
 
-    def call(
-        self,
-        x,
-        mask=None,
-    ):
-        x = self.vision_embeddings(x)
+    def call(self, inputs, mask=None):
+        x = self.vision_embeddings(inputs)
         for block in self.resblocks:
             x = block(x, mask=mask)
         x = self.encoder_layer_norm(x)
         return x
 
     def compute_output_shape(self, inputs_shape):
+        if inputs_shape is None:
+            # Fix the compatibility issue with Keras 3.1 where
+            # `compute_output_spec` fails to propagate `inputs_shape`
+            # correctly, causing it to be `None`.
+            inputs_shape = [None, None, None]
         return [inputs_shape[0], inputs_shape[1], self.hidden_dim]
 
     def get_config(self):
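
The repeated `super().__init__(dtype=dtype, **kwargs)` fix matters for mixed precision: a layer that accepts `dtype` but never forwards it leaves the base `keras.layers.Layer` on the global dtype policy, so per-layer policies such as "bfloat16" are silently ignored. A minimal sketch of the pattern, using a hypothetical layer rather than the ViT classes above:

    import keras

    class PatchProjection(keras.layers.Layer):  # hypothetical example layer
        def __init__(self, units, dtype=None, **kwargs):
            # Forward `dtype` so a per-layer policy reaches the base Layer;
            # dropping it silently falls back to the global default policy.
            super().__init__(dtype=dtype, **kwargs)
            self.units = units

    layer = PatchProjection(8, dtype="bfloat16")
    print(layer.compute_dtype)  # bfloat16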
keras_hub/src/models/phi3/phi3_presets.py

@@ -12,9 +12,7 @@ backbone_presets = {
                 "reasoning-dense properties."
             ),
             "params": 3821079552,
-            "official_name": "Phi-3",
             "path": "phi3",
-            "model_card": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
         },
         "kaggle_handle": "kaggle://keras/phi3/keras/phi3_mini_4k_instruct_en",
     },
@@ -28,9 +26,7 @@ backbone_presets = {
                 "reasoning-dense properties."
             ),
             "params": 3821079552,
-            "official_name": "Phi-3",
             "path": "phi3",
-            "model_card": "https://huggingface.co/microsoft/Phi-3-mini-128k-instruct",
         },
         "kaggle_handle": "kaggle://keras/phi3/keras/phi3_mini_128k_instruct_en",
     },
keras_hub/src/models/resnet/resnet_presets.py

@@ -8,9 +8,7 @@ backbone_presets = {
                 "at a 224x224 resolution."
             ),
             "params": 11186112,
-            "official_name": "ResNet",
             "path": "resnet",
-            "model_card": "https://arxiv.org/abs/2110.00476",
         },
         "kaggle_handle": "kaggle://keras/resnetv1/keras/resnet_18_imagenet/2",
     },
@@ -21,9 +19,7 @@ backbone_presets = {
                 "at a 224x224 resolution."
             ),
             "params": 23561152,
-            "official_name": "ResNet",
             "path": "resnet",
-            "model_card": "https://arxiv.org/abs/2110.00476",
         },
         "kaggle_handle": "kaggle://keras/resnetv1/keras/resnet_50_imagenet/2",
     },
@@ -34,9 +30,7 @@ backbone_presets = {
                 "at a 224x224 resolution."
             ),
             "params": 42605504,
-            "official_name": "ResNet",
             "path": "resnet",
-            "model_card": "https://arxiv.org/abs/2110.00476",
         },
         "kaggle_handle": "kaggle://keras/resnetv1/keras/resnet_101_imagenet/2",
     },
@@ -47,9 +41,7 @@ backbone_presets = {
                 "at a 224x224 resolution."
             ),
             "params": 58295232,
-            "official_name": "ResNet",
             "path": "resnet",
-            "model_card": "https://arxiv.org/abs/2110.00476",
         },
         "kaggle_handle": "kaggle://keras/resnetv1/keras/resnet_152_imagenet/2",
     },
@@ -60,9 +52,7 @@ backbone_presets = {
                 "dataset at a 224x224 resolution."
             ),
             "params": 23561152,
-            "official_name": "ResNet",
             "path": "resnet",
-            "model_card": "https://arxiv.org/abs/2110.00476",
         },
         "kaggle_handle": "kaggle://keras/resnetv2/keras/resnet_v2_50_imagenet/2",
     },
@@ -73,9 +63,7 @@ backbone_presets = {
                 "dataset at a 224x224 resolution."
             ),
             "params": 42605504,
-            "official_name": "ResNet",
             "path": "resnet",
-            "model_card": "https://arxiv.org/abs/2110.00476",
         },
         "kaggle_handle": "kaggle://keras/resnetv2/keras/resnet_v2_101_imagenet/2",
     },
@@ -87,11 +75,9 @@ backbone_presets = {
                 "resolution."
             ),
             "params": 11722824,
-            "official_name": "ResNet",
             "path": "resnet",
-            "model_card": "https://arxiv.org/abs/1812.01187",
         },
-        "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_18_imagenet",
+        "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_18_imagenet",
     },
     "resnet_vd_34_imagenet": {
         "metadata": {
@@ -101,11 +87,9 @@ backbone_presets = {
                 "resolution."
             ),
             "params": 21838408,
-            "official_name": "ResNet",
             "path": "resnet",
-            "model_card": "https://arxiv.org/abs/1812.01187",
         },
-        "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_34_imagenet",
+        "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_34_imagenet",
     },
     "resnet_vd_50_imagenet": {
         "metadata": {
@@ -115,11 +99,9 @@ backbone_presets = {
                 "resolution."
             ),
             "params": 25629512,
-            "official_name": "ResNet",
             "path": "resnet",
-            "model_card": "https://arxiv.org/abs/1812.01187",
         },
-        "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_50_imagenet",
+        "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_50_imagenet",
     },
     "resnet_vd_50_ssld_imagenet": {
         "metadata": {
@@ -129,11 +111,9 @@ backbone_presets = {
                 "resolution with knowledge distillation."
             ),
             "params": 25629512,
-            "official_name": "ResNet",
             "path": "resnet",
-            "model_card": "https://arxiv.org/abs/1812.01187",
         },
-        "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_50_ssld_imagenet",
+        "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_50_ssld_imagenet",
     },
     "resnet_vd_50_ssld_v2_imagenet": {
         "metadata": {
@@ -143,11 +123,9 @@ backbone_presets = {
                 "resolution with knowledge distillation and AutoAugment."
             ),
             "params": 25629512,
-            "official_name": "ResNet",
             "path": "resnet",
-            "model_card": "https://arxiv.org/abs/1812.01187",
         },
-        "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_50_ssld_v2_imagenet",
+        "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_50_ssld_v2_imagenet",
     },
     "resnet_vd_50_ssld_v2_fix_imagenet": {
         "metadata": {
@@ -158,11 +136,9 @@ backbone_presets = {
                 "additional fine-tuning of the classification head."
             ),
             "params": 25629512,
-            "official_name": "ResNet",
             "path": "resnet",
-            "model_card": "https://arxiv.org/abs/1812.01187",
         },
-        "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_50_ssld_v2_fix_imagenet",
+        "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_50_ssld_v2_fix_imagenet",
     },
     "resnet_vd_101_imagenet": {
         "metadata": {
@@ -172,11 +148,9 @@ backbone_presets = {
                 "resolution."
             ),
             "params": 44673864,
-            "official_name": "ResNet",
             "path": "resnet",
-            "model_card": "https://arxiv.org/abs/1812.01187",
         },
-        "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_101_imagenet",
+        "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_101_imagenet",
     },
     "resnet_vd_101_ssld_imagenet": {
         "metadata": {
@@ -186,11 +160,9 @@ backbone_presets = {
                 "resolution with knowledge distillation."
             ),
             "params": 44673864,
-            "official_name": "ResNet",
             "path": "resnet",
-            "model_card": "https://arxiv.org/abs/1812.01187",
         },
-        "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_101_ssld_imagenet",
+        "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_101_ssld_imagenet",
     },
     "resnet_vd_152_imagenet": {
         "metadata": {
@@ -200,11 +172,9 @@ backbone_presets = {
                 "resolution."
             ),
             "params": 60363592,
-            "official_name": "ResNet",
             "path": "resnet",
-            "model_card": "https://arxiv.org/abs/1812.01187",
         },
-        "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_152_imagenet",
+        "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_152_imagenet",
     },
     "resnet_vd_200_imagenet": {
         "metadata": {
@@ -214,10 +184,8 @@ backbone_presets = {
                 "resolution."
             ),
             "params": 74933064,
-            "official_name": "ResNet",
             "path": "resnet",
-            "model_card": "https://arxiv.org/abs/1812.01187",
         },
-        "kaggle_handle": "kaggle://kerashub/resnetvd/keras/resnet_vd_200_imagenet",
+        "kaggle_handle": "kaggle://keras/resnet_vd/keras/resnet_vd_200_imagenet",
     },
 }
keras_hub/src/models/retinanet/__init__.py

@@ -0,0 +1,5 @@
+from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
+from keras_hub.src.models.retinanet.retinanet_presets import backbone_presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, RetinaNetBackbone)
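
For context, `register_presets` is what makes a preset table resolvable by name on the owning class. A sketch of the resulting usage; the preset name below is illustrative only, the real names live in retinanet_presets.py:

    from keras_hub.models import RetinaNetBackbone

    # Resolves against the presets registered in this __init__.py.
    # "retinanet_resnet50_fpn_coco" is a hypothetical example name.
    backbone = RetinaNetBackbone.from_preset("retinanet_resnet50_fpn_coco")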
keras_hub/src/models/retinanet/anchor_generator.py

@@ -3,9 +3,13 @@ import math
 import keras
 from keras import ops
 
+from keras_hub.src.api_export import keras_hub_export
+
+# TODO: https://github.com/keras-team/keras-hub/issues/1965
 from keras_hub.src.bounding_box.converters import convert_format
 
 
+@keras_hub_export("keras_hub.layers.AnchorGenerator")
 class AnchorGenerator(keras.layers.Layer):
     """Generates anchor boxes for object detection tasks.
 
@@ -81,6 +85,7 @@ class AnchorGenerator(keras.layers.Layer):
         self.num_scales = num_scales
         self.aspect_ratios = aspect_ratios
         self.anchor_size = anchor_size
+        self.num_base_anchors = num_scales * len(aspect_ratios)
         self.built = True
 
     def call(self, inputs):
@@ -92,60 +97,61 @@ class AnchorGenerator(keras.layers.Layer):
 
         image_shape = tuple(image_shape)
 
-        multilevel_boxes = {}
+        multilevel_anchors = {}
         for level in range(self.min_level, self.max_level + 1):
-            boxes_l = []
             # Calculate the feature map size for this level
             feat_size_y = math.ceil(image_shape[0] / 2**level)
             feat_size_x = math.ceil(image_shape[1] / 2**level)
 
             # Calculate the stride (step size) for this level
-            stride_y = ops.cast(image_shape[0] / feat_size_y, "float32")
-            stride_x = ops.cast(image_shape[1] / feat_size_x, "float32")
+            stride_y = image_shape[0] // feat_size_y
+            stride_x = image_shape[1] // feat_size_x
 
             # Generate anchor center points
             # Start from stride/2 to center anchors on pixels
-            cx = ops.arange(stride_x / 2, image_shape[1], stride_x)
-            cy = ops.arange(stride_y / 2, image_shape[0], stride_y)
+            cx = ops.arange(0, feat_size_x, dtype="float32") * stride_x
+            cy = ops.arange(0, feat_size_y, dtype="float32") * stride_y
 
             # Create a grid of anchor centers
-            cx_grid, cy_grid = ops.meshgrid(cx, cy)
-
-            for scale in range(self.num_scales):
-                for aspect_ratio in self.aspect_ratios:
-                    # Calculate the intermediate scale factor
-                    intermidate_scale = 2 ** (scale / self.num_scales)
-                    # Calculate the base anchor size for this level and scale
-                    base_anchor_size = (
-                        self.anchor_size * 2**level * intermidate_scale
-                    )
-                    # Adjust anchor dimensions based on aspect ratio
-                    aspect_x = aspect_ratio**0.5
-                    aspect_y = aspect_ratio**-0.5
-                    half_anchor_size_x = base_anchor_size * aspect_x / 2.0
-                    half_anchor_size_y = base_anchor_size * aspect_y / 2.0
-
-                    # Generate anchor boxes (y1, x1, y2, x2 format)
-                    boxes = ops.stack(
-                        [
-                            cy_grid - half_anchor_size_y,
-                            cx_grid - half_anchor_size_x,
-                            cy_grid + half_anchor_size_y,
-                            cx_grid + half_anchor_size_x,
-                        ],
-                        axis=-1,
-                    )
-                    boxes_l.append(boxes)
-            # Concat anchors on the same level to tensor shape HxWx(Ax4)
-            boxes_l = ops.concatenate(boxes_l, axis=-1)
-            boxes_l = ops.reshape(boxes_l, (-1, 4))
-            # Convert to user defined
-            multilevel_boxes[f"P{level}"] = convert_format(
-                boxes_l,
-                source="yxyx",
+            cy_grid, cx_grid = ops.meshgrid(cy, cx, indexing="ij")
+            cy_grid = ops.reshape(cy_grid, (-1,))
+            cx_grid = ops.reshape(cx_grid, (-1,))
+
+            shifts = ops.stack((cx_grid, cy_grid, cx_grid, cy_grid), axis=1)
+            sizes = [
+                int(
+                    2**level * self.anchor_size * 2 ** (scale / self.num_scales)
+                )
+                for scale in range(self.num_scales)
+            ]
+
+            base_anchors = self.generate_base_anchors(
+                sizes=sizes, aspect_ratios=self.aspect_ratios
+            )
+            shifts = ops.reshape(shifts, (-1, 1, 4))
+            base_anchors = ops.reshape(base_anchors, (1, -1, 4))
+
+            anchors = shifts + base_anchors
+            anchors = ops.reshape(anchors, (-1, 4))
+            multilevel_anchors[f"P{level}"] = convert_format(
+                anchors,
+                source="xyxy",
                 target=self.bounding_box_format,
             )
-        return multilevel_boxes
+        return multilevel_anchors
+
+    def generate_base_anchors(self, sizes, aspect_ratios):
+        sizes = ops.convert_to_tensor(sizes, dtype="float32")
+        aspect_ratios = ops.convert_to_tensor(aspect_ratios)
+        h_ratios = ops.sqrt(aspect_ratios)
+        w_ratios = 1 / h_ratios
+
+        ws = ops.reshape(w_ratios[:, None] * sizes[None, :], (-1,))
+        hs = ops.reshape(h_ratios[:, None] * sizes[None, :], (-1,))
+
+        base_anchors = ops.stack([-1 * ws, -1 * hs, ws, hs], axis=1) / 2
+        base_anchors = ops.round(base_anchors)
+        return base_anchors
 
     def compute_output_shape(self, input_shape):
         multilevel_boxes_shape = {}
@@ -156,18 +162,11 @@ class AnchorGenerator(keras.layers.Layer):
 
         for i in range(self.min_level, self.max_level + 1):
             multilevel_boxes_shape[f"P{i}"] = (
-                (image_height // 2 ** (i))
-                * (image_width // 2 ** (i))
-                * self.anchors_per_location,
+                int(
+                    math.ceil(image_height / 2 ** (i))
+                    * math.ceil(image_width // 2 ** (i))
+                    * self.num_base_anchors
+                ),
                 4,
             )
         return multilevel_boxes_shape
-
-    @property
-    def anchors_per_location(self):
-        """
-        The `anchors_per_location` property returns the number of anchors
-        generated per pixel location, which is equal to
-        `num_scales * len(aspect_ratios)`.
-        """
-        return self.num_scales * len(self.aspect_ratios)
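
The rewrite replaces the per-scale, per-aspect-ratio Python loop with a vectorized scheme: one set of origin-centered base anchors per level, broadcast-added to the flattened grid of center shifts. A NumPy sketch of the base-anchor math, for intuition only (not part of the package):

    import numpy as np

    def base_anchors(sizes, aspect_ratios):
        # One origin-centered (x1, y1, x2, y2) box per (ratio, size) pair,
        # with area roughly size**2 and height/width ratio equal to the
        # aspect ratio -- mirroring generate_base_anchors above.
        sizes = np.asarray(sizes, dtype="float32")
        h_ratios = np.sqrt(np.asarray(aspect_ratios, dtype="float32"))
        w_ratios = 1.0 / h_ratios
        ws = (w_ratios[:, None] * sizes[None, :]).reshape(-1)
        hs = (h_ratios[:, None] * sizes[None, :]).reshape(-1)
        return np.round(np.stack([-ws, -hs, ws, hs], axis=1) / 2)

    # Three aspect ratios at a single size -> three base anchors.
    print(base_anchors(sizes=[32], aspect_ratios=[0.5, 1.0, 2.0]))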