keras-hub-nightly 0.15.0.dev20240911134614__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +0 -6
- keras_hub/api/__init__.py +1 -0
- keras_hub/api/models/__init__.py +22 -17
- keras_hub/{src/models/llama3/llama3_preprocessor.py → api/utils/__init__.py} +7 -8
- keras_hub/src/api_export.py +15 -9
- keras_hub/src/models/albert/albert_text_classifier.py +6 -1
- keras_hub/src/models/bert/bert_text_classifier.py +6 -1
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +6 -1
- keras_hub/src/models/densenet/densenet_backbone.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +6 -1
- keras_hub/src/models/f_net/f_net_text_classifier.py +6 -1
- keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +7 -78
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +1 -1
- keras_hub/src/models/preprocessor.py +1 -5
- keras_hub/src/models/resnet/resnet_backbone.py +3 -16
- keras_hub/src/models/resnet/resnet_image_classifier.py +26 -3
- keras_hub/src/models/resnet/resnet_presets.py +12 -12
- keras_hub/src/models/retinanet/__init__.py +13 -0
- keras_hub/src/models/retinanet/anchor_generator.py +175 -0
- keras_hub/src/models/retinanet/box_matcher.py +259 -0
- keras_hub/src/models/retinanet/non_max_supression.py +578 -0
- keras_hub/src/models/roberta/roberta_text_classifier.py +6 -1
- keras_hub/src/models/task.py +6 -6
- keras_hub/src/models/text_classifier.py +12 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +6 -1
- keras_hub/src/tests/test_case.py +21 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +1 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +1 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +1 -0
- keras_hub/src/utils/imagenet/__init__.py +13 -0
- keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
- keras_hub/src/utils/preset_utils.py +24 -33
- keras_hub/src/utils/tensor_utils.py +14 -14
- keras_hub/src/utils/timm/convert_resnet.py +0 -1
- keras_hub/src/utils/timm/preset_loader.py +6 -7
- keras_hub/src/version_utils.py +1 -1
- keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
- {keras_hub_nightly-0.15.0.dev20240911134614.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/RECORD +41 -45
- {keras_hub_nightly-0.15.0.dev20240911134614.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
- keras_hub/src/models/bart/bart_preprocessor.py +0 -264
- keras_hub/src/models/bloom/bloom_preprocessor.py +0 -178
- keras_hub/src/models/electra/electra_preprocessor.py +0 -155
- keras_hub/src/models/falcon/falcon_preprocessor.py +0 -180
- keras_hub/src/models/gemma/gemma_preprocessor.py +0 -184
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -138
- keras_hub/src/models/llama/llama_preprocessor.py +0 -182
- keras_hub/src/models/mistral/mistral_preprocessor.py +0 -183
- keras_hub/src/models/opt/opt_preprocessor.py +0 -181
- keras_hub/src/models/phi3/phi3_preprocessor.py +0 -183
- keras_hub_nightly-0.15.0.dev20240911134614.dist-info/METADATA +0 -33
- {keras_hub_nightly-0.15.0.dev20240911134614.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
keras_hub/src/models/resnet/resnet_presets.py:

```diff
@@ -14,7 +14,7 @@
 """ResNet preset configurations."""
 
 backbone_presets = {
-    "
+    "resnet_18_imagenet": {
         "metadata": {
             "description": (
                 "18-layer ResNet model pre-trained on the ImageNet 1k dataset "
@@ -25,9 +25,9 @@ backbone_presets = {
             "path": "resnet",
             "model_card": "https://arxiv.org/abs/2110.00476",
         },
-        "kaggle_handle": "kaggle://kerashub/resnetv1/keras/
+        "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_18_imagenet/2",
     },
-    "
+    "resnet_50_imagenet": {
         "metadata": {
             "description": (
                 "50-layer ResNet model pre-trained on the ImageNet 1k dataset "
@@ -38,9 +38,9 @@ backbone_presets = {
             "path": "resnet",
             "model_card": "https://arxiv.org/abs/2110.00476",
         },
-        "kaggle_handle": "kaggle://kerashub/resnetv1/keras/
+        "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_50_imagenet/2",
     },
-    "
+    "resnet_101_imagenet": {
         "metadata": {
             "description": (
                 "101-layer ResNet model pre-trained on the ImageNet 1k dataset "
@@ -51,9 +51,9 @@ backbone_presets = {
             "path": "resnet",
             "model_card": "https://arxiv.org/abs/2110.00476",
         },
-        "kaggle_handle": "kaggle://kerashub/resnetv1/keras/
+        "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_101_imagenet/2",
     },
-    "
+    "resnet_152_imagenet": {
         "metadata": {
             "description": (
                 "152-layer ResNet model pre-trained on the ImageNet 1k dataset "
@@ -64,9 +64,9 @@ backbone_presets = {
             "path": "resnet",
             "model_card": "https://arxiv.org/abs/2110.00476",
         },
-        "kaggle_handle": "kaggle://kerashub/resnetv1/keras/
+        "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_152_imagenet/2",
     },
-    "
+    "resnet_v2_50_imagenet": {
         "metadata": {
             "description": (
                 "50-layer ResNetV2 model pre-trained on the ImageNet 1k "
@@ -77,9 +77,9 @@ backbone_presets = {
             "path": "resnet",
             "model_card": "https://arxiv.org/abs/2110.00476",
         },
-        "kaggle_handle": "kaggle://kerashub/resnetv2/keras/
+        "kaggle_handle": "kaggle://kerashub/resnetv2/keras/resnet_v2_50_imagenet/2",
    },
-    "
+    "resnet_v2_101_imagenet": {
         "metadata": {
             "description": (
                 "101-layer ResNetV2 model pre-trained on the ImageNet 1k "
@@ -90,6 +90,6 @@ backbone_presets = {
             "path": "resnet",
             "model_card": "https://arxiv.org/abs/2110.00476",
         },
-        "kaggle_handle": "kaggle://kerashub/resnetv2/keras/
+        "kaggle_handle": "kaggle://kerashub/resnetv2/keras/resnet_v2_101_imagenet/2",
     },
 }
```
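These preset keys are the handles users pass to `from_preset`; keras-hub resolves each key to the (now bumped, `/2`) Kaggle handle above. A minimal usage sketch, not part of the diff, assuming the public `keras_hub.models` API:

```python
import keras_hub

# "resnet_18_imagenet" resolves to the handle
# "kaggle://kerashub/resnetv1/keras/resnet_18_imagenet/2" shown in the diff.
backbone = keras_hub.models.ResNetBackbone.from_preset("resnet_18_imagenet")
print(backbone.count_params())
```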
keras_hub/src/models/retinanet/__init__.py (new file):

```diff
@@ -0,0 +1,13 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
```
keras_hub/src/models/retinanet/anchor_generator.py (new file):

````diff
@@ -0,0 +1,175 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import keras
+from keras import ops
+
+from keras_hub.src.bounding_box.converters import convert_format
+
+
+class AnchorGenerator(keras.layers.Layer):
+    """Generates anchor boxes for object detection tasks.
+
+    This layer creates a set of anchor boxes (also known as default boxes or
+    priors) for use in object detection models, particularly those utilizing
+    Feature Pyramid Networks (FPN). It generates anchors across multiple
+    pyramid levels, with various scales and aspect ratios.
+
+    Feature Pyramid Levels:
+    - Levels typically range from 2 to 6 (P2 to P7), corresponding to different
+      resolutions of the input image.
+    - Each level l has a stride of 2^l pixels relative to the input image.
+    - Lower levels (e.g., P2) have higher resolution and are used for
+      detecting smaller objects.
+    - Higher levels (e.g., P7) have lower resolution and are used
+      for larger objects.
+
+    Args:
+        bounding_box_format (str): The format of the bounding boxes
+            to be generated. Expected to be a string like 'xyxy', 'xywh', etc.
+        min_level (int): Minimum level of the output feature pyramid.
+        max_level (int): Maximum level of the output feature pyramid.
+        num_scales (int): Number of intermediate scales added on each level.
+            For example, num_scales=2 adds one additional intermediate anchor
+            scale [2^0, 2^0.5] on each level.
+        aspect_ratios (list of float): Aspect ratios of anchors added on
+            each level. Each number indicates the ratio of width to height.
+        anchor_size (float): Scale of size of the base anchor relative to the
+            feature stride 2^level.
+
+    Call arguments:
+        images (Optional[Tensor]): An image tensor with shape `[B, H, W, C]` or
+            `[H, W, C]`. If provided, its shape will be used to determine anchor
+            sizes.
+
+    Returns:
+        Dict: A dictionary mapping feature levels
+            (e.g., 'P3', 'P4', etc.) to anchor boxes. Each entry contains a tensor
+            of shape `(H/stride * W/stride * num_anchors_per_location, 4)`,
+            where H and W are the height and width of the image, stride is 2^level,
+            and num_anchors_per_location is `num_scales * len(aspect_ratios)`.
+
+    Example:
+    ```python
+    anchor_generator = AnchorGenerator(
+        bounding_box_format='xyxy',
+        min_level=3,
+        max_level=7,
+        num_scales=3,
+        aspect_ratios=[0.5, 1.0, 2.0],
+        anchor_size=4.0,
+    )
+    anchors = anchor_generator(images=keras.ops.ones(shape=(2, 640, 480, 3)))
+    ```
+    """
+
+    def __init__(
+        self,
+        bounding_box_format,
+        min_level,
+        max_level,
+        num_scales,
+        aspect_ratios,
+        anchor_size,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.bounding_box_format = bounding_box_format
+        self.min_level = min_level
+        self.max_level = max_level
+        self.num_scales = num_scales
+        self.aspect_ratios = aspect_ratios
+        self.anchor_size = anchor_size
+        self.built = True
+
+    def call(self, images):
+        images_shape = ops.shape(images)
+        if len(images_shape) == 4:
+            image_shape = images_shape[1:-1]
+        else:
+            image_shape = images_shape[:-1]
+
+        image_shape = tuple(image_shape)
+
+        multilevel_boxes = {}
+        for level in range(self.min_level, self.max_level + 1):
+            boxes_l = []
+            # Calculate the feature map size for this level
+            feat_size_y = math.ceil(image_shape[0] / 2**level)
+            feat_size_x = math.ceil(image_shape[1] / 2**level)
+
+            # Calculate the stride (step size) for this level
+            stride_y = ops.cast(image_shape[0] / feat_size_y, "float32")
+            stride_x = ops.cast(image_shape[1] / feat_size_x, "float32")
+
+            # Generate anchor center points
+            # Start from stride/2 to center anchors on pixels
+            cx = ops.arange(stride_x / 2, image_shape[1], stride_x)
+            cy = ops.arange(stride_y / 2, image_shape[0], stride_y)
+
+            # Create a grid of anchor centers
+            cx_grid, cy_grid = ops.meshgrid(cx, cy)
+
+            for scale in range(self.num_scales):
+                for aspect_ratio in self.aspect_ratios:
+                    # Calculate the intermediate scale factor
+                    intermidate_scale = 2 ** (scale / self.num_scales)
+                    # Calculate the base anchor size for this level and scale
+                    base_anchor_size = (
+                        self.anchor_size * 2**level * intermidate_scale
+                    )
+                    # Adjust anchor dimensions based on aspect ratio
+                    aspect_x = aspect_ratio**0.5
+                    aspect_y = aspect_ratio**-0.5
+                    half_anchor_size_x = base_anchor_size * aspect_x / 2.0
+                    half_anchor_size_y = base_anchor_size * aspect_y / 2.0
+
+                    # Generate anchor boxes (y1, x1, y2, x2 format)
+                    boxes = ops.stack(
+                        [
+                            cy_grid - half_anchor_size_y,
+                            cx_grid - half_anchor_size_x,
+                            cy_grid + half_anchor_size_y,
+                            cx_grid + half_anchor_size_x,
+                        ],
+                        axis=-1,
+                    )
+                    boxes_l.append(boxes)
+            # Concat anchors on the same level to tensor shape HxWx(Ax4)
+            boxes_l = ops.concatenate(boxes_l, axis=-1)
+            boxes_l = ops.reshape(boxes_l, (-1, 4))
+            # Convert to user defined
+            multilevel_boxes[f"P{level}"] = convert_format(
+                boxes_l,
+                source="yxyx",
+                target=self.bounding_box_format,
+            )
+        return multilevel_boxes
+
+    def compute_output_shape(self, input_shape):
+        multilevel_boxes_shape = {}
+        for level in range(self.min_level, self.max_level + 1):
+            multilevel_boxes_shape[f"P{level}"] = (None, None, 4)
+        return multilevel_boxes_shape
+
+    @property
+    def anchors_per_location(self):
+        """
+        The `anchors_per_location` property returns the number of anchors
+        generated per pixel location, which is equal to
+        `num_scales * len(aspect_ratios)`.
+        """
+        return self.num_scales * len(self.aspect_ratios)
````
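The shape arithmetic in `call` is easy to sanity-check by hand. Below is an illustrative sketch (not part of the diff) that mirrors the layer's own math for the configuration used in its docstring example:

```python
import math

# Mirrors AnchorGenerator.call with num_scales=3,
# aspect_ratios=[0.5, 1.0, 2.0], anchor_size=4.0 on a 640x480 image.
num_scales, aspect_ratios, anchor_size = 3, [0.5, 1.0, 2.0], 4.0
image_h, image_w, level = 640, 480, 3

anchors_per_location = num_scales * len(aspect_ratios)  # 9
feat_h = math.ceil(image_h / 2**level)  # ceil(640 / 8) = 80
feat_w = math.ceil(image_w / 2**level)  # ceil(480 / 8) = 60

# anchors["P3"] has feat_h * feat_w * anchors_per_location rows of 4 coords.
print(feat_h * feat_w * anchors_per_location)  # 43200

# Base anchor edge length at P3 for the middle scale (scale index 1):
print(anchor_size * 2**level * 2 ** (1 / num_scales))  # ~40.3 pixels
```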
keras_hub/src/models/retinanet/box_matcher.py (new file):

````diff
@@ -0,0 +1,259 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import keras
+from keras import ops
+
+
+class BoxMatcher(keras.layers.Layer):
+    """Box matching logic based on argmax of highest value (e.g., IOU).
+
+    This class computes matches from a similarity matrix. Each row will be
+    matched to at least one column, the matched result can either be positive
+    or negative, or simply ignored depending on the setting.
+
+    The settings include `thresholds` and `match_values`, for example if:
+    1) `thresholds=[negative_threshold, positive_threshold]`, and
+    `match_values=[negative_value=0, ignore_value=-1, positive_value=1]`: the
+    rows will be assigned to positive_value if its argmax result >=
+    positive_threshold; the rows will be assigned to negative_value if its
+    argmax result < negative_threshold, and the rows will be assigned to
+    ignore_value if its argmax result is between [negative_threshold,
+    positive_threshold).
+    2) `thresholds=[negative_threshold, positive_threshold]`, and
+    `match_values=[ignore_value=-1, negative_value=0, positive_value=1]`: the
+    rows will be assigned to positive_value if its argmax result >=
+    positive_threshold; the rows will be assigned to ignore_value if its
+    argmax result < negative_threshold, and the rows will be assigned to
+    negative_value if its argmax result is between [negative_threshold,
+    positive_threshold). This is different from case 1) by swapping first two
+    values.
+    3) `thresholds=[positive_threshold]`, and
+    `match_values=[negative_values, positive_value]`: the rows will be
+    assigned to positive value if its argmax result >= positive_threshold;
+    the rows will be assigned to negative_value if its argmax result <
+    negative_threshold.
+
+    Args:
+        thresholds: A sorted list of floats to classify the matches into
+            different results (e.g. positive or negative or ignored match). The
+            list will be prepended with -Inf and appended with +Inf.
+        match_values: A list of integers representing matched results (e.g.
+            positive or negative or ignored match). len(`match_values`) must
+            equal to len(`thresholds`) + 1.
+        force_match_for_each_col: each row will be argmax matched to at
+            least one column. This means some columns will be matched to
+            multiple rows while some columns will not be matched to any rows.
+            Filtering by `thresholds` will make less columns match to positive
+            result. Setting this to True guarantees that each column will be
+            matched to positive result to at least one row.
+
+    Raises:
+        ValueError: if `thresholds` not sorted or
+            len(`match_values`) != len(`thresholds`) + 1
+
+    Example:
+    ```python
+    box_matcher = keras_cv.layers.BoxMatcher([0.3, 0.7], [-1, 0, 1])
+    iou_metric = keras_cv.bounding_box.compute_iou(anchors, boxes)
+    matched_columns, matched_match_values = box_matcher(iou_metric)
+    cls_mask = ops.less_equal(matched_match_values, 0)
+    ```
+
+    """
+
+    def __init__(
+        self,
+        thresholds,
+        match_values,
+        force_match_for_each_col=False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if sorted(thresholds) != thresholds:
+            raise ValueError(f"`threshold` must be sorted, got {thresholds}")
+        self.match_values = match_values
+        if len(match_values) != len(thresholds) + 1:
+            raise ValueError(
+                f"len(`match_values`) must be len(`thresholds`) + 1, got "
+                f"match_values {match_values}, thresholds {thresholds}"
+            )
+        thresholds.insert(0, -float("inf"))
+        thresholds.append(float("inf"))
+        self.thresholds = thresholds
+        self.force_match_for_each_col = force_match_for_each_col
+        self.built = True
+
+    def call(self, similarity_matrix):
+        """Matches each row to a column based on argmax
+
+        Args:
+            similarity_matrix: A float Tensor of shape `[num_rows, num_cols]` or
+                `[batch_size, num_rows, num_cols]` representing any similarity
+                metric.
+
+        Returns:
+            matched_columns: An integer tensor of shape `[num_rows]` or
+                `[batch_size, num_rows]` storing the index of the matched
+                column for each row.
+            matched_values: An integer tensor of shape [num_rows] or
+                `[batch_size, num_rows]` storing the match result
+                `(positive match, negative match, ignored match)`.
+        """
+        squeeze_result = False
+        if len(similarity_matrix.shape) == 2:
+            squeeze_result = True
+            similarity_matrix = ops.expand_dims(similarity_matrix, axis=0)
+        static_shape = list(similarity_matrix.shape)
+        num_rows = static_shape[1] or ops.shape(similarity_matrix)[1]
+        batch_size = static_shape[0] or ops.shape(similarity_matrix)[0]
+
+        def _match_when_cols_are_empty():
+            """Performs matching when the rows of similarity matrix are empty.
+            When the rows are empty, all detections are false positives. So we
+            return a tensor of -1's to indicate that the rows do not match to
+            any columns.
+
+            Returns:
+                matched_columns: An integer tensor of shape [batch_size,
+                    num_rows] storing the index of the matched column for each
+                    row.
+                matched_values: An integer tensor of shape [batch_size,
+                    num_rows] storing the match type indicator (e.g. positive or
+                    negative or ignored match).
+            """
+            matched_columns = ops.zeros([batch_size, num_rows], dtype="int32")
+            matched_values = -ops.ones([batch_size, num_rows], dtype="int32")
+            return matched_columns, matched_values
+
+        def _match_when_cols_are_non_empty():
+            """Performs matching when the rows of similarity matrix are
+            non-empty.
+            Returns:
+                matched_columns: An integer tensor of shape [batch_size,
+                    num_rows] storing the index of the matched column for each
+                    row.
+                matched_values: An integer tensor of shape [batch_size,
+                    num_rows] storing the match type indicator (e.g. positive or
+                    negative or ignored match).
+            """
+            # Jax traces this function even when running eagerly and the
+            # columns are non-empty. Therefore, we need to handle the case
+            # where the similarity matrix is empty. We do this by padding
+            # some -1s to the end. -1s are guaranteed to not affect argmax
+            # matching because all values in a similarity matrix are [0,1]
+            # and the indexing won't change because these are added at the
+            # end.
+            padded_similarity_matrix = ops.concatenate(
+                [similarity_matrix, -ops.ones((batch_size, num_rows, 1))],
+                axis=-1,
+            )
+
+            matched_columns = ops.argmax(
+                padded_similarity_matrix,
+                axis=-1,
+            )
+
+            # Get logical indices of ignored and unmatched columns as int32
+            matched_vals = ops.max(padded_similarity_matrix, axis=-1)
+            matched_values = ops.zeros([batch_size, num_rows], "int32")
+
+            match_dtype = matched_vals.dtype
+            for ind, low, high in zip(
+                self.match_values, self.thresholds[:-1], self.thresholds[1:]
+            ):
+                low_threshold = ops.cast(low, match_dtype)
+                high_threshold = ops.cast(high, match_dtype)
+                mask = ops.logical_and(
+                    ops.greater_equal(matched_vals, low_threshold),
+                    ops.less(matched_vals, high_threshold),
+                )
+                matched_values = self._set_values_using_indicator(
+                    matched_values, mask, ind
+                )
+
+            if self.force_match_for_each_col:
+                # [batch_size, num_cols], for each column (groundtruth_box),
+                # find the best matching row (anchor).
+                matching_rows = ops.argmax(
+                    padded_similarity_matrix,
+                    axis=1,
+                )
+                # [batch_size, num_cols, num_rows], a transposed 0-1 mapping
+                # matrix M, where M[j, i] = 1 means column j is matched to
+                # row i.
+                column_to_row_match_mapping = ops.one_hot(
+                    matching_rows, num_rows
+                )
+                # [batch_size, num_rows], for each row (anchor), find the
+                # matched column (groundtruth_box).
+                force_matched_columns = ops.argmax(
+                    column_to_row_match_mapping,
+                    axis=1,
+                )
+                # [batch_size, num_rows]
+                force_matched_column_mask = ops.cast(
+                    ops.max(column_to_row_match_mapping, axis=1),
+                    "bool",
+                )
+                # [batch_size, num_rows]
+                matched_columns = ops.where(
+                    force_matched_column_mask,
+                    force_matched_columns,
+                    matched_columns,
+                )
+                matched_values = ops.where(
+                    force_matched_column_mask,
+                    self.match_values[-1]
+                    * ops.ones([batch_size, num_rows], dtype="int32"),
+                    matched_values,
+                )
+
+            return ops.cast(matched_columns, "int32"), matched_values
+
+        num_boxes = (
+            similarity_matrix.shape[-1] or ops.shape(similarity_matrix)[-1]
+        )
+        matched_columns, matched_values = ops.cond(
+            pred=ops.greater(num_boxes, 0),
+            true_fn=_match_when_cols_are_non_empty,
+            false_fn=_match_when_cols_are_empty,
+        )
+
+        if squeeze_result:
+            matched_columns = ops.squeeze(matched_columns, axis=0)
+            matched_values = ops.squeeze(matched_values, axis=0)
+
+        return matched_columns, matched_values
+
+    def _set_values_using_indicator(self, x, indicator, val):
+        """Set the indicated fields of x to val.
+
+        Args:
+            x: tensor.
+            indicator: boolean with same shape as x.
+            val: scalar with value to set.
+        Returns:
+            modified tensor.
+        """
+        indicator = ops.cast(indicator, x.dtype)
+        return ops.where(indicator == 0, x, val)
+
+    def get_config(self):
+        config = {
+            "thresholds": self.thresholds[1:-1],
+            "match_values": self.match_values,
+            "force_match_for_each_col": self.force_match_for_each_col,
+        }
+        return config
````
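To make the threshold bucketing concrete, here is a small usage sketch (not part of the diff; it imports `BoxMatcher` from the new module path listed above, and the similarity values are made up):

```python
import numpy as np
from keras import ops

from keras_hub.src.models.retinanet.box_matcher import BoxMatcher

# thresholds=[0.3, 0.7] with match_values=[-1, 0, 1] buckets each row by its
# best (argmax) similarity: < 0.3 -> -1 (ignore), [0.3, 0.7) -> 0 (negative),
# >= 0.7 -> 1 (positive).
box_matcher = BoxMatcher(thresholds=[0.3, 0.7], match_values=[-1, 0, 1])

# 3 rows (anchors) x 2 columns (ground-truth boxes).
similarity = ops.convert_to_tensor(
    np.array(
        [
            [0.1, 0.2],  # best 0.2 -> -1
            [0.4, 0.6],  # best 0.6 ->  0
            [0.9, 0.3],  # best 0.9 ->  1
        ],
        dtype="float32",
    )
)
matched_columns, matched_values = box_matcher(similarity)
# matched_columns == [1, 1, 0], matched_values == [-1, 0, 1]
```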