keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registries.
- keras_hub/__init__.py +0 -6
- keras_hub/api/__init__.py +2 -0
- keras_hub/api/bounding_box/__init__.py +36 -0
- keras_hub/api/layers/__init__.py +14 -0
- keras_hub/api/models/__init__.py +97 -48
- keras_hub/api/tokenizers/__init__.py +30 -0
- keras_hub/api/utils/__init__.py +22 -0
- keras_hub/src/api_export.py +15 -9
- keras_hub/src/bounding_box/__init__.py +13 -0
- keras_hub/src/bounding_box/converters.py +529 -0
- keras_hub/src/bounding_box/formats.py +162 -0
- keras_hub/src/bounding_box/iou.py +263 -0
- keras_hub/src/bounding_box/to_dense.py +95 -0
- keras_hub/src/bounding_box/to_ragged.py +99 -0
- keras_hub/src/bounding_box/utils.py +194 -0
- keras_hub/src/bounding_box/validate_format.py +99 -0
- keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
- keras_hub/src/layers/preprocessing/image_converter.py +130 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
- keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
- keras_hub/src/layers/preprocessing/random_swap.py +33 -31
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
- keras_hub/src/models/albert/__init__.py +1 -2
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
- keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
- keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
- keras_hub/src/models/albert/albert_tokenizer.py +17 -36
- keras_hub/src/models/backbone.py +12 -34
- keras_hub/src/models/bart/__init__.py +1 -2
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
- keras_hub/src/models/bart/bart_tokenizer.py +12 -39
- keras_hub/src/models/bert/__init__.py +1 -5
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
- keras_hub/src/models/bert/bert_presets.py +1 -4
- keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
- keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
- keras_hub/src/models/bert/bert_tokenizer.py +17 -35
- keras_hub/src/models/bloom/__init__.py +1 -2
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
- keras_hub/src/models/causal_lm.py +10 -29
- keras_hub/src/models/causal_lm_preprocessor.py +195 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
- keras_hub/src/models/deberta_v3/__init__.py +1 -4
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
- keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
- keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
- keras_hub/src/models/densenet/densenet_backbone.py +46 -22
- keras_hub/src/models/distil_bert/__init__.py +1 -4
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
- keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
- keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
- keras_hub/src/models/efficientnet/__init__.py +13 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
- keras_hub/src/models/efficientnet/mbconv.py +238 -0
- keras_hub/src/models/electra/__init__.py +1 -2
- keras_hub/src/models/electra/electra_tokenizer.py +17 -32
- keras_hub/src/models/f_net/__init__.py +1 -2
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
- keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
- keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
- keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
- keras_hub/src/models/falcon/__init__.py +1 -2
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
- keras_hub/src/models/gemma/__init__.py +1 -2
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
- keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
- keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
- keras_hub/src/models/gpt2/__init__.py +1 -2
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
- keras_hub/src/models/image_classifier.py +0 -5
- keras_hub/src/models/image_classifier_preprocessor.py +83 -0
- keras_hub/src/models/llama/__init__.py +1 -2
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
- keras_hub/src/models/llama/llama_tokenizer.py +12 -25
- keras_hub/src/models/llama3/__init__.py +1 -2
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
- keras_hub/src/models/masked_lm.py +0 -2
- keras_hub/src/models/masked_lm_preprocessor.py +156 -0
- keras_hub/src/models/mistral/__init__.py +1 -2
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
- keras_hub/src/models/mobilenet/__init__.py +13 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
- keras_hub/src/models/opt/__init__.py +1 -2
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
- keras_hub/src/models/opt/opt_tokenizer.py +12 -41
- keras_hub/src/models/pali_gemma/__init__.py +1 -4
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
- keras_hub/src/models/phi3/__init__.py +1 -2
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
- keras_hub/src/models/preprocessor.py +72 -83
- keras_hub/src/models/resnet/__init__.py +6 -0
- keras_hub/src/models/resnet/resnet_backbone.py +390 -42
- keras_hub/src/models/resnet/resnet_image_classifier.py +33 -6
- keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
- keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
- keras_hub/src/models/resnet/resnet_presets.py +95 -0
- keras_hub/src/models/retinanet/__init__.py +13 -0
- keras_hub/src/models/retinanet/anchor_generator.py +175 -0
- keras_hub/src/models/retinanet/box_matcher.py +259 -0
- keras_hub/src/models/retinanet/non_max_supression.py +578 -0
- keras_hub/src/models/roberta/__init__.py +1 -2
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
- keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
- keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
- keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
- keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
- keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
- keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
- keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
- keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
- keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
- keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
- keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
- keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
- keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
- keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
- keras_hub/src/models/t5/__init__.py +1 -2
- keras_hub/src/models/t5/t5_tokenizer.py +13 -23
- keras_hub/src/models/task.py +71 -116
- keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
- keras_hub/src/models/text_classifier_preprocessor.py +138 -0
- keras_hub/src/models/whisper/__init__.py +1 -2
- keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
- keras_hub/src/models/whisper/whisper_backbone.py +0 -3
- keras_hub/src/models/whisper/whisper_presets.py +10 -10
- keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
- keras_hub/src/models/xlm_roberta/__init__.py +1 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
- keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
- keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
- keras_hub/src/tests/test_case.py +46 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +30 -17
- keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +20 -7
- keras_hub/src/tokenizers/tokenizer.py +67 -32
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
- keras_hub/src/tokenizers/word_piece_tokenizer.py +34 -47
- keras_hub/src/utils/imagenet/__init__.py +13 -0
- keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
- keras_hub/src/utils/keras_utils.py +0 -50
- keras_hub/src/utils/preset_utils.py +230 -68
- keras_hub/src/utils/tensor_utils.py +187 -69
- keras_hub/src/utils/timm/convert_resnet.py +19 -16
- keras_hub/src/utils/timm/preset_loader.py +66 -0
- keras_hub/src/utils/transformers/convert_albert.py +193 -0
- keras_hub/src/utils/transformers/convert_bart.py +373 -0
- keras_hub/src/utils/transformers/convert_bert.py +7 -17
- keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
- keras_hub/src/utils/transformers/convert_gemma.py +5 -19
- keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
- keras_hub/src/utils/transformers/convert_llama3.py +7 -18
- keras_hub/src/utils/transformers/convert_mistral.py +129 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
- keras_hub/src/utils/transformers/preset_loader.py +77 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
- keras_hub/src/version_utils.py +1 -1
- keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
- keras_hub_nightly-0.16.0.dev2024092017.dist-info/RECORD +334 -0
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
- keras_hub/src/models/bart/bart_preprocessor.py +0 -276
- keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
- keras_hub/src/models/electra/electra_preprocessor.py +0 -154
- keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
- keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
- keras_hub/src/models/llama/llama_preprocessor.py +0 -189
- keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
- keras_hub/src/models/opt/opt_preprocessor.py +0 -188
- keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
- keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
- keras_hub/src/utils/timm/convert.py +0 -37
- keras_hub/src/utils/transformers/convert.py +0 -101
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +0 -34
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
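
Most of the churn in this list follows one rename pattern: task classes drop the bare `Classifier` naming in favor of `TextClassifier`, with matching `*_text_classifier_preprocessor.py` modules replacing the old `*_preprocessor.py` files. A rough sketch of what the renamed API looks like is shown below (the class name is inferred from the file renames above; the preset string is illustrative and not confirmed by this diff):

```python
import keras_hub

# Formerly exposed as a bare `Classifier` task (e.g. BertClassifier); the
# renames above move it under the TextClassifier name. The preset name below
# is an assumption for illustration.
classifier = keras_hub.models.BertTextClassifier.from_preset(
    "bert_base_en_uncased",
    num_classes=2,
)
classifier.predict(["What an amazing movie!", "A total waste of my time."])
```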
keras_hub/src/bounding_box/iou.py
@@ -0,0 +1,263 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains functions to compute ious of bounding boxes."""
import math

import keras
from keras import ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.bounding_box.converters import convert_format
from keras_hub.src.bounding_box.utils import as_relative
from keras_hub.src.bounding_box.utils import is_relative


def _compute_area(box):
    """Computes area for bounding boxes.

    Args:
        box: [N, 4] or [batch_size, N, 4] float Tensor, either batched
            or unbatched boxes.
    Returns:
        a float Tensor of [N] or [batch_size, N]
    """
    y_min, x_min, y_max, x_max = ops.split(box[..., :4], 4, axis=-1)
    return ops.squeeze((y_max - y_min) * (x_max - x_min), axis=-1)


def _compute_intersection(boxes1, boxes2):
    """Computes intersection area between two sets of boxes.

    Args:
        boxes1: [N, 4] or [batch_size, N, 4] float Tensor boxes.
        boxes2: [M, 4] or [batch_size, M, 4] float Tensor boxes.
    Returns:
        a [N, M] or [batch_size, N, M] float Tensor.
    """
    y_min1, x_min1, y_max1, x_max1 = ops.split(boxes1[..., :4], 4, axis=-1)
    y_min2, x_min2, y_max2, x_max2 = ops.split(boxes2[..., :4], 4, axis=-1)
    boxes2_rank = len(boxes2.shape)
    perm = [1, 0] if boxes2_rank == 2 else [0, 2, 1]
    # [N, M] or [batch_size, N, M]
    intersect_ymax = ops.minimum(y_max1, ops.transpose(y_max2, perm))
    intersect_ymin = ops.maximum(y_min1, ops.transpose(y_min2, perm))
    intersect_xmax = ops.minimum(x_max1, ops.transpose(x_max2, perm))
    intersect_xmin = ops.maximum(x_min1, ops.transpose(x_min2, perm))

    intersect_height = intersect_ymax - intersect_ymin
    intersect_width = intersect_xmax - intersect_xmin
    zeros_t = ops.cast(0, intersect_height.dtype)
    intersect_height = ops.maximum(zeros_t, intersect_height)
    intersect_width = ops.maximum(zeros_t, intersect_width)

    return intersect_height * intersect_width


@keras_hub_export("keras_hub.bounding_box.compute_iou")
def compute_iou(
    boxes1,
    boxes2,
    bounding_box_format,
    use_masking=False,
    mask_val=-1,
    images=None,
    image_shape=None,
):
    """Computes a lookup table vector containing the ious for a given set of
    boxes.

    The lookup vector is to be indexed by [`boxes1_index`,`boxes2_index`] if
    boxes are unbatched and by [`batch`, `boxes1_index`,`boxes2_index`] if the
    boxes are batched.

    Users can pass `boxes1` and `boxes2` with different ranks. For example:
    1) `boxes1`: [batch_size, M, 4], `boxes2`: [batch_size, N, 4] -> return
        [batch_size, M, N].
    2) `boxes1`: [batch_size, M, 4], `boxes2`: [N, 4] -> return
        [batch_size, M, N]
    3) `boxes1`: [M, 4], `boxes2`: [batch_size, N, 4] -> return
        [batch_size, M, N]
    4) `boxes1`: [M, 4], `boxes2`: [N, 4] -> return [M, N]

    Args:
        boxes1: a list of bounding boxes in 'corners' format. Can be batched
            or unbatched.
        boxes2: a list of bounding boxes in 'corners' format. Can be batched
            or unbatched.
        bounding_box_format: a case-insensitive string which is one of
            `"xyxy"`, `"rel_xyxy"`, `"xyWH"`, `"center_xyWH"`, `"yxyx"`,
            `"rel_yxyx"`. For detailed information on the supported formats,
            see the
            [KerasCV bounding box documentation](https://keras.io/api/keras_cv/bounding_box/formats/).
        use_masking: whether masking will be applied. This will mask all
            `boxes1` or `boxes2` that have values less than 0 in all its 4
            dimensions. Defaults to `False`.
        mask_val: int to mask those returned IOUs if masking is enabled,
            defaults to -1.

    Returns:
        iou_lookup_table: a vector containing the pairwise ious of boxes1 and
            boxes2.
    """  # noqa: E501

    boxes1_rank = len(boxes1.shape)
    boxes2_rank = len(boxes2.shape)

    if boxes1_rank not in [2, 3]:
        raise ValueError(
            "compute_iou() expects boxes1 to be batched, or to be unbatched. "
            f"Received len(boxes1.shape)={boxes1_rank}, "
            f"len(boxes2.shape)={boxes2_rank}. Expected either "
            "len(boxes1.shape)=2 or len(boxes1.shape)=3."
        )
    if boxes2_rank not in [2, 3]:
        raise ValueError(
            "compute_iou() expects boxes2 to be batched, or to be unbatched. "
            f"Received len(boxes1.shape)={boxes1_rank}, "
            f"len(boxes2.shape)={boxes2_rank}. Expected either "
            "len(boxes2.shape)=2 or len(boxes2.shape)=3."
        )

    target_format = "yxyx"
    if is_relative(bounding_box_format):
        target_format = as_relative(target_format)

    boxes1 = convert_format(
        boxes1,
        source=bounding_box_format,
        target=target_format,
        images=images,
        image_shape=image_shape,
    )

    boxes2 = convert_format(
        boxes2,
        source=bounding_box_format,
        target=target_format,
        images=images,
        image_shape=image_shape,
    )

    intersect_area = _compute_intersection(boxes1, boxes2)
    boxes1_area = _compute_area(boxes1)
    boxes2_area = _compute_area(boxes2)
    boxes2_area_rank = len(boxes2_area.shape)
    boxes2_axis = 1 if (boxes2_area_rank == 2) else 0
    boxes1_area = ops.expand_dims(boxes1_area, axis=-1)
    boxes2_area = ops.expand_dims(boxes2_area, axis=boxes2_axis)
    union_area = boxes1_area + boxes2_area - intersect_area
    res = ops.divide(intersect_area, union_area + keras.backend.epsilon())

    if boxes1_rank == 2:
        perm = [1, 0]
    else:
        perm = [0, 2, 1]

    if not use_masking:
        return res

    mask_val_t = ops.cast(mask_val, res.dtype) * ops.ones_like(res)
    boxes1_mask = ops.less(ops.max(boxes1, axis=-1, keepdims=True), 0.0)
    boxes2_mask = ops.less(ops.max(boxes2, axis=-1, keepdims=True), 0.0)
    background_mask = ops.logical_or(
        boxes1_mask, ops.transpose(boxes2_mask, perm)
    )
    iou_lookup_table = ops.where(background_mask, mask_val_t, res)
    return iou_lookup_table


@keras_hub_export("keras_hub.bounding_box.compute_ciou")
def compute_ciou(boxes1, boxes2, bounding_box_format):
    """
    Computes the Complete IoU (CIoU) between two bounding boxes or between
    two batches of bounding boxes.

    CIoU loss is an extension of GIoU loss, which further improves the IoU
    optimization for object detection. CIoU loss not only penalizes the
    bounding box coordinates but also considers the aspect ratio and center
    distance of the boxes. The length of the last dimension should be 4 to
    represent the bounding boxes.

    Args:
        boxes1 (tensor): tensor representing the first bounding box with
            shape (..., 4).
        boxes2 (tensor): tensor representing the second bounding box with
            shape (..., 4).
        bounding_box_format: a case-insensitive string (for example, "xyxy").
            Each bounding box is defined by these 4 values. For detailed
            information on the supported formats, see the [KerasCV bounding box
            documentation](https://keras.io/api/keras_cv/bounding_box/formats/).

    Returns:
        tensor: The CIoU distance between the two bounding boxes.
    """
    target_format = "xyxy"
    if is_relative(bounding_box_format):
        target_format = as_relative(target_format)

    boxes1 = convert_format(
        boxes1, source=bounding_box_format, target=target_format
    )

    boxes2 = convert_format(
        boxes2, source=bounding_box_format, target=target_format
    )

    x_min1, y_min1, x_max1, y_max1 = ops.split(boxes1[..., :4], 4, axis=-1)
    x_min2, y_min2, x_max2, y_max2 = ops.split(boxes2[..., :4], 4, axis=-1)

    width_1 = x_max1 - x_min1
    height_1 = y_max1 - y_min1 + keras.backend.epsilon()
    width_2 = x_max2 - x_min2
    height_2 = y_max2 - y_min2 + keras.backend.epsilon()

    intersection_area = ops.maximum(
        ops.minimum(x_max1, x_max2) - ops.maximum(x_min1, x_min2), 0
    ) * ops.maximum(
        ops.minimum(y_max1, y_max2) - ops.maximum(y_min1, y_min2), 0
    )
    union_area = (
        width_1 * height_1
        + width_2 * height_2
        - intersection_area
        + keras.backend.epsilon()
    )
    iou = ops.squeeze(
        ops.divide(intersection_area, union_area + keras.backend.epsilon()),
        axis=-1,
    )

    convex_width = ops.maximum(x_max1, x_max2) - ops.minimum(x_min1, x_min2)
    convex_height = ops.maximum(y_max1, y_max2) - ops.minimum(y_min1, y_min2)
    convex_diagonal_squared = ops.squeeze(
        convex_width**2 + convex_height**2 + keras.backend.epsilon(),
        axis=-1,
    )
    centers_distance_squared = ops.squeeze(
        ((x_min1 + x_max1) / 2 - (x_min2 + x_max2) / 2) ** 2
        + ((y_min1 + y_max1) / 2 - (y_min2 + y_max2) / 2) ** 2,
        axis=-1,
    )

    v = ops.squeeze(
        ops.power(
            (4 / math.pi**2)
            * (ops.arctan(width_2 / height_2) - ops.arctan(width_1 / height_1)),
            2,
        ),
        axis=-1,
    )
    alpha = v / (v - iou + (1 + keras.backend.epsilon()))

    return iou - (
        centers_distance_squared / convex_diagonal_squared + v * alpha
    )
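
For orientation, a minimal usage sketch of the two new exports above (assuming the new nightly wheel is installed; box coordinates are illustrative):

```python
import numpy as np

import keras_hub

# Two boxes vs. one reference box, all in absolute "xyxy" format.
boxes1 = np.array([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 10.0, 10.0]])
boxes2 = np.array([[0.0, 0.0, 10.0, 10.0]])

# Pairwise IoU lookup table of shape [2, 1]: [[1.0], [0.25]].
iou = keras_hub.bounding_box.compute_iou(
    boxes1, boxes2, bounding_box_format="xyxy"
)

# CIoU compares boxes elementwise over matching (..., 4) shapes, so pair
# the first box with the reference box directly.
ciou = keras_hub.bounding_box.compute_ciou(
    boxes1[:1], boxes2, bounding_box_format="xyxy"
)
```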
keras_hub/src/bounding_box/to_dense.py
@@ -0,0 +1,95 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import keras_hub.src.bounding_box.validate_format as validate_format
from keras_hub.src.api_export import keras_hub_export

try:
    import tensorflow as tf
except ImportError:
    tf = None


def _box_shape(batched, boxes_shape, max_boxes):
    # Ensure we don't drop the final axis in RaggedTensor mode.
    if max_boxes is None:
        shape = list(boxes_shape)
        shape[-1] = 4
        return shape
    if batched:
        return [None, max_boxes, 4]
    return [max_boxes, 4]


def _classes_shape(batched, classes_shape, max_boxes):
    if max_boxes is None:
        return None
    if batched:
        return [None, max_boxes] + classes_shape[2:]
    return [max_boxes] + classes_shape[2:]


@keras_hub_export("keras_hub.bounding_box.to_dense")
def to_dense(bounding_boxes, max_boxes=None, default_value=-1):
    """to_dense converts bounding boxes to dense tensors.

    Args:
        bounding_boxes: bounding boxes in KerasCV dictionary format.
        max_boxes: the maximum number of boxes, used to pad tensors to a
            given shape. This can be used to make object detection pipelines
            TPU compatible.
        default_value: the default value to pad bounding boxes with. Defaults
            to -1.
    """
    info = validate_format.validate_format(bounding_boxes)

    # Guards against errors in metrics regarding modification of inputs.
    # Also guards against unexpected behavior when modifying downstream.
    bounding_boxes = bounding_boxes.copy()

    # Already running in dense (masked) mode.
    if not info["ragged"]:
        # Even if already dense, still copy the dictionary for API
        # consistency.
        return bounding_boxes

    if isinstance(bounding_boxes["classes"], tf.RaggedTensor):
        bounding_boxes["classes"] = bounding_boxes["classes"].to_tensor(
            default_value=default_value,
            shape=_classes_shape(
                info["is_batched"], bounding_boxes["classes"].shape, max_boxes
            ),
        )

    if isinstance(bounding_boxes["boxes"], tf.RaggedTensor):
        bounding_boxes["boxes"] = bounding_boxes["boxes"].to_tensor(
            default_value=default_value,
            shape=_box_shape(
                info["is_batched"], bounding_boxes["boxes"].shape, max_boxes
            ),
        )

    if "confidence" in bounding_boxes:
        if isinstance(bounding_boxes["confidence"], tf.RaggedTensor):
            bounding_boxes["confidence"] = bounding_boxes[
                "confidence"
            ].to_tensor(
                default_value=default_value,
                shape=_classes_shape(
                    info["is_batched"],
                    bounding_boxes["confidence"].shape,
                    max_boxes,
                ),
            )

    return bounding_boxes
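
A quick sketch of `to_dense` on a ragged batch (TensorFlow is required for the ragged inputs; shapes and values are illustrative):

```python
import tensorflow as tf

import keras_hub

# A ragged batch: image 0 has two boxes, image 1 has one.
bounding_boxes = {
    "boxes": tf.ragged.constant(
        [[[0, 0, 5, 5], [2, 2, 8, 8]], [[1, 1, 4, 4]]],
        ragged_rank=1,
        inner_shape=(4,),
        dtype=tf.float32,
    ),
    "classes": tf.ragged.constant([[0, 1], [2]], dtype=tf.float32),
}

# Pad every image out to 4 boxes; empty slots get the -1 default value.
dense = keras_hub.bounding_box.to_dense(bounding_boxes, max_boxes=4)
print(dense["boxes"].shape)  # (2, 4, 4)
```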
keras_hub/src/bounding_box/to_ragged.py
@@ -0,0 +1,99 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras

import keras_hub.src.bounding_box.validate_format as validate_format
from keras_hub.src.api_export import keras_hub_export

try:
    import tensorflow as tf
except ImportError:
    tf = None


@keras_hub_export("keras_hub.bounding_box.to_ragged")
def to_ragged(bounding_boxes, sentinel=-1, dtype="float32"):
    """Converts a dense padded bounding box `tf.Tensor` to a `tf.RaggedTensor`.

    Bounding boxes are ragged tensors in most use cases. Converting them to a
    dense tensor makes it easier to work with the TensorFlow ecosystem.
    This function can be used to filter out the masked-out bounding boxes by
    checking for the padding sentinel value in the class_id axis of the
    bounding_boxes.

    Example:
    ```python
    bounding_boxes = {
        "boxes": tf.constant([[2, 3, 4, 5], [0, 1, 2, 3]]),
        "classes": tf.constant([[-1, 1]]),
    }
    bounding_boxes = bounding_box.to_ragged(bounding_boxes)
    print(bounding_boxes)
    # {
    #     "boxes": [[0, 1, 2, 3]],
    #     "classes": [[1]]
    # }
    ```

    Args:
        bounding_boxes: a Tensor of bounding boxes. May be batched, or
            unbatched.
        sentinel: The value indicating that a bounding box does not exist at
            the current index, and that the corresponding box is padding.
            Defaults to -1.
        dtype: the data type to use for the underlying Tensors.

    Returns:
        dictionary of `tf.RaggedTensor` or `tf.Tensor` containing the filtered
        bounding boxes.
    """
    if keras.config.backend() != "tensorflow":
        raise NotImplementedError(
            "`bounding_box.to_ragged` was called using a backend which does "
            "not support ragged tensors. "
            f"Current backend: {keras.backend.backend()}."
        )

    info = validate_format.validate_format(bounding_boxes)

    if info["ragged"]:
        return bounding_boxes

    boxes = bounding_boxes.get("boxes")
    classes = bounding_boxes.get("classes")
    confidence = bounding_boxes.get("confidence", None)

    mask = classes != sentinel

    boxes = tf.ragged.boolean_mask(boxes, mask)
    classes = tf.ragged.boolean_mask(classes, mask)
    if confidence is not None:
        confidence = tf.ragged.boolean_mask(confidence, mask)

    if isinstance(boxes, tf.Tensor):
        boxes = tf.RaggedTensor.from_tensor(boxes)

    if isinstance(classes, tf.Tensor) and len(classes.shape) > 1:
        classes = tf.RaggedTensor.from_tensor(classes)

    if confidence is not None:
        if isinstance(confidence, tf.Tensor) and len(confidence.shape) > 1:
            confidence = tf.RaggedTensor.from_tensor(confidence)

    result = bounding_boxes.copy()
    result["boxes"] = tf.cast(boxes, dtype)
    result["classes"] = tf.cast(classes, dtype)

    if confidence is not None:
        result["confidence"] = tf.cast(confidence, dtype)

    return result
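
And the inverse direction, filtering out sentinel-padded rows (only valid on the TensorFlow backend, per the `NotImplementedError` guard above; values are illustrative):

```python
import tensorflow as tf

import keras_hub

# Dense, padded batch: the second slot is -1 sentinel padding.
bounding_boxes = {
    "boxes": tf.constant([[[0.0, 0.0, 5.0, 5.0], [-1.0, -1.0, -1.0, -1.0]]]),
    "classes": tf.constant([[0.0, -1.0]]),
}

# Rows whose class equals the sentinel (-1) are dropped.
ragged = keras_hub.bounding_box.to_ragged(bounding_boxes)
print(ragged["boxes"])  # roughly: <tf.RaggedTensor [[[0.0, 0.0, 5.0, 5.0]]]>
```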
keras_hub/src/bounding_box/utils.py
@@ -0,0 +1,194 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for working with bounding boxes."""

from keras import ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.bounding_box import converters
from keras_hub.src.bounding_box.formats import XYWH


@keras_hub_export("keras_hub.bounding_box.is_relative")
def is_relative(bounding_box_format):
    """A util to check if a bounding box format uses relative coordinates."""
    if bounding_box_format.lower() not in converters.TO_XYXY_CONVERTERS:
        raise ValueError(
            "`is_relative()` received an unsupported format for the argument "
            f"`bounding_box_format`. `bounding_box_format` should be one of "
            f"{converters.TO_XYXY_CONVERTERS.keys()}. "
            f"Got bounding_box_format={bounding_box_format}"
        )

    return bounding_box_format.startswith("rel")


@keras_hub_export("keras_hub.bounding_box.as_relative")
def as_relative(bounding_box_format):
    """A util to get the relative equivalent of a provided bounding box format.

    If the specified format is already a relative format,
    it will be returned unchanged.
    """

    if not is_relative(bounding_box_format):
        return "rel_" + bounding_box_format

    return bounding_box_format


def _relative_area(boxes, bounding_box_format):
    boxes = converters.convert_format(
        boxes,
        source=bounding_box_format,
        target="rel_xywh",
    )
    widths = boxes[..., XYWH.WIDTH]
    heights = boxes[..., XYWH.HEIGHT]
    # Handle corner case where shear performs a full inversion.
    return ops.where(
        ops.logical_and(widths > 0, heights > 0), widths * heights, 0.0
    )


@keras_hub_export("keras_hub.bounding_box.clip_to_image")
def clip_to_image(
    bounding_boxes, bounding_box_format, images=None, image_shape=None
):
    """Clips bounding boxes to image boundaries.

    `clip_to_image()` clips bounding boxes that have coordinates out of bounds
    of an image down to the boundaries of the image. This is done by converting
    the bounding boxes to relative format, then clipping them to the `[0, 1]`
    range. Additionally, bounding boxes that end up with a zero area have their
    class ID set to -1, indicating that there is no object present in them.

    Args:
        bounding_boxes: bounding box tensor to clip.
        bounding_box_format: the KerasCV bounding box format the bounding boxes
            are in.
        images: list of images to clip the bounding boxes to.
        image_shape: the shape of the images to clip the bounding boxes to.
    """
    boxes, classes = bounding_boxes["boxes"], bounding_boxes["classes"]

    boxes = converters.convert_format(
        boxes,
        source=bounding_box_format,
        target="rel_xyxy",
        images=images,
        image_shape=image_shape,
    )
    boxes, classes, images, squeeze = _format_inputs(boxes, classes, images)
    x1, y1, x2, y2 = ops.split(boxes, 4, axis=-1)
    clipped_bounding_boxes = ops.concatenate(
        [
            ops.clip(x1, 0, 1),
            ops.clip(y1, 0, 1),
            ops.clip(x2, 0, 1),
            ops.clip(y2, 0, 1),
        ],
        axis=-1,
    )
    areas = _relative_area(
        clipped_bounding_boxes, bounding_box_format="rel_xyxy"
    )
    clipped_bounding_boxes = converters.convert_format(
        clipped_bounding_boxes,
        source="rel_xyxy",
        target=bounding_box_format,
        images=images,
        image_shape=image_shape,
    )
    clipped_bounding_boxes = ops.where(
        ops.expand_dims(areas > 0.0, axis=-1), clipped_bounding_boxes, -1.0
    )
    classes = ops.where(areas > 0.0, classes, -1)
    nan_indices = ops.any(ops.isnan(clipped_bounding_boxes), axis=-1)
    classes = ops.where(nan_indices, -1, classes)

    # TODO update dict and return
    clipped_bounding_boxes, classes = _format_outputs(
        clipped_bounding_boxes, classes, squeeze
    )

    bounding_boxes.update(
        {"boxes": clipped_bounding_boxes, "classes": classes}
    )

    return bounding_boxes


@keras_hub_export("keras_hub.bounding_box.clip_boxes")
def clip_boxes(boxes, image_shape):
    """Clip boxes to the boundaries of the image shape."""
    if boxes.shape[-1] != 4:
        raise ValueError(
            "boxes.shape[-1] is {:d}, but must be 4.".format(boxes.shape[-1])
        )

    if isinstance(image_shape, (list, tuple)):
        height, width, _ = image_shape
        max_length = ops.stack([height, width, height, width], axis=-1)
    else:
        image_shape = ops.cast(image_shape, dtype=boxes.dtype)
        height = image_shape[0]
        width = image_shape[1]
        max_length = ops.stack([height, width, height, width], axis=-1)

    clipped_boxes = ops.maximum(ops.minimum(boxes, max_length), 0.0)
    return clipped_boxes


def _format_inputs(boxes, classes, images):
    boxes_rank = len(boxes.shape)
    if boxes_rank > 3:
        raise ValueError(
            "Expected len(boxes.shape)=2, or len(boxes.shape)=3, got "
            f"len(boxes.shape)={boxes_rank}"
        )
    boxes_includes_batch = boxes_rank == 3
    # Determine if images needs an expand_dims() call.
    if images is not None:
        images_rank = len(images.shape)
        if images_rank > 4:
            raise ValueError(
                "Expected len(images.shape)=3, or len(images.shape)=4, got "
                f"len(images.shape)={images_rank}"
            )
        images_include_batch = images_rank == 4
        if boxes_includes_batch != images_include_batch:
            raise ValueError(
                "clip_to_image() expects both boxes and images to be batched, "
                "or both boxes and images to be unbatched. Received "
                f"len(boxes.shape)={boxes_rank}, "
                f"len(images.shape)={images_rank}. Expected either "
                "len(boxes.shape)=2 AND len(images.shape)=3, or "
                "len(boxes.shape)=3 AND len(images.shape)=4."
            )
        if not images_include_batch:
            images = ops.expand_dims(images, axis=0)

    if not boxes_includes_batch:
        return (
            ops.expand_dims(boxes, axis=0),
            ops.expand_dims(classes, axis=0),
            images,
            True,
        )
    return boxes, classes, images, False


def _format_outputs(boxes, classes, squeeze):
    if squeeze:
        return ops.squeeze(boxes, axis=0), ops.squeeze(classes, axis=0)
    return boxes, classes
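
Finally, a short sketch of the clipping utilities above (assuming a 100x100 image; coordinates are illustrative):

```python
import numpy as np

import keras_hub

# One box hangs 50px past the right edge of a 100x100 image.
bounding_boxes = {
    "boxes": np.array([[[50.0, 50.0, 150.0, 90.0]]]),
    "classes": np.array([[0.0]]),
}

clipped = keras_hub.bounding_box.clip_to_image(
    bounding_boxes,
    bounding_box_format="xyxy",
    image_shape=(100, 100, 3),
)
# clipped["boxes"] -> [[[50., 50., 100., 90.]]]

# clip_boxes is the lower-level variant that operates on a raw box tensor.
raw = keras_hub.bounding_box.clip_boxes(
    np.array([[50.0, 50.0, 150.0, 90.0]]), image_shape=(100, 100, 3)
)
```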