keras-hub-nightly 0.23.0.dev202508260411__py3-none-any.whl → 0.23.0.dev202508280418__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (40)
  1. keras_hub/layers/__init__.py +6 -0
  2. keras_hub/models/__init__.py +21 -0
  3. keras_hub/src/layers/modeling/position_embedding.py +21 -6
  4. keras_hub/src/layers/modeling/rotary_embedding.py +16 -6
  5. keras_hub/src/layers/modeling/sine_position_encoding.py +21 -8
  6. keras_hub/src/layers/modeling/token_and_position_embedding.py +2 -1
  7. keras_hub/src/models/backbone.py +10 -15
  8. keras_hub/src/models/d_fine/__init__.py +0 -0
  9. keras_hub/src/models/d_fine/d_fine_attention.py +461 -0
  10. keras_hub/src/models/d_fine/d_fine_backbone.py +891 -0
  11. keras_hub/src/models/d_fine/d_fine_decoder.py +944 -0
  12. keras_hub/src/models/d_fine/d_fine_encoder.py +365 -0
  13. keras_hub/src/models/d_fine/d_fine_hybrid_encoder.py +642 -0
  14. keras_hub/src/models/d_fine/d_fine_image_converter.py +8 -0
  15. keras_hub/src/models/d_fine/d_fine_layers.py +1828 -0
  16. keras_hub/src/models/d_fine/d_fine_loss.py +938 -0
  17. keras_hub/src/models/d_fine/d_fine_object_detector.py +875 -0
  18. keras_hub/src/models/d_fine/d_fine_object_detector_preprocessor.py +14 -0
  19. keras_hub/src/models/d_fine/d_fine_presets.py +2 -0
  20. keras_hub/src/models/d_fine/d_fine_utils.py +827 -0
  21. keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +4 -1
  22. keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +3 -2
  23. keras_hub/src/models/hgnetv2/hgnetv2_layers.py +27 -11
  24. keras_hub/src/models/parseq/__init__.py +0 -0
  25. keras_hub/src/models/parseq/parseq_backbone.py +134 -0
  26. keras_hub/src/models/parseq/parseq_causal_lm.py +466 -0
  27. keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py +168 -0
  28. keras_hub/src/models/parseq/parseq_decoder.py +418 -0
  29. keras_hub/src/models/parseq/parseq_image_converter.py +8 -0
  30. keras_hub/src/models/parseq/parseq_tokenizer.py +221 -0
  31. keras_hub/src/tests/test_case.py +37 -1
  32. keras_hub/src/utils/preset_utils.py +49 -0
  33. keras_hub/src/utils/tensor_utils.py +23 -1
  34. keras_hub/src/utils/transformers/convert_vit.py +4 -1
  35. keras_hub/src/version.py +1 -1
  36. keras_hub/tokenizers/__init__.py +3 -0
  37. {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/METADATA +1 -1
  38. {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/RECORD +40 -20
  39. {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/WHEEL +0 -0
  40. {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/top_level.txt +0 -0
keras_hub/src/models/d_fine/d_fine_decoder.py
@@ -0,0 +1,944 @@
+ import math
+
+ import keras
+ import numpy as np
+
+ from keras_hub.src.models.d_fine.d_fine_attention import DFineMultiheadAttention
+ from keras_hub.src.models.d_fine.d_fine_attention import (
+     DFineMultiscaleDeformableAttention,
+ )
+ from keras_hub.src.models.d_fine.d_fine_layers import DFineGate
+ from keras_hub.src.models.d_fine.d_fine_layers import DFineIntegral
+ from keras_hub.src.models.d_fine.d_fine_layers import DFineLQE
+ from keras_hub.src.models.d_fine.d_fine_layers import DFineMLP
+ from keras_hub.src.models.d_fine.d_fine_layers import DFineMLPPredictionHead
+ from keras_hub.src.models.d_fine.d_fine_utils import d_fine_kernel_initializer
+ from keras_hub.src.models.d_fine.d_fine_utils import distance2bbox
+ from keras_hub.src.models.d_fine.d_fine_utils import inverse_sigmoid
+ from keras_hub.src.models.d_fine.d_fine_utils import weighting_function
+ from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+ class DFineDecoderLayer(keras.layers.Layer):
+     """Single decoder layer for D-FINE models.
+
+     This layer is the fundamental building block of the `DFineDecoder`. It
+     refines a set of object queries by first allowing them to interact with
+     each other via self-attention (`DFineMultiheadAttention`), and then
+     attending to the image features from the encoder via cross-attention
+     (`DFineMultiscaleDeformableAttention`). A feed-forward network with a
+     gating mechanism (`DFineGate`) further processes the queries.
+
+     Args:
+         hidden_dim: int, Hidden dimension size for all attention and
+             feed-forward layers.
+         decoder_attention_heads: int, Number of attention heads for both
+             self-attention and cross-attention mechanisms.
+         attention_dropout: float, Dropout probability for attention weights.
+         decoder_activation_function: str, Activation function name for the
+             feed-forward network (e.g., `"relu"` or `"gelu"`).
+         dropout: float, General dropout probability applied to layer outputs.
+         activation_dropout: float, Dropout probability applied after the
+             activation in the feed-forward network.
+         layer_norm_eps: float, Epsilon value for layer normalization to
+             prevent division by zero.
+         decoder_ffn_dim: int, Hidden dimension size for the feed-forward
+             network.
+         num_feature_levels: int, Number of feature pyramid levels to attend
+             to.
+         decoder_offset_scale: float, Scaling factor for deformable attention
+             offsets.
+         decoder_method: str, Method used for deformable attention
+             computation.
+         decoder_n_points: int or list, Number of sampling points per feature
+             level. If int, the same number is used for all levels; if list,
+             a specific count per level.
+         spatial_shapes: list, List of spatial dimensions `(height, width)`
+             for each feature level.
+         num_queries: int, Number of object queries processed by the decoder.
+         kernel_initializer: str or Initializer, optional, Initializer for
+             the kernel weights. Defaults to `"glorot_uniform"`.
+         bias_initializer: str or Initializer, optional, Initializer for
+             the bias weights. Defaults to `"zeros"`.
+         **kwargs: Additional keyword arguments passed to the parent class.
+     """
+
+     def __init__(
+         self,
+         hidden_dim,
+         decoder_attention_heads,
+         attention_dropout,
+         decoder_activation_function,
+         dropout,
+         activation_dropout,
+         layer_norm_eps,
+         decoder_ffn_dim,
+         num_feature_levels,
+         decoder_offset_scale,
+         decoder_method,
+         decoder_n_points,
+         spatial_shapes,
+         num_queries,
+         kernel_initializer="glorot_uniform",
+         bias_initializer="zeros",
+         dtype=None,
+         **kwargs,
+     ):
+         super().__init__(dtype=dtype, **kwargs)
+         self.hidden_dim = hidden_dim
+         self.num_queries = num_queries
+         self.decoder_attention_heads = decoder_attention_heads
+         self.attention_dropout_rate = attention_dropout
+         self.decoder_activation_function = decoder_activation_function
+         self.layer_norm_eps = layer_norm_eps
+         self.decoder_ffn_dim = decoder_ffn_dim
+         self.num_feature_levels = num_feature_levels
+         self.decoder_offset_scale = decoder_offset_scale
+         self.decoder_method = decoder_method
+         self.decoder_n_points = decoder_n_points
+         self.spatial_shapes = spatial_shapes
+         self.kernel_initializer = keras.initializers.get(kernel_initializer)
+         self.bias_initializer = keras.initializers.get(bias_initializer)
+
+         self.self_attn = DFineMultiheadAttention(
+             embedding_dim=self.hidden_dim,
+             num_heads=self.decoder_attention_heads,
+             dropout=self.attention_dropout_rate,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             bias_initializer=clone_initializer(self.bias_initializer),
+             dtype=self.dtype_policy,
+             name="self_attn",
+         )
+         self.dropout_layer = keras.layers.Dropout(
+             rate=dropout, name="dropout_layer", dtype=self.dtype_policy
+         )
+         self.activation_dropout_layer = keras.layers.Dropout(
+             rate=activation_dropout,
+             name="activation_dropout_layer",
+             dtype=self.dtype_policy,
+         )
+         self.activation_fn = keras.layers.Activation(
+             self.decoder_activation_function,
+             name="activation_fn",
+             dtype=self.dtype_policy,
+         )
+         self.self_attn_layer_norm = keras.layers.LayerNormalization(
+             epsilon=self.layer_norm_eps,
+             name="self_attn_layer_norm",
+             dtype=self.dtype_policy,
+         )
+         self.encoder_attn = DFineMultiscaleDeformableAttention(
+             hidden_dim=self.hidden_dim,
+             decoder_attention_heads=self.decoder_attention_heads,
+             num_feature_levels=self.num_feature_levels,
+             decoder_offset_scale=self.decoder_offset_scale,
+             dtype=self.dtype_policy,
+             decoder_method=self.decoder_method,
+             decoder_n_points=self.decoder_n_points,
+             spatial_shapes=self.spatial_shapes,
+             num_queries=self.num_queries,
+             name="encoder_attn",
+         )
+         self.fc1 = keras.layers.Dense(
+             self.decoder_ffn_dim,
+             name="fc1",
+             dtype=self.dtype_policy,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             bias_initializer=clone_initializer(self.bias_initializer),
+         )
+         self.fc2 = keras.layers.Dense(
+             self.hidden_dim,
+             name="fc2",
+             dtype=self.dtype_policy,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             bias_initializer=clone_initializer(self.bias_initializer),
+         )
+         self.final_layer_norm = keras.layers.LayerNormalization(
+             epsilon=self.layer_norm_eps,
+             name="final_layer_norm",
+             dtype=self.dtype_policy,
+         )
+         self.gateway = DFineGate(
+             self.hidden_dim, name="gateway", dtype=self.dtype_policy
+         )
+
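The repeated `clone_initializer` calls above are worth a note: each sublayer receives its own initializer instance rebuilt from the same config, rather than sharing one (possibly seed-carrying) object across many kernels. A minimal sketch of the underlying idea, using plain Keras rather than the `keras_hub` helper:

import keras

# Rebuild an initializer from its config: same distribution, independent
# object, so a seed's random sequence is not replayed across sublayers.
base = keras.initializers.GlorotUniform(seed=1234)
clone = keras.initializers.GlorotUniform.from_config(base.get_config())
assert clone is not base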
+     def build(self, input_shape):
+         batch_size = input_shape[0]
+         num_queries = input_shape[1]
+         hidden_dim = self.hidden_dim
+         attention_input_shape = (batch_size, num_queries, hidden_dim)
+         self.self_attn.build(attention_input_shape)
+         self.encoder_attn.build(attention_input_shape)
+         self.fc1.build(attention_input_shape)
+         self.fc2.build((batch_size, num_queries, self.decoder_ffn_dim))
+         self.gateway.build(attention_input_shape)
+         self.self_attn_layer_norm.build(attention_input_shape)
+         self.final_layer_norm.build(attention_input_shape)
+         super().build(input_shape)
+
+     def call(
+         self,
+         hidden_states,
+         position_embeddings=None,
+         reference_points=None,
+         spatial_shapes=None,
+         encoder_hidden_states=None,
+         attention_mask=None,
+         output_attentions=False,
+         training=None,
+     ):
+         self_attn_output, self_attn_weights = self.self_attn(
+             hidden_states=hidden_states,
+             position_embeddings=position_embeddings,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             training=training,
+         )
+         hidden_states_2 = self_attn_output
+         hidden_states_2 = self.dropout_layer(hidden_states_2, training=training)
+         hidden_states = hidden_states + hidden_states_2
+         hidden_states = self.self_attn_layer_norm(
+             hidden_states, training=training
+         )
+         residual = hidden_states
+         query_for_cross_attn = residual
+         if position_embeddings is not None:
+             query_for_cross_attn = query_for_cross_attn + position_embeddings
+         encoder_attn_output_tensor, cross_attn_weights_tensor = (
+             self.encoder_attn(
+                 hidden_states=query_for_cross_attn,
+                 encoder_hidden_states=encoder_hidden_states,
+                 reference_points=reference_points,
+                 spatial_shapes=spatial_shapes,
+                 training=training,
+             )
+         )
+         hidden_states_2 = encoder_attn_output_tensor
+         current_cross_attn_weights = (
+             cross_attn_weights_tensor if output_attentions else None
+         )
+         hidden_states_2 = self.dropout_layer(hidden_states_2, training=training)
+         hidden_states = self.gateway(
+             residual, hidden_states_2, training=training
+         )
+         hidden_states_ffn = self.fc1(hidden_states)
+         hidden_states_2 = self.activation_fn(
+             hidden_states_ffn, training=training
+         )
+         hidden_states_2 = self.activation_dropout_layer(
+             hidden_states_2, training=training
+         )
+         hidden_states_2 = self.fc2(hidden_states_2)
+         hidden_states_2 = self.dropout_layer(hidden_states_2, training=training)
+         hidden_states = hidden_states + hidden_states_2
+         dtype_name = keras.backend.standardize_dtype(self.compute_dtype)
+         if dtype_name == "float16":
+             clamp_value = np.finfo(np.float16).max - 1000.0
+         else:  # float32, bfloat16
+             clamp_value = np.finfo(np.float32).max - 1000.0
+         hidden_states_clamped = keras.ops.clip(
+             hidden_states, x_min=-clamp_value, x_max=clamp_value
+         )
+         hidden_states = self.final_layer_norm(
+             hidden_states_clamped, training=training
+         )
+         return hidden_states, self_attn_weights, current_cross_attn_weights
+
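The clamp ahead of `final_layer_norm` protects half-precision runs: a residual sum that overflows to `inf` in `float16` would propagate `nan` through the layer norm. A standalone sketch of the same trick (the values here are illustrative):

import numpy as np
import keras

clamp_value = np.finfo(np.float16).max - 1000.0  # 64504.0
x = keras.ops.convert_to_tensor([[7.0e4, -7.0e4, 1.0]])
# Clip just inside the float16 range so a later cast stays finite.
x = keras.ops.clip(x, x_min=-clamp_value, x_max=clamp_value)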
+     def compute_output_spec(
+         self,
+         hidden_states,
+         position_embeddings=None,
+         reference_points=None,
+         spatial_shapes=None,
+         encoder_hidden_states=None,
+         attention_mask=None,
+         output_attentions=False,
+         training=None,
+     ):
+         hidden_states_output_spec = keras.KerasTensor(
+             shape=hidden_states.shape, dtype=self.compute_dtype
+         )
+         self_attn_output_spec = self.self_attn.compute_output_spec(
+             hidden_states=hidden_states,
+             position_embeddings=position_embeddings,
+             attention_mask=attention_mask,
+             output_attentions=True,
+         )
+         _, self_attn_weights_spec = self_attn_output_spec
+         _, cross_attn_weights_spec = self.encoder_attn.compute_output_spec(
+             hidden_states=hidden_states,
+             encoder_hidden_states=encoder_hidden_states,
+             reference_points=reference_points,
+             spatial_shapes=spatial_shapes,
+         )
+         if not output_attentions:
+             self_attn_weights_spec = None
+             cross_attn_weights_spec = None
+         return (
+             hidden_states_output_spec,
+             self_attn_weights_spec,
+             cross_attn_weights_spec,
+         )
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_dim": self.hidden_dim,
+                 "decoder_attention_heads": self.decoder_attention_heads,
+                 "attention_dropout": self.attention_dropout_rate,
+                 "decoder_activation_function": self.decoder_activation_function,
+                 "dropout": self.dropout_layer.rate,
+                 "activation_dropout": self.activation_dropout_layer.rate,
+                 "layer_norm_eps": self.layer_norm_eps,
+                 "decoder_ffn_dim": self.decoder_ffn_dim,
+                 "num_feature_levels": self.num_feature_levels,
+                 "decoder_offset_scale": self.decoder_offset_scale,
+                 "decoder_method": self.decoder_method,
+                 "decoder_n_points": self.decoder_n_points,
+                 "spatial_shapes": self.spatial_shapes,
+                 "num_queries": self.num_queries,
+                 "kernel_initializer": keras.initializers.serialize(
+                     self.kernel_initializer
+                 ),
+                 "bias_initializer": keras.initializers.serialize(
+                     self.bias_initializer
+                 ),
+             }
+         )
+         return config
+
+
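For orientation, a hypothetical smoke test of `DFineDecoderLayer` follows. The hyperparameters are made up (they match no released preset), and the exact input contract of `DFineMultiscaleDeformableAttention` should be confirmed against `d_fine_attention.py`; the shapes below simply follow what `DFineDecoder.call` passes in: queries of shape `(batch, num_queries, hidden_dim)`, encoder memory as the concatenation of all flattened feature levels, and reference points expanded to `(batch, num_queries, 1, 4)`.

import numpy as np

spatial_shapes = [(8, 8), (4, 4)]
layer = DFineDecoderLayer(
    hidden_dim=64,
    decoder_attention_heads=4,
    attention_dropout=0.0,
    decoder_activation_function="relu",
    dropout=0.0,
    activation_dropout=0.0,
    layer_norm_eps=1e-5,
    decoder_ffn_dim=128,
    num_feature_levels=2,
    decoder_offset_scale=0.5,
    decoder_method="default",
    decoder_n_points=4,
    spatial_shapes=spatial_shapes,
    num_queries=10,
)
queries = np.random.rand(1, 10, 64).astype("float32")
memory = np.random.rand(1, 8 * 8 + 4 * 4, 64).astype("float32")
ref_points = np.random.rand(1, 10, 1, 4).astype("float32")
hidden, _, _ = layer(
    hidden_states=queries,
    reference_points=ref_points,
    spatial_shapes=spatial_shapes,
    encoder_hidden_states=memory,
)  # hidden: (1, 10, 64)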
+ class DFineDecoder(keras.layers.Layer):
+     """Complete decoder module for D-FINE object detection models.
+
+     This class implements the full D-FINE decoder, which is responsible for
+     transforming a set of object queries into final bounding box and class
+     predictions. It consists of a stack of `DFineDecoderLayer` instances
+     that iteratively refine the queries. At each layer, prediction heads
+     (`class_embed`, `bbox_embed`) generate intermediate outputs, which are
+     used for auxiliary loss calculation during training. The final layer's
+     output represents the model's predictions.
+
+     Args:
+         eval_idx: int, Index of the decoder layer used for evaluation.
+             Negative values count from the end (e.g., `-1` for the last
+             layer).
+         num_decoder_layers: int, Number of decoder layers in the stack.
+         dropout: float, General dropout probability applied throughout the
+             decoder.
+         hidden_dim: int, Hidden dimension size for all components.
+         reg_scale: float, Scaling factor for regression loss and coordinate
+             prediction.
+         max_num_bins: int, Maximum number of bins for integral-based
+             coordinate prediction.
+         upsampling_factor: float, Upsampling factor used in coordinate
+             prediction weighting.
+         decoder_attention_heads: int, Number of attention heads in each
+             decoder layer.
+         attention_dropout: float, Dropout probability for attention
+             mechanisms.
+         decoder_activation_function: str, Activation function for
+             feed-forward networks.
+         activation_dropout: float, Dropout probability after activation
+             functions.
+         layer_norm_eps: float, Epsilon for layer normalization stability.
+         decoder_ffn_dim: int, Hidden dimension for feed-forward networks.
+         num_feature_levels: int, Number of feature pyramid levels.
+         decoder_offset_scale: float, Scaling factor for deformable attention
+             offsets.
+         decoder_method: str, Method for deformable attention computation,
+             either `"default"` or `"discrete"`.
+         decoder_n_points: int or list, Number of sampling points per feature
+             level.
+         top_prob_values: int, Number of top probability values used in LQE.
+         lqe_hidden_dim: int, Hidden dimension for LQE networks.
+         num_lqe_layers: int, Number of layers in LQE networks.
+         num_labels: int, Number of object classes for classification.
+         spatial_shapes: list, Spatial dimensions for each feature level.
+         layer_scale: float, Scaling factor for layer-wise feature
+             dimensions.
+         num_queries: int, Number of object queries processed by the decoder.
+         initializer_bias_prior_prob: float, optional, Prior probability for
+             the bias of the classification head. Used to initialize the bias
+             of the `class_embed` layers. Defaults to `None`.
+         **kwargs: Additional keyword arguments passed to the parent class.
+     """
+
+     def __init__(
+         self,
+         eval_idx,
+         num_decoder_layers,
+         dropout,
+         hidden_dim,
+         reg_scale,
+         max_num_bins,
+         upsampling_factor,
+         decoder_attention_heads,
+         attention_dropout,
+         decoder_activation_function,
+         activation_dropout,
+         layer_norm_eps,
+         decoder_ffn_dim,
+         num_feature_levels,
+         decoder_offset_scale,
+         decoder_method,
+         decoder_n_points,
+         top_prob_values,
+         lqe_hidden_dim,
+         num_lqe_layers,
+         num_labels,
+         spatial_shapes,
+         layer_scale,
+         num_queries,
+         initializer_bias_prior_prob=None,
+         dtype=None,
+         **kwargs,
+     ):
+         super().__init__(dtype=dtype, **kwargs)
+         self.eval_idx = (
+             eval_idx if eval_idx >= 0 else num_decoder_layers + eval_idx
+         )
+         self.dropout_rate = dropout
+         self.num_queries = num_queries
+         self.hidden_dim = hidden_dim
+         self.num_decoder_layers = num_decoder_layers
+         self.reg_scale_val = reg_scale
+         self.max_num_bins = max_num_bins
+         self.upsampling_factor_val = upsampling_factor
+         self.decoder_attention_heads = decoder_attention_heads
+         self.attention_dropout_rate = attention_dropout
+         self.decoder_activation_function = decoder_activation_function
+         self.activation_dropout_rate = activation_dropout
+         self.layer_norm_eps = layer_norm_eps
+         self.decoder_ffn_dim = decoder_ffn_dim
+         self.num_feature_levels = num_feature_levels
+         self.decoder_offset_scale = decoder_offset_scale
+         self.decoder_method = decoder_method
+         self.decoder_n_points = decoder_n_points
+         self.top_prob_values = top_prob_values
+         self.lqe_hidden_dim = lqe_hidden_dim
+         self.num_lqe_layers = num_lqe_layers
+         self.num_labels = num_labels
+         self.spatial_shapes = spatial_shapes
+         self.layer_scale = layer_scale
+         self.initializer_bias_prior_prob = initializer_bias_prior_prob
+         self.initializer = d_fine_kernel_initializer()
+         self.decoder_layers = []
+         for i in range(self.num_decoder_layers):
+             self.decoder_layers.append(
+                 DFineDecoderLayer(
+                     self.hidden_dim,
+                     self.decoder_attention_heads,
+                     self.attention_dropout_rate,
+                     self.decoder_activation_function,
+                     self.dropout_rate,
+                     self.activation_dropout_rate,
+                     self.layer_norm_eps,
+                     self.decoder_ffn_dim,
+                     self.num_feature_levels,
+                     self.decoder_offset_scale,
+                     self.decoder_method,
+                     self.decoder_n_points,
+                     self.spatial_shapes,
+                     num_queries=self.num_queries,
+                     kernel_initializer=clone_initializer(self.initializer),
+                     bias_initializer="zeros",
+                     dtype=self.dtype_policy,
+                     name=f"decoder_layer_{i}",
+                 )
+             )
+
+         self.query_pos_head = DFineMLPPredictionHead(
+             input_dim=4,
+             hidden_dim=(2 * self.hidden_dim),
+             output_dim=self.hidden_dim,
+             num_layers=2,
+             dtype=self.dtype_policy,
+             kernel_initializer=clone_initializer(self.initializer),
+             bias_initializer="zeros",
+             name="query_pos_head",
+         )
+
+         num_pred = self.num_decoder_layers
+         scaled_dim = round(self.hidden_dim * self.layer_scale)
+         if initializer_bias_prior_prob is None:
+             prior_prob = 1 / (self.num_labels + 1)
+         else:
+             prior_prob = initializer_bias_prior_prob
+         class_embed_bias = float(-math.log((1 - prior_prob) / prior_prob))
+         self.class_embed = [
+             keras.layers.Dense(
+                 self.num_labels,
+                 name=f"class_embed_{i}",
+                 dtype=self.dtype_policy,
+                 kernel_initializer="glorot_uniform",
+                 bias_initializer=keras.initializers.Constant(class_embed_bias),
+             )
+             for i in range(num_pred)
+         ]
+         self.bbox_embed = [
+             DFineMLPPredictionHead(
+                 input_dim=self.hidden_dim,
+                 hidden_dim=self.hidden_dim,
+                 output_dim=4 * (self.max_num_bins + 1),
+                 num_layers=3,
+                 name=f"bbox_embed_{i}",
+                 dtype=self.dtype_policy,
+                 kernel_initializer=clone_initializer(self.initializer),
+                 bias_initializer="zeros",
+                 last_layer_initializer="zeros",
+             )
+             for i in range(self.eval_idx + 1)
+         ] + [
+             DFineMLPPredictionHead(
+                 input_dim=scaled_dim,
+                 hidden_dim=scaled_dim,
+                 output_dim=4 * (self.max_num_bins + 1),
+                 num_layers=3,
+                 name=f"bbox_embed_{i + self.eval_idx + 1}",
+                 dtype=self.dtype_policy,
+                 kernel_initializer=clone_initializer(self.initializer),
+                 bias_initializer="zeros",
+                 last_layer_initializer="zeros",
+             )
+             for i in range(self.num_decoder_layers - self.eval_idx - 1)
+         ]
+         self.pre_bbox_head = DFineMLP(
+             input_dim=self.hidden_dim,
+             hidden_dim=self.hidden_dim,
+             output_dim=4,
+             num_layers=3,
+             activation_function="relu",
+             dtype=self.dtype_policy,
+             kernel_initializer=clone_initializer(self.initializer),
+             bias_initializer="zeros",
+             name="pre_bbox_head",
+         )
+
+         self.integral = DFineIntegral(
+             max_num_bins=self.max_num_bins,
+             name="integral",
+             dtype=self.dtype_policy,
+         )
+
+         self.num_head = self.decoder_attention_heads
+
+         self.lqe_layers = []
+         for i in range(self.num_decoder_layers):
+             self.lqe_layers.append(
+                 DFineLQE(
+                     top_prob_values=self.top_prob_values,
+                     max_num_bins=self.max_num_bins,
+                     lqe_hidden_dim=self.lqe_hidden_dim,
+                     num_lqe_layers=self.num_lqe_layers,
+                     dtype=self.dtype_policy,
+                     name=f"lqe_layer_{i}",
+                 )
+             )
+
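`class_embed_bias` above is the standard focal-style prior initialization: choosing the bias as `-log((1 - p) / p)` makes each class logit start near `sigmoid(bias) ≈ p`, so the classifier begins by predicting the rare-positive prior rather than 0.5. A quick check with an illustrative label count (80 is an assumption, not a value from this file):

import math

num_labels = 80
prior_prob = 1 / (num_labels + 1)                # ~0.01235
bias = -math.log((1 - prior_prob) / prior_prob)  # ~-4.382
# The sigmoid of the bias recovers the prior.
assert abs(1 / (1 + math.exp(-bias)) - prior_prob) < 1e-9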
+     def build(self, input_shape):
+         if isinstance(input_shape, dict):
+             if "inputs_embeds" not in input_shape:
+                 raise ValueError(
+                     "DFineDecoder.build() received a dict input_shape "
+                     "missing 'inputs_embeds' key. Please ensure "
+                     "'inputs_embeds' is passed correctly."
+                 )
+             inputs_embeds_shape = input_shape["inputs_embeds"]
+         elif (
+             isinstance(input_shape, (list, tuple))
+             and len(input_shape) > 0
+             and isinstance(input_shape[0], (list, tuple))
+         ):
+             inputs_embeds_shape = input_shape[0]
+         else:
+             inputs_embeds_shape = input_shape
+         if not isinstance(inputs_embeds_shape, tuple):
+             raise TypeError(
+                 f"Internal error: inputs_embeds_shape was expected to be a "
+                 f"tuple, but got {type(inputs_embeds_shape)} with value "
+                 f"{inputs_embeds_shape}. Original input_shape: {input_shape}"
+             )
+
+         batch_size_ph = (
+             inputs_embeds_shape[0]
+             if inputs_embeds_shape
+             and len(inputs_embeds_shape) > 0
+             and inputs_embeds_shape[0] is not None
+             else None
+         )
+         num_queries_ph = (
+             inputs_embeds_shape[1]
+             if inputs_embeds_shape
+             and len(inputs_embeds_shape) > 1
+             and inputs_embeds_shape[1] is not None
+             else None
+         )
+         current_decoder_layer_input_shape = inputs_embeds_shape
+         for decoder_layer_instance in self.decoder_layers:
+             decoder_layer_instance.build(current_decoder_layer_input_shape)
+         qph_input_shape = (batch_size_ph, num_queries_ph, 4)
+         self.query_pos_head.build(qph_input_shape)
+         pre_bbox_head_input_shape = (
+             batch_size_ph,
+             num_queries_ph,
+             self.hidden_dim,
+         )
+         self.pre_bbox_head.build(pre_bbox_head_input_shape)
+         lqe_scores_shape = (batch_size_ph, num_queries_ph, 1)
+         lqe_pred_corners_dim = 4 * (self.max_num_bins + 1)
+         lqe_pred_corners_shape = (
+             batch_size_ph,
+             num_queries_ph,
+             lqe_pred_corners_dim,
+         )
+         lqe_build_input_shape_tuple = (lqe_scores_shape, lqe_pred_corners_shape)
+         for lqe_layer in self.lqe_layers:
+             lqe_layer.build(lqe_build_input_shape_tuple)
+         self.reg_scale = self.add_weight(
+             name="reg_scale",
+             shape=(1,),
+             initializer=keras.initializers.Constant(self.reg_scale_val),
+             trainable=False,
+         )
+         self.upsampling_factor = self.add_weight(
+             name="upsampling_factor",
+             shape=(1,),
+             initializer=keras.initializers.Constant(self.upsampling_factor_val),
+             trainable=False,
+         )
+         input_shape_for_class_embed = (
+             batch_size_ph,
+             num_queries_ph,
+             self.hidden_dim,
+         )
+         for class_embed_layer in self.class_embed:
+             class_embed_layer.build(input_shape_for_class_embed)
+         input_shape_for_bbox_embed = (
+             batch_size_ph,
+             num_queries_ph,
+             self.hidden_dim,
+         )
+         for bbox_embed_layer in self.bbox_embed:
+             bbox_embed_layer.build(input_shape_for_bbox_embed)
+         super().build(input_shape)
+
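`reg_scale` and `upsampling_factor` are registered as non-trainable weights rather than kept as plain Python floats, so they are saved with the layer and can flow through tensor ops such as `weighting_function`. A minimal sketch of the pattern, using a hypothetical layer:

import keras

class ScalarHolder(keras.layers.Layer):
    """Toy layer that stores a frozen scalar as a checkpointed weight."""

    def build(self, input_shape):
        self.scale = self.add_weight(
            name="scale",
            shape=(1,),
            initializer=keras.initializers.Constant(4.0),
            trainable=False,
        )
        super().build(input_shape)

    def call(self, x):
        return x * self.scale  # broadcasts like any other tensor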
+     def compute_output_spec(
+         self,
+         inputs_embeds,
+         encoder_hidden_states,
+         reference_points,
+         spatial_shapes,
+         attention_mask=None,
+         output_hidden_states=None,
+         output_attentions=None,
+         training=None,
+     ):
+         output_attentions = (
+             False if output_attentions is None else output_attentions
+         )
+         output_hidden_states = (
+             False if output_hidden_states is None else output_hidden_states
+         )
+         batch_size = inputs_embeds.shape[0]
+         num_queries = inputs_embeds.shape[1]
+         hidden_dim = inputs_embeds.shape[2]
+         last_hidden_state_spec = keras.KerasTensor(
+             shape=(batch_size, num_queries, hidden_dim),
+             dtype=self.compute_dtype,
+         )
+         intermediate_hidden_states_spec = None
+         if output_hidden_states:
+             intermediate_hidden_states_spec = keras.KerasTensor(
+                 shape=(
+                     batch_size,
+                     self.num_decoder_layers,
+                     num_queries,
+                     hidden_dim,
+                 ),
+                 dtype=self.compute_dtype,
+             )
+         num_layers_with_logits = self.num_decoder_layers + 1
+         intermediate_logits_spec = keras.KerasTensor(
+             shape=(
+                 batch_size,
+                 num_layers_with_logits,
+                 num_queries,
+                 self.num_labels,
+             ),
+             dtype=self.compute_dtype,
+         )
+         intermediate_reference_points_spec = keras.KerasTensor(
+             shape=(batch_size, num_layers_with_logits, num_queries, 4),
+             dtype=self.compute_dtype,
+         )
+         intermediate_predicted_corners_spec = keras.KerasTensor(
+             shape=(
+                 batch_size,
+                 num_layers_with_logits,
+                 num_queries,
+                 4 * (self.max_num_bins + 1),
+             ),
+             dtype=self.compute_dtype,
+         )
+         initial_reference_points_spec = keras.KerasTensor(
+             shape=(batch_size, num_layers_with_logits, num_queries, 4),
+             dtype=self.compute_dtype,
+         )
+         all_hidden_states_spec = None
+         all_self_attns_spec = None
+         all_cross_attentions_spec = None
+         if output_hidden_states:
+             all_hidden_states_spec = tuple(
+                 [last_hidden_state_spec] * (self.num_decoder_layers + 1)
+             )
+         if output_attentions:
+             (
+                 _,
+                 self_attn_spec,
+                 cross_attn_spec,
+             ) = self.decoder_layers[0].compute_output_spec(
+                 hidden_states=inputs_embeds,
+                 encoder_hidden_states=encoder_hidden_states,
+                 output_attentions=True,
+             )
+             all_self_attns_spec = tuple(
+                 [self_attn_spec] * self.num_decoder_layers
+             )
+             if encoder_hidden_states is not None:
+                 all_cross_attentions_spec = tuple(
+                     [cross_attn_spec] * self.num_decoder_layers
+                 )
+         outputs_tuple = [
+             last_hidden_state_spec,
+             intermediate_hidden_states_spec,
+             intermediate_logits_spec,
+             intermediate_reference_points_spec,
+             intermediate_predicted_corners_spec,
+             initial_reference_points_spec,
+             all_hidden_states_spec,
+             all_self_attns_spec,
+             all_cross_attentions_spec,
+         ]
+         return tuple(v for v in outputs_tuple if v is not None)
+
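`num_layers_with_logits = num_decoder_layers + 1` mirrors the bookkeeping in `call()` below: the first decoder layer appends both its pre-LQE and post-LQE predictions, and every later layer appends one, hence:

num_decoder_layers = 6  # illustrative
entries = 2 + (num_decoder_layers - 1)  # first layer adds 2, the rest 1 each
assert entries == num_decoder_layers + 1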
+     def call(
+         self,
+         inputs_embeds,
+         encoder_hidden_states,
+         reference_points,
+         spatial_shapes,
+         attention_mask=None,
+         output_hidden_states=None,
+         output_attentions=None,
+         training=None,
+     ):
+         output_attentions = (
+             False if output_attentions is None else output_attentions
+         )
+         output_hidden_states = (
+             False if output_hidden_states is None else output_hidden_states
+         )
+
+         hidden_states = inputs_embeds
+
+         all_hidden_states = [] if output_hidden_states else None
+         all_self_attns = [] if output_attentions else None
+         all_cross_attentions = (
+             []
+             if (output_attentions and encoder_hidden_states is not None)
+             else None
+         )
+
+         intermediate_hidden_states = []
+         intermediate_reference_points = []
+         intermediate_logits = []
+         intermediate_predicted_corners = []
+         initial_reference_points = []
+
+         output_detach = (
+             keras.ops.zeros_like(hidden_states)
+             if hidden_states is not None
+             else 0
+         )
+         pred_corners_undetach = 0
+
+         project_flat = weighting_function(
+             self.max_num_bins, self.upsampling_factor, self.reg_scale
+         )
+         project = keras.ops.expand_dims(project_flat, axis=0)
+
+         ref_points_detach = keras.ops.sigmoid(reference_points)
+
+         for i, decoder_layer_instance in enumerate(self.decoder_layers):
+             ref_points_input = keras.ops.expand_dims(ref_points_detach, axis=2)
+             query_pos_embed = self.query_pos_head(
+                 ref_points_detach, training=training
+             )
+             query_pos_embed = keras.ops.clip(query_pos_embed, -10.0, 10.0)
+
+             if output_hidden_states:
+                 all_hidden_states.append(hidden_states)
+
+             output_tuple = decoder_layer_instance(
+                 hidden_states=hidden_states,
+                 position_embeddings=query_pos_embed,
+                 reference_points=ref_points_input,
+                 spatial_shapes=spatial_shapes,
+                 encoder_hidden_states=encoder_hidden_states,
+                 attention_mask=attention_mask,
+                 output_attentions=output_attentions,
+                 training=training,
+             )
+             hidden_states = output_tuple[0]
+             self_attn_weights_from_layer = output_tuple[1]
+             cross_attn_weights_from_layer = output_tuple[2]
+
+             if i == 0:
+                 pre_bbox_head_output = self.pre_bbox_head(
+                     hidden_states, training=training
+                 )
+                 new_reference_points = keras.ops.sigmoid(
+                     pre_bbox_head_output + inverse_sigmoid(ref_points_detach)
+                 )
+                 ref_points_initial = keras.ops.stop_gradient(
+                     new_reference_points
+                 )
+
+             if self.bbox_embed is not None:
+                 bbox_embed_input = hidden_states + output_detach
+                 pred_corners = (
+                     self.bbox_embed[i](bbox_embed_input, training=training)
+                     + pred_corners_undetach
+                 )
+                 integral_output = self.integral(
+                     pred_corners, project, training=training
+                 )
+                 inter_ref_bbox = distance2bbox(
+                     ref_points_initial, integral_output, self.reg_scale
+                 )
+                 pred_corners_undetach = pred_corners
+                 ref_points_detach = keras.ops.stop_gradient(inter_ref_bbox)
+
+             output_detach = keras.ops.stop_gradient(hidden_states)
+
+             intermediate_hidden_states.append(hidden_states)
+
+             if self.class_embed is not None and self.bbox_embed is not None:
+                 class_scores = self.class_embed[i](hidden_states)
+                 refined_scores = self.lqe_layers[i](
+                     class_scores, pred_corners, training=training
+                 )
+                 if i == 0:
+                     # NOTE: For the first layer, output both pre-LQE and
+                     # post-LQE predictions to provide an initial estimate.
+                     # In the original implementation, the `torch.stack()`
+                     # op would have thrown an error due to mismatched
+                     # lengths.
+                     intermediate_logits.append(class_scores)
+                     intermediate_reference_points.append(new_reference_points)
+                     initial_reference_points.append(ref_points_initial)
+                     intermediate_predicted_corners.append(pred_corners)
+                 intermediate_logits.append(refined_scores)
+                 intermediate_reference_points.append(inter_ref_bbox)
+                 initial_reference_points.append(ref_points_initial)
+                 intermediate_predicted_corners.append(pred_corners)
+
+             if output_attentions:
+                 if self_attn_weights_from_layer is not None:
+                     all_self_attns.append(self_attn_weights_from_layer)
+                 if (
+                     encoder_hidden_states is not None
+                     and cross_attn_weights_from_layer is not None
+                 ):
+                     all_cross_attentions.append(cross_attn_weights_from_layer)
+
+         intermediate_stacked = (
+             keras.ops.stack(intermediate_hidden_states, axis=1)
+             if intermediate_hidden_states
+             else None
+         )
+
+         if self.class_embed is not None and self.bbox_embed is not None:
+             intermediate_logits_stacked = (
+                 keras.ops.stack(intermediate_logits, axis=1)
+                 if intermediate_logits
+                 else None
+             )
+             intermediate_predicted_corners_stacked = (
+                 keras.ops.stack(intermediate_predicted_corners, axis=1)
+                 if intermediate_predicted_corners
+                 else None
+             )
+             initial_reference_points_stacked = (
+                 keras.ops.stack(initial_reference_points, axis=1)
+                 if initial_reference_points
+                 else None
+             )
+             intermediate_reference_points_stacked = (
+                 keras.ops.stack(intermediate_reference_points, axis=1)
+                 if intermediate_reference_points
+                 else None
+             )
+         else:
+             intermediate_logits_stacked = None
+             intermediate_predicted_corners_stacked = None
+             initial_reference_points_stacked = None
+             intermediate_reference_points_stacked = None
+
+         if output_hidden_states:
+             all_hidden_states.append(hidden_states)
+
+         all_hidden_states_tuple = (
+             tuple(all_hidden_states) if output_hidden_states else None
+         )
+         all_self_attns_tuple = (
+             tuple(all_self_attns) if output_attentions else None
+         )
+         all_cross_attentions_tuple = (
+             tuple(all_cross_attentions)
+             if (output_attentions and encoder_hidden_states is not None)
+             else None
+         )
+
+         outputs_tuple = [
+             hidden_states,
+             intermediate_stacked,
+             intermediate_logits_stacked,
+             intermediate_reference_points_stacked,
+             intermediate_predicted_corners_stacked,
+             initial_reference_points_stacked,
+             all_hidden_states_tuple,
+             all_self_attns_tuple,
+             all_cross_attentions_tuple,
+         ]
+         return tuple(v for v in outputs_tuple if v is not None)
+
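The `pred_corners` → `integral` → `distance2bbox` chain in the loop above is D-FINE's distribution-based box regression: each of the four box distances is predicted as a softmax distribution over `max_num_bins + 1` bin values and collapsed to its expectation. A self-contained sketch of that idea, with uniform bin values standing in for the learned output of `weighting_function`:

import numpy as np
import keras

max_num_bins = 8
batch, num_queries = 1, 5
logits = np.random.randn(batch, num_queries, 4 * (max_num_bins + 1))
logits = keras.ops.reshape(
    keras.ops.convert_to_tensor(logits, dtype="float32"),
    (batch, num_queries, 4, max_num_bins + 1),
)
# Uniform spacing for illustration; `weighting_function` yields a
# non-uniform spacing shaped by `reg_scale` and `upsampling_factor`.
project = keras.ops.arange(0, max_num_bins + 1, dtype="float32")
probs = keras.ops.softmax(logits, axis=-1)
distances = keras.ops.sum(probs * project, axis=-1)  # (1, 5, 4) expectations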
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "eval_idx": self.eval_idx,
+                 "num_decoder_layers": self.num_decoder_layers,
+                 "dropout": self.dropout_rate,
+                 "hidden_dim": self.hidden_dim,
+                 "reg_scale": self.reg_scale_val,
+                 "max_num_bins": self.max_num_bins,
+                 "upsampling_factor": self.upsampling_factor_val,
+                 "decoder_attention_heads": self.decoder_attention_heads,
+                 "attention_dropout": self.attention_dropout_rate,
+                 "decoder_activation_function": self.decoder_activation_function,
+                 "activation_dropout": self.activation_dropout_rate,
+                 "layer_norm_eps": self.layer_norm_eps,
+                 "decoder_ffn_dim": self.decoder_ffn_dim,
+                 "num_feature_levels": self.num_feature_levels,
+                 "decoder_offset_scale": self.decoder_offset_scale,
+                 "decoder_method": self.decoder_method,
+                 "decoder_n_points": self.decoder_n_points,
+                 "top_prob_values": self.top_prob_values,
+                 "lqe_hidden_dim": self.lqe_hidden_dim,
+                 "num_lqe_layers": self.num_lqe_layers,
+                 "num_labels": self.num_labels,
+                 "spatial_shapes": self.spatial_shapes,
+                 "layer_scale": self.layer_scale,
+                 "num_queries": self.num_queries,
+                 "initializer_bias_prior_prob": self.initializer_bias_prior_prob,
+             }
+         )
+         return config
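Because `call()` filters `None` entries out of its return tuple, callers must unpack positionally according to the flags they passed. With the defaults (`output_hidden_states=False`, `output_attentions=False`), the layout per the `outputs_tuple` above is:

# Hypothetical unpacking of the tuple returned by call() with default flags:
(
    last_hidden_state,               # (batch, num_queries, hidden_dim)
    intermediate_hidden_states,      # (batch, num_layers, num_queries, hidden_dim)
    intermediate_logits,             # (batch, num_layers + 1, num_queries, num_labels)
    intermediate_reference_points,   # (batch, num_layers + 1, num_queries, 4)
    intermediate_predicted_corners,  # (batch, num_layers + 1, num_queries, 4 * (max_num_bins + 1))
    initial_reference_points,        # (batch, num_layers + 1, num_queries, 4)
) = outputs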