keras-hub-nightly 0.23.0.dev202508260411__py3-none-any.whl → 0.23.0.dev202508280418__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +6 -0
- keras_hub/models/__init__.py +21 -0
- keras_hub/src/layers/modeling/position_embedding.py +21 -6
- keras_hub/src/layers/modeling/rotary_embedding.py +16 -6
- keras_hub/src/layers/modeling/sine_position_encoding.py +21 -8
- keras_hub/src/layers/modeling/token_and_position_embedding.py +2 -1
- keras_hub/src/models/backbone.py +10 -15
- keras_hub/src/models/d_fine/__init__.py +0 -0
- keras_hub/src/models/d_fine/d_fine_attention.py +461 -0
- keras_hub/src/models/d_fine/d_fine_backbone.py +891 -0
- keras_hub/src/models/d_fine/d_fine_decoder.py +944 -0
- keras_hub/src/models/d_fine/d_fine_encoder.py +365 -0
- keras_hub/src/models/d_fine/d_fine_hybrid_encoder.py +642 -0
- keras_hub/src/models/d_fine/d_fine_image_converter.py +8 -0
- keras_hub/src/models/d_fine/d_fine_layers.py +1828 -0
- keras_hub/src/models/d_fine/d_fine_loss.py +938 -0
- keras_hub/src/models/d_fine/d_fine_object_detector.py +875 -0
- keras_hub/src/models/d_fine/d_fine_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/d_fine/d_fine_presets.py +2 -0
- keras_hub/src/models/d_fine/d_fine_utils.py +827 -0
- keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +4 -1
- keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +3 -2
- keras_hub/src/models/hgnetv2/hgnetv2_layers.py +27 -11
- keras_hub/src/models/parseq/__init__.py +0 -0
- keras_hub/src/models/parseq/parseq_backbone.py +134 -0
- keras_hub/src/models/parseq/parseq_causal_lm.py +466 -0
- keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py +168 -0
- keras_hub/src/models/parseq/parseq_decoder.py +418 -0
- keras_hub/src/models/parseq/parseq_image_converter.py +8 -0
- keras_hub/src/models/parseq/parseq_tokenizer.py +221 -0
- keras_hub/src/tests/test_case.py +37 -1
- keras_hub/src/utils/preset_utils.py +49 -0
- keras_hub/src/utils/tensor_utils.py +23 -1
- keras_hub/src/utils/transformers/convert_vit.py +4 -1
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +3 -0
- {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/RECORD +40 -20
- {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,827 @@
|
|
1
|
+
import keras
|
2
|
+
|
3
|
+
|
4
|
+
def d_fine_kernel_initializer(initializer_range=0.01, name="random_normal"):
    """Returns the Keras initializer selected by `name`.

    Args:
        initializer_range: float, Standard deviation of the normal
            distribution when `name` is `"random_normal"`. Ignored for the
            other initializers. Default is `0.01`.
        name: str, Which initializer to build. One of `"random_normal"`,
            `"glorot_uniform"` or `"zeros"`. Default is `"random_normal"`.

    Returns:
        A `keras.initializers` instance: `RandomNormal(mean=0.0,
        stddev=initializer_range)`, `GlorotUniform()` or `Zeros()`.

    Raises:
        ValueError: If `name` is not one of the supported names.
    """
    if name == "random_normal":
        return keras.initializers.RandomNormal(
            mean=0.0, stddev=initializer_range
        )
    if name == "glorot_uniform":
        return keras.initializers.GlorotUniform()
    if name == "zeros":
        return keras.initializers.Zeros()
    # Fail loudly: falling through would hand the caller `None` instead of an
    # initializer, hiding a typo in `name`.
    raise ValueError(
        f"Unsupported initializer name: {name!r}. Expected one of "
        "'random_normal', 'glorot_uniform' or 'zeros'."
    )
|
13
|
+
|
14
|
+
|
15
|
+
def inverse_sigmoid(x, eps=1e-5):
    """Maps values in `[0, 1]` back to logit space (the inverse of sigmoid).

    D-FINE uses this to turn normalized bounding box coordinates back into
    logits, for example in `DFineContrastiveDenoisingGroupGenerator` and
    `DFineDecoder`.

    Args:
        x: Tensor, Input values; anything outside `[0, 1]` is clipped into
            that range first.
        eps: float, Lower bound applied to both `x` and `1 - x` before the
            division and logarithm so the result stays finite at the
            boundaries. Default is `1e-5`.

    Returns:
        Tensor: `log(x / (1 - x))` evaluated on the clipped, eps-bounded input.
    """
    ops = keras.ops
    clipped = ops.clip(x, 0, 1)
    # Keep both numerator and denominator away from zero.
    numerator = ops.maximum(clipped, eps)
    denominator = ops.maximum(1 - clipped, eps)
    return ops.log(numerator / denominator)
|
35
|
+
|
36
|
+
|
37
|
+
def grid_sample(data, grid, align_corners=False, height=None, width=None):
    """Samples data at specified grid locations using bilinear interpolation.

    This function performs bilinear interpolation to sample data at arbitrary
    grid locations. It is a core component of the deformable attention
    mechanism, used within `multi_scale_deformable_attention_v2`.
    This is a Keras-native implementation (polyfill) for
    `torch.nn.functional.grid_sample`. Corner pixels that fall outside the
    input contribute zero to the interpolated value (zero padding).

    Args:
        data: Tensor, Input data tensor of shape `[batch, channels, height,
            width]`.
        grid: Tensor, Grid coordinates of shape `[batch, out_height, out_width,
            2]`. The last dimension contains `(x, y)` coordinates normalized to
            `[-1, 1]`.
        align_corners: bool, If `True`, `-1` and `1` refer to the centers of
            the corner pixels; otherwise they refer to the outer edges of the
            corner pixels. Default is `False`.
        height: int, optional, Override height for coordinate normalization
            and bounds checking.
        width: int, optional, Override width for coordinate normalization,
            bounds checking and flat indexing.

    Returns:
        Tensor: Sampled data of shape `[batch, channels, out_height,
            out_width]`.
    """
    num_batch, _, data_height, data_width = keras.ops.shape(data)
    _, out_height, out_width, _ = keras.ops.shape(grid)
    dtype = data.dtype
    grid_x_norm = grid[..., 0]
    grid_y_norm = grid[..., 1]
    h_in = height if height is not None else data_height
    w_in = width if width is not None else data_width
    height_f = keras.ops.cast(h_in, dtype=dtype)
    width_f = keras.ops.cast(w_in, dtype=dtype)
    # Map normalized [-1, 1] coordinates to pixel coordinates.
    if align_corners:
        x_unnorm = (grid_x_norm + 1) / 2 * (width_f - 1)
        y_unnorm = (grid_y_norm + 1) / 2 * (height_f - 1)
    else:
        x_unnorm = ((grid_x_norm + 1) / 2 * width_f) - 0.5
        y_unnorm = ((grid_y_norm + 1) / 2 * height_f) - 0.5
    # Integer corners surrounding each sample point.
    x0 = keras.ops.floor(x_unnorm)
    y0 = keras.ops.floor(y_unnorm)
    x1 = x0 + 1
    y1 = y0 + 1
    # Bilinear weights: each corner is weighted by the distance from the
    # sample point to the opposite corner along each axis.
    w_y0_val = y1 - y_unnorm
    w_y1_val = y_unnorm - y0
    w_x0_val = x1 - x_unnorm
    w_x1_val = x_unnorm - x0
    # Channels-last layout so one gather returns all channels of a pixel.
    data_permuted = keras.ops.transpose(data, (0, 2, 3, 1))

    def gather_padded(
        data_p,
        y_coords,
        x_coords,
        actual_data_height,
        actual_data_width,
        override_height=None,
        override_width=None,
    ):
        """Gathers `data_p[b, y, x, :]` per sample; out-of-bounds -> zeros."""
        y_coords_int = keras.ops.cast(y_coords, "int32")
        x_coords_int = keras.ops.cast(x_coords, "int32")

        y_oob = keras.ops.logical_or(
            y_coords_int < 0, y_coords_int >= actual_data_height
        )
        x_oob = keras.ops.logical_or(
            x_coords_int < 0, x_coords_int >= actual_data_width
        )
        oob_mask = keras.ops.logical_or(y_oob, x_oob)

        # Clip so the gather itself never indexes out of range; the clipped
        # samples are zeroed out below via `oob_mask`.
        y_coords_clipped = keras.ops.clip(
            y_coords_int, 0, actual_data_height - 1
        )
        x_coords_clipped = keras.ops.clip(
            x_coords_int, 0, actual_data_width - 1
        )

        width_for_indexing = (
            override_width if override_width is not None else actual_data_width
        )

        if override_height is not None and override_width is not None:
            data_flat = keras.ops.reshape(
                data_p,
                (
                    num_batch,
                    override_height * override_width,
                    keras.ops.shape(data_p)[-1],
                ),
            )
        else:
            data_flat = keras.ops.reshape(
                data_p, (num_batch, -1, keras.ops.shape(data_p)[-1])
            )
        y_coords_flat = keras.ops.reshape(
            y_coords_clipped, (num_batch, out_height * out_width)
        )
        x_coords_flat = keras.ops.reshape(
            x_coords_clipped, (num_batch, out_height * out_width)
        )
        indices = y_coords_flat * width_for_indexing + x_coords_flat

        # Offset each batch into a single flattened [batch * H * W, C] table.
        num_elements_per_batch = keras.ops.shape(data_flat)[1]
        batch_offsets = (
            keras.ops.arange(num_batch, dtype=indices.dtype)
            * num_elements_per_batch
        )
        batch_offsets = keras.ops.reshape(batch_offsets, (num_batch, 1))
        absolute_indices = indices + batch_offsets
        data_reshaped_for_gather = keras.ops.reshape(
            data_flat, (-1, keras.ops.shape(data_flat)[-1])
        )
        gathered = keras.ops.take(
            data_reshaped_for_gather, absolute_indices, axis=0
        )
        gathered = keras.ops.reshape(
            gathered, (num_batch, out_height, out_width, -1)
        )
        oob_mask_expanded = keras.ops.expand_dims(oob_mask, axis=-1)
        gathered_values = gathered * keras.ops.cast(
            keras.ops.logical_not(oob_mask_expanded), dtype=gathered.dtype
        )
        return gathered_values

    val_y0_x0 = gather_padded(data_permuted, y0, x0, h_in, w_in, height, width)
    val_y0_x1 = gather_padded(data_permuted, y0, x1, h_in, w_in, height, width)
    val_y1_x0 = gather_padded(data_permuted, y1, x0, h_in, w_in, height, width)
    val_y1_x1 = gather_padded(data_permuted, y1, x1, h_in, w_in, height, width)
    interp_val = (
        val_y0_x0 * keras.ops.expand_dims(w_y0_val * w_x0_val, -1)
        + val_y0_x1 * keras.ops.expand_dims(w_y0_val * w_x1_val, -1)
        + val_y1_x0 * keras.ops.expand_dims(w_y1_val * w_x0_val, -1)
        + val_y1_x1 * keras.ops.expand_dims(w_y1_val * w_x1_val, -1)
    )

    # Back to channels-first: [batch, channels, out_height, out_width].
    return keras.ops.transpose(interp_val, (0, 3, 1, 2))
|
175
|
+
|
176
|
+
|
177
|
+
def multi_scale_deformable_attention_v2(
    value,
    dynamic_spatial_shapes,
    sampling_locations,
    attention_weights,
    num_points,
    slice_sizes,
    spatial_shapes,
    num_levels,
    num_queries,
    method="default",
):
    """Computes multi-scale deformable attention mechanism.

    This function implements the core of the multi-scale deformable attention
    mechanism used in `DFineMultiScaleDeformableAttention`. It samples features
    at multiple scales and locations based on learned attention weights and
    sampling locations.

    Args:
        value: Tensor, Feature values of shape `[batch, seq_len, num_heads,
            hidden_dim]`, where `seq_len` holds the flattened feature maps of
            all levels one after another (split per level by `slice_sizes`).
        dynamic_spatial_shapes: Tensor, Per-level `(height, width)` read as
            `[level, 0]` and `[level, 1]`. Only used when `spatial_shapes` is
            `None` or does not have exactly `num_levels` entries.
        sampling_locations: Tensor, Sampling locations of shape
            `[batch, num_queries, num_heads, sum(num_points), 2]`. The points
            of every level are concatenated along the fourth axis (split per
            level by `num_points`) and the last axis holds `(x, y)`. For the
            `"default"` method they are mapped with `2 * loc - 1` into the
            `[-1, 1]` range expected by `grid_sample`.
        attention_weights: Tensor, Attention weights of shape `[batch,
            num_queries, num_heads, sum(num_points)]`.
        num_points: list, Number of sampling points for each level.
        slice_sizes: list, Length of each level's chunk along the `seq_len`
            axis of `value`.
        spatial_shapes: list or `None`, Static `(height, width)` per level.
        num_levels: int, Number of feature levels.
        num_queries: int, Number of queries.
        method: str, Sampling method. `"discrete"` rounds each scaled location
            to an integer pixel index and gathers that feature; `"default"`
            (and any other value) uses bilinear `grid_sample`.
            Default is `"default"`.

    Returns:
        Tensor: Output features of shape `[batch, num_queries, num_heads *
            hidden_dim]`.
    """
    value_shape = keras.ops.shape(value)
    batch_size = value_shape[0]
    seq_len = value_shape[1]
    num_heads = value_shape[2]
    hidden_dim = value_shape[3]
    sampling_shape = keras.ops.shape(sampling_locations)
    # `sampling_locations` is 5-D: axis 3 holds the points of all levels
    # concatenated, axis 4 the coordinate pair.
    total_points_from_shape = sampling_shape[3]
    coord_dim_from_shape = sampling_shape[4]
    # [batch, seq_len, heads, hidden] -> [batch * heads, hidden, seq_len].
    permuted_value = keras.ops.transpose(value, axes=(0, 2, 3, 1))
    flattened_value = keras.ops.reshape(
        permuted_value, (-1, hidden_dim, seq_len)
    )
    # Split the flattened sequence axis into one chunk per level.
    value_chunk_sizes = keras.ops.array(slice_sizes, dtype="int32")
    cum_sizes = keras.ops.concatenate(
        [
            keras.ops.zeros((1,), dtype="int32"),
            keras.ops.cumsum(value_chunk_sizes),
        ]
    )
    values = []
    # Iterate over `num_levels` rather than `len(spatial_shapes)`: the
    # per-level loop below indexes `values` by level and explicitly supports
    # `spatial_shapes` being `None`.
    for i in range(num_levels):
        start = cum_sizes[i]
        current_slice_size = slice_sizes[i]
        dynamic_slice_start_indices = (0, 0, start)
        dynamic_slice_shape = (
            keras.ops.shape(flattened_value)[0],
            keras.ops.shape(flattened_value)[1],
            current_slice_size,
        )
        sliced_value = keras.ops.slice(
            flattened_value, dynamic_slice_start_indices, dynamic_slice_shape
        )
        values.append(sliced_value)
    if method == "discrete":
        sampling_grids = sampling_locations
    else:
        # "default" and any unrecognized method: [0, 1] -> [-1, 1].
        sampling_grids = 2 * sampling_locations - 1
    # -> [batch * heads, num_queries, sum(num_points), 2].
    permuted_sampling_grids = keras.ops.transpose(
        sampling_grids, axes=(0, 2, 1, 3, 4)
    )
    flattened_sampling_grids = keras.ops.reshape(
        permuted_sampling_grids,
        (
            batch_size * num_heads,
            num_queries,
            total_points_from_shape,
            coord_dim_from_shape,
        ),
    )
    # Split the concatenated points axis into one grid per level.
    cum_points = keras.ops.concatenate(
        [
            keras.ops.zeros((1,), dtype="int32"),
            keras.ops.cumsum(keras.ops.array(num_points, dtype="int32")),
        ]
    )
    level_grids = []
    for i in range(num_levels):
        start = cum_points[i]
        current_level_num_points = num_points[i]
        slice_start_indices = (0, 0, start, 0)
        slice_shape = (
            keras.ops.shape(flattened_sampling_grids)[0],
            keras.ops.shape(flattened_sampling_grids)[1],
            current_level_num_points,
            keras.ops.shape(flattened_sampling_grids)[3],
        )
        sliced_grid = keras.ops.slice(
            flattened_sampling_grids, slice_start_indices, slice_shape
        )
        level_grids.append(sliced_grid)
    sampling_values = []
    for level_id in range(num_levels):
        if spatial_shapes is not None and len(spatial_shapes) == num_levels:
            height, width = spatial_shapes[level_id]
        else:
            height = dynamic_spatial_shapes[level_id, 0]
            width = dynamic_spatial_shapes[level_id, 1]
        value_l_ = keras.ops.reshape(
            values[level_id],
            (batch_size * num_heads, hidden_dim, height, width),
        )
        sampling_grid_l_ = level_grids[level_id]
        if method == "discrete":
            scale_factors = keras.ops.cast(
                keras.ops.array([width, height]),
                dtype=sampling_grid_l_.dtype,
            )
            sampling_coord_float = sampling_grid_l_ * scale_factors
            sampling_coord_x_int = keras.ops.cast(
                keras.ops.floor(sampling_coord_float[..., 0] + 0.5), "int32"
            )
            sampling_coord_y_int = keras.ops.cast(
                keras.ops.floor(sampling_coord_float[..., 1] + 0.5), "int32"
            )
            clamped_coord_x = keras.ops.clip(sampling_coord_x_int, 0, width - 1)
            clamped_coord_y = keras.ops.clip(
                sampling_coord_y_int, 0, height - 1
            )
            sampling_coord_stacked = keras.ops.stack(
                [clamped_coord_x, clamped_coord_y], axis=-1
            )
            B_prime = batch_size * num_heads
            Q_dim = num_queries
            P_level = num_points[level_id]
            sampling_coord = keras.ops.reshape(
                sampling_coord_stacked, (B_prime, Q_dim * P_level, 2)
            )
            value_l_permuted = keras.ops.transpose(value_l_, (0, 2, 3, 1))
            # (B_prime, Q_dim * P_level)
            y_coords_for_gather = sampling_coord[..., 1]
            # (B_prime, Q_dim * P_level)
            x_coords_for_gather = sampling_coord[..., 0]
            indices = y_coords_for_gather * width + x_coords_for_gather
            indices = keras.ops.expand_dims(indices, axis=-1)
            value_l_flat = keras.ops.reshape(
                value_l_permuted, (B_prime, height * width, hidden_dim)
            )
            gathered_values = keras.ops.take_along_axis(
                value_l_flat, indices, axis=1
            )
            permuted_gathered_values = keras.ops.transpose(
                gathered_values, axes=(0, 2, 1)
            )
            sampling_value_l_ = keras.ops.reshape(
                permuted_gathered_values, (B_prime, hidden_dim, Q_dim, P_level)
            )
        else:
            # "default" and any unrecognized method: bilinear sampling.
            sampling_value_l_ = grid_sample(
                data=value_l_,
                grid=sampling_grid_l_,
                align_corners=False,
                height=height,
                width=width,
            )
        sampling_values.append(sampling_value_l_)
    # -> [batch * heads, 1, num_queries, sum(num_points)] so the weights
    # broadcast over `hidden_dim`.
    attention_weights = keras.ops.transpose(
        attention_weights, axes=(0, 2, 1, 3)
    )
    attention_weights = keras.ops.reshape(
        attention_weights,
        (batch_size * num_heads, 1, num_queries, sum(num_points)),
    )
    concatenated_sampling_values = keras.ops.concatenate(
        sampling_values, axis=-1
    )
    weighted_values = concatenated_sampling_values * attention_weights
    summed_values = keras.ops.sum(weighted_values, axis=-1)
    output = keras.ops.reshape(
        summed_values, (batch_size, num_heads * hidden_dim, num_queries)
    )
    return keras.ops.transpose(output, axes=(0, 2, 1))
|
379
|
+
|
380
|
+
|
381
|
+
def weighting_function(max_num_bins, upsampling_factor, reg_scale):
    """Generates weighting values for binning operations.

    This function creates a set of weighting values used for integral-based
    bounding box regression. It is used in `DFineDecoder` to create a
    projection matrix for converting corner predictions into distances. The
    values are symmetric around a central zero and spaced geometrically
    (powers of a common `step`), with the outermost bins at
    `-/+ 2 * |upsampling_factor[0]| * |reg_scale|`.

    Args:
        max_num_bins: int, Controls the number of bins. Must be greater
            than 2, since it appears in the divisor `max_num_bins - 2`.
        upsampling_factor: Tensor, A scaling hyperparameter that controls the
            range of the bins used for integral-based bounding box regression.
            Only `upsampling_factor[0]` is read.
        reg_scale: float or Tensor, Regularization scale factor.

    Returns:
        Tensor: 1-D tensor of `2 * (max_num_bins // 2) + 1` values, i.e.
            `max_num_bins + 1` for an even `max_num_bins`.
    """
    upper_bound1 = abs(upsampling_factor[0]) * abs(reg_scale)
    upper_bound2 = abs(upsampling_factor[0]) * abs(reg_scale) * 2
    step = (upper_bound1 + 1) ** (2 / (max_num_bins - 2))
    left_values = [
        -((step) ** i) + 1 for i in range(max_num_bins // 2 - 1, 0, -1)
    ]
    right_values = [(step) ** i - 1 for i in range(1, max_num_bins // 2)]
    values = (
        [-upper_bound2]
        + left_values
        + [
            keras.ops.zeros_like(
                keras.ops.expand_dims(upsampling_factor[0], axis=0)
            )
        ]
        + right_values
        + [upper_bound2]
    )
    # Entries are 0-D when `reg_scale` is a Python float and 1-D when it is a
    # shape-(1,) tensor. `concatenate` rejects 0-D inputs, so flatten every
    # entry to 1-D first; this is a no-op for entries that are already (1,).
    values = [keras.ops.reshape(v, (-1,)) for v in values]
    return keras.ops.concatenate(values, 0)
|
418
|
+
|
419
|
+
|
420
|
+
def distance2bbox(points, distance, reg_scale):
    """Turns per-side distance predictions into center-format boxes.

    Used by the regression head of `DFineDecoder`: the integral-based
    distance predictions are applied around each anchor point to obtain
    corner coordinates, which are then converted to center format.

    Args:
        points: Tensor, Anchor points of shape `[..., 4]`, last axis is
            `[x, y, width, height]`.
        distance: Tensor, Predictions of shape `[..., 4]`, last axis is the
            `[left, top, right, bottom]` distance.
        reg_scale: float, Regularization scale factor; its absolute value is
            used.

    Returns:
        Tensor: Boxes of shape `[..., 4]` in `[center_x, center_y, width,
            height]` format, with the same dtype as `points`.
    """
    reg_scale = abs(reg_scale)
    anchor_x = points[..., 0]
    anchor_y = points[..., 1]
    # Per-axis factor that converts a distance into an offset.
    x_factor = points[..., 2] / reg_scale
    y_factor = points[..., 3] / reg_scale
    half_scale = 0.5 * reg_scale

    x_min = anchor_x - (half_scale + distance[..., 0]) * x_factor
    y_min = anchor_y - (half_scale + distance[..., 1]) * y_factor
    x_max = anchor_x + (half_scale + distance[..., 2]) * x_factor
    y_max = anchor_y + (half_scale + distance[..., 3]) * y_factor

    corners = keras.ops.stack([x_min, y_min, x_max, y_max], axis=-1)
    return keras.utils.bounding_boxes.convert_format(
        corners,
        source="xyxy",
        target="center_xywh",
        dtype=points.dtype,
    )
|
461
|
+
|
462
|
+
|
463
|
+
def hungarian_assignment(cost_matrix, num_queries):
    """Solves the linear assignment problem using the Hungarian algorithm.

    This function provides a JIT-compatible implementation of the Hungarian
    (Munkres) algorithm using pure `keras.ops` operations. It is designed to
    replace Scipy's `optimize.linear_sum_assignment` for backend-agnostic
    end-to-end model compilation. The implementation uses a stateful loop
    with `keras.ops.while_loop`, a state machine pattern with
    `keras.ops.switch`, and tensor-only operations to ensure compatibility
    with static graphs and standard accelerators.

    Args:
        cost_matrix: Tensor, A 2D tensor of shape `(num_rows, num_cols)`
            representing the cost of each potential assignment. `num_rows`
            typically corresponds to the number of predictions (queries),
            and `num_cols` corresponds to number of ground-truth targets.
            It is padded to `(num_queries, num_queries)` with `1e9`.
        num_queries: int, The fixed number of queries (predictions) from
            the model, used to establish static shapes for JAX compatibility.

    Returns:
        Tuple: A tuple `(row_ind, col_ind, valid_mask)` containing:
            - row_ind: Tensor with integer indices for the rows (predictions).
            - col_ind: Tensor with integer indices for the assigned columns
                (targets).
            - valid_mask: Boolean tensor where `True` indicates a valid
                assignment: the row actually received a starred zero and
                the pair falls within the original (unpadded) cost matrix
                dimensions.
    """
    # Reference: https://github.com/bmc/munkres/blob/master/munkres.py

    original_num_rows, original_num_cols = keras.ops.shape(cost_matrix)
    # Pad matrix to be square.
    padded_cost_matrix = keras.ops.full(
        (num_queries, num_queries), 1e9, dtype=cost_matrix.dtype
    )
    padded_cost_matrix = keras.ops.slice_update(
        padded_cost_matrix,
        (0, 0),
        cost_matrix,
    )
    # Step 1: Subtract row minima.
    cost = padded_cost_matrix - keras.ops.min(
        padded_cost_matrix, axis=1, keepdims=True
    )
    # Step 2: Subtract column minima.
    cost = cost - keras.ops.min(cost, axis=0, keepdims=True)

    def body(
        step,
        cost,
        starred_mask,
        row_covered,
        col_covered,
        primed_mask,
        path_start_row,
        path_start_col,
    ):
        # One state-machine transition; `step` selects which Munkres step
        # runs and every branch returns the full (next_step, ...) state.
        zero_mask = keras.ops.abs(cost) < 1e-6

        def step_2():
            # Initial starring: Star zeros with no starred zero in their row or
            # column.
            s_mask = keras.ops.zeros_like(starred_mask, dtype="bool")

            def star_zeros(i, s_m):
                def star_zeros_in_row(j, s_m_inner):
                    is_zero = zero_mask[i, j]
                    # Check if no starred zero in this row.
                    no_star_in_row = keras.ops.logical_not(
                        keras.ops.any(s_m_inner[i])
                    )
                    # Check if no starred zero in this column.
                    no_star_in_col = keras.ops.logical_not(
                        keras.ops.any(s_m_inner[:, j])
                    )

                    def can_star():
                        return keras.ops.scatter_update(
                            s_m_inner,
                            [[i, j]],
                            [True],
                        )

                    def cannot_star():
                        return s_m_inner

                    should_star = keras.ops.logical_and(
                        keras.ops.logical_and(is_zero, no_star_in_row),
                        no_star_in_col,
                    )
                    return keras.ops.cond(should_star, can_star, cannot_star)

                return keras.ops.fori_loop(
                    0, num_queries, star_zeros_in_row, s_m
                )

            s_mask = keras.ops.fori_loop(0, num_queries, star_zeros, s_mask)
            return (
                3,
                cost,
                s_mask,
                keras.ops.zeros_like(row_covered),
                keras.ops.zeros_like(col_covered),
                keras.ops.zeros_like(primed_mask),
                -1,
                -1,
            )

        def step_3():
            # Step 3: Cover each column containing a starred zero.
            new_col_covered = keras.ops.any(starred_mask, axis=0)
            num_covered = keras.ops.sum(
                keras.ops.cast(new_col_covered, "int32")
            )
            return keras.ops.cond(
                num_covered >= num_queries,
                lambda: (
                    0,
                    cost,
                    starred_mask,
                    row_covered,
                    new_col_covered,
                    primed_mask,
                    -1,
                    -1,
                ),  # Done
                lambda: (
                    4,
                    cost,
                    starred_mask,
                    row_covered,
                    new_col_covered,
                    primed_mask,
                    -1,
                    -1,
                ),  # Continue to step 4
            )

        def step_4():
            # Step 4: Find a noncovered zero and prime it.
            uncovered_zeros = keras.ops.logical_and(
                keras.ops.logical_and(
                    zero_mask,
                    keras.ops.logical_not(
                        keras.ops.expand_dims(row_covered, 1)
                    ),
                ),
                keras.ops.logical_not(keras.ops.expand_dims(col_covered, 0)),
            )

            def has_uncovered_zero():
                uncovered_zeros_flat = keras.ops.reshape(uncovered_zeros, [-1])
                first_idx = keras.ops.argmax(
                    keras.ops.cast(uncovered_zeros_flat, "int32")
                )
                r = first_idx // num_queries
                c = first_idx % num_queries
                p_mask = keras.ops.scatter_update(primed_mask, [[r, c]], [True])
                starred_in_row = starred_mask[r]

                def has_starred_in_row():
                    # Cover this row and uncover the starred zero's column.
                    star_col = keras.ops.argmax(
                        keras.ops.cast(starred_in_row, "int32")
                    )
                    r_cov = keras.ops.scatter_update(row_covered, [[r]], [True])
                    c_cov = keras.ops.scatter_update(
                        col_covered, [[star_col]], [False]
                    )
                    return 4, cost, starred_mask, r_cov, c_cov, p_mask, -1, -1

                def no_starred_in_row():
                    # The primed zero starts an augmenting path (step 5).
                    return (
                        5,
                        cost,
                        starred_mask,
                        row_covered,
                        col_covered,
                        p_mask,
                        r,
                        c,
                    )

                return keras.ops.cond(
                    keras.ops.any(starred_in_row),
                    has_starred_in_row,
                    no_starred_in_row,
                )

            def no_uncovered_zero():
                return (
                    6,
                    cost,
                    starred_mask,
                    row_covered,
                    col_covered,
                    primed_mask,
                    -1,
                    -1,
                )

            return keras.ops.cond(
                keras.ops.any(uncovered_zeros),
                has_uncovered_zero,
                no_uncovered_zero,
            )

        def step_5():
            # Step 5: Construct a series of alternating starred and primed
            # zeros.
            path = keras.ops.full((num_queries * 2, 2), -1, dtype="int32")
            path = keras.ops.scatter_update(
                path, [[0]], [[path_start_row, path_start_col]]
            )

            def build_path(count, path_state):
                def continue_building(cnt, p):
                    current_col = p[cnt - 1, 1]
                    starred_in_col = starred_mask[:, current_col]

                    def found_star():
                        star_row = keras.ops.argmax(
                            keras.ops.cast(starred_in_col, "int32")
                        )
                        p1 = keras.ops.scatter_update(
                            p, [[cnt]], [[star_row, current_col]]
                        )
                        primed_in_star_row = primed_mask[star_row]
                        prime_col = keras.ops.argmax(
                            keras.ops.cast(primed_in_star_row, "int32")
                        )
                        p2 = keras.ops.scatter_update(
                            p1, [[cnt + 1]], [[star_row, prime_col]]
                        )
                        return cnt + 2, p2

                    def no_star():
                        # Path complete. The state is returned unchanged, so
                        # the loop idles until `maximum_iterations`; the
                        # result is unaffected.
                        return cnt, p

                    return keras.ops.cond(
                        keras.ops.any(starred_in_col), found_star, no_star
                    )

                def should_continue(cnt, p):
                    return keras.ops.logical_and(
                        cnt < num_queries * 2, p[cnt - 1, 1] >= 0
                    )

                return keras.ops.while_loop(
                    should_continue,
                    continue_building,
                    (count, path_state),
                    maximum_iterations=num_queries,
                )

            path_count, final_path = build_path(1, path)
            s_mask = starred_mask

            def update_star_mask(i, mask):
                # Toggle each zero on the path: primed zeros become starred
                # and starred zeros become unstarred.
                def apply_update():
                    row_idx = final_path[i, 0]
                    col_idx = final_path[i, 1]
                    valid_row = keras.ops.logical_and(
                        row_idx >= 0, row_idx < num_queries
                    )
                    valid_col = keras.ops.logical_and(
                        col_idx >= 0, col_idx < num_queries
                    )
                    valid_indices = keras.ops.logical_and(valid_row, valid_col)

                    def do_update():
                        current_value = mask[row_idx, col_idx]
                        new_value = keras.ops.logical_not(current_value)
                        return keras.ops.scatter_update(
                            mask, [[row_idx, col_idx]], [new_value]
                        )

                    def skip_update():
                        return mask

                    return keras.ops.cond(valid_indices, do_update, skip_update)

                def skip_iteration():
                    return mask

                should_process = i < path_count
                return keras.ops.cond(
                    should_process, apply_update, skip_iteration
                )

            s_mask = keras.ops.fori_loop(
                0, num_queries * 2, update_star_mask, s_mask
            )
            return (
                3,
                cost,
                s_mask,
                keras.ops.zeros_like(row_covered),
                keras.ops.zeros_like(col_covered),
                keras.ops.zeros_like(primed_mask),
                -1,
                -1,
            )

        def step_6():
            # Step 6: Add/subtract minimum uncovered value.
            uncovered_mask = keras.ops.logical_and(
                keras.ops.logical_not(keras.ops.expand_dims(row_covered, 1)),
                keras.ops.logical_not(keras.ops.expand_dims(col_covered, 0)),
            )
            min_val = keras.ops.min(keras.ops.where(uncovered_mask, cost, 1e9))
            # Add to covered rows.
            row_adjustment = keras.ops.where(
                keras.ops.expand_dims(row_covered, 1), min_val, 0.0
            )
            # Subtract from uncovered columns.
            col_adjustment = keras.ops.where(
                keras.ops.expand_dims(col_covered, 0), 0.0, -min_val
            )
            new_cost = cost + row_adjustment + col_adjustment
            return (
                4,
                new_cost,
                starred_mask,
                row_covered,
                col_covered,
                primed_mask,
                -1,
                -1,
            )

        return keras.ops.switch(
            step - 2, [step_2, step_3, step_4, step_5, step_6]
        )

    # Main algorithm loop.
    init_state = (
        2,  # Start at step 2
        cost,
        keras.ops.zeros(
            (num_queries, num_queries), dtype="bool"
        ),  # starred_mask
        keras.ops.zeros((num_queries,), dtype="bool"),  # row_covered
        keras.ops.zeros((num_queries,), dtype="bool"),  # col_covered
        keras.ops.zeros(
            (num_queries, num_queries), dtype="bool"
        ),  # primed_mask
        -1,  # path_start_row
        -1,  # path_start_col
    )
    # Upper bound on state-machine iterations. Between two augmentations
    # (step 5) there is one step 3, at most `num_queries` row-covering step-4
    # passes, and at most `num_queries + 1` step-6 passes, each followed by
    # another step-4 pass. That is at most `3 * num_queries + 5` iterations
    # per phase, and there are at most `num_queries` augmentations.
    # A `num_queries ** 2` cap can stop the loop before every row is starred.
    max_iterations = num_queries * (3 * num_queries + 5) + 2
    final_state = keras.ops.while_loop(
        lambda step, *_: step > 0,
        body,
        init_state,
        maximum_iterations=max_iterations,
    )
    final_starred_mask = final_state[2]
    row_ind = keras.ops.arange(num_queries, dtype="int32")
    col_ind = keras.ops.argmax(
        keras.ops.cast(final_starred_mask, "int32"), axis=1
    )
    # `argmax` yields column 0 for a row without a starred zero, so only
    # rows that actually hold a star are valid. After convergence every
    # column, and therefore every row, holds exactly one star.
    has_assignment = keras.ops.any(final_starred_mask, axis=1)
    valid_mask = keras.ops.logical_and(
        keras.ops.logical_and(
            row_ind < original_num_rows, col_ind < original_num_cols
        ),
        has_assignment,
    )
    return row_ind, col_ind, valid_mask
|