keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl → 0.16.0.dev2024092017__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +0 -6
- keras_hub/api/__init__.py +2 -0
- keras_hub/api/bounding_box/__init__.py +36 -0
- keras_hub/api/layers/__init__.py +14 -0
- keras_hub/api/models/__init__.py +97 -48
- keras_hub/api/tokenizers/__init__.py +30 -0
- keras_hub/api/utils/__init__.py +22 -0
- keras_hub/src/api_export.py +15 -9
- keras_hub/src/bounding_box/__init__.py +13 -0
- keras_hub/src/bounding_box/converters.py +529 -0
- keras_hub/src/bounding_box/formats.py +162 -0
- keras_hub/src/bounding_box/iou.py +263 -0
- keras_hub/src/bounding_box/to_dense.py +95 -0
- keras_hub/src/bounding_box/to_ragged.py +99 -0
- keras_hub/src/bounding_box/utils.py +194 -0
- keras_hub/src/bounding_box/validate_format.py +99 -0
- keras_hub/src/layers/preprocessing/audio_converter.py +121 -0
- keras_hub/src/layers/preprocessing/image_converter.py +130 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +2 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +9 -8
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +2 -29
- keras_hub/src/layers/preprocessing/random_deletion.py +33 -31
- keras_hub/src/layers/preprocessing/random_swap.py +33 -31
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +101 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +3 -2
- keras_hub/src/models/albert/__init__.py +1 -2
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +6 -86
- keras_hub/src/models/albert/{albert_classifier.py → albert_text_classifier.py} +34 -10
- keras_hub/src/models/albert/{albert_preprocessor.py → albert_text_classifier_preprocessor.py} +14 -70
- keras_hub/src/models/albert/albert_tokenizer.py +17 -36
- keras_hub/src/models/backbone.py +12 -34
- keras_hub/src/models/bart/__init__.py +1 -2
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +21 -148
- keras_hub/src/models/bart/bart_tokenizer.py +12 -39
- keras_hub/src/models/bert/__init__.py +1 -5
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +6 -87
- keras_hub/src/models/bert/bert_presets.py +1 -4
- keras_hub/src/models/bert/{bert_classifier.py → bert_text_classifier.py} +19 -12
- keras_hub/src/models/bert/{bert_preprocessor.py → bert_text_classifier_preprocessor.py} +14 -70
- keras_hub/src/models/bert/bert_tokenizer.py +17 -35
- keras_hub/src/models/bloom/__init__.py +1 -2
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/bloom/bloom_tokenizer.py +12 -41
- keras_hub/src/models/causal_lm.py +10 -29
- keras_hub/src/models/causal_lm_preprocessor.py +195 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +54 -15
- keras_hub/src/models/deberta_v3/__init__.py +1 -4
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +14 -77
- keras_hub/src/models/deberta_v3/{deberta_v3_classifier.py → deberta_v3_text_classifier.py} +16 -11
- keras_hub/src/models/deberta_v3/{deberta_v3_preprocessor.py → deberta_v3_text_classifier_preprocessor.py} +23 -64
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +30 -25
- keras_hub/src/models/densenet/densenet_backbone.py +46 -22
- keras_hub/src/models/distil_bert/__init__.py +1 -4
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +14 -76
- keras_hub/src/models/distil_bert/{distil_bert_classifier.py → distil_bert_text_classifier.py} +17 -12
- keras_hub/src/models/distil_bert/{distil_bert_preprocessor.py → distil_bert_text_classifier_preprocessor.py} +23 -63
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +19 -35
- keras_hub/src/models/efficientnet/__init__.py +13 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +569 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +229 -0
- keras_hub/src/models/efficientnet/mbconv.py +238 -0
- keras_hub/src/models/electra/__init__.py +1 -2
- keras_hub/src/models/electra/electra_tokenizer.py +17 -32
- keras_hub/src/models/f_net/__init__.py +1 -2
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +12 -78
- keras_hub/src/models/f_net/{f_net_classifier.py → f_net_text_classifier.py} +17 -10
- keras_hub/src/models/f_net/{f_net_preprocessor.py → f_net_text_classifier_preprocessor.py} +19 -63
- keras_hub/src/models/f_net/f_net_tokenizer.py +17 -35
- keras_hub/src/models/falcon/__init__.py +1 -2
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/falcon/falcon_tokenizer.py +12 -35
- keras_hub/src/models/gemma/__init__.py +1 -2
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +6 -90
- keras_hub/src/models/gemma/gemma_decoder_block.py +1 -1
- keras_hub/src/models/gemma/gemma_tokenizer.py +12 -23
- keras_hub/src/models/gpt2/__init__.py +1 -2
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +12 -90
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +12 -34
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +12 -34
- keras_hub/src/models/image_classifier.py +0 -5
- keras_hub/src/models/image_classifier_preprocessor.py +83 -0
- keras_hub/src/models/llama/__init__.py +1 -2
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +6 -85
- keras_hub/src/models/llama/llama_tokenizer.py +12 -25
- keras_hub/src/models/llama3/__init__.py +1 -2
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/llama3/llama3_tokenizer.py +12 -33
- keras_hub/src/models/masked_lm.py +0 -2
- keras_hub/src/models/masked_lm_preprocessor.py +156 -0
- keras_hub/src/models/mistral/__init__.py +1 -2
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +6 -91
- keras_hub/src/models/mistral/mistral_tokenizer.py +12 -23
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +2 -2
- keras_hub/src/models/mobilenet/__init__.py +13 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +530 -0
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +114 -0
- keras_hub/src/models/opt/__init__.py +1 -2
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +6 -93
- keras_hub/src/models/opt/opt_tokenizer.py +12 -41
- keras_hub/src/models/pali_gemma/__init__.py +1 -4
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +28 -28
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +25 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +5 -5
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +11 -3
- keras_hub/src/models/phi3/__init__.py +1 -2
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -9
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +6 -89
- keras_hub/src/models/phi3/phi3_tokenizer.py +12 -36
- keras_hub/src/models/preprocessor.py +72 -83
- keras_hub/src/models/resnet/__init__.py +6 -0
- keras_hub/src/models/resnet/resnet_backbone.py +390 -42
- keras_hub/src/models/resnet/resnet_image_classifier.py +33 -6
- keras_hub/src/models/resnet/resnet_image_classifier_preprocessor.py +28 -0
- keras_hub/src/models/{llama3/llama3_preprocessor.py → resnet/resnet_image_converter.py} +7 -5
- keras_hub/src/models/resnet/resnet_presets.py +95 -0
- keras_hub/src/models/retinanet/__init__.py +13 -0
- keras_hub/src/models/retinanet/anchor_generator.py +175 -0
- keras_hub/src/models/retinanet/box_matcher.py +259 -0
- keras_hub/src/models/retinanet/non_max_supression.py +578 -0
- keras_hub/src/models/roberta/__init__.py +1 -2
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +22 -74
- keras_hub/src/models/roberta/{roberta_classifier.py → roberta_text_classifier.py} +16 -11
- keras_hub/src/models/roberta/{roberta_preprocessor.py → roberta_text_classifier_preprocessor.py} +21 -53
- keras_hub/src/models/roberta/roberta_tokenizer.py +13 -52
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +269 -0
- keras_hub/src/models/stable_diffusion_v3/__init__.py +13 -0
- keras_hub/src/models/stable_diffusion_v3/clip_encoder_block.py +103 -0
- keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +93 -0
- keras_hub/src/models/stable_diffusion_v3/clip_text_encoder.py +149 -0
- keras_hub/src/models/stable_diffusion_v3/clip_tokenizer.py +167 -0
- keras_hub/src/models/stable_diffusion_v3/mmdit.py +427 -0
- keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +317 -0
- keras_hub/src/models/stable_diffusion_v3/t5_xxl_preprocessor.py +74 -0
- keras_hub/src/models/stable_diffusion_v3/t5_xxl_text_encoder.py +155 -0
- keras_hub/src/models/stable_diffusion_v3/vae_attention.py +126 -0
- keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +186 -0
- keras_hub/src/models/t5/__init__.py +1 -2
- keras_hub/src/models/t5/t5_tokenizer.py +13 -23
- keras_hub/src/models/task.py +71 -116
- keras_hub/src/models/{classifier.py → text_classifier.py} +19 -13
- keras_hub/src/models/text_classifier_preprocessor.py +138 -0
- keras_hub/src/models/whisper/__init__.py +1 -2
- keras_hub/src/models/whisper/{whisper_audio_feature_extractor.py → whisper_audio_converter.py} +20 -18
- keras_hub/src/models/whisper/whisper_backbone.py +0 -3
- keras_hub/src/models/whisper/whisper_presets.py +10 -10
- keras_hub/src/models/whisper/whisper_tokenizer.py +20 -16
- keras_hub/src/models/xlm_roberta/__init__.py +1 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +26 -72
- keras_hub/src/models/xlm_roberta/{xlm_roberta_classifier.py → xlm_roberta_text_classifier.py} +16 -11
- keras_hub/src/models/xlm_roberta/{xlm_roberta_preprocessor.py → xlm_roberta_text_classifier_preprocessor.py} +26 -53
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +25 -10
- keras_hub/src/tests/test_case.py +46 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +30 -17
- keras_hub/src/tokenizers/byte_tokenizer.py +14 -15
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +20 -7
- keras_hub/src/tokenizers/tokenizer.py +67 -32
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +14 -15
- keras_hub/src/tokenizers/word_piece_tokenizer.py +34 -47
- keras_hub/src/utils/imagenet/__init__.py +13 -0
- keras_hub/src/utils/imagenet/imagenet_utils.py +1067 -0
- keras_hub/src/utils/keras_utils.py +0 -50
- keras_hub/src/utils/preset_utils.py +230 -68
- keras_hub/src/utils/tensor_utils.py +187 -69
- keras_hub/src/utils/timm/convert_resnet.py +19 -16
- keras_hub/src/utils/timm/preset_loader.py +66 -0
- keras_hub/src/utils/transformers/convert_albert.py +193 -0
- keras_hub/src/utils/transformers/convert_bart.py +373 -0
- keras_hub/src/utils/transformers/convert_bert.py +7 -17
- keras_hub/src/utils/transformers/convert_distilbert.py +10 -20
- keras_hub/src/utils/transformers/convert_gemma.py +5 -19
- keras_hub/src/utils/transformers/convert_gpt2.py +5 -18
- keras_hub/src/utils/transformers/convert_llama3.py +7 -18
- keras_hub/src/utils/transformers/convert_mistral.py +129 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +7 -29
- keras_hub/src/utils/transformers/preset_loader.py +77 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +2 -2
- keras_hub/src/version_utils.py +1 -1
- keras_hub_nightly-0.16.0.dev2024092017.dist-info/METADATA +202 -0
- keras_hub_nightly-0.16.0.dev2024092017.dist-info/RECORD +334 -0
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/WHEEL +1 -1
- keras_hub/src/models/bart/bart_preprocessor.py +0 -276
- keras_hub/src/models/bloom/bloom_preprocessor.py +0 -185
- keras_hub/src/models/electra/electra_preprocessor.py +0 -154
- keras_hub/src/models/falcon/falcon_preprocessor.py +0 -187
- keras_hub/src/models/gemma/gemma_preprocessor.py +0 -191
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +0 -145
- keras_hub/src/models/llama/llama_preprocessor.py +0 -189
- keras_hub/src/models/mistral/mistral_preprocessor.py +0 -190
- keras_hub/src/models/opt/opt_preprocessor.py +0 -188
- keras_hub/src/models/phi3/phi3_preprocessor.py +0 -190
- keras_hub/src/models/whisper/whisper_preprocessor.py +0 -326
- keras_hub/src/utils/timm/convert.py +0 -37
- keras_hub/src/utils/transformers/convert.py +0 -101
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +0 -34
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +0 -297
- {keras_hub_nightly-0.15.0.dev20240823171555.dist-info → keras_hub_nightly-0.16.0.dev2024092017.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,259 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import keras
|
16
|
+
from keras import ops
|
17
|
+
|
18
|
+
|
19
|
+
class BoxMatcher(keras.layers.Layer):
    """Box matching logic based on argmax of highest value (e.g., IOU).

    This class computes matches from a similarity matrix. Each row will be
    matched to at least one column, the matched result can either be positive
    or negative, or simply ignored depending on the setting.

    The settings include `thresholds` and `match_values`, for example if:
    1) `thresholds=[negative_threshold, positive_threshold]`, and
       `match_values=[negative_value=0, ignore_value=-1, positive_value=1]`:
       the rows will be assigned to positive_value if its argmax result >=
       positive_threshold; the rows will be assigned to negative_value if its
       argmax result < negative_threshold, and the rows will be assigned to
       ignore_value if its argmax result is between [negative_threshold,
       positive_threshold).
    2) `thresholds=[negative_threshold, positive_threshold]`, and
       `match_values=[ignore_value=-1, negative_value=0, positive_value=1]`:
       the rows will be assigned to positive_value if its argmax result >=
       positive_threshold; the rows will be assigned to ignore_value if its
       argmax result < negative_threshold, and the rows will be assigned to
       negative_value if its argmax result is between [negative_threshold,
       positive_threshold). This is different from case 1) by swapping first
       two values.
    3) `thresholds=[positive_threshold]`, and
       `match_values=[negative_value, positive_value]`: the rows will be
       assigned to positive_value if its argmax result >= positive_threshold;
       the rows will be assigned to negative_value if its argmax result <
       positive_threshold.

    Args:
        thresholds: A sorted list (or tuple) of floats to classify the
            matches into different results (e.g. positive or negative or
            ignored match). The caller's sequence is not modified; an
            internal copy is bracketed with -Inf and +Inf.
        match_values: A list of integers representing matched results (e.g.
            positive or negative or ignored match). len(`match_values`) must
            equal to len(`thresholds`) + 1.
        force_match_for_each_col: each row will be argmax matched to at
            least one column. This means some columns will be matched to
            multiple rows while some columns will not be matched to any rows.
            Filtering by `thresholds` will make less columns match to positive
            result. Setting this to True guarantees that each column will be
            matched to positive result to at least one row.

    Raises:
        ValueError: if `thresholds` not sorted or
            len(`match_values`) != len(`thresholds`) + 1

    Example:
    ```python
    box_matcher = BoxMatcher(thresholds=[0.3, 0.7], match_values=[-1, 0, 1])
    # `iou_metric` is any `[num_anchors, num_boxes]` (optionally batched)
    # similarity matrix, e.g. the pairwise IoU of anchors and boxes.
    matched_columns, matched_match_values = box_matcher(iou_metric)
    cls_mask = ops.less_equal(matched_match_values, 0)
    ```

    """

    def __init__(
        self,
        thresholds,
        match_values,
        force_match_for_each_col=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Work on private copies: the +/-Inf sentinels added below must not
        # leak into the caller's list (re-using that list for a second layer
        # would otherwise fail the length check), and tuples are accepted the
        # same way as lists.
        thresholds = list(thresholds)
        match_values = list(match_values)
        if sorted(thresholds) != thresholds:
            raise ValueError(f"`threshold` must be sorted, got {thresholds}")
        if len(match_values) != len(thresholds) + 1:
            raise ValueError(
                f"len(`match_values`) must be len(`thresholds`) + 1, got "
                f"match_values {match_values}, thresholds {thresholds}"
            )
        self.match_values = match_values
        # Bracket with infinities so that consecutive pairs of thresholds
        # describe one half-open interval per entry of `match_values`.
        self.thresholds = [-float("inf"), *thresholds, float("inf")]
        self.force_match_for_each_col = force_match_for_each_col
        # The layer has no weights, so there is nothing to build lazily.
        self.built = True

    def call(self, similarity_matrix):
        """Matches each row to a column based on argmax

        Args:
            similarity_matrix: A float Tensor of shape `[num_rows, num_cols]`
                or `[batch_size, num_rows, num_cols]` representing any
                similarity metric.

        Returns:
            matched_columns: An integer tensor of shape `[num_rows]` or
                `[batch_size, num_rows]` storing the index of the matched
                column for each row.
            matched_values: An integer tensor of shape `[num_rows]` or
                `[batch_size, num_rows]` storing the match result
                `(positive match, negative match, ignored match)`.
        """
        # Unbatched input is handled by adding a batch axis here and removing
        # it again from the results.
        squeeze_result = False
        if len(similarity_matrix.shape) == 2:
            squeeze_result = True
            similarity_matrix = ops.expand_dims(similarity_matrix, axis=0)
        # Prefer static dimensions; fall back to the dynamic shape when a
        # dimension is unknown (None) at trace time.
        static_shape = list(similarity_matrix.shape)
        num_rows = static_shape[1] or ops.shape(similarity_matrix)[1]
        batch_size = static_shape[0] or ops.shape(similarity_matrix)[0]

        def _match_when_cols_are_empty():
            """Performs matching when the similarity matrix has no columns.

            With no columns there is nothing to match against, so every row
            is reported with column index 0 and match value -1.

            Returns:
                matched_columns: An integer tensor of shape [batch_size,
                    num_rows] storing the index of the matched column for each
                    row.
                matched_values: An integer tensor of shape [batch_size,
                    num_rows] storing the match type indicator (e.g. positive
                    or negative or ignored match).
            """
            matched_columns = ops.zeros([batch_size, num_rows], dtype="int32")
            matched_values = -ops.ones([batch_size, num_rows], dtype="int32")
            return matched_columns, matched_values

        def _match_when_cols_are_non_empty():
            """Performs matching when the similarity matrix has columns.

            Returns:
                matched_columns: An integer tensor of shape [batch_size,
                    num_rows] storing the index of the matched column for each
                    row.
                matched_values: An integer tensor of shape [batch_size,
                    num_rows] storing the match type indicator (e.g. positive
                    or negative or ignored match).
            """
            # Jax traces this function even when running eagerly and the
            # columns are non-empty. Therefore, we need to handle the case
            # where the similarity matrix is empty. We do this by padding
            # some -1s to the end. -1s are guaranteed to not affect argmax
            # matching because all values in a similarity matrix are [0,1]
            # and the indexing won't change because these are added at the
            # end.
            padded_similarity_matrix = ops.concatenate(
                [similarity_matrix, -ops.ones((batch_size, num_rows, 1))],
                axis=-1,
            )

            matched_columns = ops.argmax(
                padded_similarity_matrix,
                axis=-1,
            )

            # Best similarity per row; bucketed below into one of
            # `match_values` according to the threshold interval it falls in.
            matched_vals = ops.max(padded_similarity_matrix, axis=-1)
            matched_values = ops.zeros([batch_size, num_rows], "int32")

            match_dtype = matched_vals.dtype
            for ind, low, high in zip(
                self.match_values, self.thresholds[:-1], self.thresholds[1:]
            ):
                low_threshold = ops.cast(low, match_dtype)
                high_threshold = ops.cast(high, match_dtype)
                # Half-open interval [low, high).
                mask = ops.logical_and(
                    ops.greater_equal(matched_vals, low_threshold),
                    ops.less(matched_vals, high_threshold),
                )
                matched_values = self._set_values_using_indicator(
                    matched_values, mask, ind
                )

            if self.force_match_for_each_col:
                # [batch_size, num_cols + 1], for each column
                # (groundtruth_box), find the best matching row (anchor).
                matching_rows = ops.argmax(
                    padded_similarity_matrix,
                    axis=1,
                )
                # [batch_size, num_cols + 1, num_rows], a transposed 0-1
                # mapping matrix M, where M[j, i] = 1 means column j is
                # matched to row i.
                column_to_row_match_mapping = ops.one_hot(
                    matching_rows, num_rows
                )
                # The padding column contains only -1s, so its argmax is
                # always row 0. Left in place it would force row 0 to a
                # positive match on every call and could assign it the
                # out-of-range column index `num_cols`. Replace its one-hot
                # row with zeros (instead of dropping it) so the column axis
                # stays non-empty for the reductions below.
                column_to_row_match_mapping = ops.concatenate(
                    [
                        column_to_row_match_mapping[:, :-1, :],
                        ops.zeros_like(column_to_row_match_mapping[:, -1:, :]),
                    ],
                    axis=1,
                )
                # [batch_size, num_rows], for each row (anchor), find the
                # matched column (groundtruth_box).
                force_matched_columns = ops.argmax(
                    column_to_row_match_mapping,
                    axis=1,
                )
                # [batch_size, num_rows], True where some column picked this
                # row as its best match.
                force_matched_column_mask = ops.cast(
                    ops.max(column_to_row_match_mapping, axis=1),
                    "bool",
                )
                # [batch_size, num_rows]
                matched_columns = ops.where(
                    force_matched_column_mask,
                    force_matched_columns,
                    matched_columns,
                )
                # Forced matches always receive the last (positive) value.
                matched_values = ops.where(
                    force_matched_column_mask,
                    self.match_values[-1]
                    * ops.ones([batch_size, num_rows], dtype="int32"),
                    matched_values,
                )

            return ops.cast(matched_columns, "int32"), matched_values

        num_boxes = (
            similarity_matrix.shape[-1] or ops.shape(similarity_matrix)[-1]
        )
        matched_columns, matched_values = ops.cond(
            pred=ops.greater(num_boxes, 0),
            true_fn=_match_when_cols_are_non_empty,
            false_fn=_match_when_cols_are_empty,
        )

        if squeeze_result:
            matched_columns = ops.squeeze(matched_columns, axis=0)
            matched_values = ops.squeeze(matched_values, axis=0)

        return matched_columns, matched_values

    def _set_values_using_indicator(self, x, indicator, val):
        """Set the indicated fields of x to val.

        Args:
            x: tensor.
            indicator: boolean with same shape as x.
            val: scalar with value to set.
        Returns:
            modified tensor.
        """
        indicator = ops.cast(indicator, x.dtype)
        # Keep `x` where the indicator is unset, write `val` elsewhere.
        return ops.where(indicator == 0, x, val)

    def get_config(self):
        """Returns the layer config, including the base `Layer` settings.

        The stored thresholds are stripped of the internal -Inf/+Inf
        sentinels so that the config can be fed back to `__init__`.
        """
        config = super().get_config()
        config.update(
            {
                "thresholds": self.thresholds[1:-1],
                "match_values": self.match_values,
                "force_match_for_each_col": self.force_match_for_each_col,
            }
        )
        return config
|