keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198)
  1. keras_hub/__init__.py +0 -6
  2. keras_hub/api/__init__.py +2 -0
  3. keras_hub/api/bounding_box/__init__.py +36 -0
  4. keras_hub/api/layers/__init__.py +14 -0
  5. keras_hub/api/models/__init__.py +97 -48
  6. keras_hub/api/tokenizers/__init__.py +30 -0
  7. keras_hub/api/utils/__init__.py +22 -0
  8. keras_hub/src/api_export.py +15 -9
  9. keras_hub/src/bounding_box/__init__.py +13 -0
  10. keras_hub/src/bounding_box/converters.py +529 -0
  11. keras_hub/src/bounding_box/formats.py +162 -0
  12. keras_hub/src/bounding_box/iou.py +263 -0
  13. keras_hub/src/bounding_box/to_dense.py +95 -0
  14. keras_hub/src/bounding_box/to_ragged.py +99 -0
  15. keras_hub/src/bounding_box/utils.py +194 -0
  16. keras_hub/src/bounding_box/validate_format.py +99 -0
  17. keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
  18. keras_hub/src/layers/preprocessing/image_converter.py +130 -0
  19. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
  20. keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
  21. keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
  22. keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
  23. keras_hub/src/layers/preprocessing/random_swap.py +33 -31
  24. keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
  25. keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
  26. keras_hub/src/models/albert/__init__.py +1 -2
  27. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
  28. keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
  29. keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
  30. keras_hub/src/models/albert/albert_tokenizer.py +17 -36
  31. keras_hub/src/models/backbone.py +12 -34
  32. keras_hub/src/models/bart/__init__.py +1 -2
  33. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
  34. keras_hub/src/models/bart/bart_tokenizer.py +12 -39
  35. keras_hub/src/models/bert/__init__.py +1 -5
  36. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
  37. keras_hub/src/models/bert/bert_presets.py +1 -4
  38. keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
  39. keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
  40. keras_hub/src/models/bert/bert_tokenizer.py +17 -35
  41. keras_hub/src/models/bloom/__init__.py +1 -2
  42. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
  43. keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
  44. keras_hub/src/models/causal_lm.py +10 -29
  45. keras_hub/src/models/causal_lm_preprocessor.py +195 -0
  46. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
  47. keras_hub/src/models/deberta_v3/__init__.py +1 -4
  48. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
  49. keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
  50. keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
  51. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
  52. keras_hub/src/models/densenet/densenet_backbone.py +46 -22
  53. keras_hub/src/models/distil_bert/__init__.py +1 -4
  54. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
  55. keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
  56. keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
  57. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
  58. keras_hub/src/models/efficientnet/__init__.py +13 -0
  59. keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
  60. keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
  61. keras_hub/src/models/efficientnet/mbconv.py +238 -0
  62. keras_hub/src/models/electra/__init__.py +1 -2
  63. keras_hub/src/models/electra/electra_tokenizer.py +17 -32
  64. keras_hub/src/models/f_net/__init__.py +1 -2
  65. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
  66. keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
  67. keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
  68. keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
  69. keras_hub/src/models/falcon/__init__.py +1 -2
  70. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
  71. keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
  72. keras_hub/src/models/gemma/__init__.py +1 -2
  73. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
  74. keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
  75. keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
  76. keras_hub/src/models/gpt2/__init__.py +1 -2
  77. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
  78. keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
  79. keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
  80. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
  81. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
  82. keras_hub/src/models/image_classifier.py +0 -5
  83. keras_hub/src/models/image_classifier_preprocessor.py +83 -0
  84. keras_hub/src/models/llama/__init__.py +1 -2
  85. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
  86. keras_hub/src/models/llama/llama_tokenizer.py +12 -25
  87. keras_hub/src/models/llama3/__init__.py +1 -2
  88. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
  89. keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
  90. keras_hub/src/models/masked_lm.py +0 -2
  91. keras_hub/src/models/masked_lm_preprocessor.py +156 -0
  92. keras_hub/src/models/mistral/__init__.py +1 -2
  93. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
  94. keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
  95. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
  96. keras_hub/src/models/mobilenet/__init__.py +13 -0
  97. keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
  98. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
  99. keras_hub/src/models/opt/__init__.py +1 -2
  100. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
  101. keras_hub/src/models/opt/opt_tokenizer.py +12 -41
  102. keras_hub/src/models/pali_gemma/__init__.py +1 -4
  103. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
  104. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
  105. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
  106. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
  107. keras_hub/src/models/phi3/__init__.py +1 -2
  108. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
  109. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
  110. keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
  111. keras_hub/src/models/preprocessor.py +72 -83
  112. keras_hub/src/models/resnet/__init__.py +6 -0
  113. keras_hub/src/models/resnet/resnet_backbone.py +390 -42
  114. keras_hub/src/models/resnet/resnet_image_classifier.py +33 -6
  115. keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
  116. keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
  117. keras_hub/src/models/resnet/resnet_presets.py +95 -0
  118. keras_hub/src/models/retinanet/__init__.py +13 -0
  119. keras_hub/src/models/retinanet/anchor_generator.py +175 -0
  120. keras_hub/src/models/retinanet/box_matcher.py +259 -0
  121. keras_hub/src/models/retinanet/non_max_supression.py +578 -0
  122. keras_hub/src/models/roberta/__init__.py +1 -2
  123. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
  124. keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
  125. keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
  126. keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
  127. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
  128. keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
  129. keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
  130. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
  131. keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
  132. keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
  133. keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
  134. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
  135. keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
  136. keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
  137. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
  138. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
  139. keras_hub/src/models/t5/__init__.py +1 -2
  140. keras_hub/src/models/t5/t5_tokenizer.py +13 -23
  141. keras_hub/src/models/task.py +71 -116
  142. keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
  143. keras_hub/src/models/text_classifier_preprocessor.py +138 -0
  144. keras_hub/src/models/whisper/__init__.py +1 -2
  145. keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
  146. keras_hub/src/models/whisper/whisper_backbone.py +0 -3
  147. keras_hub/src/models/whisper/whisper_presets.py +10 -10
  148. keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
  149. keras_hub/src/models/xlm_roberta/__init__.py +1 -4
  150. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
  151. keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
  152. keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
  153. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
  154. keras_hub/src/tests/test_case.py +46 -0
  155. keras_hub/src/tokenizers/byte_pair_tokenizer.py +30 -17
  156. keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
  157. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +20 -7
  158. keras_hub/src/tokenizers/tokenizer.py +67 -32
  159. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
  160. keras_hub/src/tokenizers/word_piece_tokenizer.py +34 -47
  161. keras_hub/src/utils/imagenet/__init__.py +13 -0
  162. keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
  163. keras_hub/src/utils/keras_utils.py +0 -50
  164. keras_hub/src/utils/preset_utils.py +230 -68
  165. keras_hub/src/utils/tensor_utils.py +187 -69
  166. keras_hub/src/utils/timm/convert_resnet.py +19 -16
  167. keras_hub/src/utils/timm/preset_loader.py +66 -0
  168. keras_hub/src/utils/transformers/convert_albert.py +193 -0
  169. keras_hub/src/utils/transformers/convert_bart.py +373 -0
  170. keras_hub/src/utils/transformers/convert_bert.py +7 -17
  171. keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
  172. keras_hub/src/utils/transformers/convert_gemma.py +5 -19
  173. keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
  174. keras_hub/src/utils/transformers/convert_llama3.py +7 -18
  175. keras_hub/src/utils/transformers/convert_mistral.py +129 -0
  176. keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
  177. keras_hub/src/utils/transformers/preset_loader.py +77 -0
  178. keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
  179. keras_hub/src/version_utils.py +1 -1
  180. keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
  181. keras_hub_nightly-0.16.0.dev2024092017.dist-info/RECORD +334 -0
  182. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
  183. keras_hub/src/models/bart/bart_preprocessor.py +0 -276
  184. keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
  185. keras_hub/src/models/electra/electra_preprocessor.py +0 -154
  186. keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
  187. keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
  188. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
  189. keras_hub/src/models/llama/llama_preprocessor.py +0 -189
  190. keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
  191. keras_hub/src/models/opt/opt_preprocessor.py +0 -188
  192. keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
  193. keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
  194. keras_hub/src/utils/timm/convert.py +0 -37
  195. keras_hub/src/utils/transformers/convert.py +0 -101
  196. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +0 -34
  197. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
  198. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,529 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Converter functions for working with bounding box formats."""
15
+
16
+ import keras
17
+ from keras import ops
18
+
19
+ from keras_hub.src.api_export import keras_hub_export
20
+
21
+ try:
22
+ import tensorflow as tf
23
+ except ImportError:
24
+ tf = None
25
+
26
+
27
class RequiresImagesException(Exception):
    """Internal signal that a converter needed image dimensions but got none.

    Raised by `_image_shape` when neither `images` nor `image_shape` is
    available; `convert_format` catches it and re-raises a `ValueError`
    with a user-facing message.
    """


# Number of coordinates stored per bounding box. Used as the number of
# sections when splitting the last axis of a box tensor with `ops.split`.
ALL_AXES = 4
34
+
35
+
36
def _encode_box_to_deltas(
    anchors,
    boxes,
    anchor_format: str,
    box_format: str,
    variance=None,
    image_shape=None,
):
    """Encodes `boxes` as offsets (deltas) relative to `anchors`.

    Both inputs are first converted to `center_yxhw`. For every box the
    deltas along the last axis are:

    - `[dy, dx]` = (box center - anchor center) / anchor size
    - `[dh, dw]` = log(box size / anchor size)

    Sizes are clamped from below to `keras.backend.epsilon()` so the division
    and the logarithm never see zero. If `variance` is given, the deltas are
    divided by it element-wise.

    Args:
        anchors: anchor boxes in `anchor_format`. Anchors are expected to be
            unbatched; `boxes` may be batched or unbatched and is broadcast
            against them.
        boxes: boxes to encode, in `box_format`.
        anchor_format: format of `anchors` (any format `convert_format`
            accepts).
        box_format: format of `boxes` (any format `convert_format` accepts).
        variance: optional length-4 sequence or tensor that scales the
            deltas.
        image_shape: optional image shape forwarded to `convert_format`;
            required when either format is relative.

    Returns:
        A tensor whose last axis holds `[dy, dx, dh, dw]`.

    Raises:
        ValueError: if `variance` is a scalar or its last axis is not 4.
    """
    if variance is not None:
        variance = ops.convert_to_tensor(variance, "float32")
        # A scalar has no last axis; report it as a ValueError rather than
        # letting `.shape[-1]` raise an IndexError.
        if len(variance.shape) == 0 or variance.shape[-1] != 4:
            raise ValueError(f"`variance` must be length 4, got {variance}")
    encoded_anchors = convert_format(
        anchors,
        source=anchor_format,
        target="center_yxhw",
        image_shape=image_shape,
    )
    boxes = convert_format(
        boxes, source=box_format, target="center_yxhw", image_shape=image_shape
    )
    anchor_dimensions = ops.maximum(
        encoded_anchors[..., 2:], keras.backend.epsilon()
    )
    box_dimensions = ops.maximum(boxes[..., 2:], keras.backend.epsilon())
    # Anchors must be unbatched; boxes can be either batched or unbatched.
    boxes_delta = ops.concatenate(
        [
            (boxes[..., :2] - encoded_anchors[..., :2]) / anchor_dimensions,
            ops.log(box_dimensions / anchor_dimensions),
        ],
        axis=-1,
    )
    if variance is not None:
        boxes_delta /= variance
    return boxes_delta
75
+
76
+
77
+ def _decode_deltas_to_boxes(
78
+ anchors,
79
+ boxes_delta,
80
+ anchor_format: str,
81
+ box_format: str,
82
+ variance=None,
83
+ image_shape=None,
84
+ ):
85
+ """Converts bounding_boxes from delta format to `center_yxhw`."""
86
+ if variance is not None:
87
+ variance = ops.convert_to_tensor(variance, "float32")
88
+ var_len = variance.shape[-1]
89
+
90
+ if var_len != 4:
91
+ raise ValueError(f"`variance` must be length 4, got {variance}")
92
+
93
+ def decode_single_level(anchor, box_delta):
94
+ encoded_anchor = convert_format(
95
+ anchor,
96
+ source=anchor_format,
97
+ target="center_yxhw",
98
+ image_shape=image_shape,
99
+ )
100
+ if variance is not None:
101
+ box_delta = box_delta * variance
102
+ # anchors be unbatched, boxes can either be batched or unbatched.
103
+ box = ops.concatenate(
104
+ [
105
+ box_delta[..., :2] * encoded_anchor[..., 2:]
106
+ + encoded_anchor[..., :2],
107
+ ops.exp(box_delta[..., 2:]) * encoded_anchor[..., 2:],
108
+ ],
109
+ axis=-1,
110
+ )
111
+ box = convert_format(
112
+ box,
113
+ source="center_yxhw",
114
+ target=box_format,
115
+ image_shape=image_shape,
116
+ )
117
+ return box
118
+
119
+ if isinstance(anchors, dict) and isinstance(boxes_delta, dict):
120
+ boxes = {}
121
+ for lvl, anchor in anchors.items():
122
+ boxes[lvl] = decode_single_level(anchor, boxes_delta[lvl])
123
+ return boxes
124
+ else:
125
+ return decode_single_level(anchors, boxes_delta)
126
+
127
+
128
def _center_yxhw_to_xyxy(boxes, images=None, image_shape=None):
    """Converts `center_yxhw` boxes into `xyxy` corners.

    `images` and `image_shape` exist only so every converter shares one
    signature; they are ignored here.
    """
    center_y, center_x, box_h, box_w = ops.split(boxes, ALL_AXES, axis=-1)
    half_h = box_h / 2.0
    half_w = box_w / 2.0
    corners = [
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    ]
    return ops.concatenate(corners, axis=-1)
134
+
135
+
136
def _center_xywh_to_xyxy(boxes, images=None, image_shape=None):
    """Converts `center_xywh` boxes into `xyxy` corners.

    `images` and `image_shape` exist only so every converter shares one
    signature; they are ignored here.
    """
    center_x, center_y, box_w, box_h = ops.split(boxes, ALL_AXES, axis=-1)
    half_w = box_w / 2.0
    half_h = box_h / 2.0
    corners = [
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    ]
    return ops.concatenate(corners, axis=-1)
142
+
143
+
144
def _xywh_to_xyxy(boxes, images=None, image_shape=None):
    """Converts `xywh` (top-left corner plus size) boxes into `xyxy`."""
    left, top, box_w, box_h = ops.split(boxes, ALL_AXES, axis=-1)
    right = left + box_w
    bottom = top + box_h
    return ops.concatenate([left, top, right, bottom], axis=-1)
147
+
148
+
149
def _xyxy_to_center_yxhw(boxes, images=None, image_shape=None):
    """Converts `xyxy` corners into `center_yxhw` boxes."""
    x_min, y_min, x_max, y_max = ops.split(boxes, ALL_AXES, axis=-1)
    center_y = (y_min + y_max) / 2.0
    center_x = (x_min + x_max) / 2.0
    box_h = y_max - y_min
    box_w = x_max - x_min
    return ops.concatenate([center_y, center_x, box_h, box_w], axis=-1)
160
+
161
+
162
def _rel_xywh_to_xyxy(boxes, images=None, image_shape=None):
    """Converts relative `rel_xywh` boxes into absolute `xyxy` corners.

    The image height/width come from `_image_shape`, so either `images` or
    `image_shape` must be provided.
    """
    img_h, img_w = _image_shape(images, image_shape, boxes)
    rel_x, rel_y, rel_w, rel_h = ops.split(boxes, ALL_AXES, axis=-1)
    x_min = img_w * rel_x
    y_min = img_h * rel_y
    x_max = img_w * (rel_x + rel_w)
    y_max = img_h * (rel_y + rel_h)
    return ops.concatenate([x_min, y_min, x_max, y_max], axis=-1)
174
+
175
+
176
+ def _xyxy_no_op(boxes, images=None, image_shape=None):
177
+ return boxes
178
+
179
+
180
def _xyxy_to_xywh(boxes, images=None, image_shape=None):
    """Converts `xyxy` corners into `xywh` (top-left corner plus size)."""
    x_min, y_min, x_max, y_max = ops.split(boxes, ALL_AXES, axis=-1)
    box_w = x_max - x_min
    box_h = y_max - y_min
    return ops.concatenate([x_min, y_min, box_w, box_h], axis=-1)
186
+
187
+
188
def _xyxy_to_rel_xywh(boxes, images=None, image_shape=None):
    """Converts absolute `xyxy` corners into relative `rel_xywh` boxes.

    Corners are normalized first, and the width/height are then taken as
    differences of the normalized corners.
    """
    img_h, img_w = _image_shape(images, image_shape, boxes)
    x_min, y_min, x_max, y_max = ops.split(boxes, ALL_AXES, axis=-1)
    rel_x_min = x_min / img_w
    rel_x_max = x_max / img_w
    rel_y_min = y_min / img_h
    rel_y_max = y_max / img_h
    rel_w = rel_x_max - rel_x_min
    rel_h = rel_y_max - rel_y_min
    return ops.concatenate([rel_x_min, rel_y_min, rel_w, rel_h], axis=-1)
200
+
201
+
202
def _xyxy_to_center_xywh(boxes, images=None, image_shape=None):
    """Converts `xyxy` corners into `center_xywh` boxes."""
    x_min, y_min, x_max, y_max = ops.split(boxes, ALL_AXES, axis=-1)
    center_x = (x_min + x_max) / 2.0
    center_y = (y_min + y_max) / 2.0
    box_w = x_max - x_min
    box_h = y_max - y_min
    return ops.concatenate([center_x, center_y, box_w, box_h], axis=-1)
213
+
214
+
215
def _rel_xyxy_to_xyxy(boxes, images=None, image_shape=None):
    """Scales relative `rel_xyxy` corners up to absolute `xyxy` pixels."""
    img_h, img_w = _image_shape(images, image_shape, boxes)
    rel_x_min, rel_y_min, rel_x_max, rel_y_max = ops.split(
        boxes, ALL_AXES, axis=-1
    )
    scaled = [
        rel_x_min * img_w,
        rel_y_min * img_h,
        rel_x_max * img_w,
        rel_y_max * img_h,
    ]
    return ops.concatenate(scaled, axis=-1)
228
+
229
+
230
def _xyxy_to_rel_xyxy(boxes, images=None, image_shape=None):
    """Normalizes absolute `xyxy` corners by the image width and height."""
    img_h, img_w = _image_shape(images, image_shape, boxes)
    x_min, y_min, x_max, y_max = ops.split(boxes, ALL_AXES, axis=-1)
    normalized = [x_min / img_w, y_min / img_h, x_max / img_w, y_max / img_h]
    return ops.concatenate(normalized, axis=-1)
243
+
244
+
245
def _yxyx_to_xyxy(boxes, images=None, image_shape=None):
    """Reorders `yxyx` coordinates into `xyxy` order (no scaling)."""
    top, left, bottom, right = ops.split(boxes, ALL_AXES, axis=-1)
    return ops.concatenate([left, top, right, bottom], axis=-1)
248
+
249
+
250
def _rel_yxyx_to_xyxy(boxes, images=None, image_shape=None):
    """Converts relative `rel_yxyx` boxes into absolute `xyxy` corners."""
    img_h, img_w = _image_shape(images, image_shape, boxes)
    rel_top, rel_left, rel_bottom, rel_right = ops.split(
        boxes, ALL_AXES, axis=-1
    )
    x_min = rel_left * img_w
    y_min = rel_top * img_h
    x_max = rel_right * img_w
    y_max = rel_bottom * img_h
    return ops.concatenate([x_min, y_min, x_max, y_max], axis=-1)
263
+
264
+
265
def _xyxy_to_yxyx(boxes, images=None, image_shape=None):
    """Reorders `xyxy` coordinates into `yxyx` order (no scaling)."""
    x_min, y_min, x_max, y_max = ops.split(boxes, ALL_AXES, axis=-1)
    return ops.concatenate([y_min, x_min, y_max, x_max], axis=-1)
268
+
269
+
270
def _xyxy_to_rel_yxyx(boxes, images=None, image_shape=None):
    """Converts absolute `xyxy` corners into relative `rel_yxyx` boxes."""
    img_h, img_w = _image_shape(images, image_shape, boxes)
    x_min, y_min, x_max, y_max = ops.split(boxes, ALL_AXES, axis=-1)
    rel_left = x_min / img_w
    rel_right = x_max / img_w
    rel_top = y_min / img_h
    rel_bottom = y_max / img_h
    return ops.concatenate(
        [rel_top, rel_left, rel_bottom, rel_right], axis=-1
    )
279
+
280
+
281
# Dispatch tables keyed by lower-case format name. Every supported format is
# converted by going through `xyxy`: source -> xyxy -> target. Key order is
# significant because `convert_format` prints `.keys()` in error messages.
TO_XYXY_CONVERTERS = dict(
    xywh=_xywh_to_xyxy,
    center_xywh=_center_xywh_to_xyxy,
    center_yxhw=_center_yxhw_to_xyxy,
    rel_xywh=_rel_xywh_to_xyxy,
    xyxy=_xyxy_no_op,
    rel_xyxy=_rel_xyxy_to_xyxy,
    yxyx=_yxyx_to_xyxy,
    rel_yxyx=_rel_yxyx_to_xyxy,
)

FROM_XYXY_CONVERTERS = dict(
    xywh=_xyxy_to_xywh,
    center_xywh=_xyxy_to_center_xywh,
    center_yxhw=_xyxy_to_center_yxhw,
    rel_xywh=_xyxy_to_rel_xywh,
    xyxy=_xyxy_no_op,
    rel_xyxy=_xyxy_to_rel_xyxy,
    yxyx=_xyxy_to_yxyx,
    rel_yxyx=_xyxy_to_rel_yxyx,
)
302
+
303
+
304
@keras_hub_export("keras_hub.bounding_box.convert_format")
def convert_format(
    boxes, source, target, images=None, image_shape=None, dtype="float32"
):
    """Converts bounding_boxes from one format to another.

    Supported formats are:
    - `"xyxy"`, also known as `corners` format. In this format the first four
        axes represent `[left, top, right, bottom]` in that order.
    - `"rel_xyxy"`. In this format, the axes are the same as `"xyxy"` but the x
        coordinates are normalized using the image width, and the y axes the
        image height. All values in `rel_xyxy` are in the range `(0, 1)`.
    - `"xywh"`. In this format the first four axes represent
        `[left, top, width, height]`.
    - `"rel_xywh"`. In this format the first four axes represent
        `[left, top, width, height]`, just like `"xywh"`. Unlike `"xywh"`, the
        values are in the range (0, 1) instead of absolute pixel values.
    - `"center_xyWH"`. In this format the first two coordinates represent the x
        and y coordinates of the center of the bounding box, while the last two
        represent the width and height of the bounding box.
    - `"center_yxHW"`. In this format the first two coordinates represent the y
        and x coordinates of the center of the bounding box, while the last two
        represent the height and width of the bounding box.
    - `"yxyx"`. In this format the first four axes represent
        `[top, left, bottom, right]` in that order.
    - `"rel_yxyx"`. In this format, the axes are the same as `"yxyx"` but the x
        coordinates are normalized using the image width, and the y axes the
        image height. All values in `rel_yxyx` are in the range (0, 1).

    Formats are case insensitive. It is recommended that you capitalize width
    and height to maximize the visual difference between `"xyWH"` and `"xyxy"`.

    Relative formats, abbreviated `rel`, make use of the shapes of the `images`
    passed (or of `image_shape`). In these formats, the coordinates, widths,
    and heights are all specified as fractions of the host image size.
    `images` may be a `tf.RaggedTensor` (requires TensorFlow). Note that using
    a ragged Tensor for images may cause a substantial performance loss, as
    each image will need to be processed separately due to the mismatching
    image shapes. Converting between two relative formats needs neither
    `images` nor `image_shape`.

    Example:

    ```python
    boxes = load_coco_dataset()
    boxes_in_xywh = keras_hub.bounding_box.convert_format(
        boxes,
        source='xyxy',
        target='xyWH'
    )
    ```

    Args:
        boxes: tensor representing bounding boxes in the format specified in
            the `source` parameter. The final axis must have size 4. `boxes`
            may be unbatched with shape `[num_boxes, 4]` or batched with shape
            `[batch_size, num_boxes, 4]`. Alternatively, `boxes` can be a
            dictionary with key `"boxes"` containing a tensor matching the
            aforementioned spec; the other entries are copied unchanged.
        source: One of `"xywh"`, `"center_xywh"`, `"center_yxhw"`,
            `"rel_xywh"`, `"xyxy"`, `"rel_xyxy"`, `"yxyx"`, `"rel_yxyx"`.
            Used to specify the original format of the `boxes` parameter.
        target: One of `"xywh"`, `"center_xywh"`, `"center_yxhw"`,
            `"rel_xywh"`, `"xyxy"`, `"rel_xyxy"`, `"yxyx"`, `"rel_yxyx"`.
            Used to specify the destination format of the `boxes` parameter.
        images: (Optional) a batch of images aligned with `boxes` on the first
            axis. Should be at least 3 dimensions, with the first 3 dimensions
            representing: `[batch_size, height, width]`. Used in some
            converters to compute relative pixel values of the bounding box
            dimensions. Required (or `image_shape`) when transforming between
            a rel format and a non-rel format.
        image_shape: (Optional) a tuple/list of length 3, or a 1D tensor of
            shape `(3,)`, whose first two entries are the image height and
            width. Alternative to `images`; the two must not both be passed.
        dtype: the data type to use when transforming the boxes, defaults to
            `"float32"`.

    Returns:
        The converted boxes, with the same batching as the input (or a copy of
        the input dict with its `"boxes"` entry converted).

    Raises:
        ValueError: on an unsupported format, a final axis other than 4,
            both `images` and `image_shape` passed, invalid ranks/shapes, or
            a rel <-> absolute conversion without `images`/`image_shape`.
    """
    # NOTE: this docstring must be a plain string literal; an f-string in this
    # position is not stored as `__doc__`.
    if isinstance(boxes, dict):
        converted_boxes = boxes.copy()
        converted_boxes["boxes"] = convert_format(
            boxes["boxes"],
            source=source,
            target=target,
            images=images,
            image_shape=image_shape,
            dtype=dtype,
        )
        return converted_boxes

    if boxes.shape[-1] is not None and boxes.shape[-1] != 4:
        raise ValueError(
            "Expected `boxes` to be a Tensor with a final dimension of "
            f"`4`. Instead, got `boxes.shape={boxes.shape}`."
        )
    if images is not None and image_shape is not None:
        raise ValueError(
            "convert_format() expects either `images` or `image_shape`, but "
            f"not both. Received images={images} image_shape={image_shape}"
        )

    _validate_image_shape(image_shape)

    source = source.lower()
    target = target.lower()
    if source not in TO_XYXY_CONVERTERS:
        raise ValueError(
            "`convert_format()` received an unsupported format for the "
            "argument `source`. `source` should be one of "
            f"{TO_XYXY_CONVERTERS.keys()}. Got source={source}"
        )
    if target not in FROM_XYXY_CONVERTERS:
        raise ValueError(
            "`convert_format()` received an unsupported format for the "
            "argument `target`. `target` should be one of "
            f"{FROM_XYXY_CONVERTERS.keys()}. Got target={target}"
        )

    boxes = ops.cast(boxes, dtype)
    if source == target:
        return boxes

    # rel->rel conversions should not require images: the scaling factors
    # cancel, so convert between the corresponding absolute layouts instead.
    if source.startswith("rel") and target.startswith("rel"):
        source = source.replace("rel_", "", 1)
        target = target.replace("rel_", "", 1)

    boxes, images, squeeze = _format_inputs(boxes, images)
    to_xyxy_fn = TO_XYXY_CONVERTERS[source]
    from_xyxy_fn = FROM_XYXY_CONVERTERS[target]

    try:
        in_xyxy = to_xyxy_fn(boxes, images=images, image_shape=image_shape)
        result = from_xyxy_fn(in_xyxy, images=images, image_shape=image_shape)
    except RequiresImagesException:
        # The internal exception carries no useful context for users.
        raise ValueError(
            "convert_format() must receive `images` or `image_shape` when "
            "transforming between relative and absolute formats. "
            f"convert_format() received source=`{source}`, "
            f"target=`{target}`, "
            f"but images={images} and image_shape={image_shape}."
        ) from None

    return _format_outputs(result, squeeze)
438
+
439
+
440
+ def _format_inputs(boxes, images):
441
+ boxes_rank = len(boxes.shape)
442
+ if boxes_rank > 3:
443
+ raise ValueError(
444
+ "Expected len(boxes.shape)=2, or len(boxes.shape)=3, got "
445
+ f"len(boxes.shape)={boxes_rank}"
446
+ )
447
+ boxes_includes_batch = boxes_rank == 3
448
+ # Determine if images needs an expand_dims() call
449
+ if images is not None:
450
+ images_rank = len(images.shape)
451
+ if images_rank > 4:
452
+ raise ValueError(
453
+ "Expected len(images.shape)=2, or len(images.shape)=3, got "
454
+ f"len(images.shape)={images_rank}"
455
+ )
456
+ images_include_batch = images_rank == 4
457
+ if boxes_includes_batch != images_include_batch:
458
+ raise ValueError(
459
+ "convert_format() expects both boxes and images to be batched, "
460
+ "or both boxes and images to be unbatched. Received "
461
+ f"len(boxes.shape)={boxes_rank}, "
462
+ f"len(images.shape)={images_rank}. Expected either "
463
+ "len(boxes.shape)=2 AND len(images.shape)=3, or "
464
+ "len(boxes.shape)=3 AND len(images.shape)=4."
465
+ )
466
+ if not images_include_batch:
467
+ images = ops.expand_dims(images, axis=0)
468
+
469
+ if not boxes_includes_batch:
470
+ return ops.expand_dims(boxes, axis=0), images, True
471
+ return boxes, images, False
472
+
473
+
474
+ def _validate_image_shape(image_shape):
475
+ # Escape early if image_shape is None and skip validation.
476
+ if image_shape is None:
477
+ return
478
+ # tuple/list
479
+ if isinstance(image_shape, (tuple, list)):
480
+ if len(image_shape) != 3:
481
+ raise ValueError(
482
+ "image_shape should be of length 3, but got "
483
+ f"image_shape={image_shape}"
484
+ )
485
+ return
486
+
487
+ # tensor
488
+ if ops.is_tensor(image_shape):
489
+ if len(image_shape.shape) > 1:
490
+ raise ValueError(
491
+ "image_shape.shape should be (3), but got "
492
+ f"image_shape.shape={image_shape.shape}"
493
+ )
494
+ if image_shape.shape[0] != 3:
495
+ raise ValueError(
496
+ "image_shape.shape should be (3), but got "
497
+ f"image_shape.shape={image_shape.shape}"
498
+ )
499
+ return
500
+
501
+ # Warn about failure cases
502
+ raise ValueError(
503
+ "Expected image_shape to be either a tuple, list, Tensor. "
504
+ f"Received image_shape={image_shape}"
505
+ )
506
+
507
+
508
+ def _format_outputs(boxes, squeeze):
509
+ if squeeze:
510
+ return ops.squeeze(boxes, axis=0)
511
+ return boxes
512
+
513
+
514
def _image_shape(images, image_shape, boxes):
    """Returns the image `(height, width)`, cast to `boxes.dtype`.

    Args:
        images: None, a dense image tensor whose axes 1 and 2 are height and
            width (i.e. already batched by `_format_inputs`), or a
            `tf.RaggedTensor`.
        image_shape: None, or a sequence/1D tensor whose first two entries are
            height and width. When given, it is used instead of `images`.
        boxes: the boxes being converted; only its dtype is read.

    Returns:
        `(height, width)`. For dense input these are scalars. For a ragged
        batch they are per-image values shaped `(batch, 1, 1)`, so they
        broadcast against per-coordinate slices of the boxes.

    Raises:
        RequiresImagesException: if both `images` and `image_shape` are None.
    """
    if images is None and image_shape is None:
        raise RequiresImagesException()

    if image_shape is not None:
        height, width = image_shape[0], image_shape[1]
    elif tf is not None and isinstance(images, tf.RaggedTensor):
        # `tf` is None when TensorFlow is not installed; in that case `images`
        # cannot be a RaggedTensor, so the dense branch is taken instead of
        # failing with AttributeError on `None.RaggedTensor`.
        height = ops.reshape(images.row_lengths(), (-1, 1))
        width = ops.reshape(ops.max(images.row_lengths(axis=2), 1), (-1, 1))
        height = ops.expand_dims(height, axis=-1)
        width = ops.expand_dims(width, axis=-1)
    else:
        dims = ops.shape(images)
        height, width = dims[1], dims[2]
    return ops.cast(height, boxes.dtype), ops.cast(width, boxes.dtype)
@@ -0,0 +1,162 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ formats.py contains axis information for each supported format.
16
+ """
17
+
18
+ from keras_hub.src.api_export import keras_hub_export
19
+
20
+
21
@keras_hub_export("keras_hub.bounding_box.XYXY")
class XYXY:
    """Axis indices for the XYXY bounding box format.

    Values in this format are absolute pixel coordinates.

    Required indices:

    - LEFT (0): left edge of the bounding box
    - TOP (1): top edge of the bounding box
    - RIGHT (2): right edge of the bounding box
    - BOTTOM (3): bottom edge of the bounding box
    """

    LEFT, TOP, RIGHT, BOTTOM = 0, 1, 2, 3
39
+
40
+
41
@keras_hub_export("keras_hub.bounding_box.REL_XYXY")
class REL_XYXY:
    """Axis indices for the REL_XYXY bounding box format.

    REL_XYXY has the same layout as XYXY, but every value is relative to
    the width and height of the origin image. Values are percentages of the
    origin image's width and height respectively.

    Required indices:

    - LEFT (0): left edge of the bounding box
    - TOP (1): top edge of the bounding box
    - RIGHT (2): right edge of the bounding box
    - BOTTOM (3): bottom edge of the bounding box
    """

    LEFT, TOP, RIGHT, BOTTOM = 0, 1, 2, 3
61
+
62
+
63
@keras_hub_export("keras_hub.bounding_box.CENTER_XYWH")
class CENTER_XYWH:
    """Axis indices for the CENTER_XYWH bounding box format.

    Values in this format are absolute pixel values.

    Required indices:

    - X (0): x coordinate of the box center
    - Y (1): y coordinate of the box center
    - WIDTH (2): width of the bounding box
    - HEIGHT (3): height of the bounding box
    """

    X, Y, WIDTH, HEIGHT = 0, 1, 2, 3
81
+
82
+
83
@keras_hub_export("keras_hub.bounding_box.XYWH")
class XYWH:
    """Axis indices for the XYWH bounding box format.

    Values in this format are absolute pixel values.

    Required indices:

    - X (0): x coordinate of the left of the bounding box
    - Y (1): y coordinate of the top of the bounding box
    - WIDTH (2): width of the bounding box
    - HEIGHT (3): height of the bounding box
    """

    X, Y, WIDTH, HEIGHT = 0, 1, 2, 3
101
+
102
+
103
@keras_hub_export("keras_hub.bounding_box.REL_XYWH")
class REL_XYWH:
    """REL_XYWH contains axis indices for the REL_XYWH format.

    REL_XYWH is like XYWH, but each value is relative to the width and
    height of the origin image. Values are percentages of the origin image's
    width and height respectively.

    The REL_XYWH format consists of the following required indices:

    - X: X coordinate of the left of the bounding box
    - Y: Y coordinate of the top of the bounding box
    - WIDTH: width of the bounding box
    - HEIGHT: height of the bounding box
    """

    X = 0
    Y = 1
    WIDTH = 2
    HEIGHT = 3
121
+
122
+
123
@keras_hub_export("keras_hub.bounding_box.YXYX")
class YXYX:
    """Axis indices for the YXYX bounding box format.

    Values in this format are absolute pixel coordinates.

    Required indices:

    - TOP (0): top edge of the bounding box
    - LEFT (1): left edge of the bounding box
    - BOTTOM (2): bottom edge of the bounding box
    - RIGHT (3): right edge of the bounding box
    """

    TOP, LEFT, BOTTOM, RIGHT = 0, 1, 2, 3
141
+
142
+
143
@keras_hub_export("keras_hub.bounding_box.REL_YXYX")
class REL_YXYX:
    """Axis indices for the REL_YXYX bounding box format.

    REL_YXYX has the same layout as YXYX, but every value is relative to
    the width and height of the origin image. Values are percentages of the
    origin image's width and height respectively.

    Required indices:

    - TOP (0): top edge of the bounding box
    - LEFT (1): left edge of the bounding box
    - BOTTOM (2): bottom edge of the bounding box
    - RIGHT (3): right edge of the bounding box
    """

    TOP, LEFT, BOTTOM, RIGHT = 0, 1, 2, 3
+ RIGHT = 3