keras-hub-nightly 0.16.1.dev202410200345__py3-none-any.whl → 0.19.0.dev202412070351__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. keras_hub/api/layers/__init__.py +12 -0
  2. keras_hub/api/models/__init__.py +32 -0
  3. keras_hub/src/bounding_box/__init__.py +2 -0
  4. keras_hub/src/bounding_box/converters.py +102 -12
  5. keras_hub/src/layers/modeling/rms_normalization.py +34 -0
  6. keras_hub/src/layers/modeling/transformer_encoder.py +27 -7
  7. keras_hub/src/layers/preprocessing/image_converter.py +5 -0
  8. keras_hub/src/models/albert/albert_presets.py +0 -8
  9. keras_hub/src/models/bart/bart_presets.py +0 -6
  10. keras_hub/src/models/bert/bert_presets.py +0 -20
  11. keras_hub/src/models/bloom/bloom_presets.py +0 -16
  12. keras_hub/src/models/clip/__init__.py +5 -0
  13. keras_hub/src/models/clip/clip_backbone.py +286 -0
  14. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  15. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  16. keras_hub/src/models/clip/clip_presets.py +93 -0
  17. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  18. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  19. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  20. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  21. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +0 -10
  22. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +0 -2
  23. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +5 -3
  24. keras_hub/src/models/densenet/densenet_backbone.py +1 -1
  25. keras_hub/src/models/densenet/densenet_presets.py +0 -6
  26. keras_hub/src/models/distil_bert/distil_bert_presets.py +0 -6
  27. keras_hub/src/models/efficientnet/__init__.py +9 -0
  28. keras_hub/src/models/efficientnet/cba.py +141 -0
  29. keras_hub/src/models/efficientnet/efficientnet_backbone.py +139 -56
  30. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  31. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  32. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  33. keras_hub/src/models/efficientnet/efficientnet_presets.py +192 -0
  34. keras_hub/src/models/efficientnet/fusedmbconv.py +81 -36
  35. keras_hub/src/models/efficientnet/mbconv.py +52 -21
  36. keras_hub/src/models/electra/electra_presets.py +0 -12
  37. keras_hub/src/models/f_net/f_net_presets.py +0 -4
  38. keras_hub/src/models/falcon/falcon_presets.py +0 -2
  39. keras_hub/src/models/flux/__init__.py +5 -0
  40. keras_hub/src/models/flux/flux_layers.py +494 -0
  41. keras_hub/src/models/flux/flux_maths.py +218 -0
  42. keras_hub/src/models/flux/flux_model.py +231 -0
  43. keras_hub/src/models/flux/flux_presets.py +14 -0
  44. keras_hub/src/models/flux/flux_text_to_image.py +142 -0
  45. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  46. keras_hub/src/models/gemma/gemma_presets.py +0 -40
  47. keras_hub/src/models/gpt2/gpt2_presets.py +0 -9
  48. keras_hub/src/models/image_object_detector.py +87 -0
  49. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  50. keras_hub/src/models/image_to_image.py +16 -10
  51. keras_hub/src/models/inpaint.py +20 -13
  52. keras_hub/src/models/llama/llama_backbone.py +1 -1
  53. keras_hub/src/models/llama/llama_presets.py +5 -15
  54. keras_hub/src/models/llama3/llama3_presets.py +0 -8
  55. keras_hub/src/models/mistral/mistral_presets.py +0 -6
  56. keras_hub/src/models/mit/mit_backbone.py +41 -27
  57. keras_hub/src/models/mit/mit_layers.py +9 -7
  58. keras_hub/src/models/mit/mit_presets.py +12 -24
  59. keras_hub/src/models/opt/opt_presets.py +0 -8
  60. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +61 -11
  61. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  62. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +166 -10
  63. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +12 -11
  64. keras_hub/src/models/phi3/phi3_presets.py +0 -4
  65. keras_hub/src/models/resnet/resnet_presets.py +10 -42
  66. keras_hub/src/models/retinanet/__init__.py +5 -0
  67. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  68. keras_hub/src/models/retinanet/feature_pyramid.py +99 -36
  69. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  70. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  71. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  72. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  73. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  74. keras_hub/src/models/retinanet/retinanet_object_detector.py +382 -0
  75. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  76. keras_hub/src/models/retinanet/retinanet_presets.py +15 -0
  77. keras_hub/src/models/roberta/roberta_presets.py +0 -4
  78. keras_hub/src/models/sam/sam_backbone.py +0 -1
  79. keras_hub/src/models/sam/sam_image_segmenter.py +9 -10
  80. keras_hub/src/models/sam/sam_presets.py +0 -6
  81. keras_hub/src/models/segformer/__init__.py +8 -0
  82. keras_hub/src/models/segformer/segformer_backbone.py +163 -0
  83. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  84. keras_hub/src/models/segformer/segformer_image_segmenter.py +171 -0
  85. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  86. keras_hub/src/models/segformer/segformer_presets.py +124 -0
  87. keras_hub/src/models/stable_diffusion_3/mmdit.py +41 -0
  88. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +38 -21
  89. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +3 -3
  90. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +3 -3
  91. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +28 -4
  92. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -1
  93. keras_hub/src/models/t5/t5_backbone.py +5 -4
  94. keras_hub/src/models/t5/t5_presets.py +41 -13
  95. keras_hub/src/models/text_to_image.py +13 -5
  96. keras_hub/src/models/vgg/vgg_backbone.py +1 -1
  97. keras_hub/src/models/vgg/vgg_presets.py +0 -8
  98. keras_hub/src/models/whisper/whisper_audio_converter.py +1 -1
  99. keras_hub/src/models/whisper/whisper_presets.py +0 -20
  100. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +0 -4
  101. keras_hub/src/tests/test_case.py +25 -0
  102. keras_hub/src/utils/preset_utils.py +17 -4
  103. keras_hub/src/utils/timm/convert_efficientnet.py +449 -0
  104. keras_hub/src/utils/timm/preset_loader.py +3 -0
  105. keras_hub/src/version_utils.py +1 -1
  106. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/METADATA +15 -26
  107. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/RECORD +109 -76
  108. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/WHEEL +1 -1
  109. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/top_level.txt +0 -0
keras_hub/src/models/retinanet/retinanet_label_encoder.py
@@ -1,9 +1,12 @@
+import math
+
 import keras
 from keras import ops
 
-from keras_hub.src.bounding_box.converters import _encode_box_to_deltas
+# TODO: https://github.com/keras-team/keras-hub/issues/1965
+from keras_hub.src.bounding_box.converters import convert_format
+from keras_hub.src.bounding_box.converters import encode_box_to_deltas
 from keras_hub.src.bounding_box.iou import compute_iou
-from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator
 from keras_hub.src.models.retinanet.box_matcher import BoxMatcher
 from keras_hub.src.utils import tensor_utils
 
@@ -24,17 +27,10 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
     consistency during training, regardless of the input format.
 
     Args:
-        bounding_box_format: str. The format of bounding boxes of input dataset.
-            Refer TODO: Add link to Keras Core Docs.
-        min_level: int. Minimum level of the output feature pyramid.
-        max_level: int. Maximum level of the output feature pyramid.
-        num_scales: int. Number of intermediate scales added on each level.
-            For example, num_scales=2 adds one additional intermediate anchor
-            scale [2^0, 2^0.5] on each level.
-        aspect_ratios: List[float]. Aspect ratios of anchors added on
-            each level. Each number indicates the ratio of width to height.
-        anchor_size: float. Scale of size of the base anchor relative to the
-            feature stride 2^level.
+        anchor_generator: A `keras_hub.layers.AnchorGenerator`.
+        bounding_box_format: str. Ground truth format of bounding boxes.
+        encoding_format: str. The desired target encoding format for the boxes.
+            TODO: https://github.com/keras-team/keras-hub/issues/1907
         positive_threshold: float. the threshold to set an anchor to positive
             match to gt box. Values above it are positive matches.
             Defaults to `0.5`
@@ -43,7 +39,7 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
             Defaults to `0.4`
         box_variance: List[float]. The scaling factors used to scale the
             bounding box targets.
-            Defaults to `[0.1, 0.1, 0.2, 0.2]`.
+            Defaults to `[1.0, 1.0, 1.0, 1.0]`.
         background_class: int. The class ID used for the background class,
             Defaults to `-1`.
         ignore_class: int. The class ID used for the ignore class,
@@ -63,15 +59,12 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
 
     def __init__(
         self,
+        anchor_generator,
        bounding_box_format,
-        min_level,
-        max_level,
-        num_scales,
-        aspect_ratios,
-        anchor_size,
+        encoding_format="center_yxhw",
         positive_threshold=0.5,
         negative_threshold=0.4,
-        box_variance=[0.1, 0.1, 0.2, 0.2],
+        box_variance=[1.0, 1.0, 1.0, 1.0],
         background_class=-1.0,
         ignore_class=-2.0,
         box_matcher_match_values=[-1, -2, 1],
@@ -79,27 +72,15 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
         **kwargs,
     ):
         super().__init__(**kwargs)
+        self.anchor_generator = anchor_generator
         self.bounding_box_format = bounding_box_format
-        self.min_level = min_level
-        self.max_level = max_level
-        self.num_scales = num_scales
-        self.aspect_ratios = aspect_ratios
-        self.anchor_size = anchor_size
+        self.encoding_format = encoding_format
         self.positive_threshold = positive_threshold
         self.box_variance = box_variance
         self.negative_threshold = negative_threshold
         self.background_class = background_class
         self.ignore_class = ignore_class
 
-        self.anchor_generator = AnchorGenerator(
-            bounding_box_format=bounding_box_format,
-            min_level=min_level,
-            max_level=max_level,
-            num_scales=num_scales,
-            aspect_ratios=aspect_ratios,
-            anchor_size=anchor_size,
-        )
-
         self.box_matcher = BoxMatcher(
             thresholds=[negative_threshold, positive_threshold],
             match_values=box_matcher_match_values,
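
The two hunks above change the `RetinaNetLabelEncoder` contract: the layer no longer builds its own `AnchorGenerator` from `min_level`/`max_level`/`num_scales`/`aspect_ratios`/`anchor_size`; the caller constructs one and injects it, along with the new `encoding_format` argument. A minimal sketch of the new construction pattern, assuming the `keras_hub.src` import paths shown in this diff (argument values here are illustrative):

```python
from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator
from keras_hub.src.models.retinanet.retinanet_label_encoder import (
    RetinaNetLabelEncoder,
)

# Build the anchor generator explicitly; the encoder no longer does this
# internally.
anchor_generator = AnchorGenerator(
    bounding_box_format="yxyx",
    min_level=3,
    max_level=7,
    num_scales=3,
    aspect_ratios=[0.5, 1.0, 2.0],
    anchor_size=4,
)
# Inject it, choosing the box-delta encoding explicitly.
label_encoder = RetinaNetLabelEncoder(
    anchor_generator,
    bounding_box_format="yxyx",
    encoding_format="center_xywh",
)
```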
@@ -116,7 +97,7 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
             images: A Tensor. The input images argument should be
                 of shape `[B, H, W, C]` or `[B, C, H, W]`.
             gt_boxes: A Tensor with shape of `[B, num_boxes, 4]`.
-            gt_labels: A Tensor with shape of `[B, num_boxes, num_classes]`
+            gt_classes: A Tensor with shape of `[B, num_boxes, num_classes]`
 
         Returns:
             box_targets: A Tensor of shape `[batch_size, num_anchors, 4]`
@@ -171,10 +152,15 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
             image_shape: Tuple indicating the image shape `[H, W, C]`.
 
         Returns:
-            Encoded boudning boxes in the format of `center_yxwh` and
+            Encoded bounding boxes in the format of `center_yxwh` and
                 corresponding labels for each encoded bounding box.
         """
-
+        anchor_boxes = convert_format(
+            anchor_boxes,
+            source=self.anchor_generator.bounding_box_format,
+            target=self.bounding_box_format,
+            image_shape=image_shape,
+        )
         iou_matrix = compute_iou(
             anchor_boxes,
             gt_boxes,
@@ -193,11 +179,12 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
             matched_gt_boxes, (-1, ops.shape(matched_gt_boxes)[1], 4)
         )
 
-        box_target = _encode_box_to_deltas(
+        box_targets = encode_box_to_deltas(
             anchors=anchor_boxes,
             boxes=matched_gt_boxes,
             anchor_format=self.bounding_box_format,
             box_format=self.bounding_box_format,
+            encoding_format=self.encoding_format,
             variance=self.box_variance,
             image_shape=image_shape,
         )
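
For orientation, the new `encoding_format` argument selects how matched ground-truth boxes become regression deltas relative to their anchors. The snippet below is only the textbook `center_xywh` encoding with per-coordinate variance scaling, not keras-hub's actual `encode_box_to_deltas` implementation; note that with the new default `box_variance=[1.0, 1.0, 1.0, 1.0]` the variance division is a no-op, consistent with the "weights ported from torch" comment later in this diff.

```python
import math

def encode_center_xywh(anchor, box, variance=(1.0, 1.0, 1.0, 1.0)):
    # anchor and box are (center_x, center_y, width, height) tuples.
    ax, ay, aw, ah = anchor
    bx, by, bw, bh = box
    # Center offsets are normalized by anchor size, sizes by a log ratio,
    # and each coordinate is scaled by its variance factor.
    return (
        (bx - ax) / aw / variance[0],
        (by - ay) / ah / variance[1],
        math.log(bw / aw) / variance[2],
        math.log(bh / ah) / variance[3],
    )
```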
@@ -205,16 +192,16 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
         matched_gt_cls_ids = tensor_utils.target_gather(
             gt_classes, matched_gt_idx
         )
-        cls_target = ops.where(
+        class_targets = ops.where(
             ops.not_equal(positive_mask, 1.0),
             self.background_class,
             matched_gt_cls_ids,
         )
-        cls_target = ops.where(
-            ops.equal(ignore_mask, 1.0), self.ignore_class, cls_target
+        class_targets = ops.where(
+            ops.equal(ignore_mask, 1.0), self.ignore_class, class_targets
         )
         label = ops.concatenate(
-            [box_target, ops.cast(cls_target, box_target.dtype)], axis=-1
+            [box_targets, ops.cast(class_targets, box_targets.dtype)], axis=-1
         )
 
         # In the case that a box in the corner of an image matches with an all
@@ -234,12 +221,11 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
         config = super().get_config()
         config.update(
             {
+                "anchor_generator": keras.layers.serialize(
+                    self.anchor_generator
+                ),
                 "bounding_box_format": self.bounding_box_format,
-                "min_level": self.min_level,
-                "max_level": self.max_level,
-                "num_scales": self.num_scales,
-                "aspect_ratios": self.aspect_ratios,
-                "anchor_size": self.anchor_size,
+                "encoding_format": self.encoding_format,
                 "positive_threshold": self.positive_threshold,
                 "box_variance": self.box_variance,
                 "negative_threshold": self.negative_threshold,
@@ -249,6 +235,18 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
         )
         return config
 
+    @classmethod
+    def from_config(cls, config):
+        config.update(
+            {
+                "anchor_generator": keras.layers.deserialize(
+                    config["anchor_generator"]
+                ),
+            }
+        )
+
+        return super().from_config(config)
+
     def compute_output_shape(
         self, images_shape, gt_boxes_shape, gt_classes_shape
     ):
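
Because `anchor_generator` is now a nested layer in the config, `from_config` has to rebuild it before calling the parent. A hedged round-trip sketch, reusing `label_encoder` and `AnchorGenerator` from the earlier sketch:

```python
# Serialize, then restore; from_config turns the nested AnchorGenerator
# dict back into a layer instance before the parent constructor runs.
config = label_encoder.get_config()
restored = RetinaNetLabelEncoder.from_config(config)
assert isinstance(restored.anchor_generator, AnchorGenerator)
```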
@@ -258,10 +256,10 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
 
         total_num_anchors = 0
         for i in range(min_level, max_level + 1):
-            total_num_anchors += (
-                (image_H // 2 ** (i))
-                * (image_W // 2 ** (i))
-                * self.anchor_generator.anchors_per_location
+            total_num_anchors += int(
+                math.ceil(image_H / 2 ** (i))
+                * math.ceil(image_W / 2 ** (i))
+                * self.anchor_generator.num_base_anchors
             )
 
         return (batch_size, total_num_anchors, 4), (
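
The switch from floor division to `math.ceil` matters whenever the image size is not an exact multiple of a level's stride. A small worked example of the new count (values illustrative: an 800x800 input, levels 3 through 7, 9 base anchors per location):

```python
import math

num_base_anchors = 9  # e.g. 3 scales x 3 aspect ratios
total = 0
for level in range(3, 8):  # min_level=3 .. max_level=7
    stride = 2**level
    cells = math.ceil(800 / stride) * math.ceil(800 / stride)
    total += int(cells * num_base_anchors)
print(total)  # 120087 = 90000 + 22500 + 5625 + 1521 + 441
```

With the old `//` arithmetic, levels 6 and 7 would have contributed 1296 and 324 anchors instead of 1521 and 441, undercounting the anchors the generator actually emits.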
keras_hub/src/models/retinanet/retinanet_object_detector.py
@@ -0,0 +1,382 @@
+import keras
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+
+# TODO: https://github.com/keras-team/keras-hub/issues/1965
+from keras_hub.src.bounding_box.converters import convert_format
+from keras_hub.src.bounding_box.converters import decode_deltas_to_boxes
+from keras_hub.src.models.image_object_detector import ImageObjectDetector
+from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator
+from keras_hub.src.models.retinanet.non_max_supression import NonMaxSuppression
+from keras_hub.src.models.retinanet.prediction_head import PredictionHead
+from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
+from keras_hub.src.models.retinanet.retinanet_label_encoder import (
+    RetinaNetLabelEncoder,
+)
+from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import (
+    RetinaNetObjectDetectorPreprocessor,
+)
+
+
+@keras_hub_export("keras_hub.models.RetinaNetObjectDetector")
+class RetinaNetObjectDetector(ImageObjectDetector):
+    """RetinaNet object detector model.
+
+    This class implements the RetinaNet object detection architecture.
+    It consists of a feature extractor backbone, a feature pyramid network(FPN),
+    and two prediction heads (for classification and bounding box regression).
+
+    Args:
+        backbone: `keras.Model`. A `keras.models.RetinaNetBackbone` class,
+            defining the backbone network architecture. Provides feature maps
+            for detection.
+        anchor_generator: A `keras_hub.layers.AnchorGenerator` instance.
+            Generates anchor boxes at different scales and aspect ratios
+            across the image. If None, a default `AnchorGenerator` is
+            created with the following parameters:
+            - `bounding_box_format`: Same as the model's
+                `bounding_box_format`.
+            - `min_level`: The backbone's `min_level`.
+            - `max_level`: The backbone's `max_level`.
+            - `num_scales`: 3.
+            - `aspect_ratios`: [0.5, 1.0, 2.0].
+            - `anchor_size`: 4.0.
+            You can create a custom `AnchorGenerator` by instantiating the
+            `keras_hub.layers.AnchorGenerator` class and passing the desired
+            arguments.
+        num_classes: int. The number of object classes to be detected.
+        bounding_box_format: str. Dataset bounding box format (e.g., "xyxy",
+            "yxyx"). The supported formats are
+            refer TODO: https://github.com/keras-team/keras-hub/issues/1907.
+            Defaults to `yxyx`.
+        label_encoder: Optional. A `RetinaNetLabelEncoder` instance. Encodes
+            ground truth boxes and classes into training targets. It matches
+            ground truth boxes to anchors based on IoU and encodes box
+            coordinates as offsets. If `None`, a default encoder is created.
+            See the
+            `keras_hub.src.models.retinanet.retinanet_label_encoder.RetinaNetLabelEncoder`
+            class for details. If None, a default encoder is created with
+            standard parameters.
+            - `anchor_generator`: Same as the model's.
+            - `bounding_box_format`: Same as the model's
+                `bounding_box_format`.
+            - `positive_threshold`: 0.5
+            - `negative_threshold`: 0.4
+            - `encoding_format`: "center_xywh"
+            - `box_variance`: [1.0, 1.0, 1.0, 1.0]
+            - `background_class`: -1
+            - `ignore_class`: -2
+        use_prediction_head_norm: bool. Whether to use Group Normalization after
+            the convolution layers in the prediction heads. Defaults to `False`.
+        classification_head_prior_probability: float. Prior probability for the
+            classification head (used for focal loss). Defaults to 0.01.
+        pre_logits_num_conv_layers: int. The number of convolutional layers in
+            the head before the logits layer. These convolutional layers are
+            applied before the final linear layer (logits) that produces the
+            output predictions (bounding box regressions, classification scores).
+        preprocessor: Optional. An instance of
+            `RetinaNetObjectDetectorPreprocessor`or a custom preprocessor.
+            Handles image preprocessing before feeding into the backbone.
+        activation: Optional. The activation function to be used in the
+            classification head. If None, sigmoid is used.
+        dtype: Optional. The data type for the prediction heads. Defaults to the
+            backbone's dtype policy.
+        prediction_decoder: Optional. A `keras.layers.Layer` instance
+            responsible for transforming RetinaNet predictions
+            (box regressions and classifications) into final bounding boxes and
+            classes with confidence scores. Defaults to a `NonMaxSuppression`
+            instance.
+    """
+
+    backbone_cls = RetinaNetBackbone
+    preprocessor_cls = RetinaNetObjectDetectorPreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        num_classes,
+        bounding_box_format="yxyx",
+        anchor_generator=None,
+        label_encoder=None,
+        use_prediction_head_norm=False,
+        classification_head_prior_probability=0.01,
+        pre_logits_num_conv_layers=4,
+        preprocessor=None,
+        activation=None,
+        dtype=None,
+        prediction_decoder=None,
+        **kwargs,
+    ):
+        # === Layers ===
+        image_input = keras.layers.Input(backbone.image_shape, name="images")
+        head_dtype = dtype or backbone.dtype_policy
+
+        anchor_generator = anchor_generator or AnchorGenerator(
+            bounding_box_format,
+            min_level=backbone.min_level,
+            max_level=backbone.max_level,
+            num_scales=3,
+            aspect_ratios=[0.5, 1.0, 2.0],
+            anchor_size=4,
+        )
+        # As weights are ported from torch they use encoded format
+        # as "center_xywh"
+        label_encoder = label_encoder or RetinaNetLabelEncoder(
+            anchor_generator,
+            bounding_box_format=bounding_box_format,
+            encoding_format="center_xywh",
+        )
+
+        box_head = PredictionHead(
+            output_filters=anchor_generator.num_base_anchors * 4,
+            num_conv_layers=pre_logits_num_conv_layers,
+            num_filters=256,
+            use_group_norm=use_prediction_head_norm,
+            use_prior_probability=True,
+            prior_probability=classification_head_prior_probability,
+            dtype=head_dtype,
+            name="box_head",
+        )
+        classification_head = PredictionHead(
+            output_filters=anchor_generator.num_base_anchors * num_classes,
+            num_conv_layers=pre_logits_num_conv_layers,
+            num_filters=256,
+            use_group_norm=use_prediction_head_norm,
+            dtype=head_dtype,
+            name="classification_head",
+        )
+
+        # === Functional Model ===
+        feature_map = backbone(image_input)
+
+        class_predictions = []
+        box_predictions = []
+
+        # Iterate through the feature pyramid levels (e.g., P3, P4, P5, P6, P7).
+        for level in feature_map:
+            box_predictions.append(
+                keras.layers.Reshape((-1, 4), name=f"box_pred_{level}")(
+                    box_head(feature_map[level])
+                )
+            )
+            class_predictions.append(
+                keras.layers.Reshape(
+                    (-1, num_classes), name=f"cls_pred_{level}"
+                )(classification_head(feature_map[level]))
+            )
+
+        # Concatenate predictions from all FPN levels.
+        class_predictions = keras.layers.Concatenate(axis=1, name="cls_logits")(
+            class_predictions
+        )
+        # box_pred is always in "center_xywh" delta-encoded no matter what
+        # format you pass in.
+        box_predictions = keras.layers.Concatenate(
+            axis=1, name="bbox_regression"
+        )(box_predictions)
+
+        outputs = {
+            "bbox_regression": box_predictions,
+            "cls_logits": class_predictions,
+        }
+
+        super().__init__(
+            inputs=image_input,
+            outputs=outputs,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.bounding_box_format = bounding_box_format
+        self.use_prediction_head_norm = use_prediction_head_norm
+        self.num_classes = num_classes
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+        self.activation = activation
+        self.pre_logits_num_conv_layers = pre_logits_num_conv_layers
+        self.box_head = box_head
+        self.classification_head = classification_head
+        self.anchor_generator = anchor_generator
+        self.label_encoder = label_encoder
+        self._prediction_decoder = prediction_decoder or NonMaxSuppression(
+            from_logits=(activation != keras.activations.sigmoid),
+            bounding_box_format=bounding_box_format,
+        )
+
+    def compute_loss(self, x, y, y_pred, sample_weight, **kwargs):
+        y_for_label_encoder = convert_format(
+            y,
+            source=self.bounding_box_format,
+            target=self.label_encoder.bounding_box_format,
+            images=x,
+        )
+
+        boxes, classes = self.label_encoder(
+            images=x,
+            gt_boxes=y_for_label_encoder["boxes"],
+            gt_classes=y_for_label_encoder["classes"],
+        )
+
+        box_pred = y_pred["bbox_regression"]
+        cls_pred = y_pred["cls_logits"]
+
+        if boxes.shape[-1] != 4:
+            raise ValueError(
+                "boxes should have shape (None, None, 4). Got "
+                f"boxes.shape={tuple(boxes.shape)}"
+            )
+
+        if box_pred.shape[-1] != 4:
+            raise ValueError(
+                "box_pred should have shape (None, None, 4). Got "
+                f"box_pred.shape={tuple(box_pred.shape)}. Does your model's "
+                "`num_classes` parameter match your losses `num_classes` "
+                "parameter?"
+            )
+        if cls_pred.shape[-1] != self.num_classes:
+            raise ValueError(
+                "cls_pred should have shape (None, None, 4). Got "
+                f"cls_pred.shape={tuple(cls_pred.shape)}. Does your model's "
+                "`num_classes` parameter match your losses `num_classes` "
+                "parameter?"
+            )
+
+        cls_labels = ops.one_hot(
+            ops.cast(classes, "int32"), self.num_classes, dtype="float32"
+        )
+        positive_mask = ops.cast(ops.greater(classes, -1.0), dtype="float32")
+        normalizer = ops.sum(positive_mask)
+        cls_weights = ops.cast(ops.not_equal(classes, -2.0), dtype="float32")
+        cls_weights /= normalizer
+        box_weights = positive_mask / normalizer
+
+        y_true = {
+            "bbox_regression": boxes,
+            "cls_logits": cls_labels,
+        }
+        sample_weights = {
+            "bbox_regression": box_weights,
+            "cls_logits": cls_weights,
+        }
+        zero_weight = {
+            "bbox_regression": ops.zeros_like(box_weights),
+            "cls_logits": ops.zeros_like(cls_weights),
+        }
+
+        sample_weight = ops.cond(
+            normalizer == 0,
+            lambda: zero_weight,
+            lambda: sample_weights,
+        )
+        return super().compute_loss(
+            x=x, y=y_true, y_pred=y_pred, sample_weight=sample_weight, **kwargs
+        )
+
+    def predict_step(self, *args):
+        outputs = super().predict_step(*args)
+        if isinstance(outputs, tuple):
+            return self.decode_predictions(outputs[0], args[-1]), outputs[1]
+        return self.decode_predictions(outputs, *args)
+
+    @property
+    def prediction_decoder(self):
+        return self._prediction_decoder
+
+    @prediction_decoder.setter
+    def prediction_decoder(self, prediction_decoder):
+        if prediction_decoder.bounding_box_format != self.bounding_box_format:
+            raise ValueError(
+                "Expected `prediction_decoder` and `RetinaNet` to "
+                "use the same `bounding_box_format`, but got "
+                "`prediction_decoder.bounding_box_format="
+                f"{prediction_decoder.bounding_box_format}`, and "
+                "`self.bounding_box_format="
+                f"{self.bounding_box_format}`."
+            )
+        self._prediction_decoder = prediction_decoder
+        self.make_predict_function(force=True)
+        self.make_train_function(force=True)
+        self.make_test_function(force=True)
+
+    def decode_predictions(self, predictions, data):
+        box_pred = predictions["bbox_regression"]
+        cls_pred = predictions["cls_logits"]
+        # box_pred is on "center_yxhw" format, convert to target format.
+        if isinstance(data, list) or isinstance(data, tuple):
+            images, _ = data
+        else:
+            images = data
+        image_shape = ops.shape(images)[1:]
+        anchor_boxes = self.anchor_generator(images)
+        anchor_boxes = ops.concatenate(list(anchor_boxes.values()), axis=0)
+        box_pred = decode_deltas_to_boxes(
+            anchors=anchor_boxes,
+            boxes_delta=box_pred,
+            encoded_format="center_xywh",
+            anchor_format=self.anchor_generator.bounding_box_format,
+            box_format=self.bounding_box_format,
+            image_shape=image_shape,
+        )
+        # box_pred is now in "self.bounding_box_format" format
+        box_pred = convert_format(
+            box_pred,
+            source=self.bounding_box_format,
+            target=self.prediction_decoder.bounding_box_format,
+            image_shape=image_shape,
+        )
+        y_pred = self.prediction_decoder(
+            box_pred, cls_pred, image_shape=image_shape
+        )
+        y_pred["boxes"] = convert_format(
+            y_pred["boxes"],
+            source=self.prediction_decoder.bounding_box_format,
+            target=self.bounding_box_format,
+            image_shape=image_shape,
+        )
+        return y_pred
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_classes": self.num_classes,
+                "use_prediction_head_norm": self.use_prediction_head_norm,
+                "pre_logits_num_conv_layers": self.pre_logits_num_conv_layers,
+                "bounding_box_format": self.bounding_box_format,
+                "anchor_generator": keras.layers.serialize(
+                    self.anchor_generator
+                ),
+                "label_encoder": keras.layers.serialize(self.label_encoder),
+                "prediction_decoder": keras.layers.serialize(
+                    self._prediction_decoder
+                ),
+            }
+        )
+
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        if "label_encoder" in config and isinstance(
+            config["label_encoder"], dict
+        ):
+            config["label_encoder"] = keras.layers.deserialize(
+                config["label_encoder"]
+            )
+
+        if "anchor_generator" in config and isinstance(
+            config["anchor_generator"], dict
+        ):
+            config["anchor_generator"] = keras.layers.deserialize(
+                config["anchor_generator"]
+            )
+
+        if "prediction_decoder" in config and isinstance(
+            config["prediction_decoder"], dict
+        ):
+            config["prediction_decoder"] = keras.layers.deserialize(
+                config["prediction_decoder"]
+            )
+
+        return super().from_config(config)
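
To make the new task model concrete, here is a hedged usage sketch; `backbone` stands in for a `RetinaNetBackbone` built elsewhere, and the input size is illustrative:

```python
import numpy as np

detector = RetinaNetObjectDetector(
    backbone=backbone,  # assumed: a RetinaNetBackbone instance in scope
    num_classes=80,
    bounding_box_format="yxyx",
)
images = np.random.uniform(0, 255, size=(1, 800, 800, 3)).astype("float32")
# predict_step routes the raw head outputs through decode_predictions,
# so boxes come back NMS-decoded in the model's bounding_box_format.
predictions = detector.predict(images)
boxes = predictions["boxes"]
```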
keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py
@@ -0,0 +1,14 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.image_object_detector_preprocessor import (
+    ImageObjectDetectorPreprocessor,
+)
+from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
+from keras_hub.src.models.retinanet.retinanet_image_converter import (
+    RetinaNetImageConverter,
+)
+
+
+@keras_hub_export("keras_hub.models.RetinaNetObjectDetectorPreprocessor")
+class RetinaNetObjectDetectorPreprocessor(ImageObjectDetectorPreprocessor):
+    backbone_cls = RetinaNetBackbone
+    image_converter_cls = RetinaNetImageConverter
keras_hub/src/models/retinanet/retinanet_presets.py
@@ -0,0 +1,15 @@
+"""RetinaNet model preset configurations."""
+
+# Metadata for loading pretrained model weights.
+backbone_presets = {
+    "retinanet_resnet50_fpn_coco": {
+        "metadata": {
+            "description": (
+                "RetinaNet model with ResNet50 backbone fine-tuned on COCO in 800x800 resolution."
+            ),
+            "params": 34121239,
+            "path": "retinanet",
+        },
+        "kaggle_handle": "kaggle://keras/retinanet/keras/retinanet_resnet50_fpn_coco/1",
+    }
+}
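
Assuming the standard KerasHub preset flow picks up this registration, loading the pretrained detector should look like:

```python
import keras_hub

# from_preset downloads the Kaggle-hosted weights registered above.
detector = keras_hub.models.RetinaNetObjectDetector.from_preset(
    "retinanet_resnet50_fpn_coco"
)
```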
keras_hub/src/models/roberta/roberta_presets.py
@@ -8,9 +8,7 @@ backbone_presets = {
             "Trained on English Wikipedia, BooksCorpus, CommonCraw, and OpenWebText."
         ),
         "params": 124052736,
-        "official_name": "RoBERTa",
         "path": "roberta",
-        "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/roberta/README.md",
     },
     "kaggle_handle": "kaggle://keras/roberta/keras/roberta_base_en/2",
 },
@@ -21,9 +19,7 @@ backbone_presets = {
             "Trained on English Wikipedia, BooksCorpus, CommonCraw, and OpenWebText."
         ),
         "params": 354307072,
-        "official_name": "RoBERTa",
         "path": "roberta",
-        "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/roberta/README.md",
     },
     "kaggle_handle": "kaggle://keras/roberta/keras/roberta_large_en/2",
 },
keras_hub/src/models/sam/sam_backbone.py
@@ -68,7 +68,6 @@ class SAMBackbone(Backbone):
         image_encoder=image_encoder,
         prompt_encoder=prompt_encoder,
         mask_decoder=mask_decoder,
-        image_shape=(image_size, image_size, 3),
     )
     backbone(input_data)
     ```