keras-hub-nightly 0.16.1.dev202410200345__py3-none-any.whl → 0.19.0.dev202412070351__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +12 -0
- keras_hub/api/models/__init__.py +32 -0
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/rms_normalization.py +34 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +27 -7
- keras_hub/src/layers/preprocessing/image_converter.py +5 -0
- keras_hub/src/models/albert/albert_presets.py +0 -8
- keras_hub/src/models/bart/bart_presets.py +0 -6
- keras_hub/src/models/bert/bert_presets.py +0 -20
- keras_hub/src/models/bloom/bloom_presets.py +0 -16
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +0 -10
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +0 -2
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +5 -3
- keras_hub/src/models/densenet/densenet_backbone.py +1 -1
- keras_hub/src/models/densenet/densenet_presets.py +0 -6
- keras_hub/src/models/distil_bert/distil_bert_presets.py +0 -6
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +139 -56
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +192 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +81 -36
- keras_hub/src/models/efficientnet/mbconv.py +52 -21
- keras_hub/src/models/electra/electra_presets.py +0 -12
- keras_hub/src/models/f_net/f_net_presets.py +0 -4
- keras_hub/src/models/falcon/falcon_presets.py +0 -2
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +494 -0
- keras_hub/src/models/flux/flux_maths.py +218 -0
- keras_hub/src/models/flux/flux_model.py +231 -0
- keras_hub/src/models/flux/flux_presets.py +14 -0
- keras_hub/src/models/flux/flux_text_to_image.py +142 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_presets.py +0 -40
- keras_hub/src/models/gpt2/gpt2_presets.py +0 -9
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_to_image.py +16 -10
- keras_hub/src/models/inpaint.py +20 -13
- keras_hub/src/models/llama/llama_backbone.py +1 -1
- keras_hub/src/models/llama/llama_presets.py +5 -15
- keras_hub/src/models/llama3/llama3_presets.py +0 -8
- keras_hub/src/models/mistral/mistral_presets.py +0 -6
- keras_hub/src/models/mit/mit_backbone.py +41 -27
- keras_hub/src/models/mit/mit_layers.py +9 -7
- keras_hub/src/models/mit/mit_presets.py +12 -24
- keras_hub/src/models/opt/opt_presets.py +0 -8
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +61 -11
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +166 -10
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +12 -11
- keras_hub/src/models/phi3/phi3_presets.py +0 -4
- keras_hub/src/models/resnet/resnet_presets.py +10 -42
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +99 -36
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +382 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +15 -0
- keras_hub/src/models/roberta/roberta_presets.py +0 -4
- keras_hub/src/models/sam/sam_backbone.py +0 -1
- keras_hub/src/models/sam/sam_image_segmenter.py +9 -10
- keras_hub/src/models/sam/sam_presets.py +0 -6
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +163 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +171 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +124 -0
- keras_hub/src/models/stable_diffusion_3/mmdit.py +41 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +38 -21
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +3 -3
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +3 -3
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +28 -4
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +41 -13
- keras_hub/src/models/text_to_image.py +13 -5
- keras_hub/src/models/vgg/vgg_backbone.py +1 -1
- keras_hub/src/models/vgg/vgg_presets.py +0 -8
- keras_hub/src/models/whisper/whisper_audio_converter.py +1 -1
- keras_hub/src/models/whisper/whisper_presets.py +0 -20
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +0 -4
- keras_hub/src/tests/test_case.py +25 -0
- keras_hub/src/utils/preset_utils.py +17 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +449 -0
- keras_hub/src/utils/timm/preset_loader.py +3 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/METADATA +15 -26
- {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/RECORD +109 -76
- {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/WHEEL +1 -1
- {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/top_level.txt +0 -0
--- keras_hub/src/models/retinanet/retinanet_label_encoder.py
+++ keras_hub/src/models/retinanet/retinanet_label_encoder.py
@@ -1,9 +1,12 @@
+import math
+
 import keras
 from keras import ops
 
-
+# TODO: https://github.com/keras-team/keras-hub/issues/1965
+from keras_hub.src.bounding_box.converters import convert_format
+from keras_hub.src.bounding_box.converters import encode_box_to_deltas
 from keras_hub.src.bounding_box.iou import compute_iou
-from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator
 from keras_hub.src.models.retinanet.box_matcher import BoxMatcher
 from keras_hub.src.utils import tensor_utils
 
@@ -24,17 +27,10 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
     consistency during training, regardless of the input format.
 
     Args:
-
-
-
-
-        num_scales: int. Number of intermediate scales added on each level.
-            For example, num_scales=2 adds one additional intermediate anchor
-            scale [2^0, 2^0.5] on each level.
-        aspect_ratios: List[float]. Aspect ratios of anchors added on
-            each level. Each number indicates the ratio of width to height.
-        anchor_size: float. Scale of size of the base anchor relative to the
-            feature stride 2^level.
+        anchor_generator: A `keras_hub.layers.AnchorGenerator`.
+        bounding_box_format: str. Ground truth format of bounding boxes.
+        encoding_format: str. The desired target encoding format for the boxes.
+            TODO: https://github.com/keras-team/keras-hub/issues/1907
         positive_threshold: float. the threshold to set an anchor to positive
             match to gt box. Values above it are positive matches.
             Defaults to `0.5`
@@ -43,7 +39,7 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
             Defaults to `0.4`
         box_variance: List[float]. The scaling factors used to scale the
             bounding box targets.
-            Defaults to `[0.1, 0.1, 0.2, 0.2]`.
+            Defaults to `[1.0, 1.0, 1.0, 1.0]`.
         background_class: int. The class ID used for the background class,
             Defaults to `-1`.
         ignore_class: int. The class ID used for the ignore class,
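The old scaled default matched the classic RetinaNet variance trick; the new unit default leaves deltas untouched, which is what the torch-ported weights introduced below expect. A numeric sketch, assuming the encoder divides each delta component by its variance term (as in the KerasCV implementation this code descends from):

```python
import numpy as np

# Assumed encoding behavior: delta / variance (decode multiplies it back).
deltas = np.array([0.05, -0.02, 0.10, 0.30])
old = deltas / np.array([0.1, 0.1, 0.2, 0.2])  # -> [0.5, -0.2, 0.5, 1.5]
new = deltas / np.array([1.0, 1.0, 1.0, 1.0])  # -> unchanged
```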
@@ -63,15 +59,12 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
 
     def __init__(
         self,
+        anchor_generator,
         bounding_box_format,
-        min_level,
-        max_level,
-        num_scales,
-        aspect_ratios,
-        anchor_size,
+        encoding_format="center_yxhw",
        positive_threshold=0.5,
         negative_threshold=0.4,
-        box_variance=[0.1, 0.1, 0.2, 0.2],
+        box_variance=[1.0, 1.0, 1.0, 1.0],
         background_class=-1.0,
         ignore_class=-2.0,
         box_matcher_match_values=[-1, -2, 1],
@@ -79,27 +72,15 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
         **kwargs,
     ):
         super().__init__(**kwargs)
+        self.anchor_generator = anchor_generator
         self.bounding_box_format = bounding_box_format
-        self.min_level = min_level
-        self.max_level = max_level
-        self.num_scales = num_scales
-        self.aspect_ratios = aspect_ratios
-        self.anchor_size = anchor_size
+        self.encoding_format = encoding_format
         self.positive_threshold = positive_threshold
         self.box_variance = box_variance
         self.negative_threshold = negative_threshold
         self.background_class = background_class
         self.ignore_class = ignore_class
 
-        self.anchor_generator = AnchorGenerator(
-            bounding_box_format=bounding_box_format,
-            min_level=min_level,
-            max_level=max_level,
-            num_scales=num_scales,
-            aspect_ratios=aspect_ratios,
-            anchor_size=anchor_size,
-        )
-
         self.box_matcher = BoxMatcher(
             thresholds=[negative_threshold, positive_threshold],
             match_values=box_matcher_match_values,
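With this change the anchor generator is built by the caller and injected, instead of being constructed inside the encoder. A minimal sketch of the new call pattern, using the same argument values the object detector below passes by default (levels 3-7 are the usual FPN range; the detector takes them from its backbone):

```python
from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator
from keras_hub.src.models.retinanet.retinanet_label_encoder import (
    RetinaNetLabelEncoder,
)

# Build the anchors once, then share the generator between the label
# encoder and any other component that needs it.
anchor_generator = AnchorGenerator(
    bounding_box_format="yxyx",
    min_level=3,
    max_level=7,
    num_scales=3,
    aspect_ratios=[0.5, 1.0, 2.0],
    anchor_size=4.0,
)
label_encoder = RetinaNetLabelEncoder(
    anchor_generator,
    bounding_box_format="yxyx",
    encoding_format="center_xywh",
)
```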
@@ -116,7 +97,7 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
             images: A Tensor. The input images argument should be
                 of shape `[B, H, W, C]` or `[B, C, H, W]`.
             gt_boxes: A Tensor with shape of `[B, num_boxes, 4]`.
-
+            gt_classes: A Tensor with shape of `[B, num_boxes, num_classes]`
 
         Returns:
             box_targets: A Tensor of shape `[batch_size, num_anchors, 4]`
@@ -171,10 +152,15 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
             image_shape: Tuple indicating the image shape `[H, W, C]`.
 
         Returns:
-            Encoded
+            Encoded bounding boxes in the format of `center_yxwh` and
                 corresponding labels for each encoded bounding box.
         """
-
+        anchor_boxes = convert_format(
+            anchor_boxes,
+            source=self.anchor_generator.bounding_box_format,
+            target=self.bounding_box_format,
+            image_shape=image_shape,
+        )
         iou_matrix = compute_iou(
             anchor_boxes,
             gt_boxes,
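The added conversion normalizes anchors into the encoder's own `bounding_box_format` before IoU is computed, so the generator and encoder no longer need to agree on a format up front. A toy `convert_format` call with the keyword signature used above:

```python
import numpy as np
from keras_hub.src.bounding_box.converters import convert_format

# One "xyxy" box (x1, y1, x2, y2) converted to "yxyx"; image_shape is
# needed for formats defined relative to the image size.
boxes = np.array([[[10.0, 20.0, 30.0, 40.0]]])
converted = convert_format(
    boxes, source="xyxy", target="yxyx", image_shape=(64, 64, 3)
)
print(converted)  # [[[20. 10. 40. 30.]]]
```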
@@ -193,11 +179,12 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
             matched_gt_boxes, (-1, ops.shape(matched_gt_boxes)[1], 4)
         )
 
-
+        box_targets = encode_box_to_deltas(
             anchors=anchor_boxes,
             boxes=matched_gt_boxes,
             anchor_format=self.bounding_box_format,
             box_format=self.bounding_box_format,
+            encoding_format=self.encoding_format,
             variance=self.box_variance,
             image_shape=image_shape,
         )
@@ -205,16 +192,16 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
         matched_gt_cls_ids = tensor_utils.target_gather(
             gt_classes, matched_gt_idx
         )
-
+        class_targets = ops.where(
             ops.not_equal(positive_mask, 1.0),
             self.background_class,
             matched_gt_cls_ids,
         )
-
-            ops.equal(ignore_mask, 1.0), self.ignore_class,
+        class_targets = ops.where(
+            ops.equal(ignore_mask, 1.0), self.ignore_class, class_targets
         )
         label = ops.concatenate(
-            [
+            [box_targets, ops.cast(class_targets, box_targets.dtype)], axis=-1
         )
 
         # In the case that a box in the corner of an image matches with an all
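The rewritten masking is easiest to read on a toy example: anchors that are not positive matches fall back to the background class, and ignored anchors are then overwritten with the ignore class.

```python
from keras import ops

matched = ops.convert_to_tensor([0.0, 3.0, 5.0])  # matched gt class ids
positive_mask = ops.convert_to_tensor([1.0, 0.0, 1.0])
ignore_mask = ops.convert_to_tensor([0.0, 1.0, 0.0])

# Same two-stage ops.where pattern as the hunk above, with the default
# background_class=-1.0 and ignore_class=-2.0.
targets = ops.where(ops.not_equal(positive_mask, 1.0), -1.0, matched)
targets = ops.where(ops.equal(ignore_mask, 1.0), -2.0, targets)
# targets == [0.0, -2.0, 5.0]: kept, ignored, kept
```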
@@ -234,12 +221,11 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
         config = super().get_config()
         config.update(
             {
+                "anchor_generator": keras.layers.serialize(
+                    self.anchor_generator
+                ),
                 "bounding_box_format": self.bounding_box_format,
-                "min_level": self.min_level,
-                "max_level": self.max_level,
-                "num_scales": self.num_scales,
-                "aspect_ratios": self.aspect_ratios,
-                "anchor_size": self.anchor_size,
+                "encoding_format": self.encoding_format,
                 "positive_threshold": self.positive_threshold,
                 "box_variance": self.box_variance,
                 "negative_threshold": self.negative_threshold,
@@ -249,6 +235,18 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
         )
         return config
 
+    @classmethod
+    def from_config(cls, config):
+        config.update(
+            {
+                "anchor_generator": keras.layers.deserialize(
+                    config["anchor_generator"]
+                ),
+            }
+        )
+
+        return super().from_config(config)
+
     def compute_output_shape(
         self, images_shape, gt_boxes_shape, gt_classes_shape
     ):
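The new `from_config` mirrors the `get_config` change above: the nested anchor generator comes back as a plain serialization dict and must be deserialized before the layer is rebuilt. A round-trip sketch, reusing the `label_encoder` and `AnchorGenerator` names from the constructor example earlier:

```python
config = label_encoder.get_config()
assert isinstance(config["anchor_generator"], dict)  # serialized form
restored = RetinaNetLabelEncoder.from_config(config)
assert isinstance(restored.anchor_generator, AnchorGenerator)
```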
@@ -258,10 +256,10 @@ class RetinaNetLabelEncoder(keras.layers.Layer):
 
         total_num_anchors = 0
         for i in range(min_level, max_level + 1):
-            total_num_anchors += (
-                (image_H
-                * (image_W
-                * self.anchor_generator.
+            total_num_anchors += int(
+                math.ceil(image_H / 2 ** (i))
+                * math.ceil(image_W / 2 ** (i))
+                * self.anchor_generator.num_base_anchors
             )
 
         return (batch_size, total_num_anchors, 4), (
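Rounding up with `math.ceil` matters for image sizes that are not multiples of the level stride, matching the number of anchor locations the generator actually produces. A quick sanity check of the arithmetic, assuming `num_base_anchors` is 9 (3 scales x 3 aspect ratios, the detector defaults below):

```python
import math

image_H, image_W = 640, 640
num_base_anchors = 9  # assumed: 3 scales * 3 aspect ratios
total = sum(
    math.ceil(image_H / 2**i) * math.ceil(image_W / 2**i) * num_base_anchors
    for i in range(3, 7 + 1)  # pyramid levels P3-P7
)
print(total)  # 76725 = (6400 + 1600 + 400 + 100 + 25) * 9
```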
--- /dev/null
+++ keras_hub/src/models/retinanet/retinanet_object_detector.py
@@ -0,0 +1,382 @@
+import keras
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+
+# TODO: https://github.com/keras-team/keras-hub/issues/1965
+from keras_hub.src.bounding_box.converters import convert_format
+from keras_hub.src.bounding_box.converters import decode_deltas_to_boxes
+from keras_hub.src.models.image_object_detector import ImageObjectDetector
+from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator
+from keras_hub.src.models.retinanet.non_max_supression import NonMaxSuppression
+from keras_hub.src.models.retinanet.prediction_head import PredictionHead
+from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
+from keras_hub.src.models.retinanet.retinanet_label_encoder import (
+    RetinaNetLabelEncoder,
+)
+from keras_hub.src.models.retinanet.retinanet_object_detector_preprocessor import (
+    RetinaNetObjectDetectorPreprocessor,
+)
+
+
+@keras_hub_export("keras_hub.models.RetinaNetObjectDetector")
+class RetinaNetObjectDetector(ImageObjectDetector):
+    """RetinaNet object detector model.
+
+    This class implements the RetinaNet object detection architecture.
+    It consists of a feature extractor backbone, a feature pyramid network(FPN),
+    and two prediction heads (for classification and bounding box regression).
+
+    Args:
+        backbone: `keras.Model`. A `keras.models.RetinaNetBackbone` class,
+            defining the backbone network architecture. Provides feature maps
+            for detection.
+        anchor_generator: A `keras_hub.layers.AnchorGenerator` instance.
+            Generates anchor boxes at different scales and aspect ratios
+            across the image. If None, a default `AnchorGenerator` is
+            created with the following parameters:
+                - `bounding_box_format`: Same as the model's
+                    `bounding_box_format`.
+                - `min_level`: The backbone's `min_level`.
+                - `max_level`: The backbone's `max_level`.
+                - `num_scales`: 3.
+                - `aspect_ratios`: [0.5, 1.0, 2.0].
+                - `anchor_size`: 4.0.
+            You can create a custom `AnchorGenerator` by instantiating the
+            `keras_hub.layers.AnchorGenerator` class and passing the desired
+            arguments.
+        num_classes: int. The number of object classes to be detected.
+        bounding_box_format: str. Dataset bounding box format (e.g., "xyxy",
+            "yxyx"). The supported formats are
+            refer TODO: https://github.com/keras-team/keras-hub/issues/1907.
+            Defaults to `yxyx`.
+        label_encoder: Optional. A `RetinaNetLabelEncoder` instance. Encodes
+            ground truth boxes and classes into training targets. It matches
+            ground truth boxes to anchors based on IoU and encodes box
+            coordinates as offsets. If `None`, a default encoder is created.
+            See the
+            `keras_hub.src.models.retinanet.retinanet_label_encoder.RetinaNetLabelEncoder`
+            class for details. If None, a default encoder is created with
+            standard parameters.
+                - `anchor_generator`: Same as the model's.
+                - `bounding_box_format`: Same as the model's
+                    `bounding_box_format`.
+                - `positive_threshold`: 0.5
+                - `negative_threshold`: 0.4
+                - `encoding_format`: "center_xywh"
+                - `box_variance`: [1.0, 1.0, 1.0, 1.0]
+                - `background_class`: -1
+                - `ignore_class`: -2
+        use_prediction_head_norm: bool. Whether to use Group Normalization after
+            the convolution layers in the prediction heads. Defaults to `False`.
+        classification_head_prior_probability: float. Prior probability for the
+            classification head (used for focal loss). Defaults to 0.01.
+        pre_logits_num_conv_layers: int. The number of convolutional layers in
+            the head before the logits layer. These convolutional layers are
+            applied before the final linear layer (logits) that produces the
+            output predictions (bounding box regressions, classification scores).
+        preprocessor: Optional. An instance of
+            `RetinaNetObjectDetectorPreprocessor`or a custom preprocessor.
+            Handles image preprocessing before feeding into the backbone.
+        activation: Optional. The activation function to be used in the
+            classification head. If None, sigmoid is used.
+        dtype: Optional. The data type for the prediction heads. Defaults to the
+            backbone's dtype policy.
+        prediction_decoder: Optional. A `keras.layers.Layer` instance
+            responsible for transforming RetinaNet predictions
+            (box regressions and classifications) into final bounding boxes and
+            classes with confidence scores. Defaults to a `NonMaxSuppression`
+            instance.
+    """
+
+    backbone_cls = RetinaNetBackbone
+    preprocessor_cls = RetinaNetObjectDetectorPreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        num_classes,
+        bounding_box_format="yxyx",
+        anchor_generator=None,
+        label_encoder=None,
+        use_prediction_head_norm=False,
+        classification_head_prior_probability=0.01,
+        pre_logits_num_conv_layers=4,
+        preprocessor=None,
+        activation=None,
+        dtype=None,
+        prediction_decoder=None,
+        **kwargs,
+    ):
+        # === Layers ===
+        image_input = keras.layers.Input(backbone.image_shape, name="images")
+        head_dtype = dtype or backbone.dtype_policy
+
+        anchor_generator = anchor_generator or AnchorGenerator(
+            bounding_box_format,
+            min_level=backbone.min_level,
+            max_level=backbone.max_level,
+            num_scales=3,
+            aspect_ratios=[0.5, 1.0, 2.0],
+            anchor_size=4,
+        )
+        # As weights are ported from torch they use encoded format
+        # as "center_xywh"
+        label_encoder = label_encoder or RetinaNetLabelEncoder(
+            anchor_generator,
+            bounding_box_format=bounding_box_format,
+            encoding_format="center_xywh",
+        )
+
+        box_head = PredictionHead(
+            output_filters=anchor_generator.num_base_anchors * 4,
+            num_conv_layers=pre_logits_num_conv_layers,
+            num_filters=256,
+            use_group_norm=use_prediction_head_norm,
+            use_prior_probability=True,
+            prior_probability=classification_head_prior_probability,
+            dtype=head_dtype,
+            name="box_head",
+        )
+        classification_head = PredictionHead(
+            output_filters=anchor_generator.num_base_anchors * num_classes,
+            num_conv_layers=pre_logits_num_conv_layers,
+            num_filters=256,
+            use_group_norm=use_prediction_head_norm,
+            dtype=head_dtype,
+            name="classification_head",
+        )
+
+        # === Functional Model ===
+        feature_map = backbone(image_input)
+
+        class_predictions = []
+        box_predictions = []
+
+        # Iterate through the feature pyramid levels (e.g., P3, P4, P5, P6, P7).
+        for level in feature_map:
+            box_predictions.append(
+                keras.layers.Reshape((-1, 4), name=f"box_pred_{level}")(
+                    box_head(feature_map[level])
+                )
+            )
+            class_predictions.append(
+                keras.layers.Reshape(
+                    (-1, num_classes), name=f"cls_pred_{level}"
+                )(classification_head(feature_map[level]))
+            )
+
+        # Concatenate predictions from all FPN levels.
+        class_predictions = keras.layers.Concatenate(axis=1, name="cls_logits")(
+            class_predictions
+        )
+        # box_pred is always in "center_xywh" delta-encoded no matter what
+        # format you pass in.
+        box_predictions = keras.layers.Concatenate(
+            axis=1, name="bbox_regression"
+        )(box_predictions)
+
+        outputs = {
+            "bbox_regression": box_predictions,
+            "cls_logits": class_predictions,
+        }
+
+        super().__init__(
+            inputs=image_input,
+            outputs=outputs,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.bounding_box_format = bounding_box_format
+        self.use_prediction_head_norm = use_prediction_head_norm
+        self.num_classes = num_classes
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+        self.activation = activation
+        self.pre_logits_num_conv_layers = pre_logits_num_conv_layers
+        self.box_head = box_head
+        self.classification_head = classification_head
+        self.anchor_generator = anchor_generator
+        self.label_encoder = label_encoder
+        self._prediction_decoder = prediction_decoder or NonMaxSuppression(
+            from_logits=(activation != keras.activations.sigmoid),
+            bounding_box_format=bounding_box_format,
+        )
+
+    def compute_loss(self, x, y, y_pred, sample_weight, **kwargs):
+        y_for_label_encoder = convert_format(
+            y,
+            source=self.bounding_box_format,
+            target=self.label_encoder.bounding_box_format,
+            images=x,
+        )
+
+        boxes, classes = self.label_encoder(
+            images=x,
+            gt_boxes=y_for_label_encoder["boxes"],
+            gt_classes=y_for_label_encoder["classes"],
+        )
+
+        box_pred = y_pred["bbox_regression"]
+        cls_pred = y_pred["cls_logits"]
+
+        if boxes.shape[-1] != 4:
+            raise ValueError(
+                "boxes should have shape (None, None, 4). Got "
+                f"boxes.shape={tuple(boxes.shape)}"
+            )
+
+        if box_pred.shape[-1] != 4:
+            raise ValueError(
+                "box_pred should have shape (None, None, 4). Got "
+                f"box_pred.shape={tuple(box_pred.shape)}. Does your model's "
+                "`num_classes` parameter match your losses `num_classes` "
+                "parameter?"
+            )
+        if cls_pred.shape[-1] != self.num_classes:
+            raise ValueError(
+                "cls_pred should have shape (None, None, 4). Got "
+                f"cls_pred.shape={tuple(cls_pred.shape)}. Does your model's "
+                "`num_classes` parameter match your losses `num_classes` "
+                "parameter?"
+            )
+
+        cls_labels = ops.one_hot(
+            ops.cast(classes, "int32"), self.num_classes, dtype="float32"
+        )
+        positive_mask = ops.cast(ops.greater(classes, -1.0), dtype="float32")
+        normalizer = ops.sum(positive_mask)
+        cls_weights = ops.cast(ops.not_equal(classes, -2.0), dtype="float32")
+        cls_weights /= normalizer
+        box_weights = positive_mask / normalizer
+
+        y_true = {
+            "bbox_regression": boxes,
+            "cls_logits": cls_labels,
+        }
+        sample_weights = {
+            "bbox_regression": box_weights,
+            "cls_logits": cls_weights,
+        }
+        zero_weight = {
+            "bbox_regression": ops.zeros_like(box_weights),
+            "cls_logits": ops.zeros_like(cls_weights),
+        }
+
+        sample_weight = ops.cond(
+            normalizer == 0,
+            lambda: zero_weight,
+            lambda: sample_weights,
+        )
+        return super().compute_loss(
+            x=x, y=y_true, y_pred=y_pred, sample_weight=sample_weight, **kwargs
+        )
+
+    def predict_step(self, *args):
+        outputs = super().predict_step(*args)
+        if isinstance(outputs, tuple):
+            return self.decode_predictions(outputs[0], args[-1]), outputs[1]
+        return self.decode_predictions(outputs, *args)
+
+    @property
+    def prediction_decoder(self):
+        return self._prediction_decoder
+
+    @prediction_decoder.setter
+    def prediction_decoder(self, prediction_decoder):
+        if prediction_decoder.bounding_box_format != self.bounding_box_format:
+            raise ValueError(
+                "Expected `prediction_decoder` and `RetinaNet` to "
+                "use the same `bounding_box_format`, but got "
+                "`prediction_decoder.bounding_box_format="
+                f"{prediction_decoder.bounding_box_format}`, and "
+                "`self.bounding_box_format="
+                f"{self.bounding_box_format}`."
+            )
+        self._prediction_decoder = prediction_decoder
+        self.make_predict_function(force=True)
+        self.make_train_function(force=True)
+        self.make_test_function(force=True)
+
+    def decode_predictions(self, predictions, data):
+        box_pred = predictions["bbox_regression"]
+        cls_pred = predictions["cls_logits"]
+        # box_pred is on "center_yxhw" format, convert to target format.
+        if isinstance(data, list) or isinstance(data, tuple):
+            images, _ = data
+        else:
+            images = data
+        image_shape = ops.shape(images)[1:]
+        anchor_boxes = self.anchor_generator(images)
+        anchor_boxes = ops.concatenate(list(anchor_boxes.values()), axis=0)
+        box_pred = decode_deltas_to_boxes(
+            anchors=anchor_boxes,
+            boxes_delta=box_pred,
+            encoded_format="center_xywh",
+            anchor_format=self.anchor_generator.bounding_box_format,
+            box_format=self.bounding_box_format,
+            image_shape=image_shape,
+        )
+        # box_pred is now in "self.bounding_box_format" format
+        box_pred = convert_format(
+            box_pred,
+            source=self.bounding_box_format,
+            target=self.prediction_decoder.bounding_box_format,
+            image_shape=image_shape,
+        )
+        y_pred = self.prediction_decoder(
+            box_pred, cls_pred, image_shape=image_shape
+        )
+        y_pred["boxes"] = convert_format(
+            y_pred["boxes"],
+            source=self.prediction_decoder.bounding_box_format,
+            target=self.bounding_box_format,
+            image_shape=image_shape,
+        )
+        return y_pred
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_classes": self.num_classes,
+                "use_prediction_head_norm": self.use_prediction_head_norm,
+                "pre_logits_num_conv_layers": self.pre_logits_num_conv_layers,
+                "bounding_box_format": self.bounding_box_format,
+                "anchor_generator": keras.layers.serialize(
+                    self.anchor_generator
+                ),
+                "label_encoder": keras.layers.serialize(self.label_encoder),
+                "prediction_decoder": keras.layers.serialize(
+                    self._prediction_decoder
+                ),
+            }
+        )
+
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        if "label_encoder" in config and isinstance(
+            config["label_encoder"], dict
+        ):
+            config["label_encoder"] = keras.layers.deserialize(
+                config["label_encoder"]
+            )
+
+        if "anchor_generator" in config and isinstance(
+            config["anchor_generator"], dict
+        ):
+            config["anchor_generator"] = keras.layers.deserialize(
+                config["anchor_generator"]
+            )
+
+        if "prediction_decoder" in config and isinstance(
+            config["prediction_decoder"], dict
+        ):
+            config["prediction_decoder"] = keras.layers.deserialize(
+                config["prediction_decoder"]
+            )
+
+        return super().from_config(config)
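One consequence of the `prediction_decoder` setter above: swapping decoders after construction is supported, but it forces Keras to rebuild its compiled predict/train/test functions. A sketch, where `detector` is a `RetinaNetObjectDetector` instance and only constructor arguments visible in this diff are used:

```python
from keras_hub.src.models.retinanet.non_max_supression import NonMaxSuppression

# The setter validates that the decoder's bounding_box_format matches the
# model's before accepting it.
detector.prediction_decoder = NonMaxSuppression(
    from_logits=True,
    bounding_box_format="yxyx",
)
```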
--- /dev/null
+++ keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py
@@ -0,0 +1,14 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.image_object_detector_preprocessor import (
+    ImageObjectDetectorPreprocessor,
+)
+from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
+from keras_hub.src.models.retinanet.retinanet_image_converter import (
+    RetinaNetImageConverter,
+)
+
+
+@keras_hub_export("keras_hub.models.RetinaNetObjectDetectorPreprocessor")
+class RetinaNetObjectDetectorPreprocessor(ImageObjectDetectorPreprocessor):
+    backbone_cls = RetinaNetBackbone
+    image_converter_cls = RetinaNetImageConverter
--- /dev/null
+++ keras_hub/src/models/retinanet/retinanet_presets.py
@@ -0,0 +1,15 @@
+"""RetinaNet model preset configurations."""
+
+# Metadata for loading pretrained model weights.
+backbone_presets = {
+    "retinanet_resnet50_fpn_coco": {
+        "metadata": {
+            "description": (
+                "RetinaNet model with ResNet50 backbone fine-tuned on COCO in 800x800 resolution."
+            ),
+            "params": 34121239,
+            "path": "retinanet",
+        },
+        "kaggle_handle": "kaggle://keras/retinanet/keras/retinanet_resnet50_fpn_coco/1",
+    }
+}
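A hypothetical end-to-end sketch tying the new files together: load the preset registered above and run inference. The `from_preset` constructor is the standard keras-hub Task entry point; the input size mirrors the 800x800 resolution named in the preset description:

```python
import numpy as np
import keras_hub

detector = keras_hub.models.RetinaNetObjectDetector.from_preset(
    "retinanet_resnet50_fpn_coco"
)
images = np.random.uniform(0, 255, size=(1, 800, 800, 3)).astype("float32")
predictions = detector.predict(images)  # dict of decoded boxes and classes
```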
--- keras_hub/src/models/roberta/roberta_presets.py
+++ keras_hub/src/models/roberta/roberta_presets.py
@@ -8,9 +8,7 @@ backbone_presets = {
                 "Trained on English Wikipedia, BooksCorpus, CommonCraw, and OpenWebText."
             ),
             "params": 124052736,
-            "official_name": "RoBERTa",
             "path": "roberta",
-            "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/roberta/README.md",
         },
         "kaggle_handle": "kaggle://keras/roberta/keras/roberta_base_en/2",
     },
@@ -21,9 +19,7 @@ backbone_presets = {
                 "Trained on English Wikipedia, BooksCorpus, CommonCraw, and OpenWebText."
             ),
             "params": 354307072,
-            "official_name": "RoBERTa",
             "path": "roberta",
-            "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/roberta/README.md",
         },
         "kaggle_handle": "kaggle://keras/roberta/keras/roberta_large_en/2",
     },
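These removals are metadata-only; the `kaggle_handle`, which actually resolves the weights, is unchanged, so existing loading code should keep working:

```python
import keras_hub

# Loads the same checkpoint as before the metadata trim.
backbone = keras_hub.models.RobertaBackbone.from_preset("roberta_base_en")
```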