keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198)
  1. keras_hub/__init__.py +0 -6
  2. keras_hub/api/__init__.py +2 -0
  3. keras_hub/api/bounding_box/__init__.py +36 -0
  4. keras_hub/api/layers/__init__.py +14 -0
  5. keras_hub/api/models/__init__.py +97 -48
  6. keras_hub/api/tokenizers/__init__.py +30 -0
  7. keras_hub/api/utils/__init__.py +22 -0
  8. keras_hub/src/api_export.py +15 -9
  9. keras_hub/src/bounding_box/__init__.py +13 -0
  10. keras_hub/src/bounding_box/converters.py +529 -0
  11. keras_hub/src/bounding_box/formats.py +162 -0
  12. keras_hub/src/bounding_box/iou.py +263 -0
  13. keras_hub/src/bounding_box/to_dense.py +95 -0
  14. keras_hub/src/bounding_box/to_ragged.py +99 -0
  15. keras_hub/src/bounding_box/utils.py +194 -0
  16. keras_hub/src/bounding_box/validate_format.py +99 -0
  17. keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
  18. keras_hub/src/layers/preprocessing/image_converter.py +130 -0
  19. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
  20. keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
  21. keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
  22. keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
  23. keras_hub/src/layers/preprocessing/random_swap.py +33 -31
  24. keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
  25. keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
  26. keras_hub/src/models/albert/__init__.py +1 -2
  27. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
  28. keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
  29. keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
  30. keras_hub/src/models/albert/albert_tokenizer.py +17 -36
  31. keras_hub/src/models/backbone.py +12 -34
  32. keras_hub/src/models/bart/__init__.py +1 -2
  33. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
  34. keras_hub/src/models/bart/bart_tokenizer.py +12 -39
  35. keras_hub/src/models/bert/__init__.py +1 -5
  36. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
  37. keras_hub/src/models/bert/bert_presets.py +1 -4
  38. keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
  39. keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
  40. keras_hub/src/models/bert/bert_tokenizer.py +17 -35
  41. keras_hub/src/models/bloom/__init__.py +1 -2
  42. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
  43. keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
  44. keras_hub/src/models/causal_lm.py +10 -29
  45. keras_hub/src/models/causal_lm_preprocessor.py +195 -0
  46. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
  47. keras_hub/src/models/deberta_v3/__init__.py +1 -4
  48. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
  49. keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
  50. keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
  51. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
  52. keras_hub/src/models/densenet/densenet_backbone.py +46 -22
  53. keras_hub/src/models/distil_bert/__init__.py +1 -4
  54. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
  55. keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
  56. keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
  57. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
  58. keras_hub/src/models/efficientnet/__init__.py +13 -0
  59. keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
  60. keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
  61. keras_hub/src/models/efficientnet/mbconv.py +238 -0
  62. keras_hub/src/models/electra/__init__.py +1 -2
  63. keras_hub/src/models/electra/electra_tokenizer.py +17 -32
  64. keras_hub/src/models/f_net/__init__.py +1 -2
  65. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
  66. keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
  67. keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
  68. keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
  69. keras_hub/src/models/falcon/__init__.py +1 -2
  70. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
  71. keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
  72. keras_hub/src/models/gemma/__init__.py +1 -2
  73. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
  74. keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
  75. keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
  76. keras_hub/src/models/gpt2/__init__.py +1 -2
  77. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
  78. keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
  79. keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
  80. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
  81. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
  82. keras_hub/src/models/image_classifier.py +0 -5
  83. keras_hub/src/models/image_classifier_preprocessor.py +83 -0
  84. keras_hub/src/models/llama/__init__.py +1 -2
  85. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
  86. keras_hub/src/models/llama/llama_tokenizer.py +12 -25
  87. keras_hub/src/models/llama3/__init__.py +1 -2
  88. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
  89. keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
  90. keras_hub/src/models/masked_lm.py +0 -2
  91. keras_hub/src/models/masked_lm_preprocessor.py +156 -0
  92. keras_hub/src/models/mistral/__init__.py +1 -2
  93. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
  94. keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
  95. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
  96. keras_hub/src/models/mobilenet/__init__.py +13 -0
  97. keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
  98. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
  99. keras_hub/src/models/opt/__init__.py +1 -2
  100. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
  101. keras_hub/src/models/opt/opt_tokenizer.py +12 -41
  102. keras_hub/src/models/pali_gemma/__init__.py +1 -4
  103. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
  104. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
  105. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
  106. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
  107. keras_hub/src/models/phi3/__init__.py +1 -2
  108. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
  109. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
  110. keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
  111. keras_hub/src/models/preprocessor.py +72 -83
  112. keras_hub/src/models/resnet/__init__.py +6 -0
  113. keras_hub/src/models/resnet/resnet_backbone.py +390 -42
  114. keras_hub/src/models/resnet/resnet_image_classifier.py +33 -6
  115. keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
  116. keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
  117. keras_hub/src/models/resnet/resnet_presets.py +95 -0
  118. keras_hub/src/models/retinanet/__init__.py +13 -0
  119. keras_hub/src/models/retinanet/anchor_generator.py +175 -0
  120. keras_hub/src/models/retinanet/box_matcher.py +259 -0
  121. keras_hub/src/models/retinanet/non_max_supression.py +578 -0
  122. keras_hub/src/models/roberta/__init__.py +1 -2
  123. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
  124. keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
  125. keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
  126. keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
  127. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
  128. keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
  129. keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
  130. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
  131. keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
  132. keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
  133. keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
  134. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
  135. keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
  136. keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
  137. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
  138. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
  139. keras_hub/src/models/t5/__init__.py +1 -2
  140. keras_hub/src/models/t5/t5_tokenizer.py +13 -23
  141. keras_hub/src/models/task.py +71 -116
  142. keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
  143. keras_hub/src/models/text_classifier_preprocessor.py +138 -0
  144. keras_hub/src/models/whisper/__init__.py +1 -2
  145. keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
  146. keras_hub/src/models/whisper/whisper_backbone.py +0 -3
  147. keras_hub/src/models/whisper/whisper_presets.py +10 -10
  148. keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
  149. keras_hub/src/models/xlm_roberta/__init__.py +1 -4
  150. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
  151. keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
  152. keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
  153. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
  154. keras_hub/src/tests/test_case.py +46 -0
  155. keras_hub/src/tokenizers/byte_pair_tokenizer.py +30 -17
  156. keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
  157. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +20 -7
  158. keras_hub/src/tokenizers/tokenizer.py +67 -32
  159. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
  160. keras_hub/src/tokenizers/word_piece_tokenizer.py +34 -47
  161. keras_hub/src/utils/imagenet/__init__.py +13 -0
  162. keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
  163. keras_hub/src/utils/keras_utils.py +0 -50
  164. keras_hub/src/utils/preset_utils.py +230 -68
  165. keras_hub/src/utils/tensor_utils.py +187 -69
  166. keras_hub/src/utils/timm/convert_resnet.py +19 -16
  167. keras_hub/src/utils/timm/preset_loader.py +66 -0
  168. keras_hub/src/utils/transformers/convert_albert.py +193 -0
  169. keras_hub/src/utils/transformers/convert_bart.py +373 -0
  170. keras_hub/src/utils/transformers/convert_bert.py +7 -17
  171. keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
  172. keras_hub/src/utils/transformers/convert_gemma.py +5 -19
  173. keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
  174. keras_hub/src/utils/transformers/convert_llama3.py +7 -18
  175. keras_hub/src/utils/transformers/convert_mistral.py +129 -0
  176. keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
  177. keras_hub/src/utils/transformers/preset_loader.py +77 -0
  178. keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
  179. keras_hub/src/version_utils.py +1 -1
  180. keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
  181. keras_hub_nightly-0.16.0.dev2024092017.dist-info/RECORD +334 -0
  182. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
  183. keras_hub/src/models/bart/bart_preprocessor.py +0 -276
  184. keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
  185. keras_hub/src/models/electra/electra_preprocessor.py +0 -154
  186. keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
  187. keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
  188. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
  189. keras_hub/src/models/llama/llama_preprocessor.py +0 -189
  190. keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
  191. keras_hub/src/models/opt/opt_preprocessor.py +0 -188
  192. keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
  193. keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
  194. keras_hub/src/utils/timm/convert.py +0 -37
  195. keras_hub/src/utils/transformers/convert.py +0 -101
  196. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +0 -34
  197. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
  198. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
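Note on the listing above: the largest user-facing change is the rename of the classifier task files to text-classifier files (for example, bert_classifier.py becomes bert_text_classifier.py) and the split of the old per-model preprocessors into task-specific preprocessors. The sketch below shows what migrating calling code might look like; it assumes the exported class names follow the new file names and uses an illustrative preset name.

    import keras_hub

    # Before (0.15.x nightlies), the task class was a plain "Classifier":
    # classifier = keras_hub.models.BertClassifier.from_preset(
    #     "bert_base_en_uncased", num_classes=2
    # )

    # After (0.16.x nightlies), text classification tasks are "TextClassifier"s:
    classifier = keras_hub.models.BertTextClassifier.from_preset(
        "bert_base_en_uncased", num_classes=2
    )
    # With the attached preprocessor, raw strings can be passed directly.
    predictions = classifier.predict(["The movie was great!"])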
keras_hub/src/models/retinanet/non_max_supression.py (new file)
@@ -0,0 +1,578 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import keras
+from keras import ops
+
+from keras_hub.src.bounding_box import converters
+from keras_hub.src.bounding_box import utils
+from keras_hub.src.bounding_box import validate_format
+
+EPSILON = 1e-8
+
+
+class NonMaxSuppression(keras.layers.Layer):
+    """A Keras layer that decodes predictions of an object detection model.
+
+    Args:
+        bounding_box_format: The format of bounding boxes of input dataset.
+            Refer
+            TODO: link keras core bounding box docs
+            for more details on supported bounding box formats.
+        from_logits: boolean, True means input score is logits, False means
+            confidence.
+        iou_threshold: a float value in the range [0, 1] representing the
+            minimum IoU threshold for two boxes to be considered
+            same for suppression. Defaults to 0.5.
+        confidence_threshold: a float value in the range [0, 1]. All boxes with
+            confidence below this value will be discarded, defaults to 0.5.
+        max_detections: the maximum detections to consider after nms is applied.
+            A large number may trigger significant memory overhead,
+            defaults to 100.
+    """
+
+    def __init__(
+        self,
+        bounding_box_format,
+        from_logits,
+        iou_threshold=0.5,
+        confidence_threshold=0.5,
+        max_detections=100,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.bounding_box_format = bounding_box_format
+        self.from_logits = from_logits
+        self.iou_threshold = iou_threshold
+        self.confidence_threshold = confidence_threshold
+        self.max_detections = max_detections
+        self.built = True
+
+    def call(
+        self, box_prediction, class_prediction, images=None, image_shape=None
+    ):
+        """Accepts images and raw scores, returning bounding box predictions.
+
+        Args:
+            box_prediction: Dense Tensor of shape [batch, boxes, 4] in the
+                `bounding_box_format` specified in the constructor.
+            class_prediction: Dense Tensor of shape [batch, boxes, num_classes].
+        """
+        target_format = "yxyx"
+        if utils.is_relative(self.bounding_box_format):
+            target_format = utils.as_relative(target_format)
+
+        box_prediction = converters.convert_format(
+            box_prediction,
+            source=self.bounding_box_format,
+            target=target_format,
+            images=images,
+            image_shape=image_shape,
+        )
+        if self.from_logits:
+            class_prediction = ops.sigmoid(class_prediction)
+
+        confidence_prediction = ops.max(class_prediction, axis=-1)
+
+        idx, valid_det = non_max_suppression(
+            box_prediction,
+            confidence_prediction,
+            max_output_size=self.max_detections,
+            iou_threshold=self.iou_threshold,
+            score_threshold=self.confidence_threshold,
+        )
+
+        box_prediction = ops.take_along_axis(
+            box_prediction, ops.expand_dims(idx, axis=-1), axis=1
+        )
+        box_prediction = ops.reshape(
+            box_prediction, (-1, self.max_detections, 4)
+        )
+        confidence_prediction = ops.take_along_axis(
+            confidence_prediction, idx, axis=1
+        )
+        class_prediction = ops.take_along_axis(
+            class_prediction, ops.expand_dims(idx, axis=-1), axis=1
+        )
+
+        box_prediction = converters.convert_format(
+            box_prediction,
+            source=target_format,
+            target=self.bounding_box_format,
+            images=images,
+            image_shape=image_shape,
+        )
+        bounding_boxes = {
+            "boxes": box_prediction,
+            "confidence": confidence_prediction,
+            "classes": ops.argmax(class_prediction, axis=-1),
+            "num_detections": valid_det,
+        }
+
+        # this is required to comply with bounding box format.
+        return mask_invalid_detections(bounding_boxes)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "bounding_box_format": self.bounding_box_format,
+                "from_logits": self.from_logits,
+                "iou_threshold": self.iou_threshold,
+                "confidence_threshold": self.confidence_threshold,
+                "max_detections": self.max_detections,
+            }
+        )
+        return config
+
+
+def non_max_suppression(
+    boxes,
+    scores,
+    max_output_size,
+    iou_threshold=0.5,
+    score_threshold=0.0,
+    tile_size=512,
+):
+    """Non-maximum suppression.
+
+    Ported from https://github.com/tensorflow/tensorflow/blob/v2.12.0/tensorflow/python/ops/image_ops_impl.py#L5368-L5458
+
+    Args:
+        boxes: a tensor of rank 2 or higher with a shape of
+            `[..., num_boxes, 4]`. Dimensions except the last two are batch
+            dimensions. The last dimension represents box coordinates in
+            yxyx format.
+        scores: a tensor of rank 1 or higher with a shape of `[..., num_boxes]`.
+        max_output_size: a scalar integer tensor representing the maximum
+            number of boxes to be selected by non max suppression.
+        iou_threshold: a float representing the threshold for
+            deciding whether boxes overlap too much with respect
+            to IoU (intersection over union).
+        score_threshold: a float representing the threshold for box scores.
+            Boxes with a score that is not larger than this threshold
+            will be suppressed.
+        tile_size: an integer representing the number of boxes in a tile, i.e.,
+            the maximum number of boxes per image that can be used to suppress
+            other boxes in parallel; larger tile_size means larger parallelism
+            and potentially more redundant work.
+
+    Returns:
+        idx: a tensor with a shape of `[..., num_boxes]` representing the
+            indices selected by non-max suppression. The leading dimensions
+            are the batch dimensions of the input boxes. All numbers are within
+            `[0, num_boxes)`. For each image (i.e., `idx[i]`), only the first
+            `num_valid[i]` indices (i.e., `idx[i][:num_valid[i]]`) are valid.
+        num_valid: a tensor of rank 0 or higher with a shape of [...]
+            representing the number of valid indices in idx. Its dimensions
+            are the batch dimensions of the input boxes.
+    """
+
+    def _sort_scores_and_boxes(scores, boxes):
+        """Sort boxes based their score from highest to lowest.
+
+        Args:
+            scores: a tensor with a shape of `[batch_size, num_boxes]`
+                representing the scores of boxes.
+            boxes: a tensor with a shape of `[batch_size, num_boxes, 4]`
+                representing the boxes.
+
+        Returns:
+            sorted_scores: a tensor with a shape of
+                `[batch_size, num_boxes]` representing the sorted scores.
+            sorted_boxes: a tensor representing the sorted boxes.
+            sorted_scores_indices: a tensor with a shape of
+                `[batch_size, num_boxes]` representing the index of the scores
+                in a sorted descending order.
+        """
+        sorted_scores_indices = ops.flip(
+            ops.cast(ops.argsort(scores, axis=1), "int32"), axis=1
+        )
+        sorted_scores = ops.take_along_axis(
+            scores,
+            sorted_scores_indices,
+            axis=1,
+        )
+        sorted_boxes = ops.take_along_axis(
+            boxes,
+            ops.expand_dims(sorted_scores_indices, axis=-1),
+            axis=1,
+        )
+        return sorted_scores, sorted_boxes, sorted_scores_indices
+
+    batch_dims = ops.shape(boxes)[:-2]
+    num_boxes = boxes.shape[-2]
+    boxes = ops.reshape(boxes, [-1, num_boxes, 4])
+    scores = ops.reshape(scores, [-1, num_boxes])
+    batch_size = boxes.shape[0]
+    if score_threshold != float("-inf"):
+        score_mask = ops.cast(scores > score_threshold, scores.dtype)
+        scores *= score_mask
+        box_mask = ops.expand_dims(ops.cast(score_mask, boxes.dtype), 2)
+        boxes *= box_mask
+
+    scores, boxes, sorted_indices = _sort_scores_and_boxes(scores, boxes)
+
+    pad = (
+        math.ceil(max(num_boxes, max_output_size) / tile_size) * tile_size
+        - num_boxes
+    )
+    boxes = ops.pad(ops.cast(boxes, "float32"), [[0, 0], [0, pad], [0, 0]])
+    scores = ops.pad(ops.cast(scores, "float32"), [[0, 0], [0, pad]])
+    num_boxes_after_padding = num_boxes + pad
+    num_iterations = num_boxes_after_padding // tile_size
+
+    def _loop_cond(unused_boxes, unused_threshold, output_size, idx):
+        return ops.logical_and(
+            ops.min(output_size) < ops.cast(max_output_size, "int32"),
+            ops.cast(idx, "int32") < num_iterations,
+        )
+
+    def suppression_loop_body(boxes, iou_threshold, output_size, idx):
+        return _suppression_loop_body(
+            boxes, iou_threshold, output_size, idx, tile_size
+        )
+
+    selected_boxes, _, output_size, _ = ops.while_loop(
+        _loop_cond,
+        suppression_loop_body,
+        [
+            boxes,
+            iou_threshold,
+            ops.zeros([batch_size], "int32"),
+            ops.array(0),
+        ],
+    )
+    num_valid = ops.minimum(output_size, max_output_size)
+    idx = num_boxes_after_padding - ops.cast(
+        ops.top_k(
+            ops.cast(ops.any(selected_boxes > 0, [2]), "int32")
+            * ops.cast(
+                ops.expand_dims(ops.arange(num_boxes_after_padding, 0, -1), 0),
+                "int32",
+            ),
+            max_output_size,
+        )[0],
+        "int32",
+    )
+    idx = ops.minimum(idx, num_boxes - 1)
+
+    index_offsets = ops.cast(ops.arange(batch_size) * num_boxes, "int32")
+    take_along_axis_idx = ops.reshape(
+        idx + ops.expand_dims(index_offsets, 1), [-1]
+    )
+
+    if keras.backend.backend() != "tensorflow":
+        idx = ops.take_along_axis(
+            ops.reshape(sorted_indices, [-1]), take_along_axis_idx
+        )
+    else:
+        import tensorflow as tf
+
+        idx = tf.gather(ops.reshape(sorted_indices, [-1]), take_along_axis_idx)
+    idx = ops.reshape(idx, [batch_size, -1])
+
+    invalid_index = ops.zeros([batch_size, max_output_size], dtype="int32")
+    idx_index = ops.cast(
+        ops.expand_dims(ops.arange(max_output_size), 0), "int32"
+    )
+    num_valid_expanded = ops.expand_dims(num_valid, 1)
+    idx = ops.where(idx_index < num_valid_expanded, idx, invalid_index)
+
+    num_valid = ops.reshape(num_valid, batch_dims)
+    return idx, num_valid
+
+
+def _bbox_overlap(boxes_a, boxes_b):
+    """Calculates the overlap (iou - intersection over union) between boxes_a
+    and boxes_b.
+
+    Args:
+        boxes_a: a tensor with a shape of `[batch_size, N, 4]`.
+            `N` is the number of boxes per image. The last dimension is the
+            pixel coordinates in `[ymin, xmin, ymax, xmax]` form.
+        boxes_b: a tensor with a shape of `[batch_size, M, 4]`. M is the number of
+            boxes. The last dimension is the pixel coordinates in
+            `[ymin, xmin, ymax, xmax]` form.
+
+    Returns:
+        intersection_over_union: a tensor with as a shape of
+            `[batch_size, N, M]`, representing the ratio of intersection area
+            over union area (IoU) between two boxes
+    """
+    if len(boxes_a.shape) == 4:
+        boxes_a = ops.squeeze(boxes_a, axis=0)
+    a_y_min, a_x_min, a_y_max, a_x_max = ops.split(boxes_a, 4, axis=2)
+    b_y_min, b_x_min, b_y_max, b_x_max = ops.split(boxes_b, 4, axis=2)
+
+    # Calculates the intersection area.
+    i_xmin = ops.maximum(a_x_min, ops.transpose(b_x_min, [0, 2, 1]))
+    i_xmax = ops.minimum(a_x_max, ops.transpose(b_x_max, [0, 2, 1]))
+    i_ymin = ops.maximum(a_y_min, ops.transpose(b_y_min, [0, 2, 1]))
+    i_ymax = ops.minimum(a_y_max, ops.transpose(b_y_max, [0, 2, 1]))
+    i_area = ops.maximum((i_xmax - i_xmin), 0) * ops.maximum(
+        (i_ymax - i_ymin), 0
+    )
+
+    # Calculates the union area.
+    a_area = (a_y_max - a_y_min) * (a_x_max - a_x_min)
+    b_area = (b_y_max - b_y_min) * (b_x_max - b_x_min)
+
+    # Adds a small epsilon to avoid divide-by-zero.
+    u_area = a_area + ops.transpose(b_area, [0, 2, 1]) - i_area + EPSILON
+
+    intersection_over_union = i_area / u_area
+
+    return intersection_over_union
+
+
+def _self_suppression(iou, _, iou_sum, iou_threshold):
+    """Suppress boxes in the same tile.
+
+    Compute boxes that cannot be suppressed by others (i.e.,
+    can_suppress_others), and then use them to suppress boxes in the same tile.
+
+    Args:
+        iou: a tensor of shape `[batch_size, num_boxes_with_padding]`
+            representing intersection over union.
+        iou_sum: a scalar tensor.
+        iou_threshold: a scalar tensor.
+
+    Returns:
+        iou_suppressed: a tensor of shape
+            `[batch_size, num_boxes_with_padding]`.
+        iou_diff: a scalar tensor representing whether any box is supressed in
+            this step.
+        iou_sum_new: a scalar tensor of shape `[batch_size]` that represents
+            the iou sum after suppression.
+        iou_threshold: a scalar tensor.
+    """
+    batch_size = ops.shape(iou)[0]
+    can_suppress_others = ops.cast(
+        ops.reshape(ops.max(iou, 1) < iou_threshold, [batch_size, -1, 1]),
+        iou.dtype,
+    )
+    iou_after_suppression = (
+        ops.reshape(
+            ops.cast(
+                ops.max(can_suppress_others * iou, 1) < iou_threshold, iou.dtype
+            ),
+            [batch_size, -1, 1],
+        )
+        * iou
+    )
+    iou_sum_new = ops.sum(iou_after_suppression, [1, 2])
+    return [
+        iou_after_suppression,
+        ops.any(iou_sum - iou_sum_new > iou_threshold),
+        iou_sum_new,
+        iou_threshold,
+    ]
+
+
+def _cross_suppression(boxes, box_slice, iou_threshold, inner_idx, tile_size):
+    """Suppress boxes between different tiles.
+
+    Args:
+        boxes: a tensor of shape `[batch_size, num_boxes_with_padding, 4]`
+        box_slice: a tensor of shape `[batch_size, tile_size, 4]`
+        iou_threshold: a scalar tensor
+        inner_idx: a scalar tensor representing the tile index of the tile
+            that is used to supress box_slice
+        tile_size: an integer representing the number of boxes in a tile
+
+    Returns:
+        boxes: unchanged boxes as input
+        box_slice_after_suppression: box_slice after suppression
+        iou_threshold: unchanged
+    """
+    slice_index = ops.expand_dims(
+        ops.expand_dims(
+            ops.cast(
+                ops.linspace(
+                    inner_idx * tile_size,
+                    (inner_idx + 1) * tile_size - 1,
+                    tile_size,
+                ),
+                "int32",
+            ),
+            axis=0,
+        ),
+        axis=-1,
+    )
+    new_slice = ops.expand_dims(
+        ops.take_along_axis(boxes, slice_index, axis=1), 0
+    )
+    iou = _bbox_overlap(new_slice, box_slice)
+    box_slice_after_suppression = (
+        ops.expand_dims(
+            ops.cast(ops.all(iou < iou_threshold, [1]), box_slice.dtype), 2
+        )
+        * box_slice
+    )
+    return boxes, box_slice_after_suppression, iou_threshold, inner_idx + 1
+
+
+def _suppression_loop_body(boxes, iou_threshold, output_size, idx, tile_size):
+    """Process boxes in the range [idx*tile_size, (idx+1)*tile_size).
+
+    Args:
+        boxes: a tensor with a shape of [batch_size, anchors, 4].
+        iou_threshold: a float representing the threshold for deciding whether
+            boxes overlap too much with respect to IOU.
+        output_size: an int32 tensor of size [batch_size]. Representing the
+            number of selected boxes for each batch.
+        idx: an integer scalar representing induction variable.
+        tile_size: an integer representing the number of boxes in a tile
+
+    Returns:
+        boxes: updated boxes.
+        iou_threshold: pass down iou_threshold to the next iteration.
+        output_size: the updated output_size.
+        idx: the updated induction variable.
+    """
+    num_tiles = boxes.shape[1] // tile_size
+    batch_size = boxes.shape[0]
+
+    def cross_suppression_func(boxes, box_slice, iou_threshold, inner_idx):
+        return _cross_suppression(
+            boxes, box_slice, iou_threshold, inner_idx, tile_size
+        )
+
+    # Iterates over tiles that can possibly suppress the current tile.
+    slice_index = ops.expand_dims(
+        ops.expand_dims(
+            ops.cast(
+                ops.linspace(
+                    idx * tile_size, (idx + 1) * tile_size - 1, tile_size
+                ),
+                "int32",
+            ),
+            axis=0,
+        ),
+        axis=-1,
+    )
+    box_slice = ops.take_along_axis(boxes, slice_index, axis=1)
+    _, box_slice, _, _ = ops.while_loop(
+        lambda _boxes, _box_slice, _threshold, inner_idx: inner_idx < idx,
+        cross_suppression_func,
+        [boxes, box_slice, iou_threshold, ops.array(0)],
+    )
+
+    # Iterates over the current tile to compute self-suppression.
+    iou = _bbox_overlap(box_slice, box_slice)
+    mask = ops.expand_dims(
+        ops.reshape(ops.arange(tile_size), [1, -1])
+        > ops.reshape(ops.arange(tile_size), [-1, 1]),
+        0,
+    )
+    iou *= ops.cast(ops.logical_and(mask, iou >= iou_threshold), iou.dtype)
+    suppressed_iou, _, _, _ = ops.while_loop(
+        lambda _iou, loop_condition, _iou_sum, _: loop_condition,
+        _self_suppression,
+        [iou, ops.array(True), ops.sum(iou, [1, 2]), iou_threshold],
+    )
+    suppressed_box = ops.sum(suppressed_iou, 1) > 0
+    box_slice *= ops.expand_dims(
+        1.0 - ops.cast(suppressed_box, box_slice.dtype), 2
+    )
+
+    # Uses box_slice to update the input boxes.
+    mask = ops.reshape(
+        ops.cast(ops.equal(ops.arange(num_tiles), idx), boxes.dtype),
+        [1, -1, 1, 1],
+    )
+    boxes = ops.tile(
+        ops.expand_dims(box_slice, 1), [1, num_tiles, 1, 1]
+    ) * mask + ops.reshape(boxes, [batch_size, num_tiles, tile_size, 4]) * (
+        1 - mask
+    )
+    boxes = ops.reshape(boxes, [batch_size, -1, 4])
+
+    # Updates output_size.
+    output_size += ops.cast(ops.sum(ops.any(box_slice > 0, [2]), [1]), "int32")
+    return boxes, iou_threshold, output_size, idx + 1
+
+
+def mask_invalid_detections(bounding_boxes):
+    """masks out invalid detections with -1s.
+
+    This utility is mainly used on the output of non-max suppression operations.
+    The output of non-max-suppression contains all the detections, even invalid
+    ones. Users are expected to use `num_detections` to determine how many boxes
+    are in each image.
+
+    In contrast, KerasHub expects all bounding boxes to be padded with -1s.
+    This function uses the value of `num_detections` to mask out
+    invalid boxes with -1s.
+
+    Args:
+        bounding_boxes: a dictionary complying with Keras bounding box format.
+            In addition to the normal required keys, these boxes are also
+            expected to have a `num_detections` key.
+
+    Returns:
+        bounding boxes with proper masking of the boxes according to
+        `num_detections`. This allows proper interop with non-max supression.
+        Returned boxes match the specification fed to the function, so if the
+        bounding box tensor uses `tf.RaggedTensor` to represent boxes the
+        returned value will also return `tf.RaggedTensor` representations.
+    """
+    # ensure we are complying with Keras bounding box format.
+    info = validate_format.validate_format(bounding_boxes)
+    if info["ragged"]:
+        raise ValueError(
+            "`bounding_box.mask_invalid_detections()` requires inputs to be "
+            "Dense tensors. Please call "
+            "`bounding_box.to_dense(bounding_boxes)` before passing your boxes "
+            "to `bounding_box.mask_invalid_detections()`."
+        )
+    if "num_detections" not in bounding_boxes:
+        raise ValueError(
+            "`bounding_boxes` must have key 'num_detections' "
+            "to be used with `bounding_box.mask_invalid_detections()`."
+        )
+
+    boxes = bounding_boxes.get("boxes")
+    classes = bounding_boxes.get("classes")
+    confidence = bounding_boxes.get("confidence", None)
+    num_detections = bounding_boxes.get("num_detections")
+
+    # Create a mask to select only the first N boxes from each batch
+    mask = ops.cast(
+        ops.expand_dims(ops.arange(boxes.shape[1]), axis=0),
+        num_detections.dtype,
+    )
+    mask = mask < num_detections[:, None]
+
+    classes = ops.where(mask, classes, -ops.ones_like(classes))
+
+    if confidence is not None:
+        confidence = ops.where(mask, confidence, -ops.ones_like(confidence))
+
+    # reuse mask for boxes
+    mask = ops.expand_dims(mask, axis=-1)
+    mask = ops.repeat(mask, repeats=boxes.shape[-1], axis=-1)
+    boxes = ops.where(mask, boxes, -ops.ones_like(boxes))
+
+    result = bounding_boxes.copy()

+    result["boxes"] = boxes
+    result["classes"] = classes
+    if confidence is not None:
+        result["confidence"] = confidence
+
+    return result
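For context, the new NonMaxSuppression layer above takes dense box and class predictions and returns a KerasHub-style bounding box dictionary. The following is an illustrative usage sketch, not part of the diff; it assumes this nightly is installed with a Keras 3 backend and that the internal import path matches the file added above (keras_hub/src/models/retinanet/non_max_supression.py).

    import numpy as np

    from keras_hub.src.models.retinanet.non_max_supression import (
        NonMaxSuppression,
    )

    # Two images, ten candidate boxes each, three classes. Boxes are built as
    # (xmin, ymin, xmin + w, ymin + h) so they are valid "xyxy" pixel boxes.
    rng = np.random.default_rng(0)
    x_min = rng.uniform(0, 100, size=(2, 10, 1))
    y_min = rng.uniform(0, 100, size=(2, 10, 1))
    w = rng.uniform(10, 50, size=(2, 10, 1))
    h = rng.uniform(10, 50, size=(2, 10, 1))
    boxes = np.concatenate([x_min, y_min, x_min + w, y_min + h], axis=-1)
    boxes = boxes.astype("float32")
    logits = rng.uniform(-1, 1, size=(2, 10, 3)).astype("float32")

    nms = NonMaxSuppression(
        bounding_box_format="xyxy",
        from_logits=True,  # the class predictions above are raw logits
        iou_threshold=0.5,
        confidence_threshold=0.3,
        max_detections=5,
    )
    outputs = nms(boxes, logits)
    # outputs is a dict with "boxes", "confidence", "classes" and
    # "num_detections"; entries beyond num_detections[i] are masked with -1
    # by mask_invalid_detections().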
@@ -14,7 +14,6 @@
 
 from keras_hub.src.models.roberta.roberta_backbone import RobertaBackbone
 from keras_hub.src.models.roberta.roberta_presets import backbone_presets
-from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer
 from keras_hub.src.utils.preset_utils import register_presets
 
-register_presets(backbone_presets, (RobertaBackbone, RobertaTokenizer))
+register_presets(backbone_presets, RobertaBackbone)
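The roberta change above reflects a repo-wide cleanup visible throughout the file list: presets are now registered against the backbone class alone rather than a (backbone, tokenizer) tuple. High-level preset loading should look the same; a minimal sketch, assuming the public RobertaBackbone and RobertaTokenizer exports and the roberta_base_en preset name are unchanged:

    import keras_hub

    # Preset names still resolve the backbone and tokenizer independently.
    backbone = keras_hub.models.RobertaBackbone.from_preset("roberta_base_en")
    tokenizer = keras_hub.models.RobertaTokenizer.from_preset("roberta_base_en")
    token_ids = tokenizer("The quick brown fox.")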