keras-hub-nightly 0.23.0.dev202508260411__py3-none-any.whl → 0.23.0.dev202508280418__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40) hide show
  1. keras_hub/layers/__init__.py +6 -0
  2. keras_hub/models/__init__.py +21 -0
  3. keras_hub/src/layers/modeling/position_embedding.py +21 -6
  4. keras_hub/src/layers/modeling/rotary_embedding.py +16 -6
  5. keras_hub/src/layers/modeling/sine_position_encoding.py +21 -8
  6. keras_hub/src/layers/modeling/token_and_position_embedding.py +2 -1
  7. keras_hub/src/models/backbone.py +10 -15
  8. keras_hub/src/models/d_fine/__init__.py +0 -0
  9. keras_hub/src/models/d_fine/d_fine_attention.py +461 -0
  10. keras_hub/src/models/d_fine/d_fine_backbone.py +891 -0
  11. keras_hub/src/models/d_fine/d_fine_decoder.py +944 -0
  12. keras_hub/src/models/d_fine/d_fine_encoder.py +365 -0
  13. keras_hub/src/models/d_fine/d_fine_hybrid_encoder.py +642 -0
  14. keras_hub/src/models/d_fine/d_fine_image_converter.py +8 -0
  15. keras_hub/src/models/d_fine/d_fine_layers.py +1828 -0
  16. keras_hub/src/models/d_fine/d_fine_loss.py +938 -0
  17. keras_hub/src/models/d_fine/d_fine_object_detector.py +875 -0
  18. keras_hub/src/models/d_fine/d_fine_object_detector_preprocessor.py +14 -0
  19. keras_hub/src/models/d_fine/d_fine_presets.py +2 -0
  20. keras_hub/src/models/d_fine/d_fine_utils.py +827 -0
  21. keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +4 -1
  22. keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +3 -2
  23. keras_hub/src/models/hgnetv2/hgnetv2_layers.py +27 -11
  24. keras_hub/src/models/parseq/__init__.py +0 -0
  25. keras_hub/src/models/parseq/parseq_backbone.py +134 -0
  26. keras_hub/src/models/parseq/parseq_causal_lm.py +466 -0
  27. keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py +168 -0
  28. keras_hub/src/models/parseq/parseq_decoder.py +418 -0
  29. keras_hub/src/models/parseq/parseq_image_converter.py +8 -0
  30. keras_hub/src/models/parseq/parseq_tokenizer.py +221 -0
  31. keras_hub/src/tests/test_case.py +37 -1
  32. keras_hub/src/utils/preset_utils.py +49 -0
  33. keras_hub/src/utils/tensor_utils.py +23 -1
  34. keras_hub/src/utils/transformers/convert_vit.py +4 -1
  35. keras_hub/src/version.py +1 -1
  36. keras_hub/tokenizers/__init__.py +3 -0
  37. {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/METADATA +1 -1
  38. {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/RECORD +40 -20
  39. {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/WHEEL +0 -0
  40. {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,938 @@
1
+ import keras
2
+
3
+ from keras_hub.src.models.d_fine.d_fine_utils import hungarian_assignment
4
+ from keras_hub.src.models.d_fine.d_fine_utils import weighting_function
5
+
6
+
7
def gather_along_first_two_dims(tensor, batch_idx, src_idx):
    """Gathers `tensor[batch_idx[j], src_idx[j]]` for every position `j`.

    The first two axes of `tensor` are merged into one row-major axis and
    the requested rows are taken along it, so any trailing feature axes are
    preserved unchanged.

    Args:
        tensor: Tensor of shape `(batch_size, num_queries, *feature_dims)`.
        batch_idx: Integer tensor of indices into the first axis.
        src_idx: Integer tensor of indices into the second axis; same shape
            and dtype as `batch_idx`.

    Returns:
        Tensor of shape `batch_idx.shape + feature_dims`.
    """
    _, num_queries, *feature_dims = keras.ops.shape(tensor)
    num_queries = keras.ops.cast(num_queries, dtype=batch_idx.dtype)
    # Row-major position of (b, q) once the first two axes are merged.
    linear_idx = batch_idx * num_queries + src_idx
    flat_tensor = keras.ops.reshape(tensor, (-1, *feature_dims))
    return keras.ops.take(flat_tensor, linear_idx, axis=0)
15
+
16
+
17
def hungarian_matcher(
    outputs,
    targets,
    num_targets_per_image,
    use_focal_loss,
    matcher_alpha,
    matcher_gamma,
    matcher_bbox_cost,
    matcher_class_cost,
    matcher_ciou_cost,
    backbone,
):
    """Performs bipartite matching between predictions and ground truths.

    For every image a `(num_queries, num_queries)` cost matrix is built
    (rows: queries; columns: that image's targets, padded or truncated to
    `num_queries`) and solved with `hungarian_assignment`. The cost is a
    weighted sum of three components:
        1. **Class cost:** with `use_focal_loss`, the focal-style cost
           `alpha * (1 - p)**gamma * -log(p) - (1 - alpha) * p**gamma *
           -log(1 - p)` on sigmoid probabilities; otherwise the negative
           softmax probability of the target class.
        2. **Bounding box cost:** L1 distance between the boxes.
        3. **CIoU cost:** the negative CIoU between the boxes after
           conversion from `center_xywh` to `xyxy`.
    Padding columns (beyond the image's target count) are given cost `1e9`.

    Args:
        outputs: dict with `"logits"` of shape `(batch_size, num_queries,
            num_classes)` and `"pred_boxes"` in `center_xywh` format.
        targets: list of dict. Only `targets[0]` is read: its `"labels"` and
            `"boxes"` hold the targets of all images concatenated along the
            first axis, in image order.
        num_targets_per_image: A tensor of shape `(batch_size,)` giving how
            many of the concatenated targets belong to each image.
        use_focal_loss: bool, selects the class-cost formulation.
        matcher_alpha: float, focal `alpha` used in the class cost.
        matcher_gamma: float, focal `gamma` used in the class cost.
        matcher_bbox_cost: float, weight of the L1 box cost.
        matcher_class_cost: float, weight of the class cost.
        matcher_ciou_cost: float, weight of the CIoU cost.
        backbone: the model backbone; `backbone.num_queries` is forwarded to
            `hungarian_assignment`.

    Returns:
        tuple: `(row_indices, col_indices, valid_masks)`, each of shape
        `(batch_size, num_queries)`. `col_indices` are positions within the
        image's own targets (column `j` of image `i` is concatenated target
        `offset_i + j`), not positions in the concatenated tensors.
        `valid_masks` is True only where a real target was matched; images
        with no targets get all-zero indices and an all-False mask.
    """
    batch_size = keras.ops.shape(outputs["logits"])[0]
    num_queries = keras.ops.shape(outputs["logits"])[1]
    out_logits = outputs["logits"]
    out_bbox = outputs["pred_boxes"]
    # Targets of all images, concatenated along axis 0.
    target_ids_all = keras.ops.cast(targets[0]["labels"], dtype="int32")
    target_bbox_all = targets[0]["boxes"]
    # target_offsets[i]:target_offsets[i + 1] is image i's slice of the
    # concatenated targets.
    target_offsets = keras.ops.concatenate(
        [
            keras.ops.zeros((1,), dtype="int32"),
            keras.ops.cumsum(num_targets_per_image),
        ]
    )
    max_matches = num_queries
    row_indices_init = keras.ops.zeros((batch_size, max_matches), dtype="int32")
    col_indices_init = keras.ops.zeros((batch_size, max_matches), dtype="int32")
    valid_masks_init = keras.ops.zeros((batch_size, max_matches), dtype="bool")

    def loop_body(i, loop_vars):
        # Matches image `i` and writes its results into row `i` of the
        # accumulators.
        row_indices, col_indices, valid_masks = loop_vars
        out_logits_i = out_logits[i]
        out_bbox_i = out_bbox[i]
        start = target_offsets[i]
        end = target_offsets[i + 1]
        num_targets_i = end - start
        # Fixed-size (num_queries) view of this image's targets so shapes are
        # static; columns past `num_targets_i` are padding.
        k = keras.ops.arange(0, num_queries)
        is_valid_target_mask = k < num_targets_i
        target_indices = start + k
        # Clamp so padding columns still gather an in-bounds (ignored) row.
        safe_target_indices = keras.ops.minimum(
            target_indices, keras.ops.shape(target_ids_all)[0] - 1
        )
        target_ids_i = keras.ops.take(
            target_ids_all, safe_target_indices, axis=0
        )
        target_bbox_i = keras.ops.take(
            target_bbox_all, safe_target_indices, axis=0
        )

        def compute_cost_matrix():
            # Returns the (num_queries, num_queries) cost for image `i`.
            if use_focal_loss:
                out_prob_i = keras.ops.sigmoid(out_logits_i)
                safe_ids_for_take = keras.ops.maximum(target_ids_i, 0)
                prob_for_target_classes = keras.ops.take(
                    out_prob_i, safe_ids_for_take, axis=1
                )
                p = prob_for_target_classes
                pos_cost = (
                    matcher_alpha
                    * keras.ops.power(1 - p, matcher_gamma)
                    * (-keras.ops.log(p + 1e-8))
                )
                neg_cost = (
                    (1 - matcher_alpha)
                    * keras.ops.power(p, matcher_gamma)
                    * (-keras.ops.log(1 - p + 1e-8))
                )
                class_cost_i = pos_cost - neg_cost
            else:
                out_prob_softmax_i = keras.ops.softmax(out_logits_i, axis=-1)
                safe_ids_for_take = keras.ops.maximum(target_ids_i, 0)
                prob_for_target_classes = keras.ops.take(
                    out_prob_softmax_i, safe_ids_for_take, axis=1
                )
                class_cost_i = -prob_for_target_classes

            # Pairwise L1 distance: (Q, 1, 4) - (1, Q, 4) summed over coords.
            bbox_cost_i = keras.ops.sum(
                keras.ops.abs(
                    keras.ops.expand_dims(out_bbox_i, 1)
                    - keras.ops.expand_dims(target_bbox_i, 0)
                ),
                axis=2,
            )
            out_bbox_corners_i = keras.utils.bounding_boxes.convert_format(
                out_bbox_i,
                source="center_xywh",
                target="xyxy",
            )
            target_bbox_corners_i = keras.utils.bounding_boxes.convert_format(
                target_bbox_i,
                source="center_xywh",
                target="xyxy",
            )
            # Broadcast to a pairwise CIoU matrix; higher CIoU = lower cost.
            ciou_cost_i = -keras.utils.bounding_boxes.compute_ciou(
                keras.ops.expand_dims(out_bbox_corners_i, 1),
                keras.ops.expand_dims(target_bbox_corners_i, 0),
                bounding_box_format="xyxy",
            )

            cost_matrix_i = (
                matcher_bbox_cost * bbox_cost_i
                + matcher_class_cost * class_cost_i
                + matcher_ciou_cost * ciou_cost_i
            )
            # Make padding columns effectively unmatchable.
            cost_matrix_i = keras.ops.where(
                keras.ops.expand_dims(is_valid_target_mask, 0),
                cost_matrix_i,
                1e9,
            )
            return cost_matrix_i

        def perform_assignment():
            cost_matrix_i = compute_cost_matrix()
            row_idx, col_idx, valid_mask = hungarian_assignment(
                cost_matrix_i, backbone.num_queries
            )
            # Discard assignments to padding columns.
            valid_mask = keras.ops.logical_and(
                valid_mask, col_idx < num_targets_i
            )
            return row_idx, col_idx, valid_mask

        def skip_assignment():
            # Image without targets: nothing is matched.
            return (
                keras.ops.zeros((num_queries,), dtype="int32"),
                keras.ops.zeros((num_queries,), dtype="int32"),
                keras.ops.zeros((num_queries,), dtype="bool"),
            )

        row_idx, col_idx, valid_mask = keras.ops.cond(
            keras.ops.greater(num_targets_i, 0),
            perform_assignment,
            skip_assignment,
        )
        # Write this image's results into row `i`.
        row_indices = keras.ops.scatter_update(
            row_indices, [[i]], keras.ops.expand_dims(row_idx, axis=0)
        )
        col_indices = keras.ops.scatter_update(
            col_indices, [[i]], keras.ops.expand_dims(col_idx, axis=0)
        )
        valid_masks = keras.ops.scatter_update(
            valid_masks, [[i]], keras.ops.expand_dims(valid_mask, axis=0)
        )
        return row_indices, col_indices, valid_masks

    row_indices, col_indices, valid_masks = keras.ops.fori_loop(
        0,
        batch_size,
        loop_body,
        (row_indices_init, col_indices_init, valid_masks_init),
    )
    return (row_indices, col_indices, valid_masks)
194
+
195
+
196
def compute_vfl_loss(
    outputs,
    targets,
    indices,
    num_boxes,
    num_classes,
    matcher_alpha,
    matcher_gamma,
):
    """Computes the Varifocal Loss (VFL) for classification.

    For each matched query, the target score of its ground-truth class is
    the IoU between the matched predicted box and ground-truth box. All
    other `(query, class)` targets are 0. Positives are weighted by that
    target score. Negatives are weighted by `matcher_alpha * p **
    matcher_gamma`, where `p` is the gradient-stopped sigmoid probability.

    Args:
        outputs: dict with `"logits"` of shape `(batch, num_queries,
            num_classes)` and `"pred_boxes"` in `center_xywh` format.
        targets: list of dict. Each entry's `"labels"` and `"boxes"` are
            stacked and flattened, and the matcher's column indices address
            the flattened result.
        indices: tuple `(row_ind, col_ind, valid_mask)` from the Hungarian
            matcher.
        num_boxes: normalization factor (total ground-truth boxes).
        num_classes: int, number of foreground classes. Unmatched queries
            are labelled with the background id `num_classes`.
        matcher_alpha: float, weight of the negative term.
        matcher_gamma: float, exponent of the negative term.

    Returns:
        dict: `{"loss_vfl": loss}`. This is the weighted binary
        cross-entropy summed over all queries, classes and images, divided
        by `num_boxes`.
    """
    _, matched_cols, matched_valid = indices
    batch_idx, src_idx = _get_source_permutation_idx(indices)
    logits = outputs["logits"]
    matched_pred_boxes = gather_along_first_two_dims(
        outputs["pred_boxes"], batch_idx, src_idx
    )
    cols_flat = keras.ops.reshape(matched_cols, (-1,))
    valid_flat = keras.ops.reshape(matched_valid, (-1,))
    scatter_positions = keras.ops.stack([batch_idx, src_idx], axis=-1)

    # Every query starts as background with a zero target score.
    background_classes = keras.ops.full(
        shape=keras.ops.shape(logits)[:2],
        fill_value=num_classes,
        dtype="int32",
    )
    zero_scores = keras.ops.zeros_like(background_classes, dtype=logits.dtype)

    # Stack the target entries and drop a possible singleton axis.
    labels = keras.ops.stack([t["labels"] for t in targets], axis=0)
    boxes = keras.ops.stack([t["boxes"] for t in targets], axis=0)
    if keras.ops.ndim(labels) == 3:
        labels = keras.ops.squeeze(labels, axis=1)
    if keras.ops.ndim(boxes) == 4:
        boxes = keras.ops.squeeze(boxes, axis=1)
    labels_flat = keras.ops.reshape(labels, (-1,))
    boxes_flat = keras.ops.reshape(boxes, (-1, 4))

    # Out-of-range column indices gather row 0; invalid matches are then
    # overwritten with background / zero boxes.
    label_count = keras.ops.cast(
        keras.ops.shape(labels_flat)[0], dtype=cols_flat.dtype
    )
    in_range = (cols_flat >= 0) & (cols_flat < label_count)
    gather_cols = keras.ops.where(in_range, cols_flat, 0)
    matched_labels = keras.ops.where(
        valid_flat,
        keras.ops.take(labels_flat, gather_cols, axis=0),
        num_classes,
    )
    matched_boxes = keras.ops.where(
        keras.ops.expand_dims(valid_flat, axis=-1),
        keras.ops.take(boxes_flat, gather_cols, axis=0),
        0.0,
    )

    # Per-match IoU: row j of both inputs is the same match, so it sits on
    # the diagonal of the pairwise IoU matrix.
    pred_xyxy = keras.utils.bounding_boxes.convert_format(
        keras.ops.stop_gradient(matched_pred_boxes),
        source="center_xywh",
        target="xyxy",
    )
    target_xyxy = keras.utils.bounding_boxes.convert_format(
        matched_boxes,
        source="center_xywh",
        target="xyxy",
    )
    pairwise_iou = keras.utils.bounding_boxes.compute_iou(
        pred_xyxy,
        target_xyxy,
        bounding_box_format="xyxy",
    )
    diag_iou = keras.ops.diagonal(pairwise_iou)
    matched_iou = diag_iou * keras.ops.cast(valid_flat, dtype=diag_iou.dtype)
    matched_iou = keras.ops.cast(matched_iou, dtype=logits.dtype)

    target_classes = keras.ops.scatter_update(
        background_classes,
        scatter_positions,
        keras.ops.cast(matched_labels, dtype="int32"),
    )
    iou_scores = keras.ops.scatter_update(
        zero_scores, scatter_positions, matched_iou
    )

    # Drop the background column so background queries become all-zero.
    one_hot = keras.ops.one_hot(target_classes, num_classes=num_classes + 1)
    one_hot = one_hot[..., :-1]
    soft_targets = keras.ops.expand_dims(iou_scores, axis=-1) * one_hot
    probs = keras.ops.sigmoid(keras.ops.stop_gradient(logits))
    focal_weight = (
        matcher_alpha * keras.ops.power(probs, matcher_gamma) * (1 - one_hot)
        + soft_targets
    )
    per_element = (
        keras.ops.binary_crossentropy(soft_targets, logits, from_logits=True)
        * focal_weight
    )
    query_count = keras.ops.cast(
        keras.ops.shape(logits)[1], dtype=per_element.dtype
    )
    loss_vfl = (
        keras.ops.sum(keras.ops.mean(per_element, axis=1))
        * query_count
        / num_boxes
    )
    return {"loss_vfl": loss_vfl}
332
+
333
+
334
def compute_box_losses(outputs, targets, indices, num_boxes):
    """Computes L1 and CIoU regression losses for matched boxes.

    Predicted boxes at the matched `(batch, query)` positions are compared
    with the ground-truth boxes selected by the matcher's column indices.
    Column indices are clipped into range. Invalid matches are zeroed by
    `valid_masks`, so they contribute nothing to either sum.

    Args:
        outputs: dict with `"pred_boxes"` in `center_xywh` format.
        targets: list of dict. Only `targets[0]["boxes"]` (`center_xywh`) is
            read; a leading axis of size 1 is squeezed away.
        indices: tuple `(row_indices, col_indices, valid_masks)` from the
            Hungarian matcher.
        num_boxes: normalization factor for both losses.

    Returns:
        dict: `"loss_bbox"` (summed absolute coordinate error) and
        `"loss_ciou"` (summed `1 - CIoU`), each divided by `num_boxes`.
    """
    _, matched_cols, matched_valid = indices
    batch_idx, src_idx = _get_source_permutation_idx(indices)
    pred_boxes = gather_along_first_two_dims(
        outputs["pred_boxes"], batch_idx, src_idx
    )

    gt_boxes = targets[0]["boxes"]
    if keras.ops.ndim(gt_boxes) == 3:
        gt_boxes = keras.ops.squeeze(gt_boxes, axis=0)

    cols_flat = keras.ops.reshape(matched_cols, [-1])
    valid_flat = keras.ops.reshape(matched_valid, [-1])

    last_row = keras.ops.maximum(keras.ops.shape(gt_boxes)[0] - 1, 0)
    last_row = keras.ops.cast(last_row, dtype=cols_flat.dtype)
    matched_gt = keras.ops.take(
        gt_boxes, keras.ops.clip(cols_flat, 0, last_row), axis=0
    )
    valid_weight = keras.ops.cast(
        keras.ops.expand_dims(valid_flat, axis=-1), matched_gt.dtype
    )
    matched_gt = matched_gt * valid_weight

    abs_error = keras.ops.abs(pred_boxes - matched_gt)
    l1_loss = keras.ops.sum(
        abs_error * keras.ops.cast(valid_weight, pred_boxes.dtype)
    )

    ciou = keras.utils.bounding_boxes.compute_ciou(
        keras.utils.bounding_boxes.convert_format(
            pred_boxes, source="center_xywh", target="xyxy"
        ),
        keras.utils.bounding_boxes.convert_format(
            matched_gt, source="center_xywh", target="xyxy"
        ),
        bounding_box_format="xyxy",
    )
    ciou_loss = keras.ops.sum(
        (1.0 - ciou) * keras.ops.cast(valid_flat, pred_boxes.dtype)
    )
    return {
        "loss_bbox": l1_loss / num_boxes,
        "loss_ciou": ciou_loss / num_boxes,
    }
403
+
404
+
405
def compute_local_losses(
    outputs,
    targets,
    indices,
    num_boxes,
    backbone,
    ddf_temperature,
    compute_ddf=None,
):
    """Computes local refinement losses (FGL and DDF).

    1. **Fine-Grained Localization loss (`loss_fgl`):** For matched queries,
       compares the predicted per-edge bin distributions (`"pred_corners"`)
       with bin targets derived by `bbox2distance` from the matched
       ground-truth boxes and `"ref_points"`. Each edge is weighted by the
       IoU between the matched predicted and ground-truth boxes.
    2. **Decoupled Distillation Focal loss (`loss_ddf`):** A
       temperature-scaled KL divergence from the teacher corner
       distributions (`"teacher_corners"`) to the student ones
       (`"pred_corners"`) over all queries. Matched queries are weighted by
       their IoU and unmatched ones by the teacher's max class probability.
       The two groups are averaged separately and then combined.

    Args:
        outputs: dict of model predictions. Reads `"pred_corners"`,
            `"ref_points"` and `"pred_boxes"`, plus `"teacher_corners"` and
            `"teacher_logits"` when the DDF loss is computed.
        targets: list of dict. Only `targets[0]["boxes"]` (`center_xywh`) is
            read; a leading axis of size 1 is squeezed away.
        indices: tuple `(row_indices, col_indices, valid_masks)` from the
            Hungarian matcher.
        num_boxes: scalar Tensor, normalization factor for the FGL loss.
        backbone: the model backbone. `backbone.decoder.max_num_bins`,
            `reg_scale` and `upsampling_factor` are read.
        ddf_temperature: float, softmax temperature for the DDF loss.
        compute_ddf: Python bool, boolean tensor, or `None`. Whether to
            compute the DDF loss. When `None`, it is computed iff
            `"teacher_corners"` and `"teacher_logits"` are present and not
            `None`.

    Returns:
        Dictionary: scalar `"loss_fgl"` and `"loss_ddf"`. Both are 0 when
        `"pred_corners"` or `"ref_points"` is missing.
    """
    losses = {}
    if (
        "pred_corners" not in outputs
        or outputs["pred_corners"] is None
        or "ref_points" not in outputs
        or outputs["ref_points"] is None
    ):
        losses["loss_fgl"] = keras.ops.convert_to_tensor(
            0.0, dtype=keras.backend.floatx()
        )
        losses["loss_ddf"] = keras.ops.convert_to_tensor(
            0.0, dtype=keras.backend.floatx()
        )
        return losses

    if compute_ddf is None:
        # `teacher_logits` must also be non-None: the DDF branch applies
        # `sigmoid` to it.
        compute_ddf = (
            "teacher_corners" in outputs
            and outputs["teacher_corners"] is not None
            and "teacher_logits" in outputs
            and outputs["teacher_logits"] is not None
        )

    num_bins_plus_one = backbone.decoder.max_num_bins + 1
    _, col_indices, valid_masks = indices
    batch_idx, src_idx = _get_source_permutation_idx(indices)
    col_indices_flat = keras.ops.reshape(col_indices, [-1])
    valid_masks_flat = keras.ops.reshape(valid_masks, [-1])

    # Matched ground-truth boxes: clip column indices into range, zero out
    # invalid matches.
    target_boxes_all = targets[0]["boxes"]
    if keras.ops.ndim(target_boxes_all) == 3:
        target_boxes_all = keras.ops.squeeze(target_boxes_all, axis=0)
    max_box_idx = keras.ops.maximum(keras.ops.shape(target_boxes_all)[0] - 1, 0)
    max_box_idx = keras.ops.cast(max_box_idx, dtype=col_indices_flat.dtype)
    safe_col_indices = keras.ops.clip(col_indices_flat, 0, max_box_idx)
    target_boxes_matched_center = keras.ops.take(
        target_boxes_all, safe_col_indices, axis=0
    )
    valid_masks_expanded = keras.ops.cast(
        keras.ops.expand_dims(valid_masks_flat, axis=-1),
        target_boxes_matched_center.dtype,
    )
    target_boxes_matched_center = (
        target_boxes_matched_center * valid_masks_expanded
    )

    # ---- FGL loss ----
    pred_corners_matched = keras.ops.reshape(
        gather_along_first_two_dims(
            outputs["pred_corners"], batch_idx, src_idx
        ),
        (-1, num_bins_plus_one),
    )
    ref_points_matched = keras.ops.stop_gradient(
        gather_along_first_two_dims(outputs["ref_points"], batch_idx, src_idx)
    )
    target_boxes_corners_matched = keras.utils.bounding_boxes.convert_format(
        target_boxes_matched_center,
        source="center_xywh",
        target="xyxy",
    )
    target_corners_dist, weight_right, weight_left = bbox2distance(
        ref_points_matched,
        target_boxes_corners_matched,
        backbone.decoder.max_num_bins,
        backbone.decoder.reg_scale,
        backbone.decoder.upsampling_factor,
    )
    pred_boxes_matched_center = gather_along_first_two_dims(
        outputs["pred_boxes"], batch_idx, src_idx
    )
    pred_boxes_corners_matched = keras.utils.bounding_boxes.convert_format(
        pred_boxes_matched_center,
        source="center_xywh",
        target="xyxy",
    )
    ious_pairwise = keras.utils.bounding_boxes.compute_iou(
        pred_boxes_corners_matched,
        target_boxes_corners_matched,
        bounding_box_format="xyxy",
    )
    # Row j of both inputs is the same match, so the per-match IoU sits on
    # the diagonal.
    ious = keras.ops.diagonal(ious_pairwise)
    ious = ious * keras.ops.cast(valid_masks_flat, dtype=ious.dtype)
    # One weight per box edge (4 per match), flattened to line up with the
    # flattened corner targets.
    weight_targets_fgl = keras.ops.stop_gradient(
        keras.ops.reshape(
            keras.ops.tile(keras.ops.expand_dims(ious, 1), [1, 4]),
            [-1],
        )
    )
    losses["loss_fgl"] = unimodal_distribution_focal_loss(
        pred_corners_matched,
        target_corners_dist,
        weight_right,
        weight_left,
        weight=weight_targets_fgl,
        avg_factor=num_boxes,
    )

    # ---- DDF loss ----
    def ddf_true_fn():
        # Distillation from the teacher corners to all predicted corners.
        pred_corners_all = keras.ops.reshape(
            outputs["pred_corners"], (-1, num_bins_plus_one)
        )
        target_corners_all = keras.ops.reshape(
            keras.ops.stop_gradient(outputs["teacher_corners"]),
            (-1, num_bins_plus_one),
        )

        def compute_ddf_loss_fn():
            # Unmatched queries are weighted by the teacher's highest class
            # probability.
            weight_targets_local = keras.ops.max(
                keras.ops.sigmoid(outputs["teacher_logits"]), axis=-1
            )
            local_shape = keras.ops.shape(weight_targets_local)
            num_queries = keras.ops.cast(local_shape[1], dtype=batch_idx.dtype)
            flat_update_indices = keras.ops.expand_dims(
                batch_idx * num_queries + src_idx, axis=-1
            )
            # `mask` is True at the matched (batch, query) positions.
            mask = keras.ops.zeros_like(weight_targets_local, dtype="bool")
            mask_flat = keras.ops.scatter_update(
                keras.ops.reshape(mask, (-1,)),
                flat_update_indices,
                keras.ops.ones_like(batch_idx, dtype="bool"),
            )
            mask = keras.ops.reshape(mask_flat, local_shape)
            # Matched queries are weighted by their match IoU instead.
            weight_targets_local_flat = keras.ops.scatter_update(
                keras.ops.reshape(weight_targets_local, (-1,)),
                flat_update_indices,
                keras.ops.cast(ious, weight_targets_local.dtype),
            )
            weight_targets_local = keras.ops.reshape(
                weight_targets_local_flat, local_shape
            )
            weight_targets_local_expanded = keras.ops.stop_gradient(
                keras.ops.reshape(
                    keras.ops.tile(
                        keras.ops.expand_dims(weight_targets_local, axis=-1),
                        [1, 1, 4],
                    ),
                    [-1],
                )
            )
            # NOTE: Original impl hardcodes `ddf_temperature` to 5.0 for
            # DDFL.
            # KerasHub lets users configure it if needed.
            # Ref: https://github.com/huggingface/transformers/blob/b374c3d12e8a42014b7911d1bddf598aeada1154/src/transformers/loss/loss_d_fine.py#L238
            pred_softmax = keras.ops.softmax(
                pred_corners_all / ddf_temperature, axis=-1
            )
            target_softmax = keras.ops.softmax(
                target_corners_all / ddf_temperature, axis=-1
            )
            kl_div = keras.ops.sum(
                target_softmax
                * (
                    keras.ops.log(target_softmax + 1e-8)
                    - keras.ops.log(pred_softmax + 1e-8)
                ),
                axis=-1,
            )
            loss_match_local = (
                weight_targets_local_expanded * (ddf_temperature**2) * kl_div
            )
            mask_flat = keras.ops.reshape(
                keras.ops.tile(keras.ops.expand_dims(mask, axis=-1), [1, 1, 4]),
                (-1,),
            )
            neg_mask_flat = keras.ops.logical_not(mask_flat)

            def masked_mean(selector):
                # Mean of `loss_match_local` over `selector`; 0 if empty.
                selector_weights = keras.ops.cast(
                    selector, loss_match_local.dtype
                )
                return keras.ops.cond(
                    keras.ops.any(selector),
                    lambda: keras.ops.sum(loss_match_local * selector_weights)
                    / keras.ops.sum(selector_weights),
                    lambda: keras.ops.convert_to_tensor(
                        0.0, dtype=loss_match_local.dtype
                    ),
                )

            loss_match_local1 = masked_mean(mask_flat)
            loss_match_local2 = masked_mean(neg_mask_flat)
            batch_scale = 1.0 / keras.ops.cast(
                keras.ops.shape(outputs["pred_boxes"])[0],
                dtype="float32",
            )
            num_pos = keras.ops.sqrt(
                keras.ops.sum(keras.ops.cast(mask, dtype="float32"))
                * batch_scale
            )
            num_neg = keras.ops.sqrt(
                keras.ops.sum(
                    keras.ops.cast(keras.ops.logical_not(mask), dtype="float32")
                )
                * batch_scale
            )
            # Counts are accumulated in float32, then cast so the weighted
            # combination does not mix dtypes (e.g. under mixed precision).
            num_pos = keras.ops.cast(num_pos, loss_match_local.dtype)
            num_neg = keras.ops.cast(num_neg, loss_match_local.dtype)
            return (
                loss_match_local1 * num_pos + loss_match_local2 * num_neg
            ) / (num_pos + num_neg + 1e-8)

        # Identical student/teacher corners: zero loss that stays connected
        # to the graph.
        all_equal = keras.ops.all(
            keras.ops.equal(pred_corners_all, target_corners_all)
        )
        return keras.ops.cond(
            all_equal,
            lambda: keras.ops.sum(pred_corners_all) * 0.0,
            compute_ddf_loss_fn,
        )

    def ddf_false_fn():
        return keras.ops.convert_to_tensor(0.0, dtype=keras.backend.floatx())

    if isinstance(compute_ddf, bool):
        # A Python bool is known at trace time. Pick the branch here so that
        # `ddf_true_fn`, which reads the teacher outputs, is never traced
        # when those outputs are absent.
        losses["loss_ddf"] = ddf_true_fn() if compute_ddf else ddf_false_fn()
    else:
        losses["loss_ddf"] = keras.ops.cond(
            compute_ddf, ddf_true_fn, ddf_false_fn
        )
    return losses
677
+
678
+
679
def _translate_gt_valid_case(
    gt_flat, valid_idx_mask, function_values, max_num_bins, mask
):
    """Maps distances to fractional bin indices and interpolation weights.

    Args:
        gt_flat: 1-D tensor of target distances.
        valid_idx_mask: boolean tensor shaped like `gt_flat`. True where the
            closest-left edge index lies in `[0, max_num_bins)`.
        function_values: 1-D tensor of bin edge values.
        max_num_bins: int, number of bins.
        mask: boolean tensor of shape `(len(gt_flat), len(function_values))`.
            As built by `translate_gt`, it is True where
            `function_values[j] <= gt_flat[i]`.

    Returns:
        tuple `(indices, weight_right, weight_left)`, each shaped like
        `gt_flat`:
        - In-range values: index of the closest-left edge, with linear
          interpolation weights towards the left and right edges.
        - Values below the first edge: index 0, `weight_left = 1`.
        - Values whose index is `>= max_num_bins`: index
          `max_num_bins - 0.1`, `weight_right = 1`.
    """
    closest_left_indices = (
        keras.ops.sum(keras.ops.cast(mask, "int32"), axis=1) - 1
    )
    indices_float = keras.ops.cast(closest_left_indices, dtype=gt_flat.dtype)
    zeros = keras.ops.zeros_like(indices_float)
    ones = keras.ops.ones_like(indices_float)

    # Gather the bracketing edges. Invalid positions use edge 0 as a
    # placeholder; their weights are overwritten below.
    left_idx = keras.ops.cast(
        keras.ops.where(valid_idx_mask, indices_float, 0.0), "int32"
    )
    gt_valid = keras.ops.where(valid_idx_mask, gt_flat, 0.0)
    last_edge = keras.ops.shape(function_values)[0] - 1
    left_values = keras.ops.take(function_values, left_idx, axis=0)
    right_values = keras.ops.take(
        function_values, keras.ops.clip(left_idx + 1, 0, last_edge), axis=0
    )
    left_diffs = keras.ops.abs(gt_valid - left_values)
    right_diffs = keras.ops.abs(right_values - gt_valid)
    wr_valid = left_diffs / (left_diffs + right_diffs + 1e-8)
    weight_right = keras.ops.where(valid_idx_mask, wr_valid, zeros)
    weight_left = keras.ops.where(valid_idx_mask, 1.0 - wr_valid, zeros)

    # The out-of-range masks must come from the *unclamped* indices.
    # Clamping first would make both masks always False and leave both
    # weights at 0 for out-of-range values.
    below_range = indices_float < 0
    above_range = indices_float >= max_num_bins
    weight_right = keras.ops.where(below_range, zeros, weight_right)
    weight_left = keras.ops.where(below_range, ones, weight_left)
    weight_right = keras.ops.where(above_range, ones, weight_right)
    weight_left = keras.ops.where(above_range, zeros, weight_left)
    indices_float = keras.ops.where(below_range, zeros, indices_float)
    indices_float = keras.ops.where(
        above_range,
        keras.ops.cast(max_num_bins - 0.1, dtype=indices_float.dtype),
        indices_float,
    )
    return indices_float, weight_right, weight_left
759
+
760
+
761
def translate_gt(gt, max_num_bins, reg_scale, up):
    """Converts distances into fractional bin indices and weights.

    Bin edges come from `weighting_function(max_num_bins, up, reg_scale)`.
    Each value of `gt` is associated with the closest edge at or below it.

    When at least one value has an in-range index (`[0, max_num_bins)`),
    the work is delegated to `_translate_gt_valid_case`. Otherwise every
    value is out of range, and this function resolves them directly:
    - Values below the first edge map to index 0 with `weight_left = 1`.
    - All other values map to index `max_num_bins - 0.1` with
      `weight_right = 1`.

    Args:
        gt: tensor of target distances; flattened before processing.
        max_num_bins: int, number of bins.
        reg_scale: scale forwarded to `weighting_function`.
        up: upsampling factor forwarded to `weighting_function`.

    Returns:
        tuple `(indices, weight_right, weight_left)` of 1-D tensors with one
        entry per element of `gt`.
    """
    gt_flat = keras.ops.reshape(gt, [-1])
    function_values = weighting_function(max_num_bins, up, reg_scale)
    # mask[i, j] is True when edge j lies at or below gt_flat[i].
    diffs = keras.ops.expand_dims(
        function_values, axis=0
    ) - keras.ops.expand_dims(gt_flat, axis=1)
    mask = diffs <= 0
    closest_left_indices = (
        keras.ops.sum(keras.ops.cast(mask, "int32"), axis=1) - 1
    )
    indices_float = keras.ops.cast(closest_left_indices, dtype=gt_flat.dtype)
    valid_idx_mask = (indices_float >= 0) & (indices_float < max_num_bins)

    def all_out_of_range():
        # Each value is either below the first edge or at/above the last
        # usable bin; handle both sides (not only the lower one).
        below_range = indices_float < 0
        zeros = keras.ops.zeros_like(indices_float)
        ones = keras.ops.ones_like(indices_float)
        upper_index = keras.ops.cast(
            max_num_bins - 0.1, dtype=indices_float.dtype
        )
        return (
            keras.ops.where(below_range, zeros, upper_index),
            keras.ops.where(below_range, zeros, ones),
            keras.ops.where(below_range, ones, zeros),
        )

    return keras.ops.cond(
        keras.ops.any(valid_idx_mask),
        lambda: _translate_gt_valid_case(
            gt_flat, valid_idx_mask, function_values, max_num_bins, mask
        ),
        all_out_of_range,
    )
786
+
787
+
788
def _compute_bbox2distance(points, bbox, max_num_bins, reg_scale, up, eps=0.1):
    """Encodes box edges as normalized distances from reference points.

    Four distances are computed per row: `points[..., 0] - bbox[..., 0]`,
    `points[..., 1] - bbox[..., 1]`, `bbox[..., 2] - points[..., 0]` and
    `bbox[..., 3] - points[..., 1]`. The x-distances are divided by
    `points[..., 2] / |reg_scale|` and the y-distances by
    `points[..., 3] / |reg_scale|`, and `0.5 * |reg_scale|` is subtracted
    from each. The results are then converted to bins by `translate_gt`.

    Returns:
        tuple `(indices, weight_right, weight_left)` with gradients stopped.
        `indices` is clipped to `[0, max_num_bins - eps]`.
    """
    scale = keras.ops.abs(reg_scale)
    offset = 0.5 * scale
    px = points[..., 0]
    py = points[..., 1]
    # The tiny constant guards against division by zero.
    x_norm = points[..., 2] / scale + 1e-16
    y_norm = points[..., 3] / scale + 1e-16
    distances = keras.ops.stack(
        [
            (px - bbox[..., 0]) / x_norm - offset,
            (py - bbox[..., 1]) / y_norm - offset,
            (bbox[..., 2] - px) / x_norm - offset,
            (bbox[..., 3] - py) / y_norm - offset,
        ],
        axis=-1,
    )
    if isinstance(up, keras.KerasTensor):
        up_value = up
    else:
        up_value = keras.ops.convert_to_tensor(up)
    bin_indices, weight_right, weight_left = translate_gt(
        distances, max_num_bins, scale, up_value
    )
    bin_indices = keras.ops.clip(bin_indices, 0, max_num_bins - eps)
    return tuple(
        keras.ops.stop_gradient(t)
        for t in (bin_indices, weight_right, weight_left)
    )
819
+
820
+
821
def bbox2distance(points, bbox, max_num_bins, reg_scale, up, eps=0.1):
    """Converts boxes to binned edge distances, tolerating empty inputs.

    Dispatches through `keras.ops.cond`:

    - when `points` has zero rows, it returns three zero vectors of dtype
      `keras.backend.floatx()`;
    - otherwise it returns the result of `_compute_bbox2distance`.

    Args:
        points: Reference points; the first axis counts points.
        bbox: Boxes matched to `points`.
        max_num_bins: Number of distribution bins.
        reg_scale: Regression scale.
        up: Upper-bound parameter for the weighting function.
        eps: Margin kept below `max_num_bins` when clipping.

    Returns:
        Tuple `(distances, weight_right, weight_left)`.
    """
    num_points = keras.ops.shape(points)[0]
    # Four distances per point. Sizing the zero outputs this way
    # (rather than a literal 0) presumably keeps both cond branches
    # shape-compatible at trace time.
    flat_size = num_points * 4

    def _zeros():
        return keras.ops.zeros((flat_size,), dtype=keras.backend.floatx())

    def _empty_branch():
        return (_zeros(), _zeros(), _zeros())

    def _compute_branch():
        return _compute_bbox2distance(
            points, bbox, max_num_bins, reg_scale, up, eps
        )

    return keras.ops.cond(
        keras.ops.equal(num_points, 0), _empty_branch, _compute_branch
    )
840
+
841
+
842
def unimodal_distribution_focal_loss(
    pred,
    label,
    weight_right,
    weight_left,
    weight=None,
    reduction="sum",
    avg_factor=None,
):
    """Cross-entropy against the two bins bracketing each continuous label.

    Each label is truncated to an integer left bin; the right bin is the
    next one. The per-element loss is:

        CE(pred, left) * weight_left + CE(pred, right) * weight_right

    with `pred` treated as logits.

    Args:
        pred: Logits passed to `sparse_categorical_crossentropy`.
        label: Continuous bin positions; flattened and cast to `int32`.
        weight_right: Weight for the right bin; flattened.
        weight_left: Weight for the left bin; flattened.
        weight: Optional per-element multiplier, cast to the loss dtype.
        reduction: `"mean"` or `"sum"`. Any other value returns the
            unreduced loss. Ignored when `avg_factor` is given.
        avg_factor: If not None, the loss is summed and divided by it.

    Returns:
        The reduced (or unreduced) loss tensor.
    """
    left_bin = keras.ops.cast(keras.ops.reshape(label, [-1]), "int32")
    right_bin = left_bin + 1

    def _ce(target):
        return keras.ops.sparse_categorical_crossentropy(
            target, pred, from_logits=True
        )

    left_term = _ce(left_bin) * keras.ops.reshape(weight_left, [-1])
    right_term = _ce(right_bin) * keras.ops.reshape(weight_right, [-1])
    loss = left_term + right_term

    if weight is not None:
        loss = loss * keras.ops.cast(weight, dtype=loss.dtype)

    # `avg_factor` takes precedence over `reduction`.
    if avg_factor is not None:
        return keras.ops.sum(loss) / avg_factor
    if reduction == "mean":
        return keras.ops.mean(loss)
    if reduction == "sum":
        return keras.ops.sum(loss)
    return loss
878
+
879
+
880
def _get_source_permutation_idx(indices):
    """Returns flat (batch, source) index pairs for matched predictions.

    The reference implementation concatenates per-sample index lists of
    varying length, which a JIT compiler cannot trace. Here every sample
    contributes a fixed number of slots, and slots whose `valid_masks`
    entry is False are filled with `0`. Downstream losses are expected to
    consult `valid_masks` to ignore those padded slots.

    Args:
        indices: Tuple `(row_indices, col_indices, valid_masks)`. Only
            `row_indices` and `valid_masks` are read. Both are indexed as
            `(batch, max_matches)` here.

    Returns:
        Tuple `(batch_idx, src_idx)` of flattened `int64` tensors of
        length `batch * max_matches`.
    """
    row_indices, _, valid_masks = indices
    batch_size, max_matches = (
        keras.ops.shape(row_indices)[0],
        keras.ops.shape(row_indices)[1],
    )
    valid_flat = keras.ops.reshape(valid_masks, (-1,))

    def _masked_int64(values):
        # Keep valid entries, replace padded ones with 0.
        return keras.ops.where(
            valid_flat, keras.ops.cast(values, "int64"), 0
        )

    # Row i of this grid is filled with the batch index i.
    batch_grid = keras.ops.tile(
        keras.ops.expand_dims(
            keras.ops.arange(batch_size, dtype="int32"), axis=1
        ),
        [1, max_matches],
    )
    batch_idx = _masked_int64(keras.ops.reshape(batch_grid, (-1,)))
    src_idx = _masked_int64(keras.ops.reshape(row_indices, (-1,)))
    return batch_idx, src_idx
913
+
914
+
915
def get_cdn_matched_indices(dn_meta):
    """Builds fixed-size matched indices for contrastive denoising (CDN).

    The reference implementation loops over the batch, collecting index
    tensors of varying size, which is not JIT-traceable. This version works
    on the whole padded `dn_meta["dn_positive_idx"]` tensor at once. In that
    tensor, `-1` marks padding.

    Args:
        dn_meta: Mapping containing `"dn_positive_idx"`, indexed here as
            `(batch, num_denoising_queries)`.

    Returns:
        Tuple `(row_indices, col_indices, valid_masks)`:
        - `row_indices`: `int64` positions `0..num_denoising_queries - 1`,
          repeated for every batch row;
        - `col_indices`: `dn_positive_idx` itself;
        - `valid_masks`: True where `dn_positive_idx != -1`.
    """
    positive_idx = dn_meta["dn_positive_idx"]
    idx_shape = keras.ops.shape(positive_idx)
    query_positions = keras.ops.arange(idx_shape[1], dtype="int64")
    row_indices = keras.ops.tile(
        keras.ops.expand_dims(query_positions, 0), [idx_shape[0], 1]
    )
    valid_masks = keras.ops.not_equal(positive_idx, -1)
    return row_indices, positive_idx, valid_masks