keras-hub-nightly 0.23.0.dev202508260411__py3-none-any.whl → 0.23.0.dev202508280418__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. keras_hub/layers/__init__.py +6 -0
  2. keras_hub/models/__init__.py +21 -0
  3. keras_hub/src/layers/modeling/position_embedding.py +21 -6
  4. keras_hub/src/layers/modeling/rotary_embedding.py +16 -6
  5. keras_hub/src/layers/modeling/sine_position_encoding.py +21 -8
  6. keras_hub/src/layers/modeling/token_and_position_embedding.py +2 -1
  7. keras_hub/src/models/backbone.py +10 -15
  8. keras_hub/src/models/d_fine/__init__.py +0 -0
  9. keras_hub/src/models/d_fine/d_fine_attention.py +461 -0
  10. keras_hub/src/models/d_fine/d_fine_backbone.py +891 -0
  11. keras_hub/src/models/d_fine/d_fine_decoder.py +944 -0
  12. keras_hub/src/models/d_fine/d_fine_encoder.py +365 -0
  13. keras_hub/src/models/d_fine/d_fine_hybrid_encoder.py +642 -0
  14. keras_hub/src/models/d_fine/d_fine_image_converter.py +8 -0
  15. keras_hub/src/models/d_fine/d_fine_layers.py +1828 -0
  16. keras_hub/src/models/d_fine/d_fine_loss.py +938 -0
  17. keras_hub/src/models/d_fine/d_fine_object_detector.py +875 -0
  18. keras_hub/src/models/d_fine/d_fine_object_detector_preprocessor.py +14 -0
  19. keras_hub/src/models/d_fine/d_fine_presets.py +2 -0
  20. keras_hub/src/models/d_fine/d_fine_utils.py +827 -0
  21. keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +4 -1
  22. keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +3 -2
  23. keras_hub/src/models/hgnetv2/hgnetv2_layers.py +27 -11
  24. keras_hub/src/models/parseq/__init__.py +0 -0
  25. keras_hub/src/models/parseq/parseq_backbone.py +134 -0
  26. keras_hub/src/models/parseq/parseq_causal_lm.py +466 -0
  27. keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py +168 -0
  28. keras_hub/src/models/parseq/parseq_decoder.py +418 -0
  29. keras_hub/src/models/parseq/parseq_image_converter.py +8 -0
  30. keras_hub/src/models/parseq/parseq_tokenizer.py +221 -0
  31. keras_hub/src/tests/test_case.py +37 -1
  32. keras_hub/src/utils/preset_utils.py +49 -0
  33. keras_hub/src/utils/tensor_utils.py +23 -1
  34. keras_hub/src/utils/transformers/convert_vit.py +4 -1
  35. keras_hub/src/version.py +1 -1
  36. keras_hub/tokenizers/__init__.py +3 -0
  37. {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/METADATA +1 -1
  38. {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/RECORD +40 -20
  39. {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/WHEEL +0 -0
  40. {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/top_level.txt +0 -0
keras_hub/src/models/d_fine/d_fine_utils.py (new file)
@@ -0,0 +1,827 @@
+ import keras
+
+
+ def d_fine_kernel_initializer(initializer_range=0.01, name="random_normal"):
+     if name == "random_normal":
+         return keras.initializers.RandomNormal(
+             mean=0.0, stddev=initializer_range
+         )
+     elif name == "glorot_uniform":
+         return keras.initializers.GlorotUniform()
+     elif name == "zeros":
+         return keras.initializers.Zeros()
+
+
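As a usage sketch (not part of the diff; the Dense layer below is only an illustration), the factory can be passed straight to a Keras layer, assuming these utilities are importable from the new d_fine_utils module listed above. Note that any name other than the three handled branches falls through and returns None.

    import keras

    # Hypothetical usage: select an initializer by name for a projection layer.
    dense = keras.layers.Dense(
        units=256,
        kernel_initializer=d_fine_kernel_initializer(
            initializer_range=0.01, name="random_normal"
        ),
    )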
+ def inverse_sigmoid(x, eps=1e-5):
+     """Computes the inverse sigmoid (logit) function.
+
+     This function computes the inverse of the sigmoid function, also known as
+     the logit function. It is used in D-FINE to transform bounding box
+     coordinates from the `[0, 1]` range back to logits, for example in
+     `DFineContrastiveDenoisingGroupGenerator` and `DFineDecoder`.
+
+     Args:
+         x: Tensor, Input tensor with values in `[0, 1]`.
+         eps: float, Small epsilon value to prevent numerical instability
+             at the boundaries. Default is `1e-5`.
+
+     Returns:
+         Tensor: The inverse sigmoid of the input tensor.
+     """
+     x = keras.ops.clip(x, 0, 1)
+     x1 = keras.ops.maximum(x, eps)
+     x2 = keras.ops.maximum(1 - x, eps)
+     return keras.ops.log(x1 / x2)
+
+
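A quick numeric sketch of the round trip (values are illustrative, not from the package):

    import keras

    boxes = keras.ops.convert_to_tensor([0.1, 0.5, 0.9])
    logits = inverse_sigmoid(boxes)        # approx. [-2.197, 0.0, 2.197]
    recovered = keras.ops.sigmoid(logits)  # approx. [0.1, 0.5, 0.9] again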
+ def grid_sample(data, grid, align_corners=False, height=None, width=None):
+     """Samples data at specified grid locations using bilinear interpolation.
+
+     This function performs bilinear interpolation to sample data at arbitrary
+     grid locations. It is a core component of the deformable attention
+     mechanism, used within `multi_scale_deformable_attention_v2`.
+     This is a Keras-native implementation (polyfill) for
+     `torch.nn.functional.grid_sample`.
+
+     Args:
+         data: Tensor, Input data tensor of shape `[batch, channels, height,
+             width]`.
+         grid: Tensor, Grid coordinates of shape `[batch, out_height, out_width,
+             2]`. The last dimension contains `(x, y)` coordinates normalized to
+             `[-1, 1]`.
+         align_corners: bool, If `True`, align corners for coordinate mapping.
+             Default is `False`.
+         height: int, optional, Override height for coordinate normalization.
+         width: int, optional, Override width for coordinate normalization.
+
+     Returns:
+         Tensor: Sampled data of shape `[batch, channels, out_height,
+             out_width]`.
+     """
+     num_batch, _, data_height, data_width = keras.ops.shape(data)
+     _, out_height, out_width, _ = keras.ops.shape(grid)
+     dtype = data.dtype
+     grid_x_norm = grid[..., 0]
+     grid_y_norm = grid[..., 1]
+     h_in = height if height is not None else data_height
+     w_in = width if width is not None else data_width
+     height_f = keras.ops.cast(h_in, dtype=dtype)
+     width_f = keras.ops.cast(w_in, dtype=dtype)
+     if align_corners:
+         x_unnorm = (grid_x_norm + 1) / 2 * (width_f - 1)
+         y_unnorm = (grid_y_norm + 1) / 2 * (height_f - 1)
+     else:
+         x_unnorm = ((grid_x_norm + 1) / 2 * width_f) - 0.5
+         y_unnorm = ((grid_y_norm + 1) / 2 * height_f) - 0.5
+     x0 = keras.ops.floor(x_unnorm)
+     y0 = keras.ops.floor(y_unnorm)
+     x1 = x0 + 1
+     y1 = y0 + 1
+     w_y0_val = y1 - y_unnorm
+     w_y1_val = y_unnorm - y0
+     w_x0_val = x1 - x_unnorm
+     w_x1_val = x_unnorm - x0
+     data_permuted = keras.ops.transpose(data, (0, 2, 3, 1))
+
+     def gather_padded(
+         data_p,
+         y_coords,
+         x_coords,
+         actual_data_height,
+         actual_data_width,
+         override_height=None,
+         override_width=None,
+     ):
+         y_coords_int = keras.ops.cast(y_coords, "int32")
+         x_coords_int = keras.ops.cast(x_coords, "int32")
+
+         y_oob = keras.ops.logical_or(
+             y_coords_int < 0, y_coords_int >= actual_data_height
+         )
+         x_oob = keras.ops.logical_or(
+             x_coords_int < 0, x_coords_int >= actual_data_width
+         )
+         oob_mask = keras.ops.logical_or(y_oob, x_oob)
+
+         y_coords_clipped = keras.ops.clip(
+             y_coords_int, 0, actual_data_height - 1
+         )
+         x_coords_clipped = keras.ops.clip(
+             x_coords_int, 0, actual_data_width - 1
+         )
+
+         width_for_indexing = (
+             override_width if override_width is not None else actual_data_width
+         )
+
+         if override_height is not None and override_width is not None:
+             data_flat = keras.ops.reshape(
+                 data_p,
+                 (
+                     num_batch,
+                     override_height * override_width,
+                     keras.ops.shape(data_p)[-1],
+                 ),
+             )
+         else:
+             data_flat = keras.ops.reshape(
+                 data_p, (num_batch, -1, keras.ops.shape(data_p)[-1])
+             )
+         y_coords_flat = keras.ops.reshape(
+             y_coords_clipped, (num_batch, out_height * out_width)
+         )
+         x_coords_flat = keras.ops.reshape(
+             x_coords_clipped, (num_batch, out_height * out_width)
+         )
+         indices = y_coords_flat * width_for_indexing + x_coords_flat
+
+         num_elements_per_batch = keras.ops.shape(data_flat)[1]
+         batch_offsets = (
+             keras.ops.arange(num_batch, dtype=indices.dtype)
+             * num_elements_per_batch
+         )
+         batch_offsets = keras.ops.reshape(batch_offsets, (num_batch, 1))
+         absolute_indices = indices + batch_offsets
+         data_reshaped_for_gather = keras.ops.reshape(
+             data_flat, (-1, keras.ops.shape(data_flat)[-1])
+         )
+         gathered = keras.ops.take(
+             data_reshaped_for_gather, absolute_indices, axis=0
+         )
+         gathered = keras.ops.reshape(
+             gathered, (num_batch, out_height, out_width, -1)
+         )
+         oob_mask_expanded = keras.ops.expand_dims(oob_mask, axis=-1)
+         gathered_values = gathered * keras.ops.cast(
+             keras.ops.logical_not(oob_mask_expanded), dtype=gathered.dtype
+         )
+         return gathered_values
+
+     batch_indices = keras.ops.arange(0, num_batch, dtype="int32")
+     batch_indices = keras.ops.reshape(batch_indices, (num_batch, 1, 1))
+     batch_indices = keras.ops.tile(batch_indices, (1, out_height, out_width))
+     val_y0_x0 = gather_padded(data_permuted, y0, x0, h_in, w_in, height, width)
+     val_y0_x1 = gather_padded(data_permuted, y0, x1, h_in, w_in, height, width)
+     val_y1_x0 = gather_padded(data_permuted, y1, x0, h_in, w_in, height, width)
+     val_y1_x1 = gather_padded(data_permuted, y1, x1, h_in, w_in, height, width)
+     interp_val = (
+         val_y0_x0 * keras.ops.expand_dims(w_y0_val * w_x0_val, -1)
+         + val_y0_x1 * keras.ops.expand_dims(w_y0_val * w_x1_val, -1)
+         + val_y1_x0 * keras.ops.expand_dims(w_y1_val * w_x0_val, -1)
+         + val_y1_x1 * keras.ops.expand_dims(w_y1_val * w_x1_val, -1)
+     )
+
+     return keras.ops.transpose(interp_val, (0, 3, 1, 2))
+
+
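A minimal sanity-check sketch (shapes and values are illustrative): sampling the normalized center (0, 0) of a 4x4 map with align_corners=False lands at pixel coordinate 1.5, so the result is the bilinear blend of the four central pixels.

    import keras

    # One 1-channel 4x4 feature map; one sampling point at the center.
    data = keras.ops.reshape(
        keras.ops.arange(16, dtype="float32"), (1, 1, 4, 4)
    )
    grid = keras.ops.zeros((1, 1, 1, 2))  # (x, y) = (0, 0) in [-1, 1] coords
    sampled = grid_sample(data, grid, align_corners=False)
    # Shape (1, 1, 1, 1); the value is 7.5, the mean of pixels 5, 6, 9, 10.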
+ def multi_scale_deformable_attention_v2(
+     value,
+     dynamic_spatial_shapes,
+     sampling_locations,
+     attention_weights,
+     num_points,
+     slice_sizes,
+     spatial_shapes,
+     num_levels,
+     num_queries,
+     method="default",
+ ):
+     """Computes multi-scale deformable attention mechanism.
+
+     This function implements the core of the multi-scale deformable attention
+     mechanism used in `DFineMultiScaleDeformableAttention`. It samples features
+     at multiple scales and locations based on learned attention weights and
+     sampling locations.
+
+     Args:
+         value: Tensor, Feature values of shape `[batch, seq_len, num_heads,
+             hidden_dim]`.
+         dynamic_spatial_shapes: Tensor, Spatial shapes for each level.
+         sampling_locations: Tensor, Sampling locations of shape
+             `[batch, num_queries, num_heads, num_levels, num_points, 2]`.
+         attention_weights: Tensor, Attention weights of shape `[batch,
+             num_queries, num_heads, total_points]`.
+         num_points: list, Number of sampling points for each level.
+         slice_sizes: list, Sizes for slicing the value tensor.
+         spatial_shapes: list, Spatial shapes for each level.
+         num_levels: int, Number of feature levels.
+         num_queries: int, Number of queries.
+         method: str, Sampling method, either `"default"` or `"discrete"`.
+             Default is `"default"`.
+
+     Returns:
+         Tensor: Output features of shape `[batch, num_queries, num_heads *
+             hidden_dim]`.
+     """
+     value_shape = keras.ops.shape(value)
+     batch_size = value_shape[0]
+     num_heads = value_shape[2]
+     hidden_dim = value_shape[3]
+     sampling_shape = keras.ops.shape(sampling_locations)
+     num_levels_from_shape = sampling_shape[3]
+     num_points_from_shape = sampling_shape[4]
+     permuted_value = keras.ops.transpose(value, axes=(0, 2, 3, 1))
+     seq_len = value_shape[1]
+     flattened_value = keras.ops.reshape(
+         permuted_value, (-1, hidden_dim, seq_len)
+     )
+     value_chunk_sizes = keras.ops.array(slice_sizes, dtype="int32")
+     cum_sizes = keras.ops.concatenate(
+         [
+             keras.ops.zeros((1,), dtype="int32"),
+             keras.ops.cumsum(value_chunk_sizes),
+         ]
+     )
+     values = []
+     for i in range(len(spatial_shapes)):
+         start = cum_sizes[i]
+         current_slice_size = slice_sizes[i]
+         dynamic_slice_start_indices = (0, 0, start)
+         dynamic_slice_shape = (
+             keras.ops.shape(flattened_value)[0],
+             keras.ops.shape(flattened_value)[1],
+             current_slice_size,
+         )
+         sliced_value = keras.ops.slice(
+             flattened_value, dynamic_slice_start_indices, dynamic_slice_shape
+         )
+         values.append(sliced_value)
+     if method == "default":
+         sampling_grids = 2 * sampling_locations - 1
+     elif method == "discrete":
+         sampling_grids = sampling_locations
+     else:
+         sampling_grids = 2 * sampling_locations - 1
+     permuted_sampling_grids = keras.ops.transpose(
+         sampling_grids, axes=(0, 2, 1, 3, 4)
+     )
+     flattened_sampling_grids = keras.ops.reshape(
+         permuted_sampling_grids,
+         (
+             batch_size * num_heads,
+             num_queries,
+             num_levels_from_shape,
+             num_points_from_shape,
+         ),
+     )
+     cum_points = keras.ops.concatenate(
+         [
+             keras.ops.zeros((1,), dtype="int32"),
+             keras.ops.cumsum(keras.ops.array(num_points, dtype="int32")),
+         ]
+     )
+     sampling_grids = []
+     for i in range(num_levels):
+         start = cum_points[i]
+         current_level_num_points = num_points[i]
+         slice_start_indices = (0, 0, start, 0)
+         slice_shape = (
+             keras.ops.shape(flattened_sampling_grids)[0],
+             keras.ops.shape(flattened_sampling_grids)[1],
+             current_level_num_points,
+             keras.ops.shape(flattened_sampling_grids)[3],
+         )
+         sliced_grid = keras.ops.slice(
+             flattened_sampling_grids, slice_start_indices, slice_shape
+         )
+         sampling_grids.append(sliced_grid)
+     sampling_values = []
+     for level_id in range(num_levels):
+         if spatial_shapes is not None and len(spatial_shapes) == num_levels:
+             height, width = spatial_shapes[level_id]
+         else:
+             height = dynamic_spatial_shapes[level_id, 0]
+             width = dynamic_spatial_shapes[level_id, 1]
+         value_l_ = keras.ops.reshape(
+             values[level_id],
+             (batch_size * num_heads, hidden_dim, height, width),
+         )
+         sampling_grid_l_ = sampling_grids[level_id]
+         if method == "default":
+             sampling_value_l_ = grid_sample(
+                 data=value_l_,
+                 grid=sampling_grid_l_,
+                 align_corners=False,
+                 height=height,
+                 width=width,
+             )
+         elif method == "discrete":
+             scale_factors = keras.ops.cast(
+                 keras.ops.array([width, height]),
+                 dtype=sampling_grid_l_.dtype,
+             )
+             sampling_coord_float = sampling_grid_l_ * scale_factors
+             sampling_coord_x_int = keras.ops.cast(
+                 keras.ops.floor(sampling_coord_float[..., 0] + 0.5), "int32"
+             )
+             sampling_coord_y_int = keras.ops.cast(
+                 keras.ops.floor(sampling_coord_float[..., 1] + 0.5), "int32"
+             )
+             clamped_coord_x = keras.ops.clip(sampling_coord_x_int, 0, width - 1)
+             clamped_coord_y = keras.ops.clip(
+                 sampling_coord_y_int, 0, height - 1
+             )
+             sampling_coord_stacked = keras.ops.stack(
+                 [clamped_coord_x, clamped_coord_y], axis=-1
+             )
+             B_prime = batch_size * num_heads
+             Q_dim = num_queries
+             P_level = num_points[level_id]
+             sampling_coord = keras.ops.reshape(
+                 sampling_coord_stacked, (B_prime, Q_dim * P_level, 2)
+             )
+             value_l_permuted = keras.ops.transpose(value_l_, (0, 2, 3, 1))
+             y_coords_for_gather = sampling_coord[
+                 ..., 1
+             ]  # (B_prime, Q_dim * P_level)
+             x_coords_for_gather = sampling_coord[
+                 ..., 0
+             ]  # (B_prime, Q_dim * P_level)
+             indices = y_coords_for_gather * width + x_coords_for_gather
+             indices = keras.ops.expand_dims(indices, axis=-1)
+             value_l_flat = keras.ops.reshape(
+                 value_l_permuted, (B_prime, height * width, hidden_dim)
+             )
+             gathered_values = keras.ops.take_along_axis(
+                 value_l_flat, indices, axis=1
+             )
+             permuted_gathered_values = keras.ops.transpose(
+                 gathered_values, axes=(0, 2, 1)
+             )
+             sampling_value_l_ = keras.ops.reshape(
+                 permuted_gathered_values, (B_prime, hidden_dim, Q_dim, P_level)
+             )
+         else:
+             sampling_value_l_ = grid_sample(
+                 data=value_l_,
+                 grid=sampling_grid_l_,
+                 align_corners=False,
+                 height=height,
+                 width=width,
+             )
+         sampling_values.append(sampling_value_l_)
+     attention_weights = keras.ops.transpose(
+         attention_weights, axes=(0, 2, 1, 3)
+     )
+     attention_weights = keras.ops.reshape(
+         attention_weights,
+         (batch_size * num_heads, 1, num_queries, sum(num_points)),
+     )
+     concatenated_sampling_values = keras.ops.concatenate(
+         sampling_values, axis=-1
+     )
+     weighted_values = concatenated_sampling_values * attention_weights
+     summed_values = keras.ops.sum(weighted_values, axis=-1)
+     output = keras.ops.reshape(
+         summed_values, (batch_size, num_heads * hidden_dim, num_queries)
+     )
+     return keras.ops.transpose(output, axes=(0, 2, 1))
+
+
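A shape-oriented sketch (all sizes are assumptions chosen for illustration). The sampling_locations layout below follows what the slicing code above expects: levels and points flattened into a single axis of size sum(num_points), with the final axis holding (x, y).

    import keras

    # Two feature levels (8x8 and 4x4), 2 heads, 16 channels, 10 queries,
    # 4 sampling points per level.
    batch, num_heads, hidden_dim, num_queries = 1, 2, 16, 10
    spatial_shapes = [(8, 8), (4, 4)]
    slice_sizes = [64, 16]            # 8 * 8 and 4 * 4
    num_points = [4, 4]
    total_points = sum(num_points)    # 8
    seq_len = sum(slice_sizes)        # 80

    value = keras.random.uniform((batch, seq_len, num_heads, hidden_dim))
    dynamic_spatial_shapes = keras.ops.convert_to_tensor(
        spatial_shapes, dtype="int32"
    )
    sampling_locations = keras.random.uniform(
        (batch, num_queries, num_heads, total_points, 2)
    )
    attention_weights = keras.ops.softmax(
        keras.random.uniform((batch, num_queries, num_heads, total_points)),
        axis=-1,
    )
    output = multi_scale_deformable_attention_v2(
        value,
        dynamic_spatial_shapes,
        sampling_locations,
        attention_weights,
        num_points,
        slice_sizes,
        spatial_shapes,
        num_levels=2,
        num_queries=num_queries,
    )
    # output shape: (batch, num_queries, num_heads * hidden_dim) == (1, 10, 32)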
+ def weighting_function(max_num_bins, upsampling_factor, reg_scale):
+     """Generates weighting values for binning operations.
+
+     This function creates a set of weighting values used for integral-based
+     bounding box regression. It is used in `DFineDecoder` to create a
+     projection matrix for converting corner predictions into distances. The
+     weights follow an exponential distribution around zero.
+
+     Args:
+         max_num_bins: int, Maximum number of bins to generate.
+         upsampling_factor: Tensor, A scaling hyperparameter that controls the
+             range of the bins used for integral-based bounding box regression.
+         reg_scale: float, Regularization scale factor.
+
+     Returns:
+         Tensor: Weighting values of shape `[max_num_bins]`.
+     """
+     upper_bound1 = abs(upsampling_factor[0]) * abs(reg_scale)
+     upper_bound2 = abs(upsampling_factor[0]) * abs(reg_scale) * 2
+     step = (upper_bound1 + 1) ** (2 / (max_num_bins - 2))
+     left_values = [
+         -((step) ** i) + 1 for i in range(max_num_bins // 2 - 1, 0, -1)
+     ]
+     right_values = [(step) ** i - 1 for i in range(1, max_num_bins // 2)]
+     values = (
+         [-upper_bound2]
+         + left_values
+         + [
+             keras.ops.zeros_like(
+                 keras.ops.expand_dims(upsampling_factor[0], axis=0)
+             )
+         ]
+         + right_values
+         + [upper_bound2]
+     )
+     values = keras.ops.concatenate(values, 0)
+     return values
+
+
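A small sketch (values are illustrative; the exact argument types used inside the package may differ). Passing one-element tensors keeps every concatenated piece at least 1D, so the call returns a 1D vector of bin edges:

    import keras

    upsampling_factor = keras.ops.ones((1,)) * 8.0
    reg_scale = keras.ops.ones((1,)) * 4.0
    bin_edges = weighting_function(
        max_num_bins=32,
        upsampling_factor=upsampling_factor,
        reg_scale=reg_scale,
    )
    # Monotonically increasing edges, symmetric around 0, spanning roughly
    # [-64, 64] for these values (upper bound = |8| * |4| * 2).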
+ def distance2bbox(points, distance, reg_scale):
+     """Converts distance predictions to bounding boxes.
+
+     This function converts distance predictions from anchor points to
+     bounding boxes. It is a key part of the regression head in `DFineDecoder`,
+     transforming the output of the integral-based prediction into final
+     bounding box coordinates.
+
+     Args:
+         points: Tensor, Anchor points of shape `[..., 4]` where the last
+             dimension contains `[x, y, width, height]`.
+         distance: Tensor, Distance predictions of shape `[..., 4]` where
+             the last dimension contains `[left, top, right, bottom]` distances.
+         reg_scale: float, Regularization scale factor.
+
+     Returns:
+         Tensor: Bounding boxes in center format of shape `[..., 4]` where
+             the last dimension contains `[center_x, center_y, width, height]`.
+     """
+     reg_scale = abs(reg_scale)
+     top_left_x = points[..., 0] - (0.5 * reg_scale + distance[..., 0]) * (
+         points[..., 2] / reg_scale
+     )
+     top_left_y = points[..., 1] - (0.5 * reg_scale + distance[..., 1]) * (
+         points[..., 3] / reg_scale
+     )
+     bottom_right_x = points[..., 0] + (0.5 * reg_scale + distance[..., 2]) * (
+         points[..., 2] / reg_scale
+     )
+     bottom_right_y = points[..., 1] + (0.5 * reg_scale + distance[..., 3]) * (
+         points[..., 3] / reg_scale
+     )
+     bboxes = keras.ops.stack(
+         [top_left_x, top_left_y, bottom_right_x, bottom_right_y], axis=-1
+     )
+     return keras.utils.bounding_boxes.convert_format(
+         bboxes,
+         source="xyxy",
+         target="center_xywh",
+         dtype=points.dtype,
+     )
+
+
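A numeric sketch with made-up values: one anchor box at the image center (center_xywh, normalized coordinates) and a uniform distance prediction of 1.0 on each side.

    import keras

    points = keras.ops.convert_to_tensor([[0.5, 0.5, 0.2, 0.2]])
    distance = keras.ops.convert_to_tensor([[1.0, 1.0, 1.0, 1.0]])
    boxes = distance2bbox(points, distance, reg_scale=4.0)
    # Each side extends (0.5 * 4 + 1) * (0.2 / 4) = 0.15 from the center,
    # giving corners (0.35, 0.35, 0.65, 0.65), i.e. approximately
    # [[0.5, 0.5, 0.3, 0.3]] in center_xywh format.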
+ def hungarian_assignment(cost_matrix, num_queries):
+     """Solves the linear assignment problem using the Hungarian algorithm.
+
+     This function provides a JIT-compatible implementation of the Hungarian
+     (Munkres) algorithm using pure `keras.ops` operations. It is designed to
+     replace SciPy's `optimize.linear_sum_assignment` for backend-agnostic
+     end-to-end model compilation. The implementation uses a stateful loop
+     with `keras.ops.while_loop`, a state machine pattern with
+     `keras.ops.switch`, and tensor-only operations to ensure compatibility
+     with static graphs and standard accelerators.
+
+     Args:
+         cost_matrix: Tensor, A 2D tensor of shape `(num_rows, num_cols)`
+             representing the cost of each potential assignment. `num_rows`
+             typically corresponds to the number of predictions (queries),
+             and `num_cols` corresponds to the number of ground-truth targets.
+         num_queries: int, The fixed number of queries (predictions) from
+             the model, used to establish static shapes for JAX compatibility.
+
+     Returns:
+         Tuple: A tuple `(row_ind, col_ind, valid_mask)` containing:
+             - row_ind: Tensor with integer indices for the rows (predictions).
+             - col_ind: Tensor with integer indices for the assigned columns
+                 (targets).
+             - valid_mask: Boolean tensor where `True` indicates a valid
+                 assignment that falls within the original (unpadded) cost
+                 matrix dimensions.
+     """
+     # Reference: https://github.com/bmc/munkres/blob/master/munkres.py
+
+     original_num_rows, original_num_cols = keras.ops.shape(cost_matrix)
+     # Pad matrix to be square.
+     padded_cost_matrix = keras.ops.full(
+         (num_queries, num_queries), 1e9, dtype=cost_matrix.dtype
+     )
+     padded_cost_matrix = keras.ops.slice_update(
+         padded_cost_matrix,
+         (0, 0),
+         cost_matrix,
+     )
+     # Step 1: Subtract row minima.
+     cost = padded_cost_matrix - keras.ops.min(
+         padded_cost_matrix, axis=1, keepdims=True
+     )
+     # Step 2: Subtract column minima.
+     cost = cost - keras.ops.min(cost, axis=0, keepdims=True)
+
+     def body(
+         step,
+         cost,
+         starred_mask,
+         row_covered,
+         col_covered,
+         primed_mask,
+         path_start_row,
+         path_start_col,
+     ):
+         zero_mask = keras.ops.abs(cost) < 1e-6
+
+         def step_2():
+             # Initial starring: Star zeros with no starred zero in their row or
+             # column.
+             s_mask = keras.ops.zeros_like(starred_mask, dtype="bool")
+
+             def star_zeros(i, s_m):
+                 def star_zeros_in_row(j, s_m_inner):
+                     is_zero = zero_mask[i, j]
+                     # Check if no starred zero in this row.
+                     no_star_in_row = keras.ops.logical_not(
+                         keras.ops.any(s_m_inner[i])
+                     )
+                     # Check if no starred zero in this column.
+                     no_star_in_col = keras.ops.logical_not(
+                         keras.ops.any(s_m_inner[:, j])
+                     )
+
+                     def can_star():
+                         return keras.ops.scatter_update(
+                             s_m_inner,
+                             [[i, j]],
+                             [True],
+                         )
+
+                     def cannot_star():
+                         return s_m_inner
+
+                     should_star = keras.ops.logical_and(
+                         keras.ops.logical_and(is_zero, no_star_in_row),
+                         no_star_in_col,
+                     )
+                     return keras.ops.cond(should_star, can_star, cannot_star)
+
+                 return keras.ops.fori_loop(
+                     0, num_queries, star_zeros_in_row, s_m
+                 )
+
+             s_mask = keras.ops.fori_loop(0, num_queries, star_zeros, s_mask)
+             return (
+                 3,
+                 cost,
+                 s_mask,
+                 keras.ops.zeros_like(row_covered),
+                 keras.ops.zeros_like(col_covered),
+                 keras.ops.zeros_like(primed_mask),
+                 -1,
+                 -1,
+             )
+
+         def step_3():
+             # Step 3: Cover each column containing a starred zero.
+             new_col_covered = keras.ops.any(starred_mask, axis=0)
+             num_covered = keras.ops.sum(
+                 keras.ops.cast(new_col_covered, "int32")
+             )
+             return keras.ops.cond(
+                 num_covered >= num_queries,
+                 lambda: (
+                     0,
+                     cost,
+                     starred_mask,
+                     row_covered,
+                     new_col_covered,
+                     primed_mask,
+                     -1,
+                     -1,
+                 ),  # Done
+                 lambda: (
+                     4,
+                     cost,
+                     starred_mask,
+                     row_covered,
+                     new_col_covered,
+                     primed_mask,
+                     -1,
+                     -1,
+                 ),  # Continue to step 4
+             )
+
+         def step_4():
+             # Step 4: Find a noncovered zero and prime it.
+             uncovered_zeros = keras.ops.logical_and(
+                 keras.ops.logical_and(
+                     zero_mask,
+                     keras.ops.logical_not(
+                         keras.ops.expand_dims(row_covered, 1)
+                     ),
+                 ),
+                 keras.ops.logical_not(keras.ops.expand_dims(col_covered, 0)),
+             )
+
+             def has_uncovered_zero():
+                 uncovered_zeros_flat = keras.ops.reshape(uncovered_zeros, [-1])
+                 first_idx = keras.ops.argmax(
+                     keras.ops.cast(uncovered_zeros_flat, "int32")
+                 )
+                 r = first_idx // num_queries
+                 c = first_idx % num_queries
+                 p_mask = keras.ops.scatter_update(primed_mask, [[r, c]], [True])
+                 starred_in_row = starred_mask[r]
+
+                 def has_starred_in_row():
+                     star_col = keras.ops.argmax(
+                         keras.ops.cast(starred_in_row, "int32")
+                     )
+                     r_cov = keras.ops.scatter_update(row_covered, [[r]], [True])
+                     c_cov = keras.ops.scatter_update(
+                         col_covered, [[star_col]], [False]
+                     )
+                     return 4, cost, starred_mask, r_cov, c_cov, p_mask, -1, -1
+
+                 def no_starred_in_row():
+                     return (
+                         5,
+                         cost,
+                         starred_mask,
+                         row_covered,
+                         col_covered,
+                         p_mask,
+                         r,
+                         c,
+                     )
+
+                 return keras.ops.cond(
+                     keras.ops.any(starred_in_row),
+                     has_starred_in_row,
+                     no_starred_in_row,
+                 )
+
+             def no_uncovered_zero():
+                 return (
+                     6,
+                     cost,
+                     starred_mask,
+                     row_covered,
+                     col_covered,
+                     primed_mask,
+                     -1,
+                     -1,
+                 )
+
+             return keras.ops.cond(
+                 keras.ops.any(uncovered_zeros),
+                 has_uncovered_zero,
+                 no_uncovered_zero,
+             )
+
+         def step_5():
+             # Step 5: Construct a series of alternating starred and primed
+             # zeros.
+             path = keras.ops.full((num_queries * 2, 2), -1, dtype="int32")
+             path = keras.ops.scatter_update(
+                 path, [[0]], [[path_start_row, path_start_col]]
+             )
+
+             def build_path(count, path_state):
+                 def continue_building(cnt, p):
+                     current_col = p[cnt - 1, 1]
+                     starred_in_col = starred_mask[:, current_col]
+
+                     def found_star():
+                         star_row = keras.ops.argmax(
+                             keras.ops.cast(starred_in_col, "int32")
+                         )
+                         p1 = keras.ops.scatter_update(
+                             p, [[cnt]], [[star_row, current_col]]
+                         )
+                         primed_in_star_row = primed_mask[star_row]
+                         prime_col = keras.ops.argmax(
+                             keras.ops.cast(primed_in_star_row, "int32")
+                         )
+                         p2 = keras.ops.scatter_update(
+                             p1, [[cnt + 1]], [[star_row, prime_col]]
+                         )
+                         return cnt + 2, p2
+
+                     def no_star():
+                         # Path complete.
+                         return cnt, p
+
+                     return keras.ops.cond(
+                         keras.ops.any(starred_in_col), found_star, no_star
+                     )
+
+                 def should_continue(cnt, p):
+                     return keras.ops.logical_and(
+                         cnt < num_queries * 2, p[cnt - 1, 1] >= 0
+                     )
+
+                 return keras.ops.while_loop(
+                     should_continue,
+                     continue_building,
+                     (count, path_state),
+                     maximum_iterations=num_queries,
+                 )
+
+             path_count, final_path = build_path(1, path)
+             s_mask = starred_mask
+
+             def update_star_mask(i, mask):
+                 def apply_update():
+                     row_idx = final_path[i, 0]
+                     col_idx = final_path[i, 1]
+                     valid_row = keras.ops.logical_and(
+                         row_idx >= 0, row_idx < num_queries
+                     )
+                     valid_col = keras.ops.logical_and(
+                         col_idx >= 0, col_idx < num_queries
+                     )
+                     valid_indices = keras.ops.logical_and(valid_row, valid_col)
+
+                     def do_update():
+                         current_value = mask[row_idx, col_idx]
+                         new_value = keras.ops.logical_not(current_value)
+                         return keras.ops.scatter_update(
+                             mask, [[row_idx, col_idx]], [new_value]
+                         )
+
+                     def skip_update():
+                         return mask
+
+                     return keras.ops.cond(valid_indices, do_update, skip_update)
+
+                 def skip_iteration():
+                     return mask
+
+                 should_process = i < path_count
+                 return keras.ops.cond(
+                     should_process, apply_update, skip_iteration
+                 )
+
+             s_mask = keras.ops.fori_loop(
+                 0, num_queries * 2, update_star_mask, s_mask
+             )
+             return (
+                 3,
+                 cost,
+                 s_mask,
+                 keras.ops.zeros_like(row_covered),
+                 keras.ops.zeros_like(col_covered),
+                 keras.ops.zeros_like(primed_mask),
+                 -1,
+                 -1,
+             )
+
+         def step_6():
+             # Step 6: Add/subtract minimum uncovered value.
+             uncovered_mask = keras.ops.logical_and(
+                 keras.ops.logical_not(keras.ops.expand_dims(row_covered, 1)),
+                 keras.ops.logical_not(keras.ops.expand_dims(col_covered, 0)),
+             )
+             min_val = keras.ops.min(keras.ops.where(uncovered_mask, cost, 1e9))
+             # Add to covered rows.
+             row_adjustment = keras.ops.where(
+                 keras.ops.expand_dims(row_covered, 1), min_val, 0.0
+             )
+             # Subtract from uncovered columns.
+             col_adjustment = keras.ops.where(
+                 keras.ops.expand_dims(col_covered, 0), 0.0, -min_val
+             )
+             new_cost = cost + row_adjustment + col_adjustment
+             return (
+                 4,
+                 new_cost,
+                 starred_mask,
+                 row_covered,
+                 col_covered,
+                 primed_mask,
+                 -1,
+                 -1,
+             )
+
+         return keras.ops.switch(
+             step - 2, [step_2, step_3, step_4, step_5, step_6]
+         )
+
+     # Main algorithm loop.
+     init_state = (
+         2,  # Start at step 2
+         cost,
+         keras.ops.zeros(
+             (num_queries, num_queries), dtype="bool"
+         ),  # starred_mask
+         keras.ops.zeros((num_queries,), dtype="bool"),  # row_covered
+         keras.ops.zeros((num_queries,), dtype="bool"),  # col_covered
+         keras.ops.zeros(
+             (num_queries, num_queries), dtype="bool"
+         ),  # primed_mask
+         -1,  # path_start_row
+         -1,  # path_start_col
+     )
+     final_state = keras.ops.while_loop(
+         lambda step, *_: step > 0,
+         body,
+         init_state,
+         maximum_iterations=num_queries * num_queries,
+     )
+     final_starred_mask = final_state[2]
+     row_ind = keras.ops.arange(num_queries, dtype="int32")
+     col_ind = keras.ops.argmax(
+         keras.ops.cast(final_starred_mask, "int32"), axis=1
+     )
+     valid_mask = keras.ops.logical_and(
+         row_ind < original_num_rows, col_ind < original_num_cols
+     )
+     return row_ind, col_ind, valid_mask
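A small end-to-end sketch (the cost values are made up): matching three predictions against two ground-truth targets, with the third prediction absorbed by the padded column.

    import keras

    cost = keras.ops.convert_to_tensor(
        [[9.0, 2.0], [1.0, 8.0], [7.0, 7.0]]
    )
    row_ind, col_ind, valid_mask = hungarian_assignment(cost, num_queries=3)
    # Expected matching: query 0 -> target 1 and query 1 -> target 0 (total
    # cost 3); the third query is paired with the padding column, so its
    # valid_mask entry is False.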