keras-hub-nightly 0.23.0.dev202508260411__py3-none-any.whl → 0.23.0.dev202508280418__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40) hide show
  1. keras_hub/layers/__init__.py +6 -0
  2. keras_hub/models/__init__.py +21 -0
  3. keras_hub/src/layers/modeling/position_embedding.py +21 -6
  4. keras_hub/src/layers/modeling/rotary_embedding.py +16 -6
  5. keras_hub/src/layers/modeling/sine_position_encoding.py +21 -8
  6. keras_hub/src/layers/modeling/token_and_position_embedding.py +2 -1
  7. keras_hub/src/models/backbone.py +10 -15
  8. keras_hub/src/models/d_fine/__init__.py +0 -0
  9. keras_hub/src/models/d_fine/d_fine_attention.py +461 -0
  10. keras_hub/src/models/d_fine/d_fine_backbone.py +891 -0
  11. keras_hub/src/models/d_fine/d_fine_decoder.py +944 -0
  12. keras_hub/src/models/d_fine/d_fine_encoder.py +365 -0
  13. keras_hub/src/models/d_fine/d_fine_hybrid_encoder.py +642 -0
  14. keras_hub/src/models/d_fine/d_fine_image_converter.py +8 -0
  15. keras_hub/src/models/d_fine/d_fine_layers.py +1828 -0
  16. keras_hub/src/models/d_fine/d_fine_loss.py +938 -0
  17. keras_hub/src/models/d_fine/d_fine_object_detector.py +875 -0
  18. keras_hub/src/models/d_fine/d_fine_object_detector_preprocessor.py +14 -0
  19. keras_hub/src/models/d_fine/d_fine_presets.py +2 -0
  20. keras_hub/src/models/d_fine/d_fine_utils.py +827 -0
  21. keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +4 -1
  22. keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +3 -2
  23. keras_hub/src/models/hgnetv2/hgnetv2_layers.py +27 -11
  24. keras_hub/src/models/parseq/__init__.py +0 -0
  25. keras_hub/src/models/parseq/parseq_backbone.py +134 -0
  26. keras_hub/src/models/parseq/parseq_causal_lm.py +466 -0
  27. keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py +168 -0
  28. keras_hub/src/models/parseq/parseq_decoder.py +418 -0
  29. keras_hub/src/models/parseq/parseq_image_converter.py +8 -0
  30. keras_hub/src/models/parseq/parseq_tokenizer.py +221 -0
  31. keras_hub/src/tests/test_case.py +37 -1
  32. keras_hub/src/utils/preset_utils.py +49 -0
  33. keras_hub/src/utils/tensor_utils.py +23 -1
  34. keras_hub/src/utils/transformers/convert_vit.py +4 -1
  35. keras_hub/src/version.py +1 -1
  36. keras_hub/tokenizers/__init__.py +3 -0
  37. {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/METADATA +1 -1
  38. {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/RECORD +40 -20
  39. {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/WHEEL +0 -0
  40. {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1828 @@
1
+ import keras
2
+ import numpy as np
3
+
4
+ from keras_hub.src.models.d_fine.d_fine_utils import inverse_sigmoid
5
+
6
+
7
class DFineGate(keras.layers.Layer):
    """Learnable gate that fuses two tensors and normalizes the result.

    `DFineDecoderLayer` uses this layer to blend the self-attention output
    (the residual) with the cross-attention output (`hidden_states`). Both
    inputs are concatenated, projected to two sigmoid gates, and combined as
    a gated sum which is then passed through layer normalization.

    Args:
        hidden_dim: int, The hidden dimension size for the gate computation.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(self, hidden_dim, dtype=None, **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        self.hidden_dim = hidden_dim
        self.norm = keras.layers.LayerNormalization(
            epsilon=1e-5, name="norm", dtype=self.dtype_policy
        )
        # The projection emits both gates at once, hence `2 * hidden_dim`
        # units. Kernel and bias start at zero.
        self.gate = keras.layers.Dense(
            2 * self.hidden_dim,
            name="gate",
            dtype=self.dtype_policy,
            kernel_initializer="zeros",
            bias_initializer="zeros",
        )

    def build(self, input_shape):
        # Batch and sequence dims are only taken from a rank-3 shape;
        # otherwise they are left unknown.
        is_rank3 = bool(input_shape) and len(input_shape) == 3
        batch_dim = input_shape[0] if is_rank3 else None
        seq_len_dim = input_shape[1] if is_rank3 else None
        # The gate consumes the concatenation of both inputs.
        self.gate.build((batch_dim, seq_len_dim, 2 * self.hidden_dim))
        self.norm.build((batch_dim, seq_len_dim, self.hidden_dim))
        super().build(input_shape)

    def call(self, second_residual, hidden_states, training=None):
        stacked = keras.ops.concatenate(
            [second_residual, hidden_states], axis=-1
        )
        gates = keras.ops.sigmoid(self.gate(stacked))
        # First half gates the residual, second half gates `hidden_states`.
        residual_gate, hidden_gate = keras.ops.split(gates, 2, axis=-1)
        blended = residual_gate * second_residual + hidden_gate * hidden_states
        return self.norm(blended, training=training)

    def get_config(self):
        config = super().get_config()
        config["hidden_dim"] = self.hidden_dim
        return config
63
+
64
+
65
class DFineMLP(keras.layers.Layer):
    """Stack of dense layers with an activation between consecutive layers.

    The D-FINE model uses this MLP in several places, for example as the
    `reg_conf` head inside `DFineLQE` (quality scores) and as the
    `pre_bbox_head` in `DFineDecoder` (initial bounding box predictions).

    Args:
        input_dim: int, The input dimension.
        hidden_dim: int, The hidden dimension for intermediate layers.
        output_dim: int, The output dimension.
        num_layers: int, The number of layers in the MLP.
        activation_function: str, The activation function to use between
            layers.
        kernel_initializer: str or Initializer, optional, Initializer for
            the kernel weights. Defaults to `"glorot_uniform"`.
        bias_initializer: str or Initializer, optional, Initializer for
            the bias weights. Defaults to `"zeros"`.
        last_layer_initializer: str or Initializer, optional, Special
            initializer for the final layer's weights and biases. If `None`,
            uses `kernel_initializer` and `bias_initializer`. Defaults to
            `None`.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        num_layers,
        activation_function="relu",
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        last_layer_initializer=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.num_layers = num_layers
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.activation_function = activation_function
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        # NOTE: The original code searches the modules for specific last
        # layers; here the last layer of each MLP is simply the one at index
        # `num_layers - 1`.
        self.last_layer_initializer = keras.initializers.get(
            last_layer_initializer
        )
        # Output width of every dense layer: hidden widths, then the output.
        layer_widths = [hidden_dim] * (num_layers - 1) + [output_dim]
        final_index = num_layers - 1
        has_special_init = self.last_layer_initializer is not None
        self.dense_layers = []
        for index, units in enumerate(layer_widths):
            # NOTE: Allows the final layer's weights and biases to get a
            # dedicated initializer (e.g. zeros for `bbox_embed` or
            # `reg_conf`).
            if index == final_index and has_special_init:
                kernel_init = self.last_layer_initializer
                bias_init = self.last_layer_initializer
            else:
                kernel_init = self.kernel_initializer
                bias_init = self.bias_initializer
            self.dense_layers.append(
                keras.layers.Dense(
                    units=units,
                    name=f"mlp_dense_layer_{index}",
                    dtype=self.dtype_policy,
                    kernel_initializer=kernel_init,
                    bias_initializer=bias_init,
                )
            )
        self.activation_layer = keras.layers.Activation(
            activation_function,
            name="mlp_activation_layer",
            dtype=self.dtype_policy,
        )

    def build(self, input_shape):
        # Build the dense layers in order, threading each layer's output
        # shape into the next one.
        shape = input_shape
        for layer in self.dense_layers:
            layer.build(shape)
            shape = layer.compute_output_shape(shape)
        super().build(input_shape)

    def call(self, stat_features, training=None):
        x = stat_features
        last_index = self.num_layers - 1
        for index, layer in zip(range(self.num_layers), self.dense_layers):
            x = layer(x)
            # No activation after the final dense layer.
            if index < last_index:
                x = self.activation_layer(x)
        return x

    def compute_output_spec(self, stat_features_spec):
        # Same shape as the input except for the last axis.
        dims = list(stat_features_spec.shape)
        dims[-1] = self.output_dim
        return keras.KerasTensor(shape=tuple(dims), dtype=self.compute_dtype)

    def get_config(self):
        serialize = keras.initializers.serialize
        config = super().get_config()
        config.update(
            {
                "input_dim": self.input_dim,
                "hidden_dim": self.hidden_dim,
                "output_dim": self.output_dim,
                "num_layers": self.num_layers,
                "activation_function": self.activation_function,
                "kernel_initializer": serialize(self.kernel_initializer),
                "bias_initializer": serialize(self.bias_initializer),
                "last_layer_initializer": serialize(
                    self.last_layer_initializer
                ),
            }
        )
        return config
193
+
194
+
195
class DFineSourceFlattener(keras.layers.Layer):
    """Flattens a list of feature maps and joins them along the sequence axis.

    `DFineBackbone` uses this layer on the multi-scale feature maps produced
    by `DFineHybridEncoder`. Each map has its spatial dimensions collapsed
    into one axis, and the results are concatenated along that axis.

    Args:
        channel_axis: int, optional, The channel axis. Defaults to `None`.
        data_format: str, optional, The data format. Defaults to `None`.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(
        self, channel_axis=None, data_format=None, dtype=None, **kwargs
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.channel_axis = channel_axis
        self.data_format = data_format

    def call(self, sources, training=None):
        channels_first = self.data_format == "channels_first"
        flattened = []
        for feature_map in sources:
            if channels_first:
                # Move channels to the last axis before flattening.
                feature_map = keras.ops.transpose(feature_map, [0, 2, 3, 1])
            batch_size = keras.ops.shape(feature_map)[0]
            channels = keras.ops.shape(feature_map)[-1]
            flattened.append(
                keras.ops.reshape(feature_map, (batch_size, -1, channels))
            )
        return keras.ops.concatenate(flattened, axis=1)

    def compute_output_shape(self, sources_shape):
        # Anything other than a non-empty list of rank-4 shape tuples yields
        # an empty shape.
        if not (sources_shape and isinstance(sources_shape, list)):
            return tuple()
        if any(
            not isinstance(shape, tuple) or len(shape) != 4
            for shape in sources_shape
        ):
            return tuple()
        channels_first = self.data_format == "channels_first"
        first_shape = sources_shape[0]
        channels = first_shape[1] if channels_first else first_shape[-1]
        height_axis, width_axis = (2, 3) if channels_first else (1, 2)
        # The flattened length is unknown as soon as one spatial dim is.
        total_spatial_elements = 0
        for shape in sources_shape:
            height, width = shape[height_axis], shape[width_axis]
            if height is None or width is None:
                total_spatial_elements = None
                break
            total_spatial_elements += height * width
        return (first_shape[0], total_spatial_elements, channels)

    def get_config(self):
        config = super().get_config()
        config["channel_axis"] = self.channel_axis
        config["data_format"] = self.data_format
        return config
267
+
268
+
269
class DFineContrastiveDenoisingGroupGenerator(keras.layers.Layer):
    """Layer to generate denoising groups for contrastive learning.

    This layer, used in `DFineBackbone`, implements the core logic for
    contrastive denoising, a key training strategy in D-FINE. It takes ground
    truth `targets`, adds controlled noise to labels and boxes, and generates
    the necessary attention masks, queries, and reference points for the
    decoder. Due to functional model constraints, noise is generated once at
    model initialization.

    Args:
        num_labels: int, The number of object classes.
        num_denoising: int, The number of denoising queries.
        label_noise_ratio: float, The ratio of label noise to apply.
        box_noise_scale: float, The scale of box noise to apply.
        seed: int, optional, The random seed for noise generation.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(
        self,
        num_labels,
        num_denoising,
        label_noise_ratio,
        box_noise_scale,
        seed=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.num_labels = num_labels
        self.num_denoising = num_denoising
        self.label_noise_ratio = label_noise_ratio
        self.box_noise_scale = box_noise_scale
        self.seed = seed
        self.seed_generator = keras.random.SeedGenerator(seed)

    def build(self, input_shape):
        super().build(input_shape)

    def call(self, targets, num_queries):
        """Builds the noised denoising queries and their attention mask.

        Args:
            targets: sequence of dicts, one per image, each with a `"labels"`
                entry and a `"boxes"` entry. Boxes are treated as
                `center_xywh` and clipped to `[0, 1]` after noising.
            num_queries: int, The number of regular (non-denoising) queries.

        Returns:
            A tuple `(input_query_class, input_query_bbox, attn_mask,
            denoising_meta_values)`. All four entries are `None` when
            `num_denoising <= 0` or when no image has any ground truth.
        """
        if self.num_denoising <= 0:
            return None, None, None, None
        num_ground_truths = [len(t["labels"]) for t in targets]
        max_gt_num = max(num_ground_truths, default=0)
        if max_gt_num == 0:
            return None, None, None, None
        # Always build at least one denoising group.
        num_groups_denoising_queries = max(
            self.num_denoising // max_gt_num, 1
        )
        batch_size = len(num_ground_truths)

        # Pad every image's labels/boxes to `max_gt_num` entries and remember
        # which entries are real ground truths.
        input_query_class = []
        input_query_bbox = []
        pad_gt_mask = []
        for i in range(batch_size):
            num_gt = num_ground_truths[i]
            if num_gt > 0:
                labels = targets[i]["labels"]
                boxes = targets[i]["boxes"]
                # Padding labels use `num_labels` as the filler class id.
                padded_class_labels = keras.ops.pad(
                    labels,
                    [[0, max_gt_num - num_gt]],
                    constant_values=self.num_labels,
                )
                padded_boxes = keras.ops.pad(
                    keras.ops.cast(boxes, dtype=self.compute_dtype),
                    [[0, max_gt_num - num_gt], [0, 0]],
                    constant_values=0.0,
                )
                mask = keras.ops.concatenate(
                    [
                        keras.ops.ones([num_gt], dtype="bool"),
                        keras.ops.zeros([max_gt_num - num_gt], dtype="bool"),
                    ]
                )
            else:
                padded_class_labels = keras.ops.full(
                    [max_gt_num], self.num_labels, dtype="int32"
                )
                padded_boxes = keras.ops.zeros(
                    [max_gt_num, 4], dtype=self.compute_dtype
                )
                mask = keras.ops.zeros([max_gt_num], dtype="bool")
            input_query_class.append(padded_class_labels)
            input_query_bbox.append(padded_boxes)
            pad_gt_mask.append(mask)
        input_query_class = keras.ops.stack(input_query_class, axis=0)
        input_query_bbox = keras.ops.stack(input_query_bbox, axis=0)
        pad_gt_mask = keras.ops.stack(pad_gt_mask, axis=0)

        # Each group holds `2 * max_gt_num` queries: a first (positive) half
        # followed by a second (negative) half.
        input_query_class = keras.ops.tile(
            input_query_class, [1, 2 * num_groups_denoising_queries]
        )
        input_query_bbox = keras.ops.tile(
            input_query_bbox, [1, 2 * num_groups_denoising_queries, 1]
        )
        pad_gt_mask = keras.ops.tile(
            pad_gt_mask, [1, 2 * num_groups_denoising_queries]
        )
        negative_gt_mask = keras.ops.zeros(
            [batch_size, max_gt_num * 2, 1], dtype=self.compute_dtype
        )
        updates_neg = keras.ops.ones(
            [batch_size, max_gt_num, 1], dtype=negative_gt_mask.dtype
        )
        # Mark the second half of a group as negative.
        negative_gt_mask = keras.ops.slice_update(
            negative_gt_mask, [0, max_gt_num, 0], updates_neg
        )
        negative_gt_mask = keras.ops.tile(
            negative_gt_mask, [1, num_groups_denoising_queries, 1]
        )
        positive_gt_mask_float = 1.0 - negative_gt_mask
        squeezed_positive_gt_mask = keras.ops.squeeze(
            positive_gt_mask_float, axis=-1
        )
        # Positive queries are the non-negative ones that are real (unpadded)
        # ground truths.
        positive_gt_mask = squeezed_positive_gt_mask * keras.ops.cast(
            pad_gt_mask, dtype=squeezed_positive_gt_mask.dtype
        )
        denoise_positive_idx = [
            keras.ops.nonzero(positive_gt_mask[i])[0]
            for i in range(batch_size)
        ]

        # Pad the per-image index lists with -1 to a common length so they can
        # be stacked into a single tensor.
        max_len = 0
        for idx in denoise_positive_idx:
            current_len = keras.ops.shape(idx)[0]
            if current_len > max_len:
                max_len = current_len
        padded_indices = []
        for idx in denoise_positive_idx:
            current_len = keras.ops.shape(idx)[0]
            pad_len = max_len - current_len
            padded = keras.ops.pad(idx, [[0, pad_len]], constant_values=-1)
            padded_indices.append(padded)
        # `padded_indices` has one entry per image and `batch_size >= 1` here
        # (the empty case returned early), so stacking is always valid.
        dn_positive_idx = keras.ops.stack(padded_indices, axis=0)

        if self.label_noise_ratio > 0:
            # NOTE: the random draws below happen in the same order as before
            # (label mask, new labels, box sign, box magnitude).
            noise_mask = keras.random.uniform(
                keras.ops.shape(input_query_class),
                dtype=self.compute_dtype,
                seed=self.seed_generator,
            ) < (self.label_noise_ratio * 0.5)
            noise_mask = keras.ops.cast(noise_mask, "bool")
            new_label = keras.random.randint(
                keras.ops.shape(input_query_class),
                0,
                self.num_labels,
                seed=self.seed_generator,
                dtype="int32",
            )
            # Only real ground-truth entries may have their label replaced.
            input_query_class = keras.ops.where(
                noise_mask & pad_gt_mask,
                new_label,
                input_query_class,
            )
        if self.box_noise_scale > 0:
            known_bbox = keras.utils.bounding_boxes.convert_format(
                input_query_bbox,
                source="center_xywh",
                target="xyxy",
                dtype=self.compute_dtype,
            )
            width_height = input_query_bbox[..., 2:]
            # Per-coordinate noise budget: half the box size times the scale.
            diff = (
                keras.ops.tile(width_height, [1, 1, 2])
                * 0.5
                * self.box_noise_scale
            )
            rand_int_sign = keras.random.randint(
                keras.ops.shape(input_query_bbox),
                0,
                2,
                seed=self.seed_generator,
            )
            # Map {0, 1} to {-1, +1}.
            rand_sign = (
                keras.ops.cast(rand_int_sign, dtype=diff.dtype) * 2.0 - 1.0
            )
            rand_part = keras.random.uniform(
                keras.ops.shape(input_query_bbox),
                seed=self.seed_generator,
                dtype=self.compute_dtype,
            )
            # Negative queries get their noise magnitude shifted up by 1.
            rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (
                1 - negative_gt_mask
            )
            rand_part = rand_part * rand_sign
            known_bbox = known_bbox + rand_part * diff
            known_bbox = keras.ops.clip(known_bbox, 0.0, 1.0)
            input_query_bbox = keras.utils.bounding_boxes.convert_format(
                known_bbox,
                source="xyxy",
                target="center_xywh",
                dtype=self.compute_dtype,
            )
            input_query_bbox = inverse_sigmoid(input_query_bbox)

        num_denoising_total = max_gt_num * 2 * num_groups_denoising_queries
        target_size = num_denoising_total + num_queries
        attn_mask = keras.ops.zeros(
            [target_size, target_size], dtype=self.compute_dtype
        )
        # Mask entries (regular query row, denoising query column).
        updates_attn1 = keras.ops.ones(
            [
                target_size - num_denoising_total,
                num_denoising_total,
            ],
            dtype=attn_mask.dtype,
        )
        attn_mask = keras.ops.slice_update(
            attn_mask, [num_denoising_total, 0], updates_attn1
        )
        # Within the denoising part, mask every column that belongs to a
        # different group than the row.
        for i in range(num_groups_denoising_queries):
            start = max_gt_num * 2 * i
            end = max_gt_num * 2 * (i + 1)
            updates_attn2 = keras.ops.ones(
                [end - start, start], dtype=attn_mask.dtype
            )
            attn_mask = keras.ops.slice_update(
                attn_mask, [start, 0], updates_attn2
            )
            updates_attn3 = keras.ops.ones(
                [end - start, num_denoising_total - end],
                dtype=attn_mask.dtype,
            )
            attn_mask = keras.ops.slice_update(
                attn_mask, [start, end], updates_attn3
            )

        denoising_meta_values = {
            "dn_positive_idx": dn_positive_idx,
            "dn_num_group": keras.ops.convert_to_tensor(
                num_groups_denoising_queries, dtype="int32"
            ),
            "dn_num_split": keras.ops.convert_to_tensor(
                [num_denoising_total, num_queries], dtype="int32"
            ),
        }
        return (
            input_query_class,
            input_query_bbox,
            attn_mask,
            denoising_meta_values,
        )

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_labels": self.num_labels,
                "num_denoising": self.num_denoising,
                "label_noise_ratio": self.label_noise_ratio,
                "box_noise_scale": self.box_noise_scale,
                "seed": self.seed,
            }
        )
        return config
530
+
531
+
532
class DFineAnchorGenerator(keras.layers.Layer):
    """Produces anchor boxes for every feature level.

    `DFineBackbone` uses this layer to generate anchor proposals, which are
    combined with the output of the encoder's bounding box head
    (`enc_bbox_head`) to form the initial reference points for the decoder's
    queries.

    Args:
        anchor_image_size: tuple, The size of the input image `(height, width)`.
        feat_strides: list, The strides of the feature maps.
        data_format: str, The data format of the image channels. Can be either
            `"channels_first"` or `"channels_last"`. If `None` is specified,
            it will use the `image_data_format` value found in your Keras
            config file at `~/.keras/keras.json`. Defaults to `None`.
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
            to use for the model's computations and weights. Defaults to `None`.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(
        self,
        anchor_image_size,
        feat_strides,
        data_format=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.anchor_image_size = anchor_image_size
        self.feat_strides = feat_strides
        self.data_format = data_format

    def call(self, sources_for_shape_derivation=None, grid_size=0.05):
        if sources_for_shape_derivation is not None:
            # Read (height, width) from the feature maps themselves.
            if self.data_format == "channels_first":
                height_axis, width_axis = 2, 3
            else:
                height_axis, width_axis = 1, 2
            spatial_shapes = [
                (
                    keras.ops.shape(source)[height_axis],
                    keras.ops.shape(source)[width_axis],
                )
                for source in sources_for_shape_derivation
            ]
        else:
            # No feature maps given: derive the grid sizes from the image
            # size and the strides.
            image_height, image_width = (
                self.anchor_image_size[0],
                self.anchor_image_size[1],
            )
            spatial_shapes = [
                (
                    keras.ops.cast(image_height / stride, "int32"),
                    keras.ops.cast(image_width / stride, "int32"),
                )
                for stride in self.feat_strides
            ]

        per_level_anchors = []
        for level, (height, width) in enumerate(spatial_shapes):
            rows = keras.ops.arange(height, dtype=self.compute_dtype)
            cols = keras.ops.arange(width, dtype=self.compute_dtype)
            grid_y, grid_x = keras.ops.meshgrid(rows, cols, indexing="ij")
            # Cell centers in (x, y) order, normalized by the grid size.
            centers = keras.ops.stack([grid_x, grid_y], axis=-1)
            centers = keras.ops.expand_dims(centers, axis=0) + 0.5
            centers = centers / keras.ops.array(
                [width, height], dtype=self.compute_dtype
            )
            # Anchor width/height doubles with every level.
            sizes = keras.ops.ones_like(centers) * grid_size * (2.0**level)
            boxes = keras.ops.concatenate([centers, sizes], axis=-1)
            per_level_anchors.append(
                keras.ops.reshape(boxes, (-1, height * width, 4))
            )

        eps = 1e-2
        anchors = keras.ops.concatenate(per_level_anchors, axis=1)
        # An anchor is valid when all four values lie strictly inside
        # (eps, 1 - eps).
        inside = (anchors > eps) & (anchors < 1 - eps)
        valid_mask = keras.ops.all(inside, axis=-1, keepdims=True)
        # Logit (inverse sigmoid) of the anchors.
        anchors_transformed = keras.ops.log(anchors / (1 - anchors))
        is_float16 = (
            keras.backend.standardize_dtype(self.compute_dtype) == "float16"
        )
        finfo_dtype = np.float16 if is_float16 else np.float32
        max_float = keras.ops.array(
            np.finfo(finfo_dtype).max, dtype=self.compute_dtype
        )
        # Invalid anchors are replaced by the largest finite float.
        anchors = keras.ops.where(valid_mask, anchors_transformed, max_float)

        return anchors, valid_mask

    def compute_output_shape(
        self, sources_for_shape_derivation_shape=None, grid_size_shape=None
    ):
        if sources_for_shape_derivation_shape is None:
            image_height, image_width = (
                self.anchor_image_size[0],
                self.anchor_image_size[1],
            )
            num_total_anchors_dim = sum(
                (image_height // stride) * (image_width // stride)
                for stride in self.feat_strides
            )
        else:
            if self.data_format == "channels_first":
                height_axis, width_axis = 2, 3
            else:
                height_axis, width_axis = 1, 2
            # The anchor count is unknown as soon as one spatial dim is.
            num_total_anchors_dim = 0
            for shape in sources_for_shape_derivation_shape:
                height, width = shape[height_axis], shape[width_axis]
                if height is None or width is None:
                    num_total_anchors_dim = None
                    break
                num_total_anchors_dim += height * width

        anchors_shape = (1, num_total_anchors_dim, 4)
        valid_mask_shape = (1, num_total_anchors_dim, 1)
        return anchors_shape, valid_mask_shape

    def get_config(self):
        config = super().get_config()
        config["anchor_image_size"] = self.anchor_image_size
        config["feat_strides"] = self.feat_strides
        config["data_format"] = self.data_format
        return config
667
+
668
+
669
class DFineSpatialShapesExtractor(keras.layers.Layer):
    """Collects the `(height, width)` of each input feature map.

    `DFineBackbone` uses this layer on the multi-scale feature maps; the
    resulting shape tensor is passed to the `DFineDecoder` for use in
    deformable attention.

    Args:
        data_format: str, optional, The data format of the input tensors.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(self, data_format=None, dtype=None, **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        self.data_format = data_format

    def call(self, sources):
        # Spatial axes sit after the channel axis for `channels_first`.
        if self.data_format == "channels_first":
            height_axis, width_axis = 2, 3
        else:
            height_axis, width_axis = 1, 2
        spatial_shapes = []
        for source in sources:
            shape = keras.ops.shape(source)
            spatial_shapes.append((shape[height_axis], shape[width_axis]))
        return keras.ops.array(spatial_shapes, dtype="int32")

    def compute_output_shape(self, input_shape):
        if not isinstance(input_shape, list):
            raise ValueError("Expected a list of shape tuples")
        # One (height, width) row per source.
        return (len(input_shape), 2)

    def get_config(self):
        config = super().get_config()
        config["data_format"] = self.data_format
        return config
707
+
708
+
709
+ class DFineInitialQueryAndReferenceGenerator(keras.layers.Layer):
710
+ """Layer to generate initial queries and reference points for the decoder.
711
+
712
+ This layer is a crucial component in `DFineBackbone` that bridges the
713
+ encoder and decoder. It selects the top-k predictions from the encoder's
714
+ output heads and uses them to generate the initial `target` (queries) and
715
+ `reference_points` that are fed into the `DFineDecoder`.
716
+
717
+ Args:
718
+ num_queries: int, The number of queries to generate.
719
+ hidden_dim: int, The hidden dimension of the model.
720
+ learn_initial_query: bool, Whether to learn the initial query
721
+ embeddings.
722
+ **kwargs: Additional keyword arguments passed to the parent class.
723
+ """
724
+
725
+ def __init__(
726
+ self,
727
+ num_queries,
728
+ hidden_dim,
729
+ learn_initial_query,
730
+ dtype=None,
731
+ **kwargs,
732
+ ):
733
+ super().__init__(dtype=dtype, **kwargs)
734
+ self.num_queries = num_queries
735
+ self.hidden_dim = hidden_dim
736
+ self.learn_initial_query = learn_initial_query
737
+ if self.learn_initial_query:
738
+ self.query_indices_base = keras.ops.expand_dims(
739
+ keras.ops.arange(self.num_queries, dtype="int32"), axis=0
740
+ )
741
+ self.weight_embedding = keras.layers.Embedding(
742
+ input_dim=num_queries,
743
+ output_dim=hidden_dim,
744
+ name="weight_embedding",
745
+ dtype=self.dtype_policy,
746
+ embeddings_initializer="glorot_uniform",
747
+ )
748
+ else:
749
+ self.weight_embedding = None
750
+
751
+ def call(
752
+ self,
753
+ inputs,
754
+ denoising_bbox_unact=None,
755
+ denoising_class=None,
756
+ training=None,
757
+ ):
758
+ (
759
+ enc_outputs_class,
760
+ enc_outputs_coord_logits_plus_anchors,
761
+ output_memory,
762
+ sources_last_element,
763
+ ) = inputs
764
+ enc_outputs_class_max = keras.ops.max(enc_outputs_class, axis=-1)
765
+ topk_ind = keras.ops.top_k(
766
+ enc_outputs_class_max, k=self.num_queries, sorted=True
767
+ )[1]
768
+
769
+ def gather_batch(elems):
770
+ data, indices = elems
771
+ return keras.ops.take(data, indices, axis=0)
772
+
773
+ reference_points_unact = keras.ops.map(
774
+ gather_batch, (enc_outputs_coord_logits_plus_anchors, topk_ind)
775
+ )
776
+ enc_topk_logits = keras.ops.map(
777
+ gather_batch, (enc_outputs_class, topk_ind)
778
+ )
779
+ enc_topk_bboxes = keras.ops.sigmoid(reference_points_unact)
780
+
781
+ if denoising_bbox_unact is not None:
782
+ current_batch_size = keras.ops.shape(reference_points_unact)[0]
783
+ denoising_bbox_unact = denoising_bbox_unact[:current_batch_size]
784
+ if denoising_class is not None:
785
+ denoising_class = denoising_class[:current_batch_size]
786
+ reference_points_unact = keras.ops.concatenate(
787
+ [denoising_bbox_unact, reference_points_unact], axis=1
788
+ )
789
+ if self.learn_initial_query:
790
+ query_indices = self.query_indices_base
791
+ target_embedding_val = self.weight_embedding(
792
+ query_indices, training=training
793
+ )
794
+ batch_size = keras.ops.shape(sources_last_element)[0]
795
+ target = keras.ops.tile(target_embedding_val, [batch_size, 1, 1])
796
+ else:
797
+ target = keras.ops.map(gather_batch, (output_memory, topk_ind))
798
+ target = keras.ops.stop_gradient(target)
799
+
800
+ if denoising_class is not None:
801
+ target = keras.ops.concatenate([denoising_class, target], axis=1)
802
+ init_reference_points = keras.ops.stop_gradient(reference_points_unact)
803
+ return init_reference_points, target, enc_topk_logits, enc_topk_bboxes
804
+
805
+ def get_config(self):
806
+ config = super().get_config()
807
+ config.update(
808
+ {
809
+ "num_queries": self.num_queries,
810
+ "hidden_dim": self.hidden_dim,
811
+ "learn_initial_query": self.learn_initial_query,
812
+ }
813
+ )
814
+ return config
815
+
816
def compute_output_spec(
    self,
    inputs,
    denoising_bbox_unact=None,
    denoising_class=None,
    training=None,
):
    """Describe the symbolic outputs of the layer without running it.

    Args:
        inputs: A 4-element sequence of specs. Only the first (class
            outputs) and third (output memory) entries are read here.
        denoising_bbox_unact: Optional spec. When it has more than one
            dimension, its second dimension is added to the query count
            of the reference-point output.
        denoising_class: Optional spec. When it has more than one
            dimension, its second dimension is added to the query count
            of the target output.
        training: Unused; accepted to mirror the `call` signature.

    Returns:
        A tuple of four `keras.KerasTensor` specs, in order: initial
        reference points `(batch, queries, 4)`, target
        `(batch, queries, d_model)`, top-k logits
        `(batch, num_queries, num_labels)` and top-k boxes
        `(batch, num_queries, 4)`.
    """
    class_spec, _, memory_spec, _ = inputs
    batch_dim = class_spec.shape[0]
    feature_dim = memory_spec.shape[-1]
    label_dim = class_spec.shape[-1]

    def _total_queries(denoising_spec):
        # Without a usable denoising spec only the top-k queries remain.
        if denoising_spec is None or len(denoising_spec.shape) <= 1:
            return self.num_queries
        extra_queries = denoising_spec.shape[1]
        # An unknown denoising length makes the total length unknown too.
        if extra_queries is None:
            return None
        return extra_queries + self.num_queries

    ref_point_queries = _total_queries(denoising_bbox_unact)
    target_queries = _total_queries(denoising_class)
    out_dtype = self.compute_dtype

    return (
        keras.KerasTensor(
            shape=(batch_dim, ref_point_queries, 4), dtype=out_dtype
        ),
        keras.KerasTensor(
            shape=(batch_dim, target_queries, feature_dim), dtype=out_dtype
        ),
        keras.KerasTensor(
            shape=(batch_dim, self.num_queries, label_dim), dtype=out_dtype
        ),
        keras.KerasTensor(
            shape=(batch_dim, self.num_queries, 4), dtype=out_dtype
        ),
    )
872
+
873
+
874
class DFineIntegral(keras.layers.Layer):
    """Layer to compute integrated values from predicted corner probabilities.

    This layer implements the integral regression technique for bounding box
    prediction. It is used in `DFineDecoder` to transform the predicted
    distribution over bins (from `bbox_embed`) into continuous distance values,
    which are then used to calculate the final box coordinates.

    Args:
        max_num_bins: int, The maximum number of bins for the predictions.
            Each distribution is read as `max_num_bins + 1` logits.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(self, max_num_bins, dtype=None, **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        self.max_num_bins = max_num_bins

    def build(self, input_shape):
        # The layer owns no weights; building only marks it as built.
        super().build(input_shape)

    def call(self, pred_corners, project, training=None):
        """Turn per-bin logits into one value per distribution.

        Args:
            pred_corners: Tensor whose first two dimensions are kept in the
                output and whose remaining elements are regrouped into
                rows of `max_num_bins + 1` logits.
            project: Weighting tensor; its transpose is right-multiplied
                with the softmax output. NOTE(review): presumably of shape
                `(1, max_num_bins + 1)` so that the trailing axis can be
                squeezed - confirm against the caller.
            training: Unused.

        Returns:
            Tensor of shape `(batch, num_queries, -1)` built from groups of
            four integrated values.
        """
        input_dims = keras.ops.shape(pred_corners)
        # One row per predicted distribution.
        bin_logits = keras.ops.reshape(
            pred_corners, (-1, self.max_num_bins + 1)
        )
        bin_probs = keras.ops.softmax(bin_logits, axis=1)
        # Weighted sum of the bin probabilities.
        weighted = keras.ops.matmul(bin_probs, keras.ops.transpose(project))
        integrated = keras.ops.squeeze(weighted, axis=-1)
        per_box = keras.ops.reshape(integrated, (-1, 4))
        return keras.ops.reshape(
            per_box, (input_dims[0], input_dims[1], -1)
        )

    def get_config(self):
        """Return the parent configuration extended with `max_num_bins`."""
        return {
            **super().get_config(),
            "max_num_bins": self.max_num_bins,
        }
920
+
921
+
922
class DFineLQE(keras.layers.Layer):
    """Layer to compute quality scores for predictions.

    This layer, used within `DFineDecoder`, implements the Localization Quality
    Estimation (LQE) head. It computes a quality score from the distribution of
    predicted bounding box corners and adds this score to the classification
    logits, enhancing prediction confidence.

    Args:
        top_prob_values: int, The number of top probabilities to consider.
        max_num_bins: int, The maximum number of bins for the predictions.
        lqe_hidden_dim: int, The hidden dimension for the MLP.
        num_lqe_layers: int, The number of layers in the MLP.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(
        self,
        top_prob_values,
        max_num_bins,
        lqe_hidden_dim,
        num_lqe_layers,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.top_prob_values = top_prob_values
        self.max_num_bins = max_num_bins
        # Per prediction, the MLP sees the top-k probabilities plus their
        # mean for each of the 4 corner distributions.
        stat_features = 4 * (self.top_prob_values + 1)
        self.reg_conf = DFineMLP(
            input_dim=stat_features,
            hidden_dim=lqe_hidden_dim,
            output_dim=1,
            num_layers=num_lqe_layers,
            dtype=self.dtype_policy,
            last_layer_initializer="zeros",
            name="reg_conf",
        )

    def build(self, input_shape):
        # NOTE(review): `input_shape` is indexed twice, so it is expected
        # to be a sequence of shapes; only the first one is read - confirm
        # against the caller.
        leading_shape = input_shape[0]
        self.reg_conf.build(
            (
                leading_shape[0],
                leading_shape[1],
                4 * (self.top_prob_values + 1),
            )
        )
        super().build(input_shape)

    def call(self, scores, pred_corners, training=None):
        """Add a quality score derived from `pred_corners` to `scores`.

        Args:
            scores: Tensor the quality score is added to.
            pred_corners: Tensor reshaped to
                `(batch, length, 4, max_num_bins + 1)` before the softmax.
            training: Forwarded to the internal MLP.

        Returns:
            `scores` plus the MLP output.
        """
        dims = keras.ops.shape(pred_corners)
        batch, num_preds = dims[0], dims[1]
        corner_logits = keras.ops.reshape(
            pred_corners, (batch, num_preds, 4, self.max_num_bins + 1)
        )
        corner_probs = keras.ops.softmax(corner_logits, axis=-1)
        # Keep only the values; the indices returned by `top_k` are unused.
        top_probs = keras.ops.top_k(
            corner_probs, k=self.top_prob_values, sorted=True
        )[0]
        top_mean = keras.ops.mean(top_probs, axis=-1, keepdims=True)
        stats = keras.ops.concatenate([top_probs, top_mean], axis=-1)
        flat_stats = keras.ops.reshape(stats, (batch, num_preds, -1))
        return scores + self.reg_conf(flat_stats, training=training)

    def get_config(self):
        """Return the parent configuration plus this layer's arguments.

        The MLP sizes are read back from the `reg_conf` sub-layer.
        """
        return {
            **super().get_config(),
            "top_prob_values": self.top_prob_values,
            "max_num_bins": self.max_num_bins,
            "lqe_hidden_dim": self.reg_conf.hidden_dim,
            "num_lqe_layers": self.reg_conf.num_layers,
        }
999
+
1000
+
1001
class DFineConvNormLayer(keras.layers.Layer):
    """Convolutional layer with normalization and optional activation.

    This is a fundamental building block used in the CNN parts of D-FINE. It
    combines a `Conv2D` layer with `BatchNormalization` and an optional
    activation. It is used extensively in layers like `DFineRepVggBlock`,
    `DFineCSPRepLayer`, and within the `DFineHybridEncoder`.

    Args:
        filters: int, The number of output channels.
        kernel_size: int, The size of the convolutional kernel.
        batch_norm_eps: float, The epsilon value for batch normalization.
        stride: int, The stride of the convolution.
        groups: int, The number of groups for grouped convolution.
        padding: int or None, The padding to apply. When `None`, the
            convolution uses `"same"` padding; otherwise the input is
            zero-padded explicitly and the convolution uses `"valid"`.
        activation_function: str or None, The activation function to use.
            A falsy value results in an identity activation.
        kernel_initializer: str or Initializer, optional, Initializer for
            the kernel weights. Defaults to `"glorot_uniform"`.
        bias_initializer: str or Initializer, optional, Initializer for
            the bias weights. Defaults to `"zeros"`.
        channel_axis: int, optional, The channel axis used by the batch
            normalization. Defaults to `None`, in which case it is derived
            from `keras.config.image_data_format()`.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(
        self,
        filters,
        kernel_size,
        batch_norm_eps,
        stride,
        groups,
        padding,
        activation_function,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        channel_axis=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.batch_norm_eps = batch_norm_eps
        self.stride = stride
        self.groups = groups
        self.padding_arg = padding
        self.activation_function = activation_function
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        # Stored as given so that `get_config` round-trips the argument.
        self.channel_axis = channel_axis

        # `BatchNormalization` needs a concrete integer axis, so the
        # documented `None` default is resolved from the global data format
        # (the same one `Conv2D` and `ZeroPadding2D` fall back to here).
        if channel_axis is None:
            data_format = keras.config.image_data_format()
            normalization_axis = -1 if data_format == "channels_last" else 1
        else:
            normalization_axis = channel_axis

        if self.padding_arg is None:
            keras_conv_padding_mode = "same"
            self.explicit_padding_layer = None
        else:
            keras_conv_padding_mode = "valid"
            self.explicit_padding_layer = keras.layers.ZeroPadding2D(
                padding=self.padding_arg,
                name=f"{self.name}_explicit_padding",
                dtype=self.dtype_policy,
            )

        # `use_bias=False`: the convolution creates no bias, so the bias
        # initializer passed below has no effect on it.
        self.convolution = keras.layers.Conv2D(
            filters=self.filters,
            kernel_size=self.kernel_size,
            strides=self.stride,
            padding=keras_conv_padding_mode,
            groups=self.groups,
            use_bias=False,
            dtype=self.dtype_policy,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            name=f"{self.name}_convolution",
        )
        self.normalization = keras.layers.BatchNormalization(
            epsilon=self.batch_norm_eps,
            name=f"{self.name}_normalization",
            axis=normalization_axis,
            dtype=self.dtype_policy,
        )
        self.activation_layer = (
            keras.layers.Activation(
                self.activation_function,
                name=f"{self.name}_activation",
                dtype=self.dtype_policy,
            )
            if self.activation_function
            else keras.layers.Identity(
                name=f"{self.name}_identity_activation", dtype=self.dtype_policy
            )
        )

    def build(self, input_shape):
        """Build the sub-layers in the order they are applied in `call`."""
        if self.explicit_padding_layer is not None:
            self.explicit_padding_layer.build(input_shape)
            shape = self.explicit_padding_layer.compute_output_shape(
                input_shape
            )
        else:
            shape = input_shape
        self.convolution.build(shape)
        conv_output_shape = self.convolution.compute_output_shape(shape)
        self.normalization.build(conv_output_shape)
        self.activation_layer.build(conv_output_shape)
        super().build(input_shape)

    def call(self, hidden_state, training=None):
        """Apply optional padding, convolution, normalization, activation.

        Args:
            hidden_state: Input tensor.
            training: Forwarded to the batch normalization only.

        Returns:
            The activated, normalized convolution output.
        """
        if self.explicit_padding_layer is not None:
            hidden_state = self.explicit_padding_layer(hidden_state)
        hidden_state = self.convolution(hidden_state)
        hidden_state = self.normalization(hidden_state, training=training)
        hidden_state = self.activation_layer(hidden_state)
        return hidden_state

    def compute_output_shape(self, input_shape):
        """Return the convolution output shape for `input_shape`."""
        shape = input_shape
        if self.explicit_padding_layer is not None:
            shape = self.explicit_padding_layer.compute_output_shape(shape)
        return self.convolution.compute_output_shape(shape)

    def get_config(self):
        """Return the parent configuration plus the constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "kernel_size": self.kernel_size,
                "batch_norm_eps": self.batch_norm_eps,
                "stride": self.stride,
                "groups": self.groups,
                "padding": self.padding_arg,
                "activation_function": self.activation_function,
                "kernel_initializer": keras.initializers.serialize(
                    self.kernel_initializer
                ),
                "bias_initializer": keras.initializers.serialize(
                    self.bias_initializer
                ),
                "channel_axis": self.channel_axis,
            }
        )
        return config
1141
+
1142
+
1143
class DFineRepVggBlock(keras.layers.Layer):
    """RepVGG-style block with two parallel convolutional paths.

    This layer implements a block inspired by the RepVGG architecture, featuring
    two parallel convolutional paths (3x3 and 1x1) that are summed. It serves
    as the core bottleneck block within the `DFineCSPRepLayer`.

    Args:
        activation_function: str, The activation function to use. A falsy
            value results in an identity activation.
        filters: int, The number of output channels.
        batch_norm_eps: float, The epsilon value for batch normalization.
        kernel_initializer: str or Initializer, optional, Initializer for
            the kernel weights. Defaults to `"glorot_uniform"`.
        bias_initializer: str or Initializer, optional, Initializer for
            the bias weights. Defaults to `"zeros"`.
        channel_axis: int, optional, The channel axis, forwarded to both
            convolution paths. Defaults to `None`.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(
        self,
        activation_function,
        filters,
        batch_norm_eps=1e-5,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        channel_axis=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.activation_function = activation_function
        self.filters = filters
        self.batch_norm_eps = batch_norm_eps
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.channel_axis = channel_axis

        # Both paths differ only in kernel size, padding and name.
        shared_path_kwargs = dict(
            filters=self.filters,
            batch_norm_eps=self.batch_norm_eps,
            stride=1,
            groups=1,
            activation_function=None,
            dtype=self.dtype_policy,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            channel_axis=self.channel_axis,
        )
        self.conv1_layer = DFineConvNormLayer(
            kernel_size=3, padding=1, name="conv1", **shared_path_kwargs
        )
        self.conv2_layer = DFineConvNormLayer(
            kernel_size=1, padding=0, name="conv2", **shared_path_kwargs
        )

        # The activation is applied once, after the two paths are summed.
        if self.activation_function:
            self.activation_layer = keras.layers.Activation(
                self.activation_function,
                name="block_activation",
                dtype=self.dtype_policy,
            )
        else:
            self.activation_layer = keras.layers.Identity(
                name="identity_activation", dtype=self.dtype_policy
            )

    def build(self, input_shape):
        """Build both paths and the activation on the same input shape."""
        for sublayer in (
            self.conv1_layer,
            self.conv2_layer,
            self.activation_layer,
        ):
            sublayer.build(input_shape)
        super().build(input_shape)

    def call(self, x, training=None):
        """Sum the 3x3 and 1x1 paths and apply the block activation."""
        path_3x3 = self.conv1_layer(x, training=training)
        path_1x1 = self.conv2_layer(x, training=training)
        return self.activation_layer(path_3x3 + path_1x1)

    def compute_output_shape(self, input_shape):
        """Return the output shape of the 3x3 path for `input_shape`."""
        return self.conv1_layer.compute_output_shape(input_shape)

    def get_config(self):
        """Return the parent configuration plus the constructor arguments."""
        return {
            **super().get_config(),
            "activation_function": self.activation_function,
            "filters": self.filters,
            "batch_norm_eps": self.batch_norm_eps,
            "kernel_initializer": keras.initializers.serialize(
                self.kernel_initializer
            ),
            "bias_initializer": keras.initializers.serialize(
                self.bias_initializer
            ),
            "channel_axis": self.channel_axis,
        }
1252
+
1253
+
1254
class DFineCSPRepLayer(keras.layers.Layer):
    """CSP (Cross Stage Partial) layer with repeated bottleneck blocks.

    This layer implements a Cross Stage Partial (CSP) block using
    `DFineRepVggBlock` as its bottleneck. It is a key component of the
    `DFineFeatureAggregationBlock` block, which forms the FPN/PAN structure in
    the `DFineHybridEncoder`.

    Args:
        activation_function: str, The activation function to use.
        batch_norm_eps: float, The epsilon value for batch normalization.
        filters: int, The number of output channels.
        num_blocks: int, The number of bottleneck blocks.
        expansion: float, The expansion factor for hidden channels. Defaults to
            `1.0`.
        kernel_initializer: str or Initializer, optional, Initializer for
            the kernel weights. Defaults to `"glorot_uniform"`.
        bias_initializer: str or Initializer, optional, Initializer for
            the bias weights. Defaults to `"zeros"`.
        channel_axis: int, optional, The channel axis, forwarded to every
            sub-layer. Defaults to `None`.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(
        self,
        activation_function,
        batch_norm_eps,
        filters,
        num_blocks,
        expansion=1.0,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        channel_axis=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.activation_function = activation_function
        self.batch_norm_eps = batch_norm_eps
        self.filters = filters
        self.num_blocks = num_blocks
        self.expansion = expansion
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.channel_axis = channel_axis
        hidden_channels = int(self.filters * self.expansion)

        def _pointwise_conv(num_filters, layer_name):
            # 1x1 convolution block shared by conv1, conv2 and conv3.
            return DFineConvNormLayer(
                filters=num_filters,
                kernel_size=1,
                batch_norm_eps=self.batch_norm_eps,
                stride=1,
                groups=1,
                padding=0,
                activation_function=self.activation_function,
                dtype=self.dtype_policy,
                kernel_initializer=self.kernel_initializer,
                bias_initializer=self.bias_initializer,
                channel_axis=self.channel_axis,
                name=layer_name,
            )

        self.conv1 = _pointwise_conv(hidden_channels, "conv1")
        self.conv2 = _pointwise_conv(hidden_channels, "conv2")
        self.bottleneck_layers = []
        for block_index in range(self.num_blocks):
            self.bottleneck_layers.append(
                DFineRepVggBlock(
                    activation_function=self.activation_function,
                    filters=hidden_channels,
                    batch_norm_eps=self.batch_norm_eps,
                    dtype=self.dtype_policy,
                    kernel_initializer=self.kernel_initializer,
                    bias_initializer=self.bias_initializer,
                    channel_axis=self.channel_axis,
                    name=f"bottleneck_{block_index}",
                )
            )
        # A projection is only needed when the hidden width differs from
        # the requested output width.
        if hidden_channels == self.filters:
            self.conv3 = keras.layers.Identity(
                name="conv3_identity", dtype=self.dtype_policy
            )
        else:
            self.conv3 = _pointwise_conv(self.filters, "conv3")

    def build(self, input_shape):
        """Build both branches, the bottlenecks and the output projection."""
        self.conv1.build(input_shape)
        self.conv2.build(input_shape)
        hidden_shape = self.conv1.compute_output_shape(input_shape)
        for block in self.bottleneck_layers:
            block.build(hidden_shape)
        self.conv3.build(hidden_shape)
        super().build(input_shape)

    def call(self, hidden_state, training=None):
        """Sum the bottleneck branch and the shortcut branch, then project.

        Args:
            hidden_state: Input tensor, fed to both branches.
            training: Forwarded to every sub-layer except an identity
                `conv3`.

        Returns:
            The output of `conv3` applied to the sum of both branches.
        """
        bottleneck_branch = self.conv1(hidden_state, training=training)
        for block in self.bottleneck_layers:
            bottleneck_branch = block(bottleneck_branch, training=training)
        shortcut_branch = self.conv2(hidden_state, training=training)
        merged = bottleneck_branch + shortcut_branch
        # The identity stand-in is called without `training`.
        if isinstance(self.conv3, keras.layers.Identity):
            return self.conv3(merged)
        return self.conv3(merged, training=training)

    def compute_output_shape(self, input_shape):
        """Return the shape after `conv1` followed by `conv3`."""
        hidden_shape = self.conv1.compute_output_shape(input_shape)
        return self.conv3.compute_output_shape(hidden_shape)

    def get_config(self):
        """Return the parent configuration plus the constructor arguments."""
        return {
            **super().get_config(),
            "activation_function": self.activation_function,
            "batch_norm_eps": self.batch_norm_eps,
            "filters": self.filters,
            "num_blocks": self.num_blocks,
            "expansion": self.expansion,
            "kernel_initializer": keras.initializers.serialize(
                self.kernel_initializer
            ),
            "bias_initializer": keras.initializers.serialize(
                self.bias_initializer
            ),
            "channel_axis": self.channel_axis,
        }
1405
+
1406
+
1407
class DFineFeatureAggregationBlock(keras.layers.Layer):
    """Complex block combining convolutional and CSP layers.

    This layer implements a complex feature extraction block combining multiple
    convolutional and `DFineCSPRepLayer` layers. It is the main building block
    for the Feature Pyramid Network (FPN) and Path Aggregation Network (PAN)
    pathways within the `DFineHybridEncoder`.

    Args:
        encoder_hidden_dim: int, The hidden dimension of the encoder.
        hidden_expansion: float, The expansion factor for hidden channels.
        batch_norm_eps: float, The epsilon value for batch normalization.
        activation_function: str, The activation function to use.
        num_blocks: int, The number of blocks in the CSP layers.
        kernel_initializer: str or Initializer, optional, Initializer for
            the kernel weights. Defaults to `"glorot_uniform"`.
        bias_initializer: str or Initializer, optional, Initializer for
            the bias weights. Defaults to `"zeros"`.
        channel_axis: int, optional, The channel axis. It is used to index
            shapes in `build` and as the split/concatenate axis in `call`,
            so an integer is required there. Defaults to `None`.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(
        self,
        encoder_hidden_dim,
        hidden_expansion,
        batch_norm_eps,
        activation_function,
        num_blocks,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        channel_axis=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.encoder_hidden_dim = encoder_hidden_dim
        self.hidden_expansion = hidden_expansion
        self.batch_norm_eps = batch_norm_eps
        self.activation_function = activation_function
        self.num_blocks = num_blocks
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.channel_axis = channel_axis

        # conv1 doubles the channels; its output is later split in halves
        # of `conv_dim` channels each.
        conv3_dim = self.encoder_hidden_dim * 2
        self.conv4_dim = int(
            self.hidden_expansion * self.encoder_hidden_dim / 2
        )
        self.conv_dim = conv3_dim // 2
        self.conv1 = DFineConvNormLayer(
            filters=conv3_dim,
            kernel_size=1,
            batch_norm_eps=self.batch_norm_eps,
            stride=1,
            groups=1,
            padding=0,
            activation_function=self.activation_function,
            dtype=self.dtype_policy,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            channel_axis=self.channel_axis,
            name="conv1",
        )
        self.csp_rep1 = DFineCSPRepLayer(
            activation_function=self.activation_function,
            batch_norm_eps=self.batch_norm_eps,
            filters=self.conv4_dim,
            num_blocks=self.num_blocks,
            dtype=self.dtype_policy,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            channel_axis=self.channel_axis,
            name="csp_rep1",
        )
        self.conv2 = DFineConvNormLayer(
            filters=self.conv4_dim,
            kernel_size=3,
            batch_norm_eps=self.batch_norm_eps,
            stride=1,
            groups=1,
            padding=1,
            activation_function=self.activation_function,
            dtype=self.dtype_policy,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            channel_axis=self.channel_axis,
            name="conv2",
        )
        self.csp_rep2 = DFineCSPRepLayer(
            activation_function=self.activation_function,
            batch_norm_eps=self.batch_norm_eps,
            filters=self.conv4_dim,
            num_blocks=self.num_blocks,
            dtype=self.dtype_policy,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            channel_axis=self.channel_axis,
            name="csp_rep2",
        )
        self.conv3 = DFineConvNormLayer(
            filters=self.conv4_dim,
            kernel_size=3,
            batch_norm_eps=self.batch_norm_eps,
            stride=1,
            groups=1,
            padding=1,
            activation_function=self.activation_function,
            dtype=self.dtype_policy,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            channel_axis=self.channel_axis,
            name="conv3",
        )
        self.conv4 = DFineConvNormLayer(
            filters=self.encoder_hidden_dim,
            kernel_size=1,
            batch_norm_eps=self.batch_norm_eps,
            stride=1,
            groups=1,
            padding=0,
            activation_function=self.activation_function,
            dtype=self.dtype_policy,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            channel_axis=self.channel_axis,
            name="conv4",
        )

    def _concat_shape(self, shape_after_conv1):
        # Shape fed to conv4: both halves of conv1's output plus the two
        # branch outputs, stacked along the channel axis.
        concat_shape = list(shape_after_conv1)
        concat_shape[self.channel_axis] = (
            self.conv_dim * 2 + self.conv4_dim * 2
        )
        return tuple(concat_shape)

    def build(self, input_shape):
        """Build every sub-layer on the shape it receives in `call`."""
        self.conv1.build(input_shape)
        shape_after_conv1 = self.conv1.compute_output_shape(input_shape)
        # csp_rep1 only sees one half of conv1's channels.
        csp_rep_input_shape_list = list(shape_after_conv1)
        csp_rep_input_shape_list[self.channel_axis] = self.conv_dim
        csp_rep_input_shape = tuple(csp_rep_input_shape_list)
        self.csp_rep1.build(csp_rep_input_shape)
        shape_after_csp_rep1 = self.csp_rep1.compute_output_shape(
            csp_rep_input_shape
        )
        self.conv2.build(shape_after_csp_rep1)
        shape_after_conv2 = self.conv2.compute_output_shape(
            shape_after_csp_rep1
        )
        self.csp_rep2.build(shape_after_conv2)
        shape_after_csp_rep2 = self.csp_rep2.compute_output_shape(
            shape_after_conv2
        )
        self.conv3.build(shape_after_csp_rep2)
        self.conv4.build(self._concat_shape(shape_after_conv1))
        super().build(input_shape)

    def call(self, input_features, training=None):
        """Aggregate features through two chained CSP branches.

        Args:
            input_features: Input feature map.
            training: Forwarded to every sub-layer.

        Returns:
            The output of `conv4` applied to the concatenation of both
            halves of `conv1`'s output and the two branch outputs.
        """
        conv1_out = self.conv1(input_features, training=training)
        # `keras.ops.split` interprets a list as split *indices*, not
        # sizes: one index yields exactly the two halves. Passing
        # `[conv_dim, conv_dim]` would add an empty tensor in between.
        first_half, second_half = keras.ops.split(
            conv1_out, [self.conv_dim], axis=self.channel_axis
        )
        branch1 = self.csp_rep1(second_half, training=training)
        branch1 = self.conv2(branch1, training=training)
        # The second branch continues from the first branch's output.
        branch2 = self.csp_rep2(branch1, training=training)
        branch2 = self.conv3(branch2, training=training)
        merged_features = keras.ops.concatenate(
            [first_half, second_half, branch1, branch2],
            axis=self.channel_axis,
        )
        merged_features = self.conv4(merged_features, training=training)
        return merged_features

    def compute_output_shape(self, input_shape):
        """Return the output shape of `conv4` for `input_shape`."""
        shape_after_conv1 = self.conv1.compute_output_shape(input_shape)
        return self.conv4.compute_output_shape(
            self._concat_shape(shape_after_conv1)
        )

    def get_config(self):
        """Return the parent configuration plus the constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "encoder_hidden_dim": self.encoder_hidden_dim,
                "hidden_expansion": self.hidden_expansion,
                "batch_norm_eps": self.batch_norm_eps,
                "activation_function": self.activation_function,
                "num_blocks": self.num_blocks,
                "kernel_initializer": keras.initializers.serialize(
                    self.kernel_initializer
                ),
                "bias_initializer": keras.initializers.serialize(
                    self.bias_initializer
                ),
                "channel_axis": self.channel_axis,
            }
        )
        return config
1608
+
1609
+
1610
class DFineSCDown(keras.layers.Layer):
    """Downsampling layer using convolutions.

    This layer is used in the `DFineHybridEncoder` to perform downsampling.
    Specifically, it is part of the Path Aggregation Network (PAN) bottom-up
    pathway, reducing the spatial resolution of feature maps.

    Args:
        encoder_hidden_dim: int, The hidden dimension of the encoder.
        batch_norm_eps: float, The epsilon value for batch normalization.
        kernel_size: int, The kernel size for the second convolution.
        stride: int, The stride for the second convolution.
        kernel_initializer: str or Initializer, optional, Initializer for
            the kernel weights. Defaults to `"glorot_uniform"`.
        bias_initializer: str or Initializer, optional, Initializer for
            the bias weights. Defaults to `"zeros"`.
        channel_axis: int, optional, The channel axis, forwarded to both
            convolutions. Defaults to `None`.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(
        self,
        encoder_hidden_dim,
        batch_norm_eps,
        kernel_size,
        stride,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        channel_axis=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.encoder_hidden_dim = encoder_hidden_dim
        self.batch_norm_eps = batch_norm_eps
        self.conv2_kernel_size = kernel_size
        self.conv2_stride = stride
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.channel_axis = channel_axis

        # Arguments common to both convolutions; neither has an activation.
        shared_conv_kwargs = dict(
            filters=self.encoder_hidden_dim,
            batch_norm_eps=self.batch_norm_eps,
            activation_function=None,
            dtype=self.dtype_policy,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            channel_axis=self.channel_axis,
        )
        # 1x1 convolution with stride 1.
        self.conv1 = DFineConvNormLayer(
            kernel_size=1,
            stride=1,
            groups=1,
            padding=0,
            name="conv1",
            **shared_conv_kwargs,
        )
        # Strided convolution with one group per output channel.
        self.conv2 = DFineConvNormLayer(
            kernel_size=self.conv2_kernel_size,
            stride=self.conv2_stride,
            groups=self.encoder_hidden_dim,
            padding=(self.conv2_kernel_size - 1) // 2,
            name="conv2",
            **shared_conv_kwargs,
        )

    def build(self, input_shape):
        """Build `conv1` on the input and `conv2` on `conv1`'s output."""
        self.conv1.build(input_shape)
        self.conv2.build(self.conv1.compute_output_shape(input_shape))
        super().build(input_shape)

    def call(self, input_features, training=None):
        """Apply `conv1` followed by `conv2`."""
        projected = self.conv1(input_features, training=training)
        return self.conv2(projected, training=training)

    def compute_output_shape(self, input_shape):
        """Return the shape after `conv1` followed by `conv2`."""
        intermediate_shape = self.conv1.compute_output_shape(input_shape)
        return self.conv2.compute_output_shape(intermediate_shape)

    def get_config(self):
        """Return the parent configuration plus the constructor arguments."""
        return {
            **super().get_config(),
            "encoder_hidden_dim": self.encoder_hidden_dim,
            "batch_norm_eps": self.batch_norm_eps,
            "kernel_size": self.conv2_kernel_size,
            "stride": self.conv2_stride,
            "kernel_initializer": keras.initializers.serialize(
                self.kernel_initializer
            ),
            "bias_initializer": keras.initializers.serialize(
                self.bias_initializer
            ),
            "channel_axis": self.channel_axis,
        }
1712
+
1713
+
1714
class DFineMLPPredictionHead(keras.layers.Layer):
    """Stack of dense layers used as a prediction head in D-FINE.

    The head is a plain MLP: a sequence of `Dense` layers named
    `linear_0`, `linear_1`, ... with a ReLU applied after every layer
    except the last one. All layers but the last output `hidden_dim`
    features; the last one outputs `output_dim` features.

    D-FINE uses this layer for the encoder's bounding box head
    (`enc_bbox_head` in `DFineBackbone`), the decoder's bounding box
    embedding (`bbox_embed` in `DFineDecoder`), and the query position head
    (`query_pos_head` in `DFineDecoder`).

    Args:
        input_dim: int, The input dimension. Stored and written to the
            config; the dense layers get their input size when built.
        hidden_dim: int, The hidden dimension for intermediate layers.
        output_dim: int, The output dimension.
        num_layers: int, The number of layers in the MLP.
        kernel_initializer: str or Initializer, optional, Initializer for
            the kernel weights. Defaults to `"glorot_uniform"`.
        bias_initializer: str or Initializer, optional, Initializer for
            the bias weights. Defaults to `"zeros"`.
        last_layer_initializer: str or Initializer, optional, Special
            initializer applied to both the kernel and the bias of the
            final layer. If `None`, the final layer uses
            `kernel_initializer` and `bias_initializer` like the others.
            Defaults to `None`.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        num_layers,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        last_layer_initializer=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.last_layer_initializer = keras.initializers.get(
            last_layer_initializer
        )

        # Output width of each dense layer: `hidden_dim` for every layer
        # but the last, then `output_dim`.
        layer_widths = [self.hidden_dim] * (self.num_layers - 1)
        layer_widths.append(self.output_dim)

        final_index = self.num_layers - 1
        has_special_init = self.last_layer_initializer is not None

        self.dense_layers = []
        for index, width in enumerate(layer_widths):
            if has_special_init and index == final_index:
                # The special initializer covers kernel and bias alike.
                kernel_init = self.last_layer_initializer
                bias_init = self.last_layer_initializer
            else:
                kernel_init = self.kernel_initializer
                bias_init = self.bias_initializer
            self.dense_layers.append(
                keras.layers.Dense(
                    units=width,
                    name=f"linear_{index}",
                    dtype=self.dtype_policy,
                    kernel_initializer=kernel_init,
                    bias_initializer=bias_init,
                )
            )

    def build(self, input_shape):
        """Build each dense layer against the previous layer's output."""
        running_shape = input_shape
        for dense in self.dense_layers:
            dense.build(running_shape)
            running_shape = dense.compute_output_shape(running_shape)
        super().build(input_shape)

    def call(self, x, training=None):
        """Run the MLP, applying ReLU after every layer but the last.

        Args:
            x: Input passed to the first dense layer.
            training: Accepted for API compatibility; not used here.

        Returns:
            The output of the final dense layer.
        """
        last_hidden_index = self.num_layers - 1
        features = x
        for index, dense in enumerate(self.dense_layers):
            features = dense(features)
            if index < last_hidden_index:
                features = keras.ops.relu(features)
        return features

    def compute_output_spec(self, x_spec):
        """Return a symbolic spec: input shape with last dim `output_dim`."""
        spec_dims = list(x_spec.shape)
        spec_dims[-1] = self.output_dim
        return keras.KerasTensor(
            shape=tuple(spec_dims), dtype=self.compute_dtype
        )

    def get_config(self):
        """Return the layer config, with initializers in serialized form."""
        serialize_initializer = keras.initializers.serialize
        own_config = {
            "input_dim": self.input_dim,
            "hidden_dim": self.hidden_dim,
            "output_dim": self.output_dim,
            "num_layers": self.num_layers,
            "kernel_initializer": serialize_initializer(
                self.kernel_initializer
            ),
            "bias_initializer": serialize_initializer(
                self.bias_initializer
            ),
            "last_layer_initializer": serialize_initializer(
                self.last_layer_initializer
            ),
        }
        config = super().get_config()
        config.update(own_config)
        return config