keras-hub-nightly 0.15.0.dev20240911134614__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +0 -6
- keras_hub/api/__init__.py +1 -0
- keras_hub/api/models/__init__.py +22 -17
- keras_hub/{src/models/llama3/llama3_preprocessor.py → api/utils/__init__.py} +7 -8
- keras_hub/src/api_export.py +15 -9
- keras_hub/src/models/albert/albert_text_classifier.py +6 -1
- keras_hub/src/models/bert/bert_text_classifier.py +6 -1
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +6 -1
- keras_hub/src/models/densenet/densenet_backbone.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +6 -1
- keras_hub/src/models/f_net/f_net_text_classifier.py +6 -1
- keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +7 -78
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +1 -1
- keras_hub/src/models/preprocessor.py +1 -5
- keras_hub/src/models/resnet/resnet_backbone.py +3 -16
- keras_hub/src/models/resnet/resnet_image_classifier.py +26 -3
- keras_hub/src/models/resnet/resnet_presets.py +12 -12
- keras_hub/src/models/retinanet/__init__.py +13 -0
- keras_hub/src/models/retinanet/anchor_generator.py +175 -0
- keras_hub/src/models/retinanet/box_matcher.py +259 -0
- keras_hub/src/models/retinanet/non_max_supression.py +578 -0
- keras_hub/src/models/roberta/roberta_text_classifier.py +6 -1
- keras_hub/src/models/task.py +6 -6
- keras_hub/src/models/text_classifier.py +12 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +6 -1
- keras_hub/src/tests/test_case.py +21 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +1 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +1 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +1 -0
- keras_hub/src/utils/imagenet/__init__.py +13 -0
- keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
- keras_hub/src/utils/preset_utils.py +24 -33
- keras_hub/src/utils/tensor_utils.py +14 -14
- keras_hub/src/utils/timm/convert_resnet.py +0 -1
- keras_hub/src/utils/timm/preset_loader.py +6 -7
- keras_hub/src/version_utils.py +1 -1
- keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
- {keras_hub_nightly-0.15.0.dev20240911134614.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/RECORD +41 -45
- {keras_hub_nightly-0.15.0.dev20240911134614.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
- keras_hub/src/models/bart/bart_preprocessor.py +0 -264
- keras_hub/src/models/bloom/bloom_preprocessor.py +0 -178
- keras_hub/src/models/electra/electra_preprocessor.py +0 -155
- keras_hub/src/models/falcon/falcon_preprocessor.py +0 -180
- keras_hub/src/models/gemma/gemma_preprocessor.py +0 -184
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -138
- keras_hub/src/models/llama/llama_preprocessor.py +0 -182
- keras_hub/src/models/mistral/mistral_preprocessor.py +0 -183
- keras_hub/src/models/opt/opt_preprocessor.py +0 -181
- keras_hub/src/models/phi3/phi3_preprocessor.py +0 -183
- keras_hub_nightly-0.15.0.dev20240911134614.dist-info/METADATA +0 -33
- {keras_hub_nightly-0.15.0.dev20240911134614.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,578 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import math
|
16
|
+
|
17
|
+
import keras
|
18
|
+
from keras import ops
|
19
|
+
|
20
|
+
from keras_hub.src.bounding_box import converters
|
21
|
+
from keras_hub.src.bounding_box import utils
|
22
|
+
from keras_hub.src.bounding_box import validate_format
|
23
|
+
|
24
|
+
EPSILON = 1e-8
|
25
|
+
|
26
|
+
|
27
|
+
class NonMaxSuppression(keras.layers.Layer):
|
28
|
+
"""A Keras layer that decodes predictions of an object detection model.
|
29
|
+
|
30
|
+
Args:
|
31
|
+
bounding_box_format: The format of bounding boxes of input dataset.
|
32
|
+
Refer
|
33
|
+
TODO: link keras core bounding box docs
|
34
|
+
for more details on supported bounding box formats.
|
35
|
+
from_logits: boolean, True means input score is logits, False means
|
36
|
+
confidence.
|
37
|
+
iou_threshold: a float value in the range [0, 1] representing the
|
38
|
+
minimum IoU threshold for two boxes to be considered
|
39
|
+
same for suppression. Defaults to 0.5.
|
40
|
+
confidence_threshold: a float value in the range [0, 1]. All boxes with
|
41
|
+
confidence below this value will be discarded, defaults to 0.5.
|
42
|
+
max_detections: the maximum detections to consider after nms is applied.
|
43
|
+
A large number may trigger significant memory overhead,
|
44
|
+
defaults to 100.
|
45
|
+
"""
|
46
|
+
|
47
|
+
def __init__(
|
48
|
+
self,
|
49
|
+
bounding_box_format,
|
50
|
+
from_logits,
|
51
|
+
iou_threshold=0.5,
|
52
|
+
confidence_threshold=0.5,
|
53
|
+
max_detections=100,
|
54
|
+
**kwargs,
|
55
|
+
):
|
56
|
+
super().__init__(**kwargs)
|
57
|
+
self.bounding_box_format = bounding_box_format
|
58
|
+
self.from_logits = from_logits
|
59
|
+
self.iou_threshold = iou_threshold
|
60
|
+
self.confidence_threshold = confidence_threshold
|
61
|
+
self.max_detections = max_detections
|
62
|
+
self.built = True
|
63
|
+
|
64
|
+
def call(
|
65
|
+
self, box_prediction, class_prediction, images=None, image_shape=None
|
66
|
+
):
|
67
|
+
"""Accepts images and raw scores, returning bounding box predictions.
|
68
|
+
|
69
|
+
Args:
|
70
|
+
box_prediction: Dense Tensor of shape [batch, boxes, 4] in the
|
71
|
+
`bounding_box_format` specified in the constructor.
|
72
|
+
class_prediction: Dense Tensor of shape [batch, boxes, num_classes].
|
73
|
+
"""
|
74
|
+
target_format = "yxyx"
|
75
|
+
if utils.is_relative(self.bounding_box_format):
|
76
|
+
target_format = utils.as_relative(target_format)
|
77
|
+
|
78
|
+
box_prediction = converters.convert_format(
|
79
|
+
box_prediction,
|
80
|
+
source=self.bounding_box_format,
|
81
|
+
target=target_format,
|
82
|
+
images=images,
|
83
|
+
image_shape=image_shape,
|
84
|
+
)
|
85
|
+
if self.from_logits:
|
86
|
+
class_prediction = ops.sigmoid(class_prediction)
|
87
|
+
|
88
|
+
confidence_prediction = ops.max(class_prediction, axis=-1)
|
89
|
+
|
90
|
+
idx, valid_det = non_max_suppression(
|
91
|
+
box_prediction,
|
92
|
+
confidence_prediction,
|
93
|
+
max_output_size=self.max_detections,
|
94
|
+
iou_threshold=self.iou_threshold,
|
95
|
+
score_threshold=self.confidence_threshold,
|
96
|
+
)
|
97
|
+
|
98
|
+
box_prediction = ops.take_along_axis(
|
99
|
+
box_prediction, ops.expand_dims(idx, axis=-1), axis=1
|
100
|
+
)
|
101
|
+
box_prediction = ops.reshape(
|
102
|
+
box_prediction, (-1, self.max_detections, 4)
|
103
|
+
)
|
104
|
+
confidence_prediction = ops.take_along_axis(
|
105
|
+
confidence_prediction, idx, axis=1
|
106
|
+
)
|
107
|
+
class_prediction = ops.take_along_axis(
|
108
|
+
class_prediction, ops.expand_dims(idx, axis=-1), axis=1
|
109
|
+
)
|
110
|
+
|
111
|
+
box_prediction = converters.convert_format(
|
112
|
+
box_prediction,
|
113
|
+
source=target_format,
|
114
|
+
target=self.bounding_box_format,
|
115
|
+
images=images,
|
116
|
+
image_shape=image_shape,
|
117
|
+
)
|
118
|
+
bounding_boxes = {
|
119
|
+
"boxes": box_prediction,
|
120
|
+
"confidence": confidence_prediction,
|
121
|
+
"classes": ops.argmax(class_prediction, axis=-1),
|
122
|
+
"num_detections": valid_det,
|
123
|
+
}
|
124
|
+
|
125
|
+
# this is required to comply with bounding box format.
|
126
|
+
return mask_invalid_detections(bounding_boxes)
|
127
|
+
|
128
|
+
def get_config(self):
|
129
|
+
config = super().get_config()
|
130
|
+
config.update(
|
131
|
+
{
|
132
|
+
"bounding_box_format": self.bounding_box_format,
|
133
|
+
"from_logits": self.from_logits,
|
134
|
+
"iou_threshold": self.iou_threshold,
|
135
|
+
"confidence_threshold": self.confidence_threshold,
|
136
|
+
"max_detections": self.max_detections,
|
137
|
+
}
|
138
|
+
)
|
139
|
+
return config
|
140
|
+
|
141
|
+
|
142
|
+
def non_max_suppression(
|
143
|
+
boxes,
|
144
|
+
scores,
|
145
|
+
max_output_size,
|
146
|
+
iou_threshold=0.5,
|
147
|
+
score_threshold=0.0,
|
148
|
+
tile_size=512,
|
149
|
+
):
|
150
|
+
"""Non-maximum suppression.
|
151
|
+
|
152
|
+
Ported from https://github.com/tensorflow/tensorflow/blob/v2.12.0/tensorflow/python/ops/image_ops_impl.py#L5368-L5458
|
153
|
+
|
154
|
+
Args:
|
155
|
+
boxes: a tensor of rank 2 or higher with a shape of
|
156
|
+
`[..., num_boxes, 4]`. Dimensions except the last two are batch
|
157
|
+
dimensions. The last dimension represents box coordinates in
|
158
|
+
yxyx format.
|
159
|
+
scores: a tensor of rank 1 or higher with a shape of `[..., num_boxes]`.
|
160
|
+
max_output_size: a scalar integer tensor representing the maximum
|
161
|
+
number of boxes to be selected by non max suppression.
|
162
|
+
iou_threshold: a float representing the threshold for
|
163
|
+
deciding whether boxes overlap too much with respect
|
164
|
+
to IoU (intersection over union).
|
165
|
+
score_threshold: a float representing the threshold for box scores.
|
166
|
+
Boxes with a score that is not larger than this threshold
|
167
|
+
will be suppressed.
|
168
|
+
tile_size: an integer representing the number of boxes in a tile, i.e.,
|
169
|
+
the maximum number of boxes per image that can be used to suppress
|
170
|
+
other boxes in parallel; larger tile_size means larger parallelism
|
171
|
+
and potentially more redundant work.
|
172
|
+
|
173
|
+
Returns:
|
174
|
+
idx: a tensor with a shape of `[batch_size, max_output_size]` representing the
|
175
|
+
indices selected by non-max suppression. The leading dimensions
|
176
|
+
are the batch dimensions of the input boxes. All numbers are within
|
177
|
+
`[0, num_boxes)`. For each image (i.e., `idx[i]`), only the first
|
178
|
+
`num_valid[i]` indices (i.e., `idx[i][:num_valid[i]]`) are valid.
|
179
|
+
num_valid: a tensor of rank 0 or higher with a shape of [...]
|
180
|
+
representing the number of valid indices in idx. Its dimensions
|
181
|
+
are the batch dimensions of the input boxes.
|
182
|
+
"""
|
183
|
+
|
184
|
+
def _sort_scores_and_boxes(scores, boxes):
|
185
|
+
"""Sort boxes based on their score from highest to lowest.
|
186
|
+
|
187
|
+
Args:
|
188
|
+
scores: a tensor with a shape of `[batch_size, num_boxes]`
|
189
|
+
representing the scores of boxes.
|
190
|
+
boxes: a tensor with a shape of `[batch_size, num_boxes, 4]`
|
191
|
+
representing the boxes.
|
192
|
+
|
193
|
+
Returns:
|
194
|
+
sorted_scores: a tensor with a shape of
|
195
|
+
`[batch_size, num_boxes]` representing the sorted scores.
|
196
|
+
sorted_boxes: a tensor representing the sorted boxes.
|
197
|
+
sorted_scores_indices: a tensor with a shape of
|
198
|
+
`[batch_size, num_boxes]` representing the index of the scores
|
199
|
+
in a sorted descending order.
|
200
|
+
"""
|
201
|
+
sorted_scores_indices = ops.flip(
|
202
|
+
ops.cast(ops.argsort(scores, axis=1), "int32"), axis=1
|
203
|
+
)
|
204
|
+
sorted_scores = ops.take_along_axis(
|
205
|
+
scores,
|
206
|
+
sorted_scores_indices,
|
207
|
+
axis=1,
|
208
|
+
)
|
209
|
+
sorted_boxes = ops.take_along_axis(
|
210
|
+
boxes,
|
211
|
+
ops.expand_dims(sorted_scores_indices, axis=-1),
|
212
|
+
axis=1,
|
213
|
+
)
|
214
|
+
return sorted_scores, sorted_boxes, sorted_scores_indices
|
215
|
+
|
216
|
+
batch_dims = ops.shape(boxes)[:-2]
|
217
|
+
num_boxes = boxes.shape[-2]
|
218
|
+
boxes = ops.reshape(boxes, [-1, num_boxes, 4])
|
219
|
+
scores = ops.reshape(scores, [-1, num_boxes])
|
220
|
+
batch_size = boxes.shape[0]
|
221
|
+
if score_threshold != float("-inf"):
|
222
|
+
score_mask = ops.cast(scores > score_threshold, scores.dtype)
|
223
|
+
scores *= score_mask
|
224
|
+
box_mask = ops.expand_dims(ops.cast(score_mask, boxes.dtype), 2)
|
225
|
+
boxes *= box_mask
|
226
|
+
|
227
|
+
scores, boxes, sorted_indices = _sort_scores_and_boxes(scores, boxes)
|
228
|
+
|
229
|
+
pad = (
|
230
|
+
math.ceil(max(num_boxes, max_output_size) / tile_size) * tile_size
|
231
|
+
- num_boxes
|
232
|
+
)
|
233
|
+
boxes = ops.pad(ops.cast(boxes, "float32"), [[0, 0], [0, pad], [0, 0]])
|
234
|
+
scores = ops.pad(ops.cast(scores, "float32"), [[0, 0], [0, pad]])
|
235
|
+
num_boxes_after_padding = num_boxes + pad
|
236
|
+
num_iterations = num_boxes_after_padding // tile_size
|
237
|
+
|
238
|
+
def _loop_cond(unused_boxes, unused_threshold, output_size, idx):
|
239
|
+
return ops.logical_and(
|
240
|
+
ops.min(output_size) < ops.cast(max_output_size, "int32"),
|
241
|
+
ops.cast(idx, "int32") < num_iterations,
|
242
|
+
)
|
243
|
+
|
244
|
+
def suppression_loop_body(boxes, iou_threshold, output_size, idx):
|
245
|
+
return _suppression_loop_body(
|
246
|
+
boxes, iou_threshold, output_size, idx, tile_size
|
247
|
+
)
|
248
|
+
|
249
|
+
selected_boxes, _, output_size, _ = ops.while_loop(
|
250
|
+
_loop_cond,
|
251
|
+
suppression_loop_body,
|
252
|
+
[
|
253
|
+
boxes,
|
254
|
+
iou_threshold,
|
255
|
+
ops.zeros([batch_size], "int32"),
|
256
|
+
ops.array(0),
|
257
|
+
],
|
258
|
+
)
|
259
|
+
num_valid = ops.minimum(output_size, max_output_size)
|
260
|
+
idx = num_boxes_after_padding - ops.cast(
|
261
|
+
ops.top_k(
|
262
|
+
ops.cast(ops.any(selected_boxes > 0, [2]), "int32")
|
263
|
+
* ops.cast(
|
264
|
+
ops.expand_dims(ops.arange(num_boxes_after_padding, 0, -1), 0),
|
265
|
+
"int32",
|
266
|
+
),
|
267
|
+
max_output_size,
|
268
|
+
)[0],
|
269
|
+
"int32",
|
270
|
+
)
|
271
|
+
idx = ops.minimum(idx, num_boxes - 1)
|
272
|
+
|
273
|
+
index_offsets = ops.cast(ops.arange(batch_size) * num_boxes, "int32")
|
274
|
+
take_along_axis_idx = ops.reshape(
|
275
|
+
idx + ops.expand_dims(index_offsets, 1), [-1]
|
276
|
+
)
|
277
|
+
|
278
|
+
if keras.backend.backend() != "tensorflow":
|
279
|
+
idx = ops.take_along_axis(
|
280
|
+
ops.reshape(sorted_indices, [-1]), take_along_axis_idx
|
281
|
+
)
|
282
|
+
else:
|
283
|
+
import tensorflow as tf
|
284
|
+
|
285
|
+
idx = tf.gather(ops.reshape(sorted_indices, [-1]), take_along_axis_idx)
|
286
|
+
idx = ops.reshape(idx, [batch_size, -1])
|
287
|
+
|
288
|
+
invalid_index = ops.zeros([batch_size, max_output_size], dtype="int32")
|
289
|
+
idx_index = ops.cast(
|
290
|
+
ops.expand_dims(ops.arange(max_output_size), 0), "int32"
|
291
|
+
)
|
292
|
+
num_valid_expanded = ops.expand_dims(num_valid, 1)
|
293
|
+
idx = ops.where(idx_index < num_valid_expanded, idx, invalid_index)
|
294
|
+
|
295
|
+
num_valid = ops.reshape(num_valid, batch_dims)
|
296
|
+
return idx, num_valid
|
297
|
+
|
298
|
+
|
299
|
+
def _bbox_overlap(boxes_a, boxes_b):
|
300
|
+
"""Calculates the overlap (iou - intersection over union) between boxes_a
|
301
|
+
and boxes_b.
|
302
|
+
|
303
|
+
Args:
|
304
|
+
boxes_a: a tensor with a shape of `[batch_size, N, 4]`.
|
305
|
+
`N` is the number of boxes per image. The last dimension is the
|
306
|
+
pixel coordinates in `[ymin, xmin, ymax, xmax]` form.
|
307
|
+
boxes_b: a tensor with a shape of `[batch_size, M, 4]`. M is the number of
|
308
|
+
boxes. The last dimension is the pixel coordinates in
|
309
|
+
`[ymin, xmin, ymax, xmax]` form.
|
310
|
+
|
311
|
+
Returns:
|
312
|
+
intersection_over_union: a tensor with a shape of
|
313
|
+
`[batch_size, N, M]`, representing the ratio of intersection area
|
314
|
+
over union area (IoU) between two boxes
|
315
|
+
"""
|
316
|
+
if len(boxes_a.shape) == 4:
|
317
|
+
boxes_a = ops.squeeze(boxes_a, axis=0)
|
318
|
+
a_y_min, a_x_min, a_y_max, a_x_max = ops.split(boxes_a, 4, axis=2)
|
319
|
+
b_y_min, b_x_min, b_y_max, b_x_max = ops.split(boxes_b, 4, axis=2)
|
320
|
+
|
321
|
+
# Calculates the intersection area.
|
322
|
+
i_xmin = ops.maximum(a_x_min, ops.transpose(b_x_min, [0, 2, 1]))
|
323
|
+
i_xmax = ops.minimum(a_x_max, ops.transpose(b_x_max, [0, 2, 1]))
|
324
|
+
i_ymin = ops.maximum(a_y_min, ops.transpose(b_y_min, [0, 2, 1]))
|
325
|
+
i_ymax = ops.minimum(a_y_max, ops.transpose(b_y_max, [0, 2, 1]))
|
326
|
+
i_area = ops.maximum((i_xmax - i_xmin), 0) * ops.maximum(
|
327
|
+
(i_ymax - i_ymin), 0
|
328
|
+
)
|
329
|
+
|
330
|
+
# Calculates the union area.
|
331
|
+
a_area = (a_y_max - a_y_min) * (a_x_max - a_x_min)
|
332
|
+
b_area = (b_y_max - b_y_min) * (b_x_max - b_x_min)
|
333
|
+
|
334
|
+
# Adds a small epsilon to avoid divide-by-zero.
|
335
|
+
u_area = a_area + ops.transpose(b_area, [0, 2, 1]) - i_area + EPSILON
|
336
|
+
|
337
|
+
intersection_over_union = i_area / u_area
|
338
|
+
|
339
|
+
return intersection_over_union
|
340
|
+
|
341
|
+
|
342
|
+
def _self_suppression(iou, _, iou_sum, iou_threshold):
|
343
|
+
"""Suppress boxes in the same tile.
|
344
|
+
|
345
|
+
Compute boxes that cannot be suppressed by others (i.e.,
|
346
|
+
can_suppress_others), and then use them to suppress boxes in the same tile.
|
347
|
+
|
348
|
+
Args:
|
349
|
+
iou: a tensor of shape `[batch_size, tile_size, tile_size]`
|
350
|
+
representing intersection over union.
|
351
|
+
iou_sum: a tensor of shape `[batch_size]` holding the iou sum before this step.
|
352
|
+
iou_threshold: a scalar tensor.
|
353
|
+
|
354
|
+
Returns:
|
355
|
+
iou_suppressed: a tensor of shape
|
356
|
+
`[batch_size, tile_size, tile_size]`.
|
357
|
+
iou_diff: a scalar tensor representing whether any box is suppressed in
|
358
|
+
this step.
|
359
|
+
iou_sum_new: a tensor of shape `[batch_size]` that represents
|
360
|
+
the iou sum after suppression.
|
361
|
+
iou_threshold: a scalar tensor.
|
362
|
+
"""
|
363
|
+
batch_size = ops.shape(iou)[0]
|
364
|
+
can_suppress_others = ops.cast(
|
365
|
+
ops.reshape(ops.max(iou, 1) < iou_threshold, [batch_size, -1, 1]),
|
366
|
+
iou.dtype,
|
367
|
+
)
|
368
|
+
iou_after_suppression = (
|
369
|
+
ops.reshape(
|
370
|
+
ops.cast(
|
371
|
+
ops.max(can_suppress_others * iou, 1) < iou_threshold, iou.dtype
|
372
|
+
),
|
373
|
+
[batch_size, -1, 1],
|
374
|
+
)
|
375
|
+
* iou
|
376
|
+
)
|
377
|
+
iou_sum_new = ops.sum(iou_after_suppression, [1, 2])
|
378
|
+
return [
|
379
|
+
iou_after_suppression,
|
380
|
+
ops.any(iou_sum - iou_sum_new > iou_threshold),
|
381
|
+
iou_sum_new,
|
382
|
+
iou_threshold,
|
383
|
+
]
|
384
|
+
|
385
|
+
|
386
|
+
def _cross_suppression(boxes, box_slice, iou_threshold, inner_idx, tile_size):
|
387
|
+
"""Suppress boxes between different tiles.
|
388
|
+
|
389
|
+
Args:
|
390
|
+
boxes: a tensor of shape `[batch_size, num_boxes_with_padding, 4]`
|
391
|
+
box_slice: a tensor of shape `[batch_size, tile_size, 4]`
|
392
|
+
iou_threshold: a scalar tensor
|
393
|
+
inner_idx: a scalar tensor representing the tile index of the tile
|
394
|
+
that is used to suppress box_slice
|
395
|
+
tile_size: an integer representing the number of boxes in a tile
|
396
|
+
|
397
|
+
Returns:
|
398
|
+
boxes: unchanged boxes as input
|
399
|
+
box_slice_after_suppression: box_slice after suppression
|
400
|
+
iou_threshold: unchanged
|
401
|
+
"""
|
402
|
+
slice_index = ops.expand_dims(
|
403
|
+
ops.expand_dims(
|
404
|
+
ops.cast(
|
405
|
+
ops.linspace(
|
406
|
+
inner_idx * tile_size,
|
407
|
+
(inner_idx + 1) * tile_size - 1,
|
408
|
+
tile_size,
|
409
|
+
),
|
410
|
+
"int32",
|
411
|
+
),
|
412
|
+
axis=0,
|
413
|
+
),
|
414
|
+
axis=-1,
|
415
|
+
)
|
416
|
+
new_slice = ops.expand_dims(
|
417
|
+
ops.take_along_axis(boxes, slice_index, axis=1), 0
|
418
|
+
)
|
419
|
+
iou = _bbox_overlap(new_slice, box_slice)
|
420
|
+
box_slice_after_suppression = (
|
421
|
+
ops.expand_dims(
|
422
|
+
ops.cast(ops.all(iou < iou_threshold, [1]), box_slice.dtype), 2
|
423
|
+
)
|
424
|
+
* box_slice
|
425
|
+
)
|
426
|
+
return boxes, box_slice_after_suppression, iou_threshold, inner_idx + 1
|
427
|
+
|
428
|
+
|
429
|
+
def _suppression_loop_body(boxes, iou_threshold, output_size, idx, tile_size):
|
430
|
+
"""Process boxes in the range [idx*tile_size, (idx+1)*tile_size).
|
431
|
+
|
432
|
+
Args:
|
433
|
+
boxes: a tensor with a shape of [batch_size, anchors, 4].
|
434
|
+
iou_threshold: a float representing the threshold for deciding whether
|
435
|
+
boxes overlap too much with respect to IOU.
|
436
|
+
output_size: an int32 tensor of size [batch_size]. Representing the
|
437
|
+
number of selected boxes for each batch.
|
438
|
+
idx: an integer scalar representing induction variable.
|
439
|
+
tile_size: an integer representing the number of boxes in a tile
|
440
|
+
|
441
|
+
Returns:
|
442
|
+
boxes: updated boxes.
|
443
|
+
iou_threshold: pass down iou_threshold to the next iteration.
|
444
|
+
output_size: the updated output_size.
|
445
|
+
idx: the updated induction variable.
|
446
|
+
"""
|
447
|
+
num_tiles = boxes.shape[1] // tile_size
|
448
|
+
batch_size = boxes.shape[0]
|
449
|
+
|
450
|
+
def cross_suppression_func(boxes, box_slice, iou_threshold, inner_idx):
|
451
|
+
return _cross_suppression(
|
452
|
+
boxes, box_slice, iou_threshold, inner_idx, tile_size
|
453
|
+
)
|
454
|
+
|
455
|
+
# Iterates over tiles that can possibly suppress the current tile.
|
456
|
+
slice_index = ops.expand_dims(
|
457
|
+
ops.expand_dims(
|
458
|
+
ops.cast(
|
459
|
+
ops.linspace(
|
460
|
+
idx * tile_size, (idx + 1) * tile_size - 1, tile_size
|
461
|
+
),
|
462
|
+
"int32",
|
463
|
+
),
|
464
|
+
axis=0,
|
465
|
+
),
|
466
|
+
axis=-1,
|
467
|
+
)
|
468
|
+
box_slice = ops.take_along_axis(boxes, slice_index, axis=1)
|
469
|
+
_, box_slice, _, _ = ops.while_loop(
|
470
|
+
lambda _boxes, _box_slice, _threshold, inner_idx: inner_idx < idx,
|
471
|
+
cross_suppression_func,
|
472
|
+
[boxes, box_slice, iou_threshold, ops.array(0)],
|
473
|
+
)
|
474
|
+
|
475
|
+
# Iterates over the current tile to compute self-suppression.
|
476
|
+
iou = _bbox_overlap(box_slice, box_slice)
|
477
|
+
mask = ops.expand_dims(
|
478
|
+
ops.reshape(ops.arange(tile_size), [1, -1])
|
479
|
+
> ops.reshape(ops.arange(tile_size), [-1, 1]),
|
480
|
+
0,
|
481
|
+
)
|
482
|
+
iou *= ops.cast(ops.logical_and(mask, iou >= iou_threshold), iou.dtype)
|
483
|
+
suppressed_iou, _, _, _ = ops.while_loop(
|
484
|
+
lambda _iou, loop_condition, _iou_sum, _: loop_condition,
|
485
|
+
_self_suppression,
|
486
|
+
[iou, ops.array(True), ops.sum(iou, [1, 2]), iou_threshold],
|
487
|
+
)
|
488
|
+
suppressed_box = ops.sum(suppressed_iou, 1) > 0
|
489
|
+
box_slice *= ops.expand_dims(
|
490
|
+
1.0 - ops.cast(suppressed_box, box_slice.dtype), 2
|
491
|
+
)
|
492
|
+
|
493
|
+
# Uses box_slice to update the input boxes.
|
494
|
+
mask = ops.reshape(
|
495
|
+
ops.cast(ops.equal(ops.arange(num_tiles), idx), boxes.dtype),
|
496
|
+
[1, -1, 1, 1],
|
497
|
+
)
|
498
|
+
boxes = ops.tile(
|
499
|
+
ops.expand_dims(box_slice, 1), [1, num_tiles, 1, 1]
|
500
|
+
) * mask + ops.reshape(boxes, [batch_size, num_tiles, tile_size, 4]) * (
|
501
|
+
1 - mask
|
502
|
+
)
|
503
|
+
boxes = ops.reshape(boxes, [batch_size, -1, 4])
|
504
|
+
|
505
|
+
# Updates output_size.
|
506
|
+
output_size += ops.cast(ops.sum(ops.any(box_slice > 0, [2]), [1]), "int32")
|
507
|
+
return boxes, iou_threshold, output_size, idx + 1
|
508
|
+
|
509
|
+
|
510
|
+
def mask_invalid_detections(bounding_boxes):
|
511
|
+
"""masks out invalid detections with -1s.
|
512
|
+
|
513
|
+
This utility is mainly used on the output of non-max suppression operations.
|
514
|
+
The output of non-max-suppression contains all the detections, even invalid
|
515
|
+
ones. Users are expected to use `num_detections` to determine how many boxes
|
516
|
+
are in each image.
|
517
|
+
|
518
|
+
In contrast, KerasHub expects all bounding boxes to be padded with -1s.
|
519
|
+
This function uses the value of `num_detections` to mask out
|
520
|
+
invalid boxes with -1s.
|
521
|
+
|
522
|
+
Args:
|
523
|
+
bounding_boxes: a dictionary complying with Keras bounding box format.
|
524
|
+
In addition to the normal required keys, these boxes are also
|
525
|
+
expected to have a `num_detections` key.
|
526
|
+
|
527
|
+
Returns:
|
528
|
+
bounding boxes with proper masking of the boxes according to
|
529
|
+
`num_detections`. This allows proper interop with non-max suppression.
|
530
|
+
Returned boxes match the specification fed to the function, so if the
|
531
|
+
bounding box tensor uses `tf.RaggedTensor` to represent boxes the
|
532
|
+
returned value will also return `tf.RaggedTensor` representations.
|
533
|
+
"""
|
534
|
+
# ensure we are complying with Keras bounding box format.
|
535
|
+
info = validate_format.validate_format(bounding_boxes)
|
536
|
+
if info["ragged"]:
|
537
|
+
raise ValueError(
|
538
|
+
"`bounding_box.mask_invalid_detections()` requires inputs to be "
|
539
|
+
"Dense tensors. Please call "
|
540
|
+
"`bounding_box.to_dense(bounding_boxes)` before passing your boxes "
|
541
|
+
"to `bounding_box.mask_invalid_detections()`."
|
542
|
+
)
|
543
|
+
if "num_detections" not in bounding_boxes:
|
544
|
+
raise ValueError(
|
545
|
+
"`bounding_boxes` must have key 'num_detections' "
|
546
|
+
"to be used with `bounding_box.mask_invalid_detections()`."
|
547
|
+
)
|
548
|
+
|
549
|
+
boxes = bounding_boxes.get("boxes")
|
550
|
+
classes = bounding_boxes.get("classes")
|
551
|
+
confidence = bounding_boxes.get("confidence", None)
|
552
|
+
num_detections = bounding_boxes.get("num_detections")
|
553
|
+
|
554
|
+
# Create a mask to select only the first N boxes from each batch
|
555
|
+
mask = ops.cast(
|
556
|
+
ops.expand_dims(ops.arange(boxes.shape[1]), axis=0),
|
557
|
+
num_detections.dtype,
|
558
|
+
)
|
559
|
+
mask = mask < num_detections[:, None]
|
560
|
+
|
561
|
+
classes = ops.where(mask, classes, -ops.ones_like(classes))
|
562
|
+
|
563
|
+
if confidence is not None:
|
564
|
+
confidence = ops.where(mask, confidence, -ops.ones_like(confidence))
|
565
|
+
|
566
|
+
# reuse mask for boxes
|
567
|
+
mask = ops.expand_dims(mask, axis=-1)
|
568
|
+
mask = ops.repeat(mask, repeats=boxes.shape[-1], axis=-1)
|
569
|
+
boxes = ops.where(mask, boxes, -ops.ones_like(boxes))
|
570
|
+
|
571
|
+
result = bounding_boxes.copy()
|
572
|
+
|
573
|
+
result["boxes"] = boxes
|
574
|
+
result["classes"] = classes
|
575
|
+
if confidence is not None:
|
576
|
+
result["confidence"] = confidence
|
577
|
+
|
578
|
+
return result
|
@@ -26,7 +26,12 @@ from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import (
|
|
26
26
|
from keras_hub.src.models.text_classifier import TextClassifier
|
27
27
|
|
28
28
|
|
29
|
-
@keras_hub_export(
|
29
|
+
@keras_hub_export(
|
30
|
+
[
|
31
|
+
"keras_hub.models.RobertaTextClassifier",
|
32
|
+
"keras_hub.models.RobertaClassifier",
|
33
|
+
]
|
34
|
+
)
|
30
35
|
class RobertaTextClassifier(TextClassifier):
|
31
36
|
"""An end-to-end RoBERTa model for classification tasks.
|
32
37
|
|
keras_hub/src/models/task.py
CHANGED
@@ -139,7 +139,6 @@ class Task(PipelineModel):
|
|
139
139
|
cls,
|
140
140
|
preset,
|
141
141
|
load_weights=True,
|
142
|
-
load_task_extras=False,
|
143
142
|
**kwargs,
|
144
143
|
):
|
145
144
|
"""Instantiate a `keras_hub.models.Task` from a model preset.
|
@@ -168,10 +167,6 @@ class Task(PipelineModel):
|
|
168
167
|
load_weights: bool. If `True`, saved weights will be loaded into
|
169
168
|
the model architecture. If `False`, all weights will be
|
170
169
|
randomly initialized.
|
171
|
-
load_task_extras: bool. If `True`, load the saved task configuration
|
172
|
-
from a `task.json` and any task specific weights from
|
173
|
-
`task.weights`. You might use this to load a classification
|
174
|
-
head for a model that has been saved with it.
|
175
170
|
|
176
171
|
Examples:
|
177
172
|
```python
|
@@ -199,7 +194,12 @@ class Task(PipelineModel):
|
|
199
194
|
# Detect the correct subclass if we need to.
|
200
195
|
if cls.backbone_cls != backbone_cls:
|
201
196
|
cls = find_subclass(preset, cls, backbone_cls)
|
202
|
-
|
197
|
+
# Specifically for classifiers, we never load task weights if
|
198
|
+
# num_classes is supplied. We handle this in the task base class because
|
199
|
+
# it is the same logic for classifiers regardless of modality (text,
|
200
|
+
# images, audio).
|
201
|
+
load_task_weights = "num_classes" not in kwargs
|
202
|
+
return loader.load_task(cls, load_weights, load_task_weights, **kwargs)
|
203
203
|
|
204
204
|
def load_task_weights(self, filepath):
|
205
205
|
"""Load only the tasks specific weights not in the backbone."""
|
@@ -17,7 +17,12 @@ from keras_hub.src.api_export import keras_hub_export
|
|
17
17
|
from keras_hub.src.models.task import Task
|
18
18
|
|
19
19
|
|
20
|
-
@keras_hub_export(
|
20
|
+
@keras_hub_export(
|
21
|
+
[
|
22
|
+
"keras_hub.models.TextClassifier",
|
23
|
+
"keras_hub.models.Classifier",
|
24
|
+
]
|
25
|
+
)
|
21
26
|
class TextClassifier(Task):
|
22
27
|
"""Base class for all classification tasks.
|
23
28
|
|
@@ -32,6 +37,12 @@ class TextClassifier(Task):
|
|
32
37
|
All `TextClassifier` tasks include a `from_preset()` constructor which can be
|
33
38
|
used to load a pre-trained config and weights.
|
34
39
|
|
40
|
+
Some, but not all, classification presets include classification head
|
41
|
+
weights in a `task.weights.h5` file. For these presets, you can omit passing
|
42
|
+
`num_classes` to restore the saved classification head. For all presets, if
|
43
|
+
`num_classes` is passed as a kwarg to `from_preset()`, the classification
|
44
|
+
head will be randomly initialized.
|
45
|
+
|
35
46
|
Example:
|
36
47
|
```python
|
37
48
|
# Load a BERT classifier with pre-trained weights.
|
@@ -28,7 +28,12 @@ from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor i
|
|
28
28
|
)
|
29
29
|
|
30
30
|
|
31
|
-
@keras_hub_export(
|
31
|
+
@keras_hub_export(
|
32
|
+
[
|
33
|
+
"keras_hub.models.XLMRobertaTextClassifier",
|
34
|
+
"keras_hub.models.XLMRobertaClassifier",
|
35
|
+
]
|
36
|
+
)
|
32
37
|
class XLMRobertaTextClassifier(TextClassifier):
|
33
38
|
"""An end-to-end XLM-RoBERTa model for classification tasks.
|
34
39
|
|