keras-hub-nightly 0.15.0.dev20240911134614__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. keras_hub/__init__.py +0 -6
  2. keras_hub/api/__init__.py +1 -0
  3. keras_hub/api/models/__init__.py +22 -17
  4. keras_hub/{src/models/llama3/llama3_preprocessor.py → api/utils/__init__.py} +7 -8
  5. keras_hub/src/api_export.py +15 -9
  6. keras_hub/src/models/albert/albert_text_classifier.py +6 -1
  7. keras_hub/src/models/bert/bert_text_classifier.py +6 -1
  8. keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +6 -1
  9. keras_hub/src/models/densenet/densenet_backbone.py +1 -1
  10. keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +6 -1
  11. keras_hub/src/models/f_net/f_net_text_classifier.py +6 -1
  12. keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
  13. keras_hub/src/models/gpt2/gpt2_preprocessor.py +7 -78
  14. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +1 -1
  15. keras_hub/src/models/preprocessor.py +1 -5
  16. keras_hub/src/models/resnet/resnet_backbone.py +3 -16
  17. keras_hub/src/models/resnet/resnet_image_classifier.py +26 -3
  18. keras_hub/src/models/resnet/resnet_presets.py +12 -12
  19. keras_hub/src/models/retinanet/__init__.py +13 -0
  20. keras_hub/src/models/retinanet/anchor_generator.py +175 -0
  21. keras_hub/src/models/retinanet/box_matcher.py +259 -0
  22. keras_hub/src/models/retinanet/non_max_supression.py +578 -0
  23. keras_hub/src/models/roberta/roberta_text_classifier.py +6 -1
  24. keras_hub/src/models/task.py +6 -6
  25. keras_hub/src/models/text_classifier.py +12 -1
  26. keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +6 -1
  27. keras_hub/src/tests/test_case.py +21 -0
  28. keras_hub/src/tokenizers/byte_pair_tokenizer.py +1 -0
  29. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +1 -0
  30. keras_hub/src/tokenizers/word_piece_tokenizer.py +1 -0
  31. keras_hub/src/utils/imagenet/__init__.py +13 -0
  32. keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
  33. keras_hub/src/utils/preset_utils.py +24 -33
  34. keras_hub/src/utils/tensor_utils.py +14 -14
  35. keras_hub/src/utils/timm/convert_resnet.py +0 -1
  36. keras_hub/src/utils/timm/preset_loader.py +6 -7
  37. keras_hub/src/version_utils.py +1 -1
  38. keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
  39. {keras_hub_nightly-0.15.0.dev20240911134614.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/RECORD +41 -45
  40. {keras_hub_nightly-0.15.0.dev20240911134614.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
  41. keras_hub/src/models/bart/bart_preprocessor.py +0 -264
  42. keras_hub/src/models/bloom/bloom_preprocessor.py +0 -178
  43. keras_hub/src/models/electra/electra_preprocessor.py +0 -155
  44. keras_hub/src/models/falcon/falcon_preprocessor.py +0 -180
  45. keras_hub/src/models/gemma/gemma_preprocessor.py +0 -184
  46. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -138
  47. keras_hub/src/models/llama/llama_preprocessor.py +0 -182
  48. keras_hub/src/models/mistral/mistral_preprocessor.py +0 -183
  49. keras_hub/src/models/opt/opt_preprocessor.py +0 -181
  50. keras_hub/src/models/phi3/phi3_preprocessor.py +0 -183
  51. keras_hub_nightly-0.15.0.dev20240911134614.dist-info/METADATA +0 -33
  52. {keras_hub_nightly-0.15.0.dev20240911134614.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
keras_hub/src/models/resnet/resnet_presets.py
@@ -14,7 +14,7 @@
 """ResNet preset configurations."""
 
 backbone_presets = {
-    "resnet_18": {
+    "resnet_18_imagenet": {
         "metadata": {
             "description": (
                 "18-layer ResNet model pre-trained on the ImageNet 1k dataset "
@@ -25,9 +25,9 @@ backbone_presets = {
             "path": "resnet",
             "model_card": "https://arxiv.org/abs/2110.00476",
         },
-        "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_18/1",
+        "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_18_imagenet/2",
     },
-    "resnet_50": {
+    "resnet_50_imagenet": {
         "metadata": {
             "description": (
                 "50-layer ResNet model pre-trained on the ImageNet 1k dataset "
@@ -38,9 +38,9 @@ backbone_presets = {
             "path": "resnet",
             "model_card": "https://arxiv.org/abs/2110.00476",
         },
-        "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_50/1",
+        "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_50_imagenet/2",
     },
-    "resnet_101": {
+    "resnet_101_imagenet": {
         "metadata": {
             "description": (
                 "101-layer ResNet model pre-trained on the ImageNet 1k dataset "
@@ -51,9 +51,9 @@ backbone_presets = {
             "path": "resnet",
             "model_card": "https://arxiv.org/abs/2110.00476",
         },
-        "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_101/1",
+        "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_101_imagenet/2",
     },
-    "resnet_152": {
+    "resnet_152_imagenet": {
         "metadata": {
             "description": (
                 "152-layer ResNet model pre-trained on the ImageNet 1k dataset "
@@ -64,9 +64,9 @@ backbone_presets = {
             "path": "resnet",
             "model_card": "https://arxiv.org/abs/2110.00476",
         },
-        "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_152/1",
+        "kaggle_handle": "kaggle://kerashub/resnetv1/keras/resnet_152_imagenet/2",
     },
-    "resnet_v2_50": {
+    "resnet_v2_50_imagenet": {
         "metadata": {
             "description": (
                 "50-layer ResNetV2 model pre-trained on the ImageNet 1k "
@@ -77,9 +77,9 @@ backbone_presets = {
             "path": "resnet",
             "model_card": "https://arxiv.org/abs/2110.00476",
         },
-        "kaggle_handle": "kaggle://kerashub/resnetv2/keras/resnet_v2_50/1",
+        "kaggle_handle": "kaggle://kerashub/resnetv2/keras/resnet_v2_50_imagenet/2",
     },
-    "resnet_v2_101": {
+    "resnet_v2_101_imagenet": {
         "metadata": {
             "description": (
                 "101-layer ResNetV2 model pre-trained on the ImageNet 1k "
@@ -90,6 +90,6 @@ backbone_presets = {
             "path": "resnet",
             "model_card": "https://arxiv.org/abs/2110.00476",
         },
-        "kaggle_handle": "kaggle://kerashub/resnetv2/keras/resnet_v2_101/1",
+        "kaggle_handle": "kaggle://kerashub/resnetv2/keras/resnet_v2_101_imagenet/2",
     },
 }
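The rename is user-visible: every ResNet preset gains an `_imagenet` suffix and its Kaggle handle moves to version 2, so code that loaded checkpoints by the old preset name must be updated. A minimal sketch, assuming the standard KerasHub `from_preset` constructor (only the preset name comes from this diff; the rest is ordinary usage):

```python
import keras_hub

# The 0.15 name "resnet_50" no longer exists in 0.16; the renamed preset
# resolves to kaggle://kerashub/resnetv1/keras/resnet_50_imagenet/2.
backbone = keras_hub.models.ResNetBackbone.from_preset("resnet_50_imagenet")
```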
keras_hub/src/models/retinanet/__init__.py (new file)
@@ -0,0 +1,13 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
keras_hub/src/models/retinanet/anchor_generator.py (new file)
@@ -0,0 +1,175 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import keras
+from keras import ops
+
+from keras_hub.src.bounding_box.converters import convert_format
+
+
+class AnchorGenerator(keras.layers.Layer):
+    """Generates anchor boxes for object detection tasks.
+
+    This layer creates a set of anchor boxes (also known as default boxes or
+    priors) for use in object detection models, particularly those utilizing
+    Feature Pyramid Networks (FPN). It generates anchors across multiple
+    pyramid levels, with various scales and aspect ratios.
+
+    Feature Pyramid Levels:
+    - Levels typically range from 2 to 7 (P2 to P7), corresponding to
+      different resolutions of the input image.
+    - Each level l has a stride of 2^l pixels relative to the input image.
+    - Lower levels (e.g., P2) have higher resolution and are used for
+      detecting smaller objects.
+    - Higher levels (e.g., P7) have lower resolution and are used
+      for larger objects.
+
+    Args:
+        bounding_box_format (str): The format of the bounding boxes
+            to be generated. Expected to be a string like 'xyxy', 'xywh', etc.
+        min_level (int): Minimum level of the output feature pyramid.
+        max_level (int): Maximum level of the output feature pyramid.
+        num_scales (int): Number of intermediate scales added on each level.
+            For example, num_scales=2 adds one additional intermediate anchor
+            scale [2^0, 2^0.5] on each level.
+        aspect_ratios (list of float): Aspect ratios of anchors added on
+            each level. Each number indicates the ratio of width to height.
+        anchor_size (float): Scale of the base anchor size relative to the
+            feature stride 2^level.
+
+    Call arguments:
+        images (Optional[Tensor]): An image tensor with shape `[B, H, W, C]`
+            or `[H, W, C]`. If provided, its shape will be used to determine
+            anchor sizes.
+
+    Returns:
+        Dict: A dictionary mapping feature levels (e.g., 'P3', 'P4', etc.)
+        to anchor boxes. Each entry contains a tensor of shape
+        `(H/stride * W/stride * num_anchors_per_location, 4)`, where H and W
+        are the height and width of the image, stride is 2^level, and
+        num_anchors_per_location is `num_scales * len(aspect_ratios)`.
+
+    Example:
+    ```python
+    anchor_generator = AnchorGenerator(
+        bounding_box_format='xyxy',
+        min_level=3,
+        max_level=7,
+        num_scales=3,
+        aspect_ratios=[0.5, 1.0, 2.0],
+        anchor_size=4.0,
+    )
+    anchors = anchor_generator(images=keras.ops.ones(shape=(2, 640, 480, 3)))
+    ```
+    """
+
+    def __init__(
+        self,
+        bounding_box_format,
+        min_level,
+        max_level,
+        num_scales,
+        aspect_ratios,
+        anchor_size,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.bounding_box_format = bounding_box_format
+        self.min_level = min_level
+        self.max_level = max_level
+        self.num_scales = num_scales
+        self.aspect_ratios = aspect_ratios
+        self.anchor_size = anchor_size
+        self.built = True
+
+    def call(self, images):
+        images_shape = ops.shape(images)
+        if len(images_shape) == 4:
+            image_shape = images_shape[1:-1]
+        else:
+            image_shape = images_shape[:-1]
+
+        image_shape = tuple(image_shape)
+
+        multilevel_boxes = {}
+        for level in range(self.min_level, self.max_level + 1):
+            boxes_l = []
+            # Calculate the feature map size for this level.
+            feat_size_y = math.ceil(image_shape[0] / 2**level)
+            feat_size_x = math.ceil(image_shape[1] / 2**level)
+
+            # Calculate the stride (step size) for this level.
+            stride_y = ops.cast(image_shape[0] / feat_size_y, "float32")
+            stride_x = ops.cast(image_shape[1] / feat_size_x, "float32")
+
+            # Generate anchor center points, starting from stride/2 to
+            # center anchors on pixels.
+            cx = ops.arange(stride_x / 2, image_shape[1], stride_x)
+            cy = ops.arange(stride_y / 2, image_shape[0], stride_y)
+
+            # Create a grid of anchor centers.
+            cx_grid, cy_grid = ops.meshgrid(cx, cy)
+
+            for scale in range(self.num_scales):
+                for aspect_ratio in self.aspect_ratios:
+                    # Calculate the intermediate scale factor.
+                    intermediate_scale = 2 ** (scale / self.num_scales)
+                    # Calculate the base anchor size for this level and scale.
+                    base_anchor_size = (
+                        self.anchor_size * 2**level * intermediate_scale
+                    )
+                    # Adjust anchor dimensions based on aspect ratio.
+                    aspect_x = aspect_ratio**0.5
+                    aspect_y = aspect_ratio**-0.5
+                    half_anchor_size_x = base_anchor_size * aspect_x / 2.0
+                    half_anchor_size_y = base_anchor_size * aspect_y / 2.0
+
+                    # Generate anchor boxes in (y1, x1, y2, x2) format.
+                    boxes = ops.stack(
+                        [
+                            cy_grid - half_anchor_size_y,
+                            cx_grid - half_anchor_size_x,
+                            cy_grid + half_anchor_size_y,
+                            cx_grid + half_anchor_size_x,
+                        ],
+                        axis=-1,
+                    )
+                    boxes_l.append(boxes)
+            # Concatenate anchors on the same level into shape HxWx(Ax4).
+            boxes_l = ops.concatenate(boxes_l, axis=-1)
+            boxes_l = ops.reshape(boxes_l, (-1, 4))
+            # Convert to the user-defined box format.
+            multilevel_boxes[f"P{level}"] = convert_format(
+                boxes_l,
+                source="yxyx",
+                target=self.bounding_box_format,
+            )
+        return multilevel_boxes
+
+    def compute_output_shape(self, input_shape):
+        multilevel_boxes_shape = {}
+        for level in range(self.min_level, self.max_level + 1):
+            multilevel_boxes_shape[f"P{level}"] = (None, None, 4)
+        return multilevel_boxes_shape
+
+    @property
+    def anchors_per_location(self):
+        """The number of anchors generated per pixel location, equal to
+        `num_scales * len(aspect_ratios)`.
+        """
+        return self.num_scales * len(self.aspect_ratios)
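The per-level output shape follows directly from the loop above: at level l the grid has ceil(H / 2^l) * ceil(W / 2^l) cells, each holding `num_scales * len(aspect_ratios)` anchors. A minimal sketch checking this arithmetic against the new layer (the import path is inferred from the file location in this diff and may not be a public export):

```python
from keras import ops

from keras_hub.src.models.retinanet.anchor_generator import AnchorGenerator

generator = AnchorGenerator(
    bounding_box_format="xyxy",
    min_level=3,
    max_level=7,
    num_scales=3,
    aspect_ratios=[0.5, 1.0, 2.0],
    anchor_size=4.0,
)
anchors = generator(images=ops.ones(shape=(2, 640, 480, 3)))

# At P3 the stride is 2**3 = 8, so a 640x480 image gives an 80x60 grid
# with 3 scales * 3 aspect ratios = 9 anchors per cell:
# 80 * 60 * 9 = 43200 anchors of 4 coordinates each.
print(anchors["P3"].shape)  # (43200, 4)
print(generator.anchors_per_location)  # 9
```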
keras_hub/src/models/retinanet/box_matcher.py (new file)
@@ -0,0 +1,259 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import keras
+from keras import ops
+
+
+class BoxMatcher(keras.layers.Layer):
+    """Box matching logic based on argmax of highest value (e.g., IOU).
+
+    This class computes matches from a similarity matrix. Each row will be
+    matched to at least one column, and the matched result can be positive,
+    negative, or simply ignored, depending on the setting.
+
+    The settings include `thresholds` and `match_values`. For example, if:
+    1) `thresholds=[negative_threshold, positive_threshold]` and
+       `match_values=[negative_value=0, ignore_value=-1, positive_value=1]`:
+       a row is assigned positive_value if its argmax result >=
+       positive_threshold, negative_value if its argmax result <
+       negative_threshold, and ignore_value if its argmax result is in
+       [negative_threshold, positive_threshold).
+    2) `thresholds=[negative_threshold, positive_threshold]` and
+       `match_values=[ignore_value=-1, negative_value=0, positive_value=1]`:
+       a row is assigned positive_value if its argmax result >=
+       positive_threshold, ignore_value if its argmax result <
+       negative_threshold, and negative_value if its argmax result is in
+       [negative_threshold, positive_threshold). This differs from case 1)
+       by swapping the first two values.
+    3) `thresholds=[positive_threshold]` and
+       `match_values=[negative_value, positive_value]`: a row is assigned
+       positive_value if its argmax result >= positive_threshold, and
+       negative_value otherwise.
+
+    Args:
+        thresholds: A sorted list of floats that classify the matches into
+            different results (e.g. positive or negative or ignored match).
+            The list will be prepended with -Inf and appended with +Inf.
+        match_values: A list of integers representing match results (e.g.
+            positive or negative or ignored match). len(`match_values`) must
+            equal len(`thresholds`) + 1.
+        force_match_for_each_col: Each row will be argmax matched to at
+            least one column. This means some columns will be matched to
+            multiple rows while some columns will not be matched to any
+            rows. Filtering by `thresholds` will make fewer columns match
+            with a positive result. Setting this to True guarantees that
+            each column will be matched with a positive result to at least
+            one row.
+
+    Raises:
+        ValueError: if `thresholds` is not sorted or
+            len(`match_values`) != len(`thresholds`) + 1.
+
+    Example:
+    ```python
+    box_matcher = BoxMatcher([0.3, 0.7], [-1, 0, 1])
+    iou_metric = compute_iou(anchors, boxes)  # any pairwise similarity
+    matched_columns, matched_match_values = box_matcher(iou_metric)
+    cls_mask = ops.less_equal(matched_match_values, 0)
+    ```
+    """
+
+    def __init__(
+        self,
+        thresholds,
+        match_values,
+        force_match_for_each_col=False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if sorted(thresholds) != thresholds:
+            raise ValueError(f"`thresholds` must be sorted, got {thresholds}")
+        self.match_values = match_values
+        if len(match_values) != len(thresholds) + 1:
+            raise ValueError(
+                f"len(`match_values`) must be len(`thresholds`) + 1, got "
+                f"match_values {match_values}, thresholds {thresholds}"
+            )
+        thresholds.insert(0, -float("inf"))
+        thresholds.append(float("inf"))
+        self.thresholds = thresholds
+        self.force_match_for_each_col = force_match_for_each_col
+        self.built = True
+
+    def call(self, similarity_matrix):
+        """Matches each row to a column based on argmax.
+
+        Args:
+            similarity_matrix: A float Tensor of shape `[num_rows, num_cols]`
+                or `[batch_size, num_rows, num_cols]` representing any
+                similarity metric.
+
+        Returns:
+            matched_columns: An integer tensor of shape `[num_rows]` or
+                `[batch_size, num_rows]` storing the index of the matched
+                column for each row.
+            matched_values: An integer tensor of shape `[num_rows]` or
+                `[batch_size, num_rows]` storing the match result
+                (positive match, negative match, ignored match).
+        """
+        squeeze_result = False
+        if len(similarity_matrix.shape) == 2:
+            squeeze_result = True
+            similarity_matrix = ops.expand_dims(similarity_matrix, axis=0)
+        static_shape = list(similarity_matrix.shape)
+        num_rows = static_shape[1] or ops.shape(similarity_matrix)[1]
+        batch_size = static_shape[0] or ops.shape(similarity_matrix)[0]
+
+        def _match_when_cols_are_empty():
+            """Performs matching when the columns of the similarity matrix
+            are empty. When the columns are empty, all detections are false
+            positives. So we return a tensor of -1's to indicate that the
+            rows do not match to any columns.
+
+            Returns:
+                matched_columns: An integer tensor of shape [batch_size,
+                    num_rows] storing the index of the matched column for
+                    each row.
+                matched_values: An integer tensor of shape [batch_size,
+                    num_rows] storing the match type indicator (e.g.
+                    positive or negative or ignored match).
+            """
+            matched_columns = ops.zeros([batch_size, num_rows], dtype="int32")
+            matched_values = -ops.ones([batch_size, num_rows], dtype="int32")
+            return matched_columns, matched_values
+
+        def _match_when_cols_are_non_empty():
+            """Performs matching when the columns of the similarity matrix
+            are non-empty.
+
+            Returns:
+                matched_columns: An integer tensor of shape [batch_size,
+                    num_rows] storing the index of the matched column for
+                    each row.
+                matched_values: An integer tensor of shape [batch_size,
+                    num_rows] storing the match type indicator (e.g.
+                    positive or negative or ignored match).
+            """
+            # JAX traces this function even when running eagerly and the
+            # columns are non-empty. Therefore, we need to handle the case
+            # where the similarity matrix is empty. We do this by padding
+            # some -1s to the end. The -1s are guaranteed not to affect
+            # argmax matching because all values in a similarity matrix are
+            # in [0, 1], and the indexing won't change because they are
+            # added at the end.
+            padded_similarity_matrix = ops.concatenate(
+                [similarity_matrix, -ops.ones((batch_size, num_rows, 1))],
+                axis=-1,
+            )
+
+            matched_columns = ops.argmax(
+                padded_similarity_matrix,
+                axis=-1,
+            )
+
+            # Get logical indices of ignored and unmatched columns as int32.
+            matched_vals = ops.max(padded_similarity_matrix, axis=-1)
+            matched_values = ops.zeros([batch_size, num_rows], "int32")
+
+            match_dtype = matched_vals.dtype
+            for ind, low, high in zip(
+                self.match_values, self.thresholds[:-1], self.thresholds[1:]
+            ):
+                low_threshold = ops.cast(low, match_dtype)
+                high_threshold = ops.cast(high, match_dtype)
+                mask = ops.logical_and(
+                    ops.greater_equal(matched_vals, low_threshold),
+                    ops.less(matched_vals, high_threshold),
+                )
+                matched_values = self._set_values_using_indicator(
+                    matched_values, mask, ind
+                )
+
+            if self.force_match_for_each_col:
+                # [batch_size, num_cols], for each column (groundtruth_box),
+                # find the best matching row (anchor).
+                matching_rows = ops.argmax(
+                    padded_similarity_matrix,
+                    axis=1,
+                )
+                # [batch_size, num_cols, num_rows], a transposed 0-1 mapping
+                # matrix M, where M[j, i] = 1 means column j is matched to
+                # row i.
+                column_to_row_match_mapping = ops.one_hot(
+                    matching_rows, num_rows
+                )
+                # [batch_size, num_rows], for each row (anchor), find the
+                # matched column (groundtruth_box).
+                force_matched_columns = ops.argmax(
+                    column_to_row_match_mapping,
+                    axis=1,
+                )
+                # [batch_size, num_rows]
+                force_matched_column_mask = ops.cast(
+                    ops.max(column_to_row_match_mapping, axis=1),
+                    "bool",
+                )
+                # [batch_size, num_rows]
+                matched_columns = ops.where(
+                    force_matched_column_mask,
+                    force_matched_columns,
+                    matched_columns,
+                )
+                matched_values = ops.where(
+                    force_matched_column_mask,
+                    self.match_values[-1]
+                    * ops.ones([batch_size, num_rows], dtype="int32"),
+                    matched_values,
+                )
+
+            return ops.cast(matched_columns, "int32"), matched_values
+
+        num_boxes = (
+            similarity_matrix.shape[-1] or ops.shape(similarity_matrix)[-1]
+        )
+        matched_columns, matched_values = ops.cond(
+            pred=ops.greater(num_boxes, 0),
+            true_fn=_match_when_cols_are_non_empty,
+            false_fn=_match_when_cols_are_empty,
+        )
+
+        if squeeze_result:
+            matched_columns = ops.squeeze(matched_columns, axis=0)
+            matched_values = ops.squeeze(matched_values, axis=0)
+
+        return matched_columns, matched_values
+
+    def _set_values_using_indicator(self, x, indicator, val):
+        """Sets the indicated fields of x to val.
+
+        Args:
+            x: tensor.
+            indicator: boolean with same shape as x.
+            val: scalar with value to set.
+
+        Returns:
+            modified tensor.
+        """
+        indicator = ops.cast(indicator, x.dtype)
+        return ops.where(indicator == 0, x, val)
+
+    def get_config(self):
+        config = {
+            "thresholds": self.thresholds[1:-1],
+            "match_values": self.match_values,
+            "force_match_for_each_col": self.force_match_for_each_col,
+        }
+        return config
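To make the threshold bucketing concrete, here is a small sketch with made-up IOU values (the import path is inferred from the file location in this diff; only `BoxMatcher` itself comes from the code above):

```python
from keras import ops

from keras_hub.src.models.retinanet.box_matcher import BoxMatcher

# Three anchors (rows) scored against two ground-truth boxes (columns)
# with made-up IOU values in [0, 1].
similarity = ops.convert_to_tensor(
    [
        [0.1, 0.2],  # best IOU 0.2 -> < 0.3         -> negative (0)
        [0.4, 0.6],  # best IOU 0.6 -> in [0.3, 0.7) -> ignored (-1)
        [0.9, 0.3],  # best IOU 0.9 -> >= 0.7        -> positive (1)
    ]
)
matcher = BoxMatcher(thresholds=[0.3, 0.7], match_values=[0, -1, 1])
matched_columns, matched_values = matcher(similarity)
print(ops.convert_to_numpy(matched_columns))  # [1 1 0]
print(ops.convert_to_numpy(matched_values))   # [ 0 -1  1]
```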