keras-hub-nightly 0.23.0.dev202508260411__py3-none-any.whl → 0.23.0.dev202508280418__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +6 -0
- keras_hub/models/__init__.py +21 -0
- keras_hub/src/layers/modeling/position_embedding.py +21 -6
- keras_hub/src/layers/modeling/rotary_embedding.py +16 -6
- keras_hub/src/layers/modeling/sine_position_encoding.py +21 -8
- keras_hub/src/layers/modeling/token_and_position_embedding.py +2 -1
- keras_hub/src/models/backbone.py +10 -15
- keras_hub/src/models/d_fine/__init__.py +0 -0
- keras_hub/src/models/d_fine/d_fine_attention.py +461 -0
- keras_hub/src/models/d_fine/d_fine_backbone.py +891 -0
- keras_hub/src/models/d_fine/d_fine_decoder.py +944 -0
- keras_hub/src/models/d_fine/d_fine_encoder.py +365 -0
- keras_hub/src/models/d_fine/d_fine_hybrid_encoder.py +642 -0
- keras_hub/src/models/d_fine/d_fine_image_converter.py +8 -0
- keras_hub/src/models/d_fine/d_fine_layers.py +1828 -0
- keras_hub/src/models/d_fine/d_fine_loss.py +938 -0
- keras_hub/src/models/d_fine/d_fine_object_detector.py +875 -0
- keras_hub/src/models/d_fine/d_fine_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/d_fine/d_fine_presets.py +2 -0
- keras_hub/src/models/d_fine/d_fine_utils.py +827 -0
- keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +4 -1
- keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +3 -2
- keras_hub/src/models/hgnetv2/hgnetv2_layers.py +27 -11
- keras_hub/src/models/parseq/__init__.py +0 -0
- keras_hub/src/models/parseq/parseq_backbone.py +134 -0
- keras_hub/src/models/parseq/parseq_causal_lm.py +466 -0
- keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py +168 -0
- keras_hub/src/models/parseq/parseq_decoder.py +418 -0
- keras_hub/src/models/parseq/parseq_image_converter.py +8 -0
- keras_hub/src/models/parseq/parseq_tokenizer.py +221 -0
- keras_hub/src/tests/test_case.py +37 -1
- keras_hub/src/utils/preset_utils.py +49 -0
- keras_hub/src/utils/tensor_utils.py +23 -1
- keras_hub/src/utils/transformers/convert_vit.py +4 -1
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +3 -0
- {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/RECORD +40 -20
- {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,938 @@
|
|
1
|
+
import keras
|
2
|
+
|
3
|
+
from keras_hub.src.models.d_fine.d_fine_utils import hungarian_assignment
|
4
|
+
from keras_hub.src.models.d_fine.d_fine_utils import weighting_function
|
5
|
+
|
6
|
+
|
7
|
+
def gather_along_first_two_dims(tensor, batch_idx, src_idx):
    """Gathers entries of `tensor` addressed by pairs of leading indices.

    The first two axes of `tensor` are flattened into a single axis and each
    pair `(batch_idx[j], src_idx[j])` is converted to the flat row index
    `batch_idx[j] * tensor.shape[1] + src_idx[j]`. The result is therefore
    `tensor[batch_idx[j], src_idx[j], ...]` for every `j`.

    Args:
        tensor: Tensor of rank >= 2 whose first two axes are indexed.
        batch_idx: Integer tensor of indices into axis 0 of `tensor`.
        src_idx: Integer tensor of indices into axis 1 of `tensor`,
            combinable element-wise with `batch_idx`.

    Returns:
        Tensor of the gathered rows. Its leading axes follow the shape of
        the computed index tensor and its trailing axes are
        `tensor.shape[2:]`.
    """
    _, num_queries, *feature_dims = keras.ops.shape(tensor)
    # Match the index dtype so the linear-index arithmetic is well typed.
    num_queries = keras.ops.cast(num_queries, dtype=batch_idx.dtype)
    linear_idx = batch_idx * num_queries + src_idx
    flat_tensor = keras.ops.reshape(tensor, (-1, *feature_dims))
    return keras.ops.take(flat_tensor, linear_idx, axis=0)
|
15
|
+
|
16
|
+
|
17
|
+
def hungarian_matcher(
    outputs,
    targets,
    num_targets_per_image,
    use_focal_loss,
    matcher_alpha,
    matcher_gamma,
    matcher_bbox_cost,
    matcher_class_cost,
    matcher_ciou_cost,
    backbone,
):
    """Performs bipartite matching between predictions and ground truths.

    This method implements the Hungarian matching algorithm to find the
    optimal one-to-one assignment between the model's predictions (queries)
    and the ground truth objects. The cost matrix for the assignment is a
    weighted sum of three components:
    1. **Class Cost:** The cost of classifying a query into the wrong
        class.
    2. **Bounding Box Cost:** The L1 distance between the predicted and
        ground truth bounding boxes.
    3. **CIoU Cost:** The negative Complete Intersection over Union (CIoU)
        between the predicted and ground truth boxes.

    Matching runs once per image inside `keras.ops.fori_loop`. Images with
    no ground truth objects skip the assignment and get all-invalid rows.

    Args:
        outputs: dict, A dictionary containing predicted `"logits"` and
            `"pred_boxes"`, both indexed as `[batch, query, ...]`. Boxes
            are converted from the `"center_xywh"` format.
        targets: list of dict, Only `targets[0]` is read. Its `"labels"`
            and `"boxes"` (`"center_xywh"`) are sliced with offsets derived
            from `num_targets_per_image`. They are therefore expected to
            hold the ground truth of every image back to back.
        num_targets_per_image: A tensor of shape `(batch_size,)` indicating
            the number of ground truth objects in each image.
        use_focal_loss: bool, If `True`, the class cost is a focal-style
            cost on sigmoid probabilities. Otherwise it is the negative
            softmax probability of the target class.
        matcher_alpha: float, Focal weighting factor (focal cost only).
        matcher_gamma: float, Focal focusing exponent (focal cost only).
        matcher_bbox_cost: float, Weight of the L1 box cost.
        matcher_class_cost: float, Weight of the class cost.
        matcher_ciou_cost: float, Weight of the CIoU cost.
        backbone: The model backbone. Only `backbone.num_queries` is read;
            it is forwarded to `hungarian_assignment`.

    Returns:
        tuple: A tuple of three tensors `(row_indices, col_indices,
        valid_masks)`, each of shape `(batch_size, num_queries)`.
        `row_indices` and `col_indices` contain the indices of matched
        predictions and ground truths. `col_indices` are relative to each
        image's own targets. `valid_masks` indicates which matches are
        valid.
    """
    batch_size = keras.ops.shape(outputs["logits"])[0]
    num_queries = keras.ops.shape(outputs["logits"])[1]
    out_logits = outputs["logits"]
    out_bbox = outputs["pred_boxes"]
    target_ids_all = keras.ops.cast(targets[0]["labels"], dtype="int32")
    target_bbox_all = targets[0]["boxes"]
    # Exclusive prefix sum: image i owns targets [offsets[i], offsets[i+1]).
    target_offsets = keras.ops.concatenate(
        [
            keras.ops.zeros((1,), dtype="int32"),
            keras.ops.cumsum(num_targets_per_image),
        ]
    )
    max_matches = num_queries
    row_indices_init = keras.ops.zeros((batch_size, max_matches), dtype="int32")
    col_indices_init = keras.ops.zeros((batch_size, max_matches), dtype="int32")
    valid_masks_init = keras.ops.zeros((batch_size, max_matches), dtype="bool")

    def loop_body(i, loop_vars):
        # Match image `i` and write its results into row `i` of each buffer.
        row_indices, col_indices, valid_masks = loop_vars
        out_logits_i = out_logits[i]
        out_bbox_i = out_bbox[i]
        start = target_offsets[i]
        end = target_offsets[i + 1]
        num_targets_i = end - start
        # The cost matrix always has `num_queries` target columns (a static
        # size). Only the first `num_targets_i` are real targets.
        k = keras.ops.arange(0, num_queries)
        is_valid_target_mask = k < num_targets_i
        target_indices = start + k
        # Clamp so `take` stays in bounds. Padded columns are masked out of
        # the cost matrix and of the final valid mask below.
        safe_target_indices = keras.ops.minimum(
            target_indices, keras.ops.shape(target_ids_all)[0] - 1
        )
        target_ids_i = keras.ops.take(
            target_ids_all, safe_target_indices, axis=0
        )
        target_bbox_i = keras.ops.take(
            target_bbox_all, safe_target_indices, axis=0
        )

        def compute_cost_matrix():
            # Rows index predictions, columns index (padded) target slots.
            if use_focal_loss:
                out_prob_i = keras.ops.sigmoid(out_logits_i)
                # Negative label ids are clamped to 0 so `take` is in range.
                safe_ids_for_take = keras.ops.maximum(target_ids_i, 0)
                prob_for_target_classes = keras.ops.take(
                    out_prob_i, safe_ids_for_take, axis=1
                )
                p = prob_for_target_classes
                pos_cost = (
                    matcher_alpha
                    * keras.ops.power(1 - p, matcher_gamma)
                    * (-keras.ops.log(p + 1e-8))
                )
                neg_cost = (
                    (1 - matcher_alpha)
                    * keras.ops.power(p, matcher_gamma)
                    * (-keras.ops.log(1 - p + 1e-8))
                )
                class_cost_i = pos_cost - neg_cost
            else:
                out_prob_softmax_i = keras.ops.softmax(out_logits_i, axis=-1)
                safe_ids_for_take = keras.ops.maximum(target_ids_i, 0)
                prob_for_target_classes = keras.ops.take(
                    out_prob_softmax_i, safe_ids_for_take, axis=1
                )
                class_cost_i = -prob_for_target_classes

            # Pairwise L1 distance over the 4 box coordinates.
            bbox_cost_i = keras.ops.sum(
                keras.ops.abs(
                    keras.ops.expand_dims(out_bbox_i, 1)
                    - keras.ops.expand_dims(target_bbox_i, 0)
                ),
                axis=2,
            )
            out_bbox_corners_i = keras.utils.bounding_boxes.convert_format(
                out_bbox_i,
                source="center_xywh",
                target="xyxy",
            )
            target_bbox_corners_i = keras.utils.bounding_boxes.convert_format(
                target_bbox_i,
                source="center_xywh",
                target="xyxy",
            )
            # Higher CIoU means a better match, so it enters as a negative
            # cost. The expand_dims calls broadcast it to all pairs.
            ciou_cost_i = -keras.utils.bounding_boxes.compute_ciou(
                keras.ops.expand_dims(out_bbox_corners_i, 1),
                keras.ops.expand_dims(target_bbox_corners_i, 0),
                bounding_box_format="xyxy",
            )

            cost_matrix_i = (
                matcher_bbox_cost * bbox_cost_i
                + matcher_class_cost * class_cost_i
                + matcher_ciou_cost * ciou_cost_i
            )
            # Make padded target columns prohibitively expensive to match.
            cost_matrix_i = keras.ops.where(
                keras.ops.expand_dims(is_valid_target_mask, 0),
                cost_matrix_i,
                1e9,
            )
            return cost_matrix_i

        def perform_assignment():
            cost_matrix_i = compute_cost_matrix()
            # NOTE(review): presumably returns three (num_queries,) tensors,
            # matching `skip_assignment` below - confirm in d_fine_utils.
            row_idx, col_idx, valid_mask = hungarian_assignment(
                cost_matrix_i, backbone.num_queries
            )
            # Drop any assignment that landed on a padded target column.
            valid_mask = keras.ops.logical_and(
                valid_mask, col_idx < num_targets_i
            )
            return row_idx, col_idx, valid_mask

        def skip_assignment():
            # No targets in this image: all-invalid placeholder matches.
            return (
                keras.ops.zeros((num_queries,), dtype="int32"),
                keras.ops.zeros((num_queries,), dtype="int32"),
                keras.ops.zeros((num_queries,), dtype="bool"),
            )

        row_idx, col_idx, valid_mask = keras.ops.cond(
            keras.ops.greater(num_targets_i, 0),
            perform_assignment,
            skip_assignment,
        )
        # Overwrite row `i` of each (batch_size, num_queries) buffer.
        row_indices = keras.ops.scatter_update(
            row_indices, [[i]], keras.ops.expand_dims(row_idx, axis=0)
        )
        col_indices = keras.ops.scatter_update(
            col_indices, [[i]], keras.ops.expand_dims(col_idx, axis=0)
        )
        valid_masks = keras.ops.scatter_update(
            valid_masks, [[i]], keras.ops.expand_dims(valid_mask, axis=0)
        )
        return row_indices, col_indices, valid_masks

    row_indices, col_indices, valid_masks = keras.ops.fori_loop(
        0,
        batch_size,
        loop_body,
        (row_indices_init, col_indices_init, valid_masks_init),
    )
    return (row_indices, col_indices, valid_masks)
|
194
|
+
|
195
|
+
|
196
|
+
def compute_vfl_loss(
    outputs,
    targets,
    indices,
    num_boxes,
    num_classes,
    matcher_alpha,
    matcher_gamma,
):
    """Computes the Varifocal Loss (VFL) for classification.

    VFL is an asymmetric focal loss variant designed for dense object
    detection. It treats the Intersection over Union (IoU) between a
    predicted box and its matched ground truth box as the target score for
    positive examples while down-weighting the loss for negative examples.
    This helps the model focus on high-quality localizations.

    Args:
        outputs: dict, A dictionary containing the model's predictions,
            including `"logits"` and `"pred_boxes"` (boxes in the
            `"center_xywh"` format).
        targets: list of dict, A list of dictionaries containing ground
            truth `"labels"` and `"boxes"`. The entries of all dictionaries
            are stacked and flattened before being indexed.
        indices: tuple, `(row_ind, col_ind, valid_mask)` from the
            Hungarian matcher, indicating the assignments between
            predictions and targets.
        num_boxes: int, The total number of ground truth boxes in the batch,
            used for normalization.
        num_classes: int, Number of foreground classes. Unmatched queries
            are assigned class id `num_classes`, which is dropped from the
            one-hot encoding, so they contribute only negative terms.
        matcher_alpha: float, Weight of the negative (non-target) term.
        matcher_gamma: float, Exponent applied to the predicted sigmoid
            score in the negative-term weight.

    Returns:
        Dictionary: `{"loss_vfl": ...}`, the scalar VFL loss.
    """
    _, col_indices, valid_masks = indices
    batch_idx, src_idx = _get_source_permutation_idx(indices)
    src_boxes = gather_along_first_two_dims(
        outputs["pred_boxes"], batch_idx, src_idx
    )
    flat_col_indices = keras.ops.reshape(col_indices, (-1,))
    flat_valid_masks = keras.ops.reshape(valid_masks, (-1,))
    src_logits = outputs["logits"]
    # Every query starts as "background" (class id `num_classes`) with a
    # target score of 0. Matched queries are overwritten below.
    target_classes_init = keras.ops.full(
        shape=keras.ops.shape(src_logits)[:2],
        fill_value=num_classes,
        dtype="int32",
    )
    target_score_original = keras.ops.zeros_like(
        target_classes_init, dtype=src_logits.dtype
    )
    # (batch, query) coordinates used to scatter per-match values.
    update_indices = keras.ops.stack([batch_idx, src_idx], axis=-1)

    def process_targets():
        target_labels_tensor = keras.ops.stack(
            [t["labels"] for t in targets], axis=0
        )
        target_boxes_tensor = keras.ops.stack(
            [t["boxes"] for t in targets], axis=0
        )
        if keras.ops.ndim(target_labels_tensor) == 3:
            target_labels_tensor = keras.ops.squeeze(
                target_labels_tensor, axis=1
            )
        if keras.ops.ndim(target_boxes_tensor) == 4:
            target_boxes_tensor = keras.ops.squeeze(target_boxes_tensor, axis=1)
        flat_target_labels = keras.ops.reshape(target_labels_tensor, (-1,))
        flat_target_boxes = keras.ops.reshape(target_boxes_tensor, (-1, 4))
        num_targets = keras.ops.shape(flat_target_labels)[0]
        num_targets = keras.ops.cast(num_targets, dtype=flat_col_indices.dtype)
        # NOTE(review): the matcher's column indices are used directly as
        # indices into the flattened labels/boxes of *all* targets. Confirm
        # they are global (not per-image) for batches with several images.
        safe_flat_col_indices = keras.ops.where(
            (flat_col_indices >= 0) & (flat_col_indices < num_targets),
            flat_col_indices,
            0,
        )
        target_classes_flat = keras.ops.take(
            flat_target_labels, safe_flat_col_indices, axis=0
        )
        target_boxes_flat = keras.ops.take(
            flat_target_boxes, safe_flat_col_indices, axis=0
        )
        # Invalid matches fall back to background class / zero boxes.
        target_classes_flat = keras.ops.where(
            flat_valid_masks, target_classes_flat, num_classes
        )
        target_boxes_flat = keras.ops.where(
            keras.ops.expand_dims(flat_valid_masks, axis=-1),
            target_boxes_flat,
            0.0,
        )
        # IoU targets must not back-propagate into the box predictions.
        src_boxes_corners = keras.utils.bounding_boxes.convert_format(
            keras.ops.stop_gradient(src_boxes),
            source="center_xywh",
            target="xyxy",
        )
        target_boxes_corners = keras.utils.bounding_boxes.convert_format(
            target_boxes_flat,
            source="center_xywh",
            target="xyxy",
        )
        ious_matrix = keras.utils.bounding_boxes.compute_iou(
            src_boxes_corners,
            target_boxes_corners,
            bounding_box_format="xyxy",
        )
        # The diagonal of the pairwise IoU matrix holds the IoU of each
        # (prediction, matched target) pair. This assumes `src_boxes` and
        # `target_boxes_flat` have the same length and order.
        ious = keras.ops.diagonal(ious_matrix)
        ious = ious * keras.ops.cast(flat_valid_masks, dtype=ious.dtype)
        target_classes_flat = keras.ops.cast(target_classes_flat, dtype="int32")
        ious = keras.ops.cast(ious, dtype=src_logits.dtype)
        # NOTE(review): entries with an invalid match are scattered too (as
        # background / IoU 0). If they share a (batch, query) coordinate with
        # a valid match, the duplicate-index write order decides the result.
        # Verify what `_get_source_permutation_idx` emits for them.
        target_classes_updated = keras.ops.scatter_update(
            target_classes_init, update_indices, target_classes_flat
        )
        target_score_updated = keras.ops.scatter_update(
            target_score_original, update_indices, ious
        )
        return target_classes_updated, target_score_updated

    target_classes, target_score_original = process_targets()
    # One-hot over num_classes + 1 ids, then drop the background column.
    target_one_hot = keras.ops.one_hot(
        target_classes, num_classes=num_classes + 1
    )[..., :-1]
    # Positives are supervised with their IoU; everything else with 0.
    target_score = (
        keras.ops.expand_dims(target_score_original, axis=-1) * target_one_hot
    )
    pred_score_sigmoid = keras.ops.sigmoid(keras.ops.stop_gradient(src_logits))
    # Varifocal weight: alpha * p^gamma on negatives, IoU score on positives.
    weight = (
        matcher_alpha
        * keras.ops.power(pred_score_sigmoid, matcher_gamma)
        * (1 - target_one_hot)
        + target_score
    )
    loss_vfl = keras.ops.binary_crossentropy(
        target_score, src_logits, from_logits=True
    )
    loss_vfl = loss_vfl * weight
    # mean over queries * num_queries == sum over queries. The total is
    # then summed over the remaining axes and normalized by `num_boxes`.
    loss_vfl = (
        keras.ops.sum(keras.ops.mean(loss_vfl, axis=1))
        * keras.ops.cast(keras.ops.shape(src_logits)[1], dtype=loss_vfl.dtype)
        / num_boxes
    )
    return {"loss_vfl": loss_vfl}
|
332
|
+
|
333
|
+
|
334
|
+
def compute_box_losses(outputs, targets, indices, num_boxes):
    """Computes the bounding box regression losses for matched queries.

    Both losses are evaluated on the predictions that the Hungarian matcher
    paired with a ground truth box. Entries without a valid match are
    zeroed out through the matcher's valid mask.
    1. **L1 Loss (`loss_bbox`):** Sum of absolute coordinate differences
        between predicted and ground truth boxes (`"center_xywh"`).
    2. **Complete IoU Loss (`loss_ciou`):** Sum of `1 - CIoU`, computed
        after converting both box sets to `"xyxy"`.

    Args:
        outputs: dict, A dictionary containing predicted `"pred_boxes"`.
        targets: list of dict, Only `targets[0]["boxes"]` is read. If that
            tensor is 3-D, its leading axis is squeezed away.
        indices: tuple, `(row_ind, col_ind, valid_mask)` from the
            Hungarian matcher.
        num_boxes: int, The total number of ground truth boxes, used to
            normalize both losses.

    Returns:
        Dictionary: `{"loss_bbox": ..., "loss_ciou": ...}`, each a scalar
        divided by `num_boxes`.
    """
    _, matched_cols, match_valid = indices
    batch_idx, src_idx = _get_source_permutation_idx(indices)
    pred_boxes = gather_along_first_two_dims(
        outputs["pred_boxes"], batch_idx, src_idx
    )

    gt_boxes = targets[0]["boxes"]
    if keras.ops.ndim(gt_boxes) == 3:
        gt_boxes = keras.ops.squeeze(gt_boxes, axis=0)

    flat_cols = keras.ops.reshape(matched_cols, [-1])
    flat_valid = keras.ops.reshape(match_valid, [-1])

    # Clamp matched column ids into [0, num_gt - 1] so `take` never reads
    # out of range. Rows without a valid match are zeroed right after.
    last_gt = keras.ops.cast(
        keras.ops.maximum(keras.ops.shape(gt_boxes)[0] - 1, 0),
        dtype=flat_cols.dtype,
    )
    matched_gt = keras.ops.take(
        gt_boxes, keras.ops.clip(flat_cols, 0, last_gt), axis=0
    )
    row_mask = keras.ops.cast(
        keras.ops.expand_dims(flat_valid, axis=-1), matched_gt.dtype
    )
    matched_gt = matched_gt * row_mask

    masked_abs_err = keras.ops.abs(pred_boxes - matched_gt) * keras.ops.cast(
        row_mask, pred_boxes.dtype
    )
    l1_loss = keras.ops.sum(masked_abs_err)

    pred_xyxy = keras.utils.bounding_boxes.convert_format(
        pred_boxes, source="center_xywh", target="xyxy"
    )
    gt_xyxy = keras.utils.bounding_boxes.convert_format(
        matched_gt, source="center_xywh", target="xyxy"
    )
    ciou = keras.utils.bounding_boxes.compute_ciou(
        pred_xyxy, gt_xyxy, bounding_box_format="xyxy"
    )
    ciou_loss = keras.ops.sum(
        (1.0 - ciou) * keras.ops.cast(flat_valid, pred_boxes.dtype)
    )

    return {
        "loss_bbox": l1_loss / num_boxes,
        "loss_ciou": ciou_loss / num_boxes,
    }
|
403
|
+
|
404
|
+
|
405
|
+
def compute_local_losses(
    outputs,
    targets,
    indices,
    num_boxes,
    backbone,
    ddf_temperature,
    compute_ddf=None,
):
    """Computes local refinement losses (FGL and DDF).

    This function calculates two losses for fine-grained box refinement:
    1. **Fine-grained Localization Loss (`loss_fgl`):** This loss operates
        on the integral-based representation of the bounding box corners.
        It is a focal-style loss applied to the distribution over discrete
        bins, encouraging the model to produce sharp, unimodal
        distributions around the true corner locations.
    2. **Decoupled Distillation Focal Loss (`loss_ddf`):** A knowledge
        distillation loss. It minimizes the KL-divergence between the
        corner distribution in `"pred_corners"` (student) and that in
        `"teacher_corners"` (teacher). Matched and unmatched queries are
        averaged separately and then blended, each weighted by the square
        root of its per-image count.

    Args:
        outputs: dict, A dictionary of model predictions, including
            `"pred_corners"`, `"ref_points"`, `"pred_boxes"`, and
            optionally teacher predictions `"teacher_corners"` and
            `"teacher_logits"`.
        targets: list of dict, Only `targets[0]["boxes"]` (in the
            `"center_xywh"` format) is read. If that tensor is 3-D, its
            leading axis is squeezed away.
        indices: tuple of Tensors, The assignments from the Hungarian
            matcher, `(row_ind, col_ind, valid_mask)`.
        num_boxes: scalar Tensor, The total number of ground truth boxes for
            normalization.
        backbone: The model backbone. `backbone.decoder.max_num_bins`,
            `backbone.decoder.reg_scale` and
            `backbone.decoder.upsampling_factor` are read.
        ddf_temperature: float, Softmax temperature used for the DDF
            KL-divergence. The loss is also scaled by its square.
        compute_ddf: bool or boolean scalar tensor, optional. Whether to
            compute the DDF loss. If `None`, it is enabled only when both
            `"teacher_corners"` and `"teacher_logits"` are present and not
            `None` in `outputs`. A Python bool selects the branch directly.
            A tensor is routed through `keras.ops.cond`.

    Returns:
        Dictionary: A dictionary with keys `"loss_fgl"` and `"loss_ddf"`.
        Both are zero when `"pred_corners"` or `"ref_points"` is missing.
    """
    losses = {}
    if (
        "pred_corners" not in outputs
        or outputs["pred_corners"] is None
        or "ref_points" not in outputs
        or outputs["ref_points"] is None
    ):
        losses["loss_fgl"] = keras.ops.convert_to_tensor(
            0.0, dtype=keras.backend.floatx()
        )
        losses["loss_ddf"] = keras.ops.convert_to_tensor(
            0.0, dtype=keras.backend.floatx()
        )
        return losses

    if compute_ddf is None:
        # Both teacher tensors are read by the DDF branch, so both must exist.
        compute_ddf = (
            "teacher_corners" in outputs
            and outputs["teacher_corners"] is not None
            and "teacher_logits" in outputs
            and outputs["teacher_logits"] is not None
        )

    _, col_indices, valid_masks = indices
    batch_idx, src_idx = _get_source_permutation_idx(indices)
    col_indices_flat = keras.ops.reshape(col_indices, [-1])
    valid_masks_flat = keras.ops.reshape(valid_masks, [-1])
    target_boxes_all = targets[0]["boxes"]
    if keras.ops.ndim(target_boxes_all) == 3:
        target_boxes_all = keras.ops.squeeze(target_boxes_all, axis=0)
    # Clamp matched column ids so `take` stays in bounds. Rows without a
    # valid match are zeroed by the mask right after.
    max_box_idx = keras.ops.maximum(keras.ops.shape(target_boxes_all)[0] - 1, 0)
    max_box_idx = keras.ops.cast(max_box_idx, dtype=col_indices_flat.dtype)
    safe_col_indices = keras.ops.clip(col_indices_flat, 0, max_box_idx)
    target_boxes_matched_center = keras.ops.take(
        target_boxes_all, safe_col_indices, axis=0
    )
    valid_masks_expanded = keras.ops.expand_dims(valid_masks_flat, axis=-1)
    valid_masks_expanded = keras.ops.cast(
        valid_masks_expanded, target_boxes_matched_center.dtype
    )
    target_boxes_matched_center = (
        target_boxes_matched_center * valid_masks_expanded
    )

    pred_corners_matched_flat = gather_along_first_two_dims(
        outputs["pred_corners"], batch_idx, src_idx
    )
    # One row per (matched query, box edge) with max_num_bins + 1 bins.
    pred_corners_matched = keras.ops.reshape(
        pred_corners_matched_flat,
        (-1, backbone.decoder.max_num_bins + 1),
    )
    ref_points_matched = gather_along_first_two_dims(
        outputs["ref_points"], batch_idx, src_idx
    )
    ref_points_matched = keras.ops.stop_gradient(ref_points_matched)
    target_boxes_corners_matched = keras.utils.bounding_boxes.convert_format(
        target_boxes_matched_center,
        source="center_xywh",
        target="xyxy",
    )
    reg_scale_tensor = backbone.decoder.reg_scale
    up_tensor = backbone.decoder.upsampling_factor
    target_corners_dist, weight_right, weight_left = bbox2distance(
        ref_points_matched,
        target_boxes_corners_matched,
        backbone.decoder.max_num_bins,
        reg_scale_tensor,
        up_tensor,
    )
    pred_boxes_matched_center = gather_along_first_two_dims(
        outputs["pred_boxes"], batch_idx, src_idx
    )
    pred_boxes_corners_matched = keras.utils.bounding_boxes.convert_format(
        pred_boxes_matched_center,
        source="center_xywh",
        target="xyxy",
    )
    ious_pairwise = keras.utils.bounding_boxes.compute_iou(
        pred_boxes_corners_matched,
        target_boxes_corners_matched,
        bounding_box_format="xyxy",
    )
    # Diagonal of the pairwise matrix = IoU of each matched pair.
    ious = keras.ops.diagonal(ious_pairwise)
    ious = ious * keras.ops.cast(valid_masks_flat, dtype=ious.dtype)
    # Repeat each IoU for the 4 box edges, matching the corner rows above.
    weight_targets_fgl = keras.ops.reshape(
        keras.ops.tile(keras.ops.expand_dims(ious, 1), [1, 4]),
        [-1],
    )
    weight_targets_fgl = keras.ops.stop_gradient(weight_targets_fgl)
    losses["loss_fgl"] = unimodal_distribution_focal_loss(
        pred_corners_matched,
        target_corners_dist,
        weight_right,
        weight_left,
        weight=weight_targets_fgl,
        avg_factor=num_boxes,
    )

    def ddf_true_fn():
        pred_corners_all = keras.ops.reshape(
            outputs["pred_corners"],
            (-1, backbone.decoder.max_num_bins + 1),
        )
        target_corners_all = keras.ops.reshape(
            keras.ops.stop_gradient(outputs["teacher_corners"]),
            (-1, backbone.decoder.max_num_bins + 1),
        )

        def compute_ddf_loss_fn():
            # Unmatched queries are weighted by the teacher's max class
            # confidence; matched queries are overwritten with their IoU.
            weight_targets_local = keras.ops.max(
                keras.ops.sigmoid(outputs["teacher_logits"]), axis=-1
            )
            num_queries = keras.ops.cast(
                keras.ops.shape(weight_targets_local)[1],
                dtype=batch_idx.dtype,
            )
            flat_update_indices = batch_idx * num_queries + src_idx
            flat_update_indices = keras.ops.expand_dims(
                flat_update_indices, axis=-1
            )
            mask = keras.ops.zeros_like(weight_targets_local, dtype="bool")
            mask_flat = keras.ops.scatter_update(
                keras.ops.reshape(mask, (-1,)),
                flat_update_indices,
                keras.ops.ones_like(batch_idx, dtype="bool"),
            )
            mask = keras.ops.reshape(
                mask_flat, keras.ops.shape(weight_targets_local)
            )
            weight_targets_local_flat = keras.ops.reshape(
                weight_targets_local, (-1,)
            )
            weight_targets_local_matched_flat = keras.ops.scatter_update(
                weight_targets_local_flat,
                flat_update_indices,
                ious,
            )
            weight_targets_local = keras.ops.reshape(
                weight_targets_local_matched_flat,
                keras.ops.shape(weight_targets_local),
            )
            weight_targets_local_expanded = keras.ops.reshape(
                keras.ops.tile(
                    keras.ops.expand_dims(weight_targets_local, axis=-1),
                    [1, 1, 4],
                ),
                [-1],
            )
            weight_targets_local_expanded = keras.ops.stop_gradient(
                weight_targets_local_expanded
            )
            # NOTE: Original impl hardcodes `ddf_temperature` to 5.0 for
            # DDFL.
            # KerasHub lets users configure it if needed.
            # Ref: https://github.com/huggingface/transformers/blob/b374c3d12e8a42014b7911d1bddf598aeada1154/src/transformers/loss/loss_d_fine.py#L238
            pred_softmax = keras.ops.softmax(
                pred_corners_all / ddf_temperature, axis=-1
            )
            target_softmax = keras.ops.softmax(
                target_corners_all / ddf_temperature, axis=-1
            )
            # KL(teacher || student) per corner distribution.
            kl_div = keras.ops.sum(
                target_softmax
                * (
                    keras.ops.log(target_softmax + 1e-8)
                    - keras.ops.log(pred_softmax + 1e-8)
                ),
                axis=-1,
            )
            loss_match_local = (
                weight_targets_local_expanded * (ddf_temperature**2) * kl_div
            )
            mask_expanded = keras.ops.expand_dims(mask, axis=-1)
            mask_expanded = keras.ops.tile(mask_expanded, [1, 1, 4])
            mask_flat = keras.ops.reshape(mask_expanded, (-1,))
            # Mean loss over matched (positive) corner rows, 0 if none.
            loss_match_local1 = keras.ops.cond(
                keras.ops.any(mask_flat),
                lambda: keras.ops.sum(
                    loss_match_local
                    * keras.ops.cast(mask_flat, loss_match_local.dtype)
                )
                / keras.ops.sum(
                    keras.ops.cast(mask_flat, loss_match_local.dtype)
                ),
                lambda: keras.ops.convert_to_tensor(
                    0.0, dtype=loss_match_local.dtype
                ),
            )
            # Mean loss over unmatched (negative) corner rows, 0 if none.
            neg_mask_flat = keras.ops.logical_not(mask_flat)
            loss_match_local2 = keras.ops.cond(
                keras.ops.any(neg_mask_flat),
                lambda: keras.ops.sum(
                    loss_match_local
                    * keras.ops.cast(neg_mask_flat, loss_match_local.dtype)
                )
                / keras.ops.sum(
                    keras.ops.cast(neg_mask_flat, loss_match_local.dtype)
                ),
                lambda: keras.ops.convert_to_tensor(
                    0.0, dtype=loss_match_local.dtype
                ),
            )
            # Blend both means, weighted by sqrt of per-image counts.
            batch_scale = 1.0 / keras.ops.cast(
                keras.ops.shape(outputs["pred_boxes"])[0],
                dtype="float32",
            )
            num_pos = keras.ops.sqrt(
                keras.ops.sum(keras.ops.cast(mask, dtype="float32"))
                * batch_scale
            )
            num_neg = keras.ops.sqrt(
                keras.ops.sum(keras.ops.cast(~mask, dtype="float32"))
                * batch_scale
            )
            return (
                loss_match_local1 * num_pos + loss_match_local2 * num_neg
            ) / (num_pos + num_neg + 1e-8)

        # Identical student/teacher (e.g. the teacher layer itself): the
        # loss is zero, but keep it connected to `pred_corners_all`.
        all_equal = keras.ops.all(
            keras.ops.equal(pred_corners_all, target_corners_all)
        )
        return keras.ops.cond(
            all_equal,
            lambda: keras.ops.sum(pred_corners_all) * 0.0,
            compute_ddf_loss_fn,
        )

    def ddf_false_fn():
        return keras.ops.convert_to_tensor(0.0, dtype=keras.backend.floatx())

    if isinstance(compute_ddf, bool):
        # A static flag selects the branch in Python. `keras.ops.cond` may
        # trace both branches on graph backends (e.g. `jax.lax.cond`). That
        # would run `ddf_true_fn`, which reads `outputs["teacher_corners"]`,
        # even when no teacher outputs exist.
        losses["loss_ddf"] = ddf_true_fn() if compute_ddf else ddf_false_fn()
    else:
        losses["loss_ddf"] = keras.ops.cond(
            compute_ddf, ddf_true_fn, ddf_false_fn
        )
    return losses
|
677
|
+
|
678
|
+
|
679
|
+
def _translate_gt_valid_case(
    gt_flat, valid_idx_mask, function_values, max_num_bins, mask
):
    """Compute bin indices and interpolation weights for `translate_gt`.

    Each entry of `gt_flat` is located between two adjacent values of
    `function_values`; the returned weights split that entry between the
    left bin (`weight_left`) and the right bin (`weight_right`).

    Args:
        gt_flat: 1D tensor of target values.
        valid_idx_mask: Boolean tensor shaped like `gt_flat`; `True` where
            the closest-left index lies in `[0, max_num_bins)`.
        function_values: 1D tensor of bin edge values.
        max_num_bins: Number of bins.
        mask: Boolean tensor of shape `(len(gt_flat), len(function_values))`
            (as built by `translate_gt`), `True` where
            `function_values[j] <= gt_flat[i]`.

    Returns:
        Tuple `(indices, weight_right, weight_left)`, each shaped like
        `gt_flat`:
        - In-range entries get their closest-left index and linear
          interpolation weights.
        - Entries with a negative closest-left index get index 0 and all
          weight on the left bin.
        - Entries whose index is `>= max_num_bins` get index
          `max_num_bins - 0.1` and all weight on the right bin.
    """
    closest_left_indices = (
        keras.ops.sum(keras.ops.cast(mask, "int32"), axis=1) - 1
    )
    indices_float = keras.ops.cast(closest_left_indices, dtype=gt_flat.dtype)
    zeros = keras.ops.zeros_like(indices_float)
    ones = keras.ops.ones_like(indices_float)

    # Out-of-range entries gather from bin 0 with value 0 so every `take`
    # stays in bounds; their results are masked out below.
    safe_indices = keras.ops.cast(
        keras.ops.where(valid_idx_mask, indices_float, 0.0), "int32"
    )
    gt_valid = keras.ops.where(valid_idx_mask, gt_flat, 0.0)
    left_values = keras.ops.take(function_values, safe_indices, axis=0)
    right_values = keras.ops.take(
        function_values,
        keras.ops.clip(
            safe_indices + 1,
            0,
            keras.ops.shape(function_values)[0] - 1,
        ),
        axis=0,
    )
    left_diffs = keras.ops.abs(gt_valid - left_values)
    right_diffs = keras.ops.abs(right_values - gt_valid)
    # The epsilon guards against two identical adjacent edge values.
    wr_valid = left_diffs / (left_diffs + right_diffs + 1e-8)
    wl_valid = 1.0 - wr_valid
    weight_right = keras.ops.where(valid_idx_mask, wr_valid, zeros)
    weight_left = keras.ops.where(valid_idx_mask, wl_valid, zeros)

    # Build the out-of-range masks from the UNCLAMPED indices. Testing them
    # after clamping (as previously done) can never match, which silently
    # left both weights at zero for out-of-range entries.
    below_range = indices_float < 0
    above_range = indices_float >= max_num_bins
    weight_right = keras.ops.where(below_range, zeros, weight_right)
    weight_left = keras.ops.where(below_range, ones, weight_left)
    weight_right = keras.ops.where(above_range, ones, weight_right)
    weight_left = keras.ops.where(above_range, zeros, weight_left)
    indices_float = keras.ops.where(below_range, zeros, indices_float)
    indices_float = keras.ops.where(
        above_range,
        keras.ops.cast(max_num_bins - 0.1, dtype=indices_float.dtype),
        indices_float,
    )
    return indices_float, weight_right, weight_left
|
759
|
+
|
760
|
+
|
761
|
+
def translate_gt(gt, max_num_bins, reg_scale, up):
    """Convert continuous targets into bin indices and interpolation weights.

    Args:
        gt: Tensor of target values; flattened to 1D before processing.
        max_num_bins: Number of bins.
        reg_scale: Passed through to `weighting_function`.
        up: Passed through to `weighting_function`.

    Returns:
        Tuple `(indices, weight_right, weight_left)` of 1D tensors with one
        entry per element of `gt`:
        - Entries whose closest-left index falls below 0 map to index 0
          with all weight on the left bin.
        - Entries whose index is `>= max_num_bins` map to index
          `max_num_bins - 0.1` with all weight on the right bin.
    """
    gt_flat = keras.ops.reshape(gt, [-1])
    function_values = weighting_function(max_num_bins, up, reg_scale)
    # mask[i, j] is True where function_values[j] <= gt_flat[i].
    diffs = keras.ops.expand_dims(
        function_values, axis=0
    ) - keras.ops.expand_dims(gt_flat, axis=1)
    mask = diffs <= 0
    closest_left_indices = (
        keras.ops.sum(keras.ops.cast(mask, "int32"), axis=1) - 1
    )
    indices_float = keras.ops.cast(closest_left_indices, dtype=gt_flat.dtype)
    valid_idx_mask = (indices_float >= 0) & (indices_float < max_num_bins)

    def _all_out_of_range():
        # With no in-range entry, every element is either below range
        # (index < 0) or at/above range (index >= max_num_bins); handle
        # each case per element instead of treating all as below range.
        below_range = indices_float < 0
        zeros = keras.ops.zeros_like(indices_float)
        ones = keras.ops.ones_like(indices_float)
        return (
            keras.ops.where(
                below_range,
                zeros,
                keras.ops.cast(max_num_bins - 0.1, dtype=indices_float.dtype),
            ),
            keras.ops.where(below_range, zeros, ones),
            keras.ops.where(below_range, ones, zeros),
        )

    return keras.ops.cond(
        keras.ops.any(valid_idx_mask),
        lambda: _translate_gt_valid_case(
            gt_flat, valid_idx_mask, function_values, max_num_bins, mask
        ),
        _all_out_of_range,
    )
|
786
|
+
|
787
|
+
|
788
|
+
def _compute_bbox2distance(points, bbox, max_num_bins, reg_scale, up, eps=0.1):
    """Encode box edges as normalized offsets and discretize them into bins.

    The four offsets (left, top, right, bottom) are measured between
    `points[..., :2]` and the box coordinates in `bbox`. Horizontal offsets
    are divided by `points[..., 2] / |reg_scale|` and vertical offsets by
    `points[..., 3] / |reg_scale|` (presumably anchor width/height — confirm
    with the caller). `0.5 * |reg_scale|` is then subtracted from each
    offset. The result goes through `translate_gt` and is clipped to
    `[0, max_num_bins - eps]`.

    Returns:
        Tuple `(indices, weight_right, weight_left)` with gradients
        stopped on every element.
    """
    scale = keras.ops.abs(reg_scale)
    half_scale = 0.5 * scale
    x_norm = points[..., 2] / scale + 1e-16
    y_norm = points[..., 3] / scale + 1e-16
    px = points[..., 0]
    py = points[..., 1]
    edge_offsets = keras.ops.stack(
        [
            (px - bbox[..., 0]) / x_norm - half_scale,
            (py - bbox[..., 1]) / y_norm - half_scale,
            (bbox[..., 2] - px) / x_norm - half_scale,
            (bbox[..., 3] - py) / y_norm - half_scale,
        ],
        axis=-1,
    )
    if isinstance(up, keras.KerasTensor):
        up_tensor = up
    else:
        up_tensor = keras.ops.convert_to_tensor(up)
    bin_targets, weight_right, weight_left = translate_gt(
        edge_offsets, max_num_bins, scale, up_tensor
    )
    bin_targets = keras.ops.clip(bin_targets, 0, max_num_bins - eps)
    # These are regression targets; no gradient should flow into them.
    return (
        keras.ops.stop_gradient(bin_targets),
        keras.ops.stop_gradient(weight_right),
        keras.ops.stop_gradient(weight_left),
    )
|
819
|
+
|
820
|
+
|
821
|
+
def bbox2distance(points, bbox, max_num_bins, reg_scale, up, eps=0.1):
    """Discretize box-edge offsets, guarding the case of zero points.

    When `points` has zero rows, returns three zero tensors of length
    `4 * 0` in `keras.backend.floatx()`. Otherwise it delegates to
    `_compute_bbox2distance`.
    """
    num_points = keras.ops.shape(points)[0]
    flat_size = num_points * 4

    def _empty_targets():
        floatx = keras.backend.floatx()
        return (
            keras.ops.zeros((flat_size,), dtype=floatx),
            keras.ops.zeros((flat_size,), dtype=floatx),
            keras.ops.zeros((flat_size,), dtype=floatx),
        )

    def _targets():
        return _compute_bbox2distance(
            points, bbox, max_num_bins, reg_scale, up, eps
        )

    return keras.ops.cond(
        keras.ops.equal(num_points, 0), _empty_targets, _targets
    )
|
840
|
+
|
841
|
+
|
842
|
+
def unimodal_distribution_focal_loss(
    pred,
    label,
    weight_right,
    weight_left,
    weight=None,
    reduction="sum",
    avg_factor=None,
):
    """Cross-entropy spread over the two bins adjacent to each target.

    Args:
        pred: Logits (`from_logits=True`); presumably
            `(num_targets, num_bins)` — confirm with the caller.
        label: Target bin positions. They are flattened and cast to int32
            (truncation) to get the left bin; the right bin is left + 1.
        weight_right: Weights for the right-bin term, flattened.
        weight_left: Weights for the left-bin term, flattened.
        weight: Optional per-element weights, cast to the loss dtype.
        reduction: `"mean"` or `"sum"`; any other value returns the
            unreduced per-element loss.
        avg_factor: If given, overrides `reduction` and the result is
            `sum(loss) / avg_factor`.

    Returns:
        The reduced (or unreduced) loss tensor.
    """
    left_bins = keras.ops.cast(keras.ops.reshape(label, [-1]), "int32")
    right_bins = left_bins + 1
    ce_left = keras.ops.sparse_categorical_crossentropy(
        left_bins, pred, from_logits=True
    )
    ce_right = keras.ops.sparse_categorical_crossentropy(
        right_bins, pred, from_logits=True
    )
    loss = ce_left * keras.ops.reshape(
        weight_left, [-1]
    ) + ce_right * keras.ops.reshape(weight_right, [-1])
    if weight is not None:
        loss = loss * keras.ops.cast(weight, dtype=loss.dtype)
    if avg_factor is not None:
        return keras.ops.sum(loss) / avg_factor
    if reduction == "mean":
        return keras.ops.mean(loss)
    if reduction == "sum":
        return keras.ops.sum(loss)
    return loss
|
878
|
+
|
879
|
+
|
880
|
+
def _get_source_permutation_idx(indices):
    """Return flat (batch, source) index pairs for matched predictions.

    This is a JIT-friendly (JAX-compatible) alternative to building
    dynamically sized index tensors by concatenating per-image lists.
    Instead, both outputs have a fixed size of `batch_size * max_matches`,
    and every position where `valid_masks` is False holds `0`. Callers are
    expected to use `valid_masks` to ignore those padded positions.

    Args:
        indices: Tuple `(row_indices, col_indices, valid_masks)`; only
            `row_indices` and `valid_masks` are used, both of shape
            `(batch_size, max_matches)`.

    Returns:
        Tuple `(batch_idx, src_idx)` of flat int64 tensors.
    """
    row_indices, _, valid_masks = indices
    num_images = keras.ops.shape(row_indices)[0]
    num_slots = keras.ops.shape(row_indices)[1]
    valid_flat = keras.ops.reshape(valid_masks, (-1,))

    def _masked_int64(values):
        # Flatten, widen to int64, and zero out padded positions.
        flat = keras.ops.reshape(values, (-1,))
        return keras.ops.where(valid_flat, keras.ops.cast(flat, "int64"), 0)

    image_ids = keras.ops.tile(
        keras.ops.expand_dims(
            keras.ops.arange(num_images, dtype="int32"), axis=1
        ),
        [1, num_slots],
    )
    return _masked_int64(image_ids), _masked_int64(row_indices)
|
913
|
+
|
914
|
+
|
915
|
+
def get_cdn_matched_indices(dn_meta):
    """Build fixed-size matched indices for contrastive denoising (CDN).

    This is a JIT-friendly (JAX-compatible) alternative to iterating over
    the batch and building lists of dynamically sized index tensors. The
    whole batch is handled with tensor ops.

    `dn_meta["dn_positive_idx"]` is read as a tensor whose first two
    dimensions are `(batch_size, num_denoising_queries)`. Entries equal to
    -1 mark padding. NOTE(review): the meaning of non-padding entries is
    set by whoever builds `dn_positive_idx`; confirm there.

    Returns:
        Tuple `(row_indices, col_indices, valid_masks)`:
        - `row_indices` is `arange(num_denoising_queries)` (int64),
          repeated for every batch element.
        - `col_indices` is `dn_positive_idx` unchanged.
        - `valid_masks` is True where `col_indices != -1`.
    """
    col_indices = dn_meta["dn_positive_idx"]
    batch_size, num_denoising_queries = keras.ops.shape(col_indices)[:2]
    query_ids = keras.ops.expand_dims(
        keras.ops.arange(num_denoising_queries, dtype="int64"), 0
    )
    row_indices = keras.ops.tile(query_ids, [batch_size, 1])
    valid_masks = keras.ops.not_equal(col_indices, -1)
    return row_indices, col_indices, valid_masks
|