keras-hub-nightly 0.15.0.dev20240911134614__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. keras_hub/__init__.py +0 -6
  2. keras_hub/api/__init__.py +1 -0
  3. keras_hub/api/models/__init__.py +22 -17
  4. keras_hub/{src/models/llama3/llama3_preprocessor.py → api/utils/__init__.py} +7 -8
  5. keras_hub/src/api_export.py +15 -9
  6. keras_hub/src/models/albert/albert_text_classifier.py +6 -1
  7. keras_hub/src/models/bert/bert_text_classifier.py +6 -1
  8. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +6 -1
  9. keras_hub/src/models/densenet/densenet_backbone.py +1 -1
  10. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +6 -1
  11. keras_hub/src/models/f_net/f_net_text_classifier.py +6 -1
  12. keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
  13. keras_hub/src/models/gpt2/gpt2_preprocessor.py +7 -78
  14. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +1 -1
  15. keras_hub/src/models/preprocessor.py +1 -5
  16. keras_hub/src/models/resnet/resnet_backbone.py +3 -16
  17. keras_hub/src/models/resnet/resnet_image_classifier.py +26 -3
  18. keras_hub/src/models/resnet/resnet_presets.py +12 -12
  19. keras_hub/src/models/retinanet/__init__.py +13 -0
  20. keras_hub/src/models/retinanet/anchor_generator.py +175 -0
  21. keras_hub/src/models/retinanet/box_matcher.py +259 -0
  22. keras_hub/src/models/retinanet/non_max_supression.py +578 -0
  23. keras_hub/src/models/roberta/roberta_text_classifier.py +6 -1
  24. keras_hub/src/models/task.py +6 -6
  25. keras_hub/src/models/text_classifier.py +12 -1
  26. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +6 -1
  27. keras_hub/src/tests/test_case.py +21 -0
  28. keras_hub/src/tokenizers/byte_pair_tokenizer.py +1 -0
  29. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +1 -0
  30. keras_hub/src/tokenizers/word_piece_tokenizer.py +1 -0
  31. keras_hub/src/utils/imagenet/__init__.py +13 -0
  32. keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
  33. keras_hub/src/utils/preset_utils.py +24 -33
  34. keras_hub/src/utils/tensor_utils.py +14 -14
  35. keras_hub/src/utils/timm/convert_resnet.py +0 -1
  36. keras_hub/src/utils/timm/preset_loader.py +6 -7
  37. keras_hub/src/version_utils.py +1 -1
  38. keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
  39. {keras_hub_nightly-0.15.0.dev20240911134614.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/RECORD +41 -45
  40. {keras_hub_nightly-0.15.0.dev20240911134614.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
  41. keras_hub/src/models/bart/bart_preprocessor.py +0 -264
  42. keras_hub/src/models/bloom/bloom_preprocessor.py +0 -178
  43. keras_hub/src/models/electra/electra_preprocessor.py +0 -155
  44. keras_hub/src/models/falcon/falcon_preprocessor.py +0 -180
  45. keras_hub/src/models/gemma/gemma_preprocessor.py +0 -184
  46. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -138
  47. keras_hub/src/models/llama/llama_preprocessor.py +0 -182
  48. keras_hub/src/models/mistral/mistral_preprocessor.py +0 -183
  49. keras_hub/src/models/opt/opt_preprocessor.py +0 -181
  50. keras_hub/src/models/phi3/phi3_preprocessor.py +0 -183
  51. keras_hub_nightly-0.15.0.dev20240911134614.dist-info/METADATA +0 -33
  52. {keras_hub_nightly-0.15.0.dev20240911134614.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,578 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+
17
+ import keras
18
+ from keras import ops
19
+
20
+ from keras_hub.src.bounding_box import converters
21
+ from keras_hub.src.bounding_box import utils
22
+ from keras_hub.src.bounding_box import validate_format
23
+
24
+ EPSILON = 1e-8
25
+
26
+
27
class NonMaxSuppression(keras.layers.Layer):
    """Layer that turns raw detector outputs into suppressed detections.

    Boxes are converted to a `yxyx` layout, filtered with the module-level
    `non_max_suppression()` function, converted back to
    `bounding_box_format` and finally passed through
    `mask_invalid_detections()`.

    Args:
        bounding_box_format: The format of the bounding boxes handed to
            `call()`; the returned boxes use the same format.
            TODO: link the Keras bounding box docs that list the supported
            formats.
        from_logits: boolean. When `True` the class scores are logits and a
            sigmoid is applied to them; when `False` they are used as
            confidences directly.
        iou_threshold: a float value in the range [0, 1]. Two boxes whose
            IoU reaches this value are treated as the same object for
            suppression. Defaults to `0.5`.
        confidence_threshold: a float value in the range [0, 1]. Boxes with
            a confidence below this value are discarded. Defaults to `0.5`.
        max_detections: the maximum number of detections kept per image
            after suppression. A large number may trigger significant
            memory overhead. Defaults to `100`.
    """

    def __init__(
        self,
        bounding_box_format,
        from_logits,
        iou_threshold=0.5,
        confidence_threshold=0.5,
        max_detections=100,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.bounding_box_format = bounding_box_format
        self.from_logits = from_logits
        self.iou_threshold = iou_threshold
        self.confidence_threshold = confidence_threshold
        self.max_detections = max_detections
        # The layer creates no weights, so it is flagged as built up front.
        self.built = True

    def call(
        self, box_prediction, class_prediction, images=None, image_shape=None
    ):
        """Accepts raw boxes and scores, returning suppressed detections.

        Args:
            box_prediction: Dense Tensor of shape [batch, boxes, 4] in the
                `bounding_box_format` specified in the constructor.
            class_prediction: Dense Tensor of shape
                [batch, boxes, num_classes].
            images: optional, forwarded unchanged to
                `converters.convert_format()`.
            image_shape: optional, forwarded unchanged to
                `converters.convert_format()`.

        Returns:
            The output of `mask_invalid_detections()` applied to a dict with
            the keys `"boxes"`, `"confidence"`, `"classes"` and
            `"num_detections"`.
        """
        # Suppression runs on yxyx boxes; keep the relative/absolute flavour
        # of the configured format.
        if utils.is_relative(self.bounding_box_format):
            nms_format = utils.as_relative("yxyx")
        else:
            nms_format = "yxyx"

        yxyx_boxes = converters.convert_format(
            box_prediction,
            source=self.bounding_box_format,
            target=nms_format,
            images=images,
            image_shape=image_shape,
        )
        class_scores = (
            ops.sigmoid(class_prediction)
            if self.from_logits
            else class_prediction
        )
        # A box's confidence is the score of its best class.
        best_scores = ops.max(class_scores, axis=-1)

        keep_idx, num_kept = non_max_suppression(
            yxyx_boxes,
            best_scores,
            max_output_size=self.max_detections,
            iou_threshold=self.iou_threshold,
            score_threshold=self.confidence_threshold,
        )
        gather_idx = ops.expand_dims(keep_idx, axis=-1)

        kept_boxes = ops.reshape(
            ops.take_along_axis(yxyx_boxes, gather_idx, axis=1),
            (-1, self.max_detections, 4),
        )
        kept_scores = ops.take_along_axis(best_scores, keep_idx, axis=1)
        kept_class_scores = ops.take_along_axis(
            class_scores, gather_idx, axis=1
        )

        kept_boxes = converters.convert_format(
            kept_boxes,
            source=nms_format,
            target=self.bounding_box_format,
            images=images,
            image_shape=image_shape,
        )

        # Masking is required to comply with the bounding box format.
        return mask_invalid_detections(
            {
                "boxes": kept_boxes,
                "confidence": kept_scores,
                "classes": ops.argmax(kept_class_scores, axis=-1),
                "num_detections": num_kept,
            }
        )

    def get_config(self):
        """Returns the base layer config extended with the NMS settings."""
        nms_settings = {
            "bounding_box_format": self.bounding_box_format,
            "from_logits": self.from_logits,
            "iou_threshold": self.iou_threshold,
            "confidence_threshold": self.confidence_threshold,
            "max_detections": self.max_detections,
        }
        config = super().get_config()
        config.update(nms_settings)
        return config
140
+
141
+
142
def non_max_suppression(
    boxes,
    scores,
    max_output_size,
    iou_threshold=0.5,
    score_threshold=0.0,
    tile_size=512,
):
    """Non-maximum suppression.

    Ported from https://github.com/tensorflow/tensorflow/blob/v2.12.0/tensorflow/python/ops/image_ops_impl.py#L5368-L5458

    The `boxes` and `scores` arguments are never modified in place; the
    score-threshold masking works on new tensors.

    Args:
        boxes: a tensor of rank 2 or higher with a shape of
            `[..., num_boxes, 4]`. Dimensions except the last two are batch
            dimensions. The last dimension represents box coordinates in
            yxyx format.
        scores: a tensor of rank 1 or higher with a shape of `[..., num_boxes]`.
        max_output_size: a scalar integer representing the maximum
            number of boxes to be selected by non max suppression.
        iou_threshold: a float representing the threshold for
            deciding whether boxes overlap too much with respect
            to IoU (intersection over union).
        score_threshold: a float representing the threshold for box scores.
            Boxes with a score that is not larger than this threshold
            will be suppressed. Pass `float("-inf")` to skip the score
            filtering entirely.
        tile_size: an integer representing the number of boxes in a tile, i.e.,
            the maximum number of boxes per image that can be used to suppress
            other boxes in parallel; larger tile_size means larger parallelism
            and potentially more redundant work.

    Returns:
        idx: an int32 tensor with a shape of `[batch_size, max_output_size]`
            (batch dimensions flattened) holding the indices selected by
            non-max suppression. All numbers are within `[0, num_boxes)`.
            For each image (i.e., `idx[i]`), only the first `num_valid[i]`
            indices (i.e., `idx[i][:num_valid[i]]`) are valid; the remaining
            entries are set to 0.
        num_valid: a tensor of rank 0 or higher with a shape of [...]
            representing the number of valid indices in idx. Its dimensions
            are the batch dimensions of the input boxes.
    """
    batch_dims = ops.shape(boxes)[:-2]
    num_boxes = boxes.shape[-2]
    boxes = ops.reshape(boxes, [-1, num_boxes, 4])
    scores = ops.reshape(scores, [-1, num_boxes])
    batch_size = boxes.shape[0]

    if score_threshold != float("-inf"):
        # Zero out scores and boxes that fail the threshold. This is done
        # out of place on purpose: `reshape` may return a view, and an
        # in-place `*=` would then overwrite the caller's tensors on
        # backends whose tensors support in-place arithmetic.
        score_mask = ops.cast(scores > score_threshold, scores.dtype)
        scores = scores * score_mask
        box_mask = ops.expand_dims(ops.cast(score_mask, boxes.dtype), 2)
        boxes = boxes * box_mask

    # Sort by descending score; `sorted_indices` maps sorted positions back
    # to positions in the caller's ordering.
    sorted_indices = ops.flip(
        ops.cast(ops.argsort(scores, axis=1), "int32"), axis=1
    )
    scores = ops.take_along_axis(scores, sorted_indices, axis=1)
    boxes = ops.take_along_axis(
        boxes, ops.expand_dims(sorted_indices, axis=-1), axis=1
    )

    # Pad the box axis up to a whole number of tiles that can also hold
    # `max_output_size` entries.
    pad = (
        math.ceil(max(num_boxes, max_output_size) / tile_size) * tile_size
        - num_boxes
    )
    boxes = ops.pad(ops.cast(boxes, "float32"), [[0, 0], [0, pad], [0, 0]])
    scores = ops.pad(ops.cast(scores, "float32"), [[0, 0], [0, pad]])
    num_boxes_after_padding = num_boxes + pad
    num_iterations = num_boxes_after_padding // tile_size

    def _loop_cond(unused_boxes, unused_threshold, output_size, idx):
        # Continue while some image still has room for more boxes and there
        # are tiles left to process.
        return ops.logical_and(
            ops.min(output_size) < ops.cast(max_output_size, "int32"),
            ops.cast(idx, "int32") < num_iterations,
        )

    def suppression_loop_body(boxes, iou_threshold, output_size, idx):
        return _suppression_loop_body(
            boxes, iou_threshold, output_size, idx, tile_size
        )

    selected_boxes, _, output_size, _ = ops.while_loop(
        _loop_cond,
        suppression_loop_body,
        [
            boxes,
            iou_threshold,
            ops.zeros([batch_size], "int32"),
            ops.array(0),
        ],
    )
    num_valid = ops.minimum(output_size, max_output_size)

    # Surviving boxes have a positive coordinate. Weighting the survivor
    # mask with a descending ramp makes `top_k` return the earliest
    # (highest-scoring) survivors first; subtracting from the padded length
    # turns the ramp values back into positions.
    survivor_mask = ops.cast(ops.any(selected_boxes > 0, [2]), "int32")
    descending_ramp = ops.cast(
        ops.expand_dims(ops.arange(num_boxes_after_padding, 0, -1), 0),
        "int32",
    )
    idx = num_boxes_after_padding - ops.cast(
        ops.top_k(survivor_mask * descending_ramp, max_output_size)[0],
        "int32",
    )
    # Positions that fall into the padding are clamped to the last real box.
    idx = ops.minimum(idx, num_boxes - 1)

    # Translate sorted positions back to original box indices through a
    # flat gather over the whole batch.
    index_offsets = ops.cast(ops.arange(batch_size) * num_boxes, "int32")
    take_along_axis_idx = ops.reshape(
        idx + ops.expand_dims(index_offsets, 1), [-1]
    )

    if keras.backend.backend() != "tensorflow":
        idx = ops.take_along_axis(
            ops.reshape(sorted_indices, [-1]), take_along_axis_idx
        )
    else:
        import tensorflow as tf

        idx = tf.gather(ops.reshape(sorted_indices, [-1]), take_along_axis_idx)
    idx = ops.reshape(idx, [batch_size, -1])

    # Entries past `num_valid` are reset to index 0.
    invalid_index = ops.zeros([batch_size, max_output_size], dtype="int32")
    idx_index = ops.cast(
        ops.expand_dims(ops.arange(max_output_size), 0), "int32"
    )
    num_valid_expanded = ops.expand_dims(num_valid, 1)
    idx = ops.where(idx_index < num_valid_expanded, idx, invalid_index)

    num_valid = ops.reshape(num_valid, batch_dims)
    return idx, num_valid
297
+
298
+
299
def _bbox_overlap(boxes_a, boxes_b):
    """Computes the pairwise IoU between two batched sets of boxes.

    Args:
        boxes_a: a tensor with a shape of `[batch_size, N, 4]`. `N` is the
            number of boxes per image. The last dimension holds pixel
            coordinates in `[ymin, xmin, ymax, xmax]` form. A rank-4 input
            has its leading axis squeezed away first.
        boxes_b: a tensor with a shape of `[batch_size, M, 4]`. `M` is the
            number of boxes. The last dimension holds pixel coordinates in
            `[ymin, xmin, ymax, xmax]` form.

    Returns:
        A tensor with a shape of `[batch_size, N, M]` holding the ratio of
        intersection area over union area (IoU) for every pair of boxes.
    """

    def _as_row(column):
        # [batch, M, 1] -> [batch, 1, M] so it broadcasts against [batch, N, 1].
        return ops.transpose(column, [0, 2, 1])

    if len(boxes_a.shape) == 4:
        boxes_a = ops.squeeze(boxes_a, axis=0)

    y0_a, x0_a, y1_a, x1_a = ops.split(boxes_a, 4, axis=2)
    y0_b, x0_b, y1_b, x1_b = ops.split(boxes_b, 4, axis=2)

    # Intersection rectangle, clamped to zero extent for disjoint boxes.
    left = ops.maximum(x0_a, _as_row(x0_b))
    right = ops.minimum(x1_a, _as_row(x1_b))
    top = ops.maximum(y0_a, _as_row(y0_b))
    bottom = ops.minimum(y1_a, _as_row(y1_b))
    overlap_width = ops.maximum((right - left), 0)
    overlap_height = ops.maximum((bottom - top), 0)
    intersection = overlap_width * overlap_height

    area_a = (y1_a - y0_a) * (x1_a - x0_a)
    area_b = (y1_b - y0_b) * (x1_b - x0_b)

    # EPSILON keeps the division defined when both boxes are degenerate.
    union = area_a + _as_row(area_b) - intersection + EPSILON

    return intersection / union
340
+
341
+
342
def _self_suppression(iou, _, iou_sum, iou_threshold):
    """Runs one round of suppression among the boxes of a single tile.

    Boxes that no other box suppresses (`can_suppress_others`) are found
    first and are then used to suppress the remaining boxes of the tile.
    The function has the signature of a `while_loop` body: it takes and
    returns the same four loop variables.

    Args:
        iou: a rank-3 tensor of pairwise IoU values for the tile (it is
            reduced over axes 1 and 2 below).
        _: the loop-condition value from the previous round, unused.
        iou_sum: the per-batch sum of `iou` before this round.
        iou_threshold: a scalar threshold.

    Returns:
        A list `[iou_suppressed, iou_diff, iou_sum_new, iou_threshold]`:
        iou_suppressed: `iou` after this round of suppression.
        iou_diff: a scalar boolean, True when the IoU sum of any batch
            entry dropped by more than `iou_threshold` in this round.
        iou_sum_new: a tensor of shape `[batch_size]` with the IoU sum
            after suppression.
        iou_threshold: passed through unchanged.
    """
    batch_size = ops.shape(iou)[0]
    column_shape = [batch_size, -1, 1]

    # A box can suppress others when nothing overlaps it above threshold.
    not_suppressed = ops.max(iou, 1) < iou_threshold
    can_suppress_others = ops.cast(
        ops.reshape(not_suppressed, column_shape), iou.dtype
    )

    # Only rows of boxes that survive those suppressors are kept.
    survives = ops.cast(
        ops.max(can_suppress_others * iou, 1) < iou_threshold, iou.dtype
    )
    iou_after_suppression = ops.reshape(survives, column_shape) * iou

    iou_sum_new = ops.sum(iou_after_suppression, [1, 2])
    changed = ops.any(iou_sum - iou_sum_new > iou_threshold)
    return [iou_after_suppression, changed, iou_sum_new, iou_threshold]
384
+
385
+
386
def _cross_suppression(boxes, box_slice, iou_threshold, inner_idx, tile_size):
    """Suppresses the boxes of one tile using the boxes of another tile.

    The function has the shape of a `while_loop` body over `inner_idx`.

    Args:
        boxes: a tensor of shape `[batch_size, num_boxes_with_padding, 4]`.
        box_slice: a tensor of shape `[batch_size, tile_size, 4]`, the tile
            being suppressed.
        iou_threshold: a scalar tensor.
        inner_idx: a scalar tensor, the index of the tile whose boxes are
            used to suppress `box_slice`.
        tile_size: an integer representing the number of boxes in a tile.

    Returns:
        boxes: the input boxes, unchanged.
        box_slice_after_suppression: `box_slice` with suppressed boxes
            multiplied by zero.
        iou_threshold: unchanged.
        inner_idx: the input `inner_idx` plus one.
    """
    # Positions covered by tile `inner_idx`, shaped [1, tile_size, 1] for a
    # gather along the box axis.
    first = inner_idx * tile_size
    last = (inner_idx + 1) * tile_size - 1
    positions = ops.cast(ops.linspace(first, last, tile_size), "int32")
    gather_idx = ops.expand_dims(ops.expand_dims(positions, axis=0), axis=-1)

    suppressor_tile = ops.expand_dims(
        ops.take_along_axis(boxes, gather_idx, axis=1), 0
    )
    overlap = _bbox_overlap(suppressor_tile, box_slice)

    # A box is kept only when every suppressor overlaps it below threshold.
    keep = ops.cast(ops.all(overlap < iou_threshold, [1]), box_slice.dtype)
    surviving_slice = ops.expand_dims(keep, 2) * box_slice

    return boxes, surviving_slice, iou_threshold, inner_idx + 1
427
+
428
+
429
def _suppression_loop_body(boxes, iou_threshold, output_size, idx, tile_size):
    """Processes the boxes in the range [idx*tile_size, (idx+1)*tile_size).

    The tile is first suppressed by every earlier tile, then by its own
    boxes, and the result is written back into `boxes`.

    Args:
        boxes: a tensor with a shape of [batch_size, anchors, 4].
        iou_threshold: a float representing the threshold for deciding
            whether boxes overlap too much with respect to IOU.
        output_size: an int32 tensor of size [batch_size], the number of
            selected boxes for each batch entry so far.
        idx: an integer scalar, the index of the tile to process.
        tile_size: an integer representing the number of boxes in a tile.

    Returns:
        boxes: the boxes with tile `idx` replaced by its suppressed version.
        iou_threshold: passed down unchanged to the next iteration.
        output_size: the updated output_size.
        idx: the input `idx` plus one.
    """
    batch_size = boxes.shape[0]
    num_tiles = boxes.shape[1] // tile_size

    def _earlier_tiles_remain(_boxes, _box_slice, _threshold, inner_idx):
        return inner_idx < idx

    def _suppress_with_earlier_tile(
        boxes, box_slice, iou_threshold, inner_idx
    ):
        return _cross_suppression(
            boxes, box_slice, iou_threshold, inner_idx, tile_size
        )

    def _still_changing(_iou, loop_condition, _iou_sum, _):
        return loop_condition

    # Gather the current tile and let each earlier tile suppress it.
    tile_positions = ops.cast(
        ops.linspace(idx * tile_size, (idx + 1) * tile_size - 1, tile_size),
        "int32",
    )
    gather_idx = ops.expand_dims(
        ops.expand_dims(tile_positions, axis=0), axis=-1
    )
    box_slice = ops.take_along_axis(boxes, gather_idx, axis=1)
    _, box_slice, _, _ = ops.while_loop(
        _earlier_tiles_remain,
        _suppress_with_earlier_tile,
        [boxes, box_slice, iou_threshold, ops.array(0)],
    )

    # Self-suppression inside the tile: keep only pairs above the threshold
    # whose column index is larger than the row index.
    iou = _bbox_overlap(box_slice, box_slice)
    within_tile = ops.arange(tile_size)
    upper_triangle = ops.expand_dims(
        ops.reshape(within_tile, [1, -1]) > ops.reshape(within_tile, [-1, 1]),
        0,
    )
    iou *= ops.cast(
        ops.logical_and(upper_triangle, iou >= iou_threshold), iou.dtype
    )
    suppressed_iou, _, _, _ = ops.while_loop(
        _still_changing,
        _self_suppression,
        [iou, ops.array(True), ops.sum(iou, [1, 2]), iou_threshold],
    )
    suppressed_box = ops.sum(suppressed_iou, 1) > 0
    box_slice *= ops.expand_dims(
        1.0 - ops.cast(suppressed_box, box_slice.dtype), 2
    )

    # Write the tile back: a one-hot mask over tiles selects the new slice
    # for tile `idx` and the old boxes for every other tile.
    tile_selector = ops.reshape(
        ops.cast(ops.equal(ops.arange(num_tiles), idx), boxes.dtype),
        [1, -1, 1, 1],
    )
    new_tiles = ops.tile(ops.expand_dims(box_slice, 1), [1, num_tiles, 1, 1])
    old_tiles = ops.reshape(boxes, [batch_size, num_tiles, tile_size, 4])
    boxes = new_tiles * tile_selector + old_tiles * (1 - tile_selector)
    boxes = ops.reshape(boxes, [batch_size, -1, 4])

    # A surviving box still has at least one positive coordinate.
    output_size += ops.cast(ops.sum(ops.any(box_slice > 0, [2]), [1]), "int32")
    return boxes, iou_threshold, output_size, idx + 1
508
+
509
+
510
def mask_invalid_detections(bounding_boxes):
    """Replaces detections beyond `num_detections` with -1s.

    Non-max suppression returns a fixed number of detection slots per image
    and reports how many of them are real through `num_detections`.
    KerasHub instead expects unused slots to be padded with -1s, so this
    function overwrites every slot at or past `num_detections` with -1 in
    `boxes`, `classes` and, when present, `confidence`.

    Args:
        bounding_boxes: a dictionary complying with Keras bounding box
            format. In addition to the normal required keys, it must have a
            `num_detections` key.

    Returns:
        A shallow copy of `bounding_boxes` whose `boxes`, `classes` and
        (if present) `confidence` entries are masked according to
        `num_detections`.

    Raises:
        ValueError: if the boxes are reported as ragged by
            `validate_format`, or if the `num_detections` key is missing.
    """
    # Ensure we are complying with Keras bounding box format.
    format_info = validate_format.validate_format(bounding_boxes)
    if format_info["ragged"]:
        raise ValueError(
            "`bounding_box.mask_invalid_detections()` requires inputs to be "
            "Dense tensors. Please call "
            "`bounding_box.to_dense(bounding_boxes)` before passing your boxes "
            "to `bounding_box.mask_invalid_detections()`."
        )
    if "num_detections" not in bounding_boxes:
        raise ValueError(
            "`bounding_boxes` must have key 'num_detections' "
            "to be used with `bounding_box.mask_invalid_detections()`."
        )

    boxes = bounding_boxes.get("boxes")
    classes = bounding_boxes.get("classes")
    confidence = bounding_boxes.get("confidence")
    num_detections = bounding_boxes["num_detections"]

    # True for the first `num_detections[i]` slots of image `i`.
    slot_index = ops.cast(
        ops.expand_dims(ops.arange(boxes.shape[1]), axis=0),
        num_detections.dtype,
    )
    is_valid = slot_index < num_detections[:, None]

    result = bounding_boxes.copy()
    result["classes"] = ops.where(is_valid, classes, -ops.ones_like(classes))
    if confidence is not None:
        result["confidence"] = ops.where(
            is_valid, confidence, -ops.ones_like(confidence)
        )

    # Repeat the per-slot mask across the coordinate axis for the boxes.
    is_valid_box = ops.repeat(
        ops.expand_dims(is_valid, axis=-1), repeats=boxes.shape[-1], axis=-1
    )
    result["boxes"] = ops.where(is_valid_box, boxes, -ops.ones_like(boxes))

    return result
@@ -26,7 +26,12 @@ from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import (
26
26
  from keras_hub.src.models.text_classifier import TextClassifier
27
27
 
28
28
 
29
- @keras_hub_export("keras_hub.models.RobertaTextClassifier")
29
+ @keras_hub_export(
30
+ [
31
+ "keras_hub.models.RobertaTextClassifier",
32
+ "keras_hub.models.RobertaClassifier",
33
+ ]
34
+ )
30
35
  class RobertaTextClassifier(TextClassifier):
31
36
  """An end-to-end RoBERTa model for classification tasks.
32
37
 
@@ -139,7 +139,6 @@ class Task(PipelineModel):
139
139
  cls,
140
140
  preset,
141
141
  load_weights=True,
142
- load_task_extras=False,
143
142
  **kwargs,
144
143
  ):
145
144
  """Instantiate a `keras_hub.models.Task` from a model preset.
@@ -168,10 +167,6 @@ class Task(PipelineModel):
168
167
  load_weights: bool. If `True`, saved weights will be loaded into
169
168
  the model architecture. If `False`, all weights will be
170
169
  randomly initialized.
171
- load_task_extras: bool. If `True`, load the saved task configuration
172
- from a `task.json` and any task specific weights from
173
- `task.weights`. You might use this to load a classification
174
- head for a model that has been saved with it.
175
170
 
176
171
  Examples:
177
172
  ```python
@@ -199,7 +194,12 @@ class Task(PipelineModel):
199
194
  # Detect the correct subclass if we need to.
200
195
  if cls.backbone_cls != backbone_cls:
201
196
  cls = find_subclass(preset, cls, backbone_cls)
202
- return loader.load_task(cls, load_weights, load_task_extras, **kwargs)
197
+ # Specifically for classifiers, we never load task weights if
198
+ # num_classes is supplied. We handle this in the task base class because
199
+ # it is the same logic for classifiers regardless of modality (text,
200
+ # images, audio).
201
+ load_task_weights = "num_classes" not in kwargs
202
+ return loader.load_task(cls, load_weights, load_task_weights, **kwargs)
203
203
 
204
204
  def load_task_weights(self, filepath):
205
205
  """Load only the tasks specific weights not in the backbone."""
@@ -17,7 +17,12 @@ from keras_hub.src.api_export import keras_hub_export
17
17
  from keras_hub.src.models.task import Task
18
18
 
19
19
 
20
- @keras_hub_export("keras_hub.models.TextClassifier")
20
+ @keras_hub_export(
21
+ [
22
+ "keras_hub.models.TextClassifier",
23
+ "keras_hub.models.Classifier",
24
+ ]
25
+ )
21
26
  class TextClassifier(Task):
22
27
  """Base class for all classification tasks.
23
28
 
@@ -32,6 +37,12 @@ class TextClassifier(Task):
32
37
  All `TextClassifier` tasks include a `from_preset()` constructor which can be
33
38
  used to load a pre-trained config and weights.
34
39
 
40
+ Some, but not all, classification presets include classification head
41
+ weights in a `task.weights.h5` file. For these presets, you can omit passing
42
+ `num_classes` to restore the saved classification head. For all presets, if
43
+ `num_classes` is passed as a kwarg to `from_preset()`, the classification
44
+ head will be randomly initialized.
45
+
35
46
  Example:
36
47
  ```python
37
48
  # Load a BERT classifier with pre-trained weights.
@@ -28,7 +28,12 @@ from keras_hub.src.models.xlm_roberta.xlm_roberta_text_classifier_preprocessor i
28
28
  )
29
29
 
30
30
 
31
- @keras_hub_export("keras_hub.models.XLMRobertaTextClassifier")
31
+ @keras_hub_export(
32
+ [
33
+ "keras_hub.models.XLMRobertaTextClassifier",
34
+ "keras_hub.models.XLMRobertaClassifier",
35
+ ]
36
+ )
32
37
  class XLMRobertaTextClassifier(TextClassifier):
33
38
  """An end-to-end XLM-RoBERTa model for classification tasks.
34
39