keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198)
  1. keras_hub/__init__.py +0 -6
  2. keras_hub/api/__init__.py +2 -0
  3. keras_hub/api/bounding_box/__init__.py +36 -0
  4. keras_hub/api/layers/__init__.py +14 -0
  5. keras_hub/api/models/__init__.py +97 -48
  6. keras_hub/api/tokenizers/__init__.py +30 -0
  7. keras_hub/api/utils/__init__.py +22 -0
  8. keras_hub/src/api_export.py +15 -9
  9. keras_hub/src/bounding_box/__init__.py +13 -0
  10. keras_hub/src/bounding_box/converters.py +529 -0
  11. keras_hub/src/bounding_box/formats.py +162 -0
  12. keras_hub/src/bounding_box/iou.py +263 -0
  13. keras_hub/src/bounding_box/to_dense.py +95 -0
  14. keras_hub/src/bounding_box/to_ragged.py +99 -0
  15. keras_hub/src/bounding_box/utils.py +194 -0
  16. keras_hub/src/bounding_box/validate_format.py +99 -0
  17. keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
  18. keras_hub/src/layers/preprocessing/image_converter.py +130 -0
  19. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
  20. keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
  21. keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
  22. keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
  23. keras_hub/src/layers/preprocessing/random_swap.py +33 -31
  24. keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
  25. keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
  26. keras_hub/src/models/albert/__init__.py +1 -2
  27. keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
  28. keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
  29. keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
  30. keras_hub/src/models/albert/albert_tokenizer.py +17 -36
  31. keras_hub/src/models/backbone.py +12 -34
  32. keras_hub/src/models/bart/__init__.py +1 -2
  33. keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
  34. keras_hub/src/models/bart/bart_tokenizer.py +12 -39
  35. keras_hub/src/models/bert/__init__.py +1 -5
  36. keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
  37. keras_hub/src/models/bert/bert_presets.py +1 -4
  38. keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
  39. keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
  40. keras_hub/src/models/bert/bert_tokenizer.py +17 -35
  41. keras_hub/src/models/bloom/__init__.py +1 -2
  42. keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
  43. keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
  44. keras_hub/src/models/causal_lm.py +10 -29
  45. keras_hub/src/models/causal_lm_preprocessor.py +195 -0
  46. keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
  47. keras_hub/src/models/deberta_v3/__init__.py +1 -4
  48. keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
  49. keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
  50. keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
  51. keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
  52. keras_hub/src/models/densenet/densenet_backbone.py +46 -22
  53. keras_hub/src/models/distil_bert/__init__.py +1 -4
  54. keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
  55. keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
  56. keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
  57. keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
  58. keras_hub/src/models/efficientnet/__init__.py +13 -0
  59. keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
  60. keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
  61. keras_hub/src/models/efficientnet/mbconv.py +238 -0
  62. keras_hub/src/models/electra/__init__.py +1 -2
  63. keras_hub/src/models/electra/electra_tokenizer.py +17 -32
  64. keras_hub/src/models/f_net/__init__.py +1 -2
  65. keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
  66. keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
  67. keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
  68. keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
  69. keras_hub/src/models/falcon/__init__.py +1 -2
  70. keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
  71. keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
  72. keras_hub/src/models/gemma/__init__.py +1 -2
  73. keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
  74. keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
  75. keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
  76. keras_hub/src/models/gpt2/__init__.py +1 -2
  77. keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
  78. keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
  79. keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
  80. keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
  81. keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
  82. keras_hub/src/models/image_classifier.py +0 -5
  83. keras_hub/src/models/image_classifier_preprocessor.py +83 -0
  84. keras_hub/src/models/llama/__init__.py +1 -2
  85. keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
  86. keras_hub/src/models/llama/llama_tokenizer.py +12 -25
  87. keras_hub/src/models/llama3/__init__.py +1 -2
  88. keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
  89. keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
  90. keras_hub/src/models/masked_lm.py +0 -2
  91. keras_hub/src/models/masked_lm_preprocessor.py +156 -0
  92. keras_hub/src/models/mistral/__init__.py +1 -2
  93. keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
  94. keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
  95. keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
  96. keras_hub/src/models/mobilenet/__init__.py +13 -0
  97. keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
  98. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
  99. keras_hub/src/models/opt/__init__.py +1 -2
  100. keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
  101. keras_hub/src/models/opt/opt_tokenizer.py +12 -41
  102. keras_hub/src/models/pali_gemma/__init__.py +1 -4
  103. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
  104. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
  105. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
  106. keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
  107. keras_hub/src/models/phi3/__init__.py +1 -2
  108. keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
  109. keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
  110. keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
  111. keras_hub/src/models/preprocessor.py +72 -83
  112. keras_hub/src/models/resnet/__init__.py +6 -0
  113. keras_hub/src/models/resnet/resnet_backbone.py +390 -42
  114. keras_hub/src/models/resnet/resnet_image_classifier.py +33 -6
  115. keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
  116. keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
  117. keras_hub/src/models/resnet/resnet_presets.py +95 -0
  118. keras_hub/src/models/retinanet/__init__.py +13 -0
  119. keras_hub/src/models/retinanet/anchor_generator.py +175 -0
  120. keras_hub/src/models/retinanet/box_matcher.py +259 -0
  121. keras_hub/src/models/retinanet/non_max_supression.py +578 -0
  122. keras_hub/src/models/roberta/__init__.py +1 -2
  123. keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
  124. keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
  125. keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
  126. keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
  127. keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
  128. keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
  129. keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
  130. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
  131. keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
  132. keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
  133. keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
  134. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
  135. keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
  136. keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
  137. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
  138. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
  139. keras_hub/src/models/t5/__init__.py +1 -2
  140. keras_hub/src/models/t5/t5_tokenizer.py +13 -23
  141. keras_hub/src/models/task.py +71 -116
  142. keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
  143. keras_hub/src/models/text_classifier_preprocessor.py +138 -0
  144. keras_hub/src/models/whisper/__init__.py +1 -2
  145. keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
  146. keras_hub/src/models/whisper/whisper_backbone.py +0 -3
  147. keras_hub/src/models/whisper/whisper_presets.py +10 -10
  148. keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
  149. keras_hub/src/models/xlm_roberta/__init__.py +1 -4
  150. keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
  151. keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
  152. keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
  153. keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
  154. keras_hub/src/tests/test_case.py +46 -0
  155. keras_hub/src/tokenizers/byte_pair_tokenizer.py +30 -17
  156. keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
  157. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +20 -7
  158. keras_hub/src/tokenizers/tokenizer.py +67 -32
  159. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
  160. keras_hub/src/tokenizers/word_piece_tokenizer.py +34 -47
  161. keras_hub/src/utils/imagenet/__init__.py +13 -0
  162. keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
  163. keras_hub/src/utils/keras_utils.py +0 -50
  164. keras_hub/src/utils/preset_utils.py +230 -68
  165. keras_hub/src/utils/tensor_utils.py +187 -69
  166. keras_hub/src/utils/timm/convert_resnet.py +19 -16
  167. keras_hub/src/utils/timm/preset_loader.py +66 -0
  168. keras_hub/src/utils/transformers/convert_albert.py +193 -0
  169. keras_hub/src/utils/transformers/convert_bart.py +373 -0
  170. keras_hub/src/utils/transformers/convert_bert.py +7 -17
  171. keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
  172. keras_hub/src/utils/transformers/convert_gemma.py +5 -19
  173. keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
  174. keras_hub/src/utils/transformers/convert_llama3.py +7 -18
  175. keras_hub/src/utils/transformers/convert_mistral.py +129 -0
  176. keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
  177. keras_hub/src/utils/transformers/preset_loader.py +77 -0
  178. keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
  179. keras_hub/src/version_utils.py +1 -1
  180. keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
  181. keras_hub_nightly-0.16.0.dev2024092017.dist-info/RECORD +334 -0
  182. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
  183. keras_hub/src/models/bart/bart_preprocessor.py +0 -276
  184. keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
  185. keras_hub/src/models/electra/electra_preprocessor.py +0 -154
  186. keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
  187. keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
  188. keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
  189. keras_hub/src/models/llama/llama_preprocessor.py +0 -189
  190. keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
  191. keras_hub/src/models/opt/opt_preprocessor.py +0 -188
  192. keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
  193. keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
  194. keras_hub/src/utils/timm/convert.py +0 -37
  195. keras_hub/src/utils/transformers/convert.py +0 -101
  196. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +0 -34
  197. keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
  198. {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,259 @@
1
+ # Copyright 2024 The KerasHub Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import keras
16
+ from keras import ops
17
+
18
+
19
class BoxMatcher(keras.layers.Layer):
    """Box matching logic based on argmax of highest value (e.g., IOU).

    This class computes matches from a similarity matrix. Each row is matched
    to exactly one column (the column with the highest similarity for that
    row). The match result for the row is then classified as positive,
    negative, or ignored, depending on which threshold bucket its highest
    similarity falls into.

    The settings include `thresholds` and `match_values`, for example if:
    1) `thresholds=[negative_threshold, positive_threshold]`, and
       `match_values=[negative_value=0, ignore_value=-1, positive_value=1]`:
       the rows will be assigned to positive_value if their argmax result >=
       positive_threshold; the rows will be assigned to negative_value if
       their argmax result < negative_threshold, and the rows will be
       assigned to ignore_value if their argmax result is in
       [negative_threshold, positive_threshold).
    2) `thresholds=[negative_threshold, positive_threshold]`, and
       `match_values=[ignore_value=-1, negative_value=0, positive_value=1]`:
       the rows will be assigned to positive_value if their argmax result >=
       positive_threshold; the rows will be assigned to ignore_value if
       their argmax result < negative_threshold, and the rows will be
       assigned to negative_value if their argmax result is in
       [negative_threshold, positive_threshold). This differs from case 1)
       by swapping the first two values.
    3) `thresholds=[positive_threshold]`, and
       `match_values=[negative_value, positive_value]`: the rows will be
       assigned to positive_value if their argmax result >=
       positive_threshold; the rows will be assigned to negative_value if
       their argmax result < positive_threshold.

    Args:
        thresholds: A sorted list (or tuple) of floats used to classify the
            matches into different results (e.g. positive, negative or
            ignored match). The values are copied; internally the copy is
            prepended with -inf and appended with +inf, so that every
            similarity falls into exactly one `[low, high)` bucket. The
            caller's sequence is not modified.
        match_values: A list of integers representing matched results (e.g.
            positive, negative or ignored match). `len(match_values)` must
            equal `len(thresholds) + 1`; `match_values[i]` is assigned to
            rows whose best similarity falls into the i-th bucket.
        force_match_for_each_col: By default, each row is argmax-matched to
            one column. This means some columns may be matched to multiple
            rows while others are matched to no row, and filtering by
            `thresholds` reduces the number of positive columns further.
            If True, each column additionally claims the row with the
            highest similarity for that column. That row is assigned the
            column's index and `match_values[-1]`, overriding the
            threshold-based result. If several columns choose the same row,
            the row keeps the lowest such column index.

    Raises:
        ValueError: if `thresholds` is not sorted or
            `len(match_values) != len(thresholds) + 1`.

    Example:
    ```python
    box_matcher = BoxMatcher([0.3, 0.7], [-1, 0, 1])
    # `iou_matrix` is a `[num_anchors, num_boxes]` similarity matrix, e.g.
    # the pairwise IoU between anchors and ground-truth boxes.
    matched_columns, matched_match_values = box_matcher(iou_matrix)
    cls_mask = ops.less_equal(matched_match_values, 0)
    ```
    """

    def __init__(
        self,
        thresholds,
        match_values,
        force_match_for_each_col=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Work on a private copy. The original inserted the +/-inf sentinels
        # into the caller's list in place, which broke reuse of that list
        # and made a sorted tuple fail the sortedness check.
        thresholds = list(thresholds)
        if sorted(thresholds) != thresholds:
            raise ValueError(f"`threshold` must be sorted, got {thresholds}")
        self.match_values = match_values
        if len(match_values) != len(thresholds) + 1:
            raise ValueError(
                f"len(`match_values`) must be len(`thresholds`) + 1, got "
                f"match_values {match_values}, thresholds {thresholds}"
            )
        # Sentinel bounds so that consecutive pairs form the `[low, high)`
        # buckets iterated over in `call`.
        self.thresholds = [-float("inf")] + thresholds + [float("inf")]
        self.force_match_for_each_col = force_match_for_each_col
        # The layer has no weights, so it is built immediately.
        self.built = True

    def call(self, similarity_matrix):
        """Matches each row to a column based on argmax.

        Args:
            similarity_matrix: A float Tensor of shape `[num_rows, num_cols]`
                or `[batch_size, num_rows, num_cols]` representing any
                similarity metric.

        Returns:
            matched_columns: An integer tensor of shape `[num_rows]` or
                `[batch_size, num_rows]` storing the index of the matched
                column for each row.
            matched_values: An integer tensor of shape `[num_rows]` or
                `[batch_size, num_rows]` storing the match result
                `(positive match, negative match, ignored match)`.
        """
        # Unbatched input is promoted to a batch of one and squeezed back at
        # the end.
        squeeze_result = False
        if len(similarity_matrix.shape) == 2:
            squeeze_result = True
            similarity_matrix = ops.expand_dims(similarity_matrix, axis=0)
        # Prefer static dimensions; fall back to dynamic ones when unknown.
        static_shape = list(similarity_matrix.shape)
        num_rows = static_shape[1] or ops.shape(similarity_matrix)[1]
        batch_size = static_shape[0] or ops.shape(similarity_matrix)[0]

        def _match_when_cols_are_empty():
            """Performs matching when the similarity matrix has no columns.

            Every row gets column index 0 and match value -1.
            NOTE(review): -1 is hard-coded here and is not taken from
            `self.match_values`.

            Returns:
                matched_columns: An integer tensor of shape [batch_size,
                    num_rows] storing the index of the matched column for
                    each row.
                matched_values: An integer tensor of shape [batch_size,
                    num_rows] storing the match type indicator (e.g.
                    positive or negative or ignored match).
            """
            matched_columns = ops.zeros([batch_size, num_rows], dtype="int32")
            matched_values = -ops.ones([batch_size, num_rows], dtype="int32")
            return matched_columns, matched_values

        def _match_when_cols_are_non_empty():
            """Performs matching when the similarity matrix has columns.

            Returns:
                matched_columns: An integer tensor of shape [batch_size,
                    num_rows] storing the index of the matched column for
                    each row.
                matched_values: An integer tensor of shape [batch_size,
                    num_rows] storing the match type indicator (e.g.
                    positive or negative or ignored match).
            """
            # Jax traces this function even when running eagerly and the
            # columns are non-empty. Therefore, we need to handle the case
            # where the similarity matrix is empty. We do this by padding
            # some -1s to the end. -1s are guaranteed to not affect argmax
            # matching because all values in a similarity matrix are [0,1]
            # and the indexing won't change because these are added at the
            # end.
            padded_similarity_matrix = ops.concatenate(
                [similarity_matrix, -ops.ones((batch_size, num_rows, 1))],
                axis=-1,
            )

            matched_columns = ops.argmax(
                padded_similarity_matrix,
                axis=-1,
            )

            # Classify each row's best similarity into a threshold bucket.
            matched_vals = ops.max(padded_similarity_matrix, axis=-1)
            matched_values = ops.zeros([batch_size, num_rows], "int32")

            match_dtype = matched_vals.dtype
            for ind, low, high in zip(
                self.match_values, self.thresholds[:-1], self.thresholds[1:]
            ):
                low_threshold = ops.cast(low, match_dtype)
                high_threshold = ops.cast(high, match_dtype)
                mask = ops.logical_and(
                    ops.greater_equal(matched_vals, low_threshold),
                    ops.less(matched_vals, high_threshold),
                )
                matched_values = self._set_values_using_indicator(
                    matched_values, mask, ind
                )

            if self.force_match_for_each_col:
                # [batch_size, num_cols + 1], for each column
                # (groundtruth_box), find the best matching row (anchor).
                # The padded width is kept so the argmax below never reduces
                # over an empty axis under JAX tracing.
                matching_rows = ops.argmax(
                    padded_similarity_matrix,
                    axis=1,
                )
                # [batch_size, num_cols + 1, num_rows], a transposed 0-1
                # mapping matrix M, where M[j, i] = 1 means column j is
                # matched to row i.
                column_to_row_match_mapping = ops.one_hot(
                    matching_rows, num_rows
                )
                # The padding column is constant (-1), so its argmax always
                # picks row 0. Left in place, it would force-match row 0 to
                # a positive result with the non-existent column index
                # `num_cols`. Zero out its mapping row while keeping the
                # padded width.
                column_to_row_match_mapping = ops.concatenate(
                    [
                        column_to_row_match_mapping[:, :-1, :],
                        ops.zeros_like(column_to_row_match_mapping[:, -1:, :]),
                    ],
                    axis=1,
                )
                # [batch_size, num_rows], for each row (anchor), find the
                # matched column (groundtruth_box).
                force_matched_columns = ops.argmax(
                    column_to_row_match_mapping,
                    axis=1,
                )
                # [batch_size, num_rows], True where some real column chose
                # this row.
                force_matched_column_mask = ops.cast(
                    ops.max(column_to_row_match_mapping, axis=1),
                    "bool",
                )
                # [batch_size, num_rows]
                matched_columns = ops.where(
                    force_matched_column_mask,
                    force_matched_columns,
                    matched_columns,
                )
                matched_values = ops.where(
                    force_matched_column_mask,
                    self.match_values[-1]
                    * ops.ones([batch_size, num_rows], dtype="int32"),
                    matched_values,
                )

            return ops.cast(matched_columns, "int32"), matched_values

        num_boxes = (
            similarity_matrix.shape[-1] or ops.shape(similarity_matrix)[-1]
        )
        matched_columns, matched_values = ops.cond(
            pred=ops.greater(num_boxes, 0),
            true_fn=_match_when_cols_are_non_empty,
            false_fn=_match_when_cols_are_empty,
        )

        if squeeze_result:
            matched_columns = ops.squeeze(matched_columns, axis=0)
            matched_values = ops.squeeze(matched_values, axis=0)

        return matched_columns, matched_values

    def _set_values_using_indicator(self, x, indicator, val):
        """Set the indicated fields of x to val.

        Args:
            x: tensor.
            indicator: boolean with same shape as x.
            val: scalar with value to set.

        Returns:
            modified tensor.
        """
        indicator = ops.cast(indicator, x.dtype)
        return ops.where(indicator == 0, x, val)

    def get_config(self):
        """Returns the layer config, including the base `Layer` fields.

        `thresholds` is returned without the internal +/-inf sentinels, so
        the config round-trips through `__init__`.
        """
        config = super().get_config()
        config.update(
            {
                "thresholds": self.thresholds[1:-1],
                "match_values": self.match_values,
                "force_match_for_each_col": self.force_match_for_each_col,
            }
        )
        return config