keras-hub 0.25.0.dev0__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. keras_hub/layers/__init__.py +21 -0
  2. keras_hub/models/__init__.py +27 -0
  3. keras_hub/src/layers/modeling/non_max_supression.py +5 -2
  4. keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
  5. keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
  6. keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
  7. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
  8. keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
  9. keras_hub/src/models/albert/albert_backbone.py +1 -3
  10. keras_hub/src/models/backbone.py +3 -0
  11. keras_hub/src/models/bart/bart_backbone.py +1 -3
  12. keras_hub/src/models/bert/bert_backbone.py +2 -4
  13. keras_hub/src/models/bloom/bloom_backbone.py +1 -3
  14. keras_hub/src/models/causal_lm.py +2 -2
  15. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
  16. keras_hub/src/models/edrec/edrec_backbone.py +147 -0
  17. keras_hub/src/models/edrec/edrec_layers.py +434 -0
  18. keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
  19. keras_hub/src/models/electra/electra_backbone.py +1 -3
  20. keras_hub/src/models/f_net/f_net_backbone.py +1 -3
  21. keras_hub/src/models/falcon/falcon_backbone.py +1 -3
  22. keras_hub/src/models/flux/flux_layers.py +3 -3
  23. keras_hub/src/models/flux/flux_maths.py +29 -15
  24. keras_hub/src/models/gemma/gemma_backbone.py +1 -3
  25. keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
  26. keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
  27. keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
  28. keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
  29. keras_hub/src/models/gemma3/gemma3_decoder_block.py +23 -3
  30. keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
  31. keras_hub/src/models/gemma3/gemma3_presets.py +79 -7
  32. keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
  33. keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
  34. keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
  35. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
  36. keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
  37. keras_hub/src/models/llama/llama_backbone.py +1 -3
  38. keras_hub/src/models/masked_lm.py +1 -1
  39. keras_hub/src/models/mistral/mistral_backbone.py +1 -3
  40. keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
  41. keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
  42. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
  43. keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
  44. keras_hub/src/models/phi3/phi3_backbone.py +1 -3
  45. keras_hub/src/models/qwen/qwen_backbone.py +1 -3
  46. keras_hub/src/models/qwen/qwen_presets.py +209 -0
  47. keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
  48. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
  49. keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
  50. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
  51. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
  52. keras_hub/src/models/rqvae/__init__.py +5 -0
  53. keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
  54. keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
  55. keras_hub/src/models/rwkv7/__init__.py +5 -0
  56. keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
  57. keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
  58. keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
  59. keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
  60. keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
  61. keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
  62. keras_hub/src/models/sam/sam_backbone.py +5 -1
  63. keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
  64. keras_hub/src/models/sam3/__init__.py +7 -0
  65. keras_hub/src/models/sam3/roi_align.py +222 -0
  66. keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
  67. keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
  68. keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
  69. keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
  70. keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
  71. keras_hub/src/models/sam3/sam3_layers.py +814 -0
  72. keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
  73. keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
  74. keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
  75. keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
  76. keras_hub/src/models/sam3/sam3_presets.py +16 -0
  77. keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
  78. keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
  79. keras_hub/src/models/sam3/sam3_utils.py +134 -0
  80. keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
  81. keras_hub/src/models/segformer/segformer_backbone.py +6 -6
  82. keras_hub/src/models/siglip/siglip_layers.py +1 -3
  83. keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
  84. keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
  85. keras_hub/src/models/t5/t5_backbone.py +1 -3
  86. keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
  87. keras_hub/src/models/task.py +1 -1
  88. keras_hub/src/tests/test_case.py +394 -3
  89. keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
  90. keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
  91. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
  92. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
  93. keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
  94. keras_hub/src/utils/preset_utils.py +1 -1
  95. keras_hub/src/utils/tensor_utils.py +12 -0
  96. keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
  97. keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
  98. keras_hub/src/utils/transformers/convert_sam3.py +472 -0
  99. keras_hub/src/utils/transformers/export/gemma3.py +196 -0
  100. keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
  101. keras_hub/src/utils/transformers/export/qwen.py +136 -0
  102. keras_hub/src/utils/transformers/preset_loader.py +15 -1
  103. keras_hub/src/version.py +1 -1
  104. keras_hub/tokenizers/__init__.py +6 -0
  105. {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
  106. {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
  107. {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
  108. keras_hub/src/models/gemma3/rms_normalization.py +0 -26
  109. {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
keras_hub/src/models/sam3/sam3_detr_decoder.py (new file)
@@ -0,0 +1,641 @@
+ import math
+
+ from keras import layers
+ from keras import ops
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.sam3.sam3_layers import SAM3MLP
+ from keras_hub.src.models.sam3.sam3_layers import SAM3Attention
+ from keras_hub.src.models.sam3.sam3_layers import SAM3DecoderMLP
+ from keras_hub.src.models.sam3.sam3_layers import SAM3SinePositionEmbedding
+ from keras_hub.src.models.sam3.sam3_utils import box_cxcywh_to_xyxy
+ from keras_hub.src.models.sam3.sam3_utils import create_bidirectional_mask
+ from keras_hub.src.models.sam3.sam3_utils import inverse_sigmoid
+
+
+ class SAM3DetrDecoderLayer(layers.Layer):
+     def __init__(
+         self,
+         hidden_dim,
+         intermediate_dim,
+         num_heads,
+         hidden_activation="relu",
+         dropout_rate=0.0,
+         layer_norm_epsilon=1e-6,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.hidden_dim = int(hidden_dim)
+         self.intermediate_dim = int(intermediate_dim)
+         self.num_heads = int(num_heads)
+         self.dropout_rate = float(dropout_rate)
+         self.hidden_activation = hidden_activation
+         self.layer_norm_epsilon = float(layer_norm_epsilon)
+
+         self.self_attn = SAM3Attention(
+             hidden_dim=self.hidden_dim,
+             num_heads=self.num_heads,
+             dtype=self.dtype_policy,
+             name="self_attn",
+         )
+         self.self_attn_dropout = layers.Dropout(
+             rate=self.dropout_rate,
+             dtype=self.dtype_policy,
+             name="self_attn_dropout",
+         )
+         self.self_attn_layer_norm = layers.LayerNormalization(
+             epsilon=self.layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="self_attn_layer_norm",
+         )
+         self.text_cross_attn = SAM3Attention(
+             hidden_dim=self.hidden_dim,
+             num_heads=self.num_heads,
+             dtype=self.dtype_policy,
+             name="text_cross_attn",
+         )
+         self.text_cross_attn_dropout = layers.Dropout(
+             rate=self.dropout_rate,
+             dtype=self.dtype_policy,
+             name="text_cross_attn_dropout",
+         )
+         self.text_cross_attn_layer_norm = layers.LayerNormalization(
+             epsilon=self.layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="text_cross_attn_layer_norm",
+         )
+         self.vision_cross_attn = SAM3Attention(
+             hidden_dim=self.hidden_dim,
+             num_heads=self.num_heads,
+             dtype=self.dtype_policy,
+             name="vision_cross_attn",
+         )
+         self.vision_cross_attn_dropout = layers.Dropout(
+             rate=self.dropout_rate,
+             dtype=self.dtype_policy,
+             name="vision_cross_attn_dropout",
+         )
+         self.vision_cross_attn_layer_norm = layers.LayerNormalization(
+             epsilon=self.layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="vision_cross_attn_layer_norm",
+         )
+         self.mlp = SAM3MLP(
+             hidden_dim=self.hidden_dim,
+             intermediate_dim=self.intermediate_dim,
+             activation=self.hidden_activation,
+             dropout_rate=self.dropout_rate,
+             dtype=self.dtype_policy,
+             name="mlp",
+         )
+         self.mlp_dropout = layers.Dropout(
+             rate=self.dropout_rate,
+             dtype=self.dtype_policy,
+             name="mlp_dropout",
+         )
+         self.mlp_layer_norm = layers.LayerNormalization(
+             epsilon=self.layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="mlp_layer_norm",
+         )
+
+     def build(
+         self,
+         hidden_states_shape,
+         query_pos_shape,
+         text_features_shape,
+         vision_features_shape,
+         vision_pos_encodings_shape,
+         text_cross_attn_masks_shape,
+         vision_cross_attn_masks_shape,
+     ):
+         self.self_attn.build(
+             hidden_states_shape, hidden_states_shape, hidden_states_shape
+         )
+         self.self_attn_dropout.build(hidden_states_shape)
+         self.self_attn_layer_norm.build(hidden_states_shape)
+         self.text_cross_attn.build(
+             hidden_states_shape, text_features_shape, text_features_shape
+         )
+         self.text_cross_attn_dropout.build(hidden_states_shape)
+         self.text_cross_attn_layer_norm.build(hidden_states_shape)
+         self.vision_cross_attn.build(
+             hidden_states_shape, hidden_states_shape, vision_features_shape
+         )
+         self.vision_cross_attn_dropout.build(hidden_states_shape)
+         self.vision_cross_attn_layer_norm.build(hidden_states_shape)
+         self.mlp.build(hidden_states_shape)
+         self.mlp_dropout.build(hidden_states_shape)
+         self.mlp_layer_norm.build(hidden_states_shape)
+
+     def call(
+         self,
+         hidden_states,
+         query_pos,
+         text_features,
+         vision_features,
+         vision_pos_encodings,
+         text_cross_attn_masks,
+         vision_cross_attn_masks,
+         training=None,
+     ):
+         # Prepend zeros to query_pos for presence token.
+         query_pos = ops.pad(query_pos, [[0, 0], [1, 0], [0, 0]])
+
+         # Self-attention with query position encoding.
+         residual = hidden_states
+         query_with_pos = ops.add(hidden_states, query_pos)
+         attn_output = self.self_attn(
+             query=query_with_pos,
+             key=query_with_pos,
+             value=hidden_states,
+             attention_mask=None,
+             training=training,
+         )
+         hidden_states = ops.add(
+             residual, self.self_attn_dropout(attn_output, training=training)
+         )
+         hidden_states = self.self_attn_layer_norm(hidden_states)
+
+         # Text cross-attention: queries attend to text features.
+         residual = hidden_states
+         query_with_pos = ops.add(hidden_states, query_pos)
+         attn_output = self.text_cross_attn(
+             query=query_with_pos,
+             key=text_features,
+             value=text_features,
+             attention_mask=text_cross_attn_masks,
+             training=training,
+         )
+         hidden_states = ops.add(
+             residual,
+             self.text_cross_attn_dropout(attn_output, training=training),
+         )
+         hidden_states = self.text_cross_attn_layer_norm(hidden_states)
+
+         # Vision cross-attention: queries attend to vision features (with RPB)
+         residual = hidden_states
+         query_with_pos = ops.add(hidden_states, query_pos)
+         key_with_pos = ops.add(vision_features, vision_pos_encodings)
+         attn_output = self.vision_cross_attn(
+             query=query_with_pos,
+             key=key_with_pos,
+             value=vision_features,
+             attention_bias=vision_cross_attn_masks,
+             training=training,
+         )
+         hidden_states = ops.add(
+             residual,
+             self.vision_cross_attn_dropout(attn_output, training=training),
+         )
+         hidden_states = self.vision_cross_attn_layer_norm(hidden_states)
+
+         # MLP.
+         residual = hidden_states
+         hidden_states = self.mlp(hidden_states, training=training)
+         hidden_states = ops.add(
+             residual, self.mlp_dropout(hidden_states, training=training)
+         )
+         hidden_states = self.mlp_layer_norm(hidden_states)
+         return hidden_states
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_dim": self.hidden_dim,
+                 "intermediate_dim": self.intermediate_dim,
+                 "num_heads": self.num_heads,
+                 "hidden_activation": self.hidden_activation,
+                 "dropout_rate": self.dropout_rate,
+                 "layer_norm_epsilon": self.layer_norm_epsilon,
+             }
+         )
+         return config
+
+     def compute_output_shape(
+         self,
+         hidden_states_shape,
+         query_pos_shape,
+         text_features_shape,
+         vision_features_shape,
+         vision_pos_encodings_shape,
+         text_cross_attn_masks_shape,
+         vision_cross_attn_masks_shape,
+     ):
+         return hidden_states_shape
+
+
+ @keras_hub_export("keras_hub.layers.SAM3DetrDecoder")
+ class SAM3DetrDecoder(layers.Layer):
+     """A DETR decoder for the Segment Anything Model 3 (SAM3).
+
+     This layer implements a transformer-based decoder that predicts object
+     queries. It processes object queries and fused features through multiple
+     layers of self-attention and cross-attention.
+
+     Args:
+         image_shape: tuple. The shape of the input image
+             (height, width, channels).
+         patch_size: int. The size of the patches to be extracted from the image.
+         num_layers: int. The number of transformer layers.
+         hidden_dim: int. The hidden dimension of the transformer layers.
+         intermediate_dim: int. The dimension of the intermediate layer in the
+             transformer's MLP.
+         num_heads: int. The number of attention heads.
+         num_queries: int. The number of object queries.
+         hidden_activation: str. The activation function for the transformer
+             layers. Defaults to `"relu"`.
+         dropout_rate: float. The dropout rate for the MLP and attention.
+             Defaults to `0.0`.
+         layer_norm_epsilon: float. The epsilon value for layer normalization.
+             Defaults to `1e-6`.
+     """
+
+     def __init__(
+         self,
+         image_shape,
+         patch_size,
+         num_layers,
+         hidden_dim,
+         intermediate_dim,
+         num_heads,
+         num_queries,
+         hidden_activation="relu",
+         dropout_rate=0.0,
+         layer_norm_epsilon=1e-6,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.image_shape = (
+             int(image_shape[0]),
+             int(image_shape[1]),
+             int(image_shape[2]),
+         )
+         self.patch_size = int(patch_size)
+         self.num_layers = int(num_layers)
+         self.hidden_dim = int(hidden_dim)
+         self.intermediate_dim = int(intermediate_dim)
+         self.num_heads = int(num_heads)
+         self.num_queries = int(num_queries)
+         self.hidden_activation = hidden_activation
+         self.dropout_rate = float(dropout_rate)
+         self.layer_norm_epsilon = float(layer_norm_epsilon)
+         self.height = self.image_shape[0] // self.patch_size
+         self.width = self.image_shape[1] // self.patch_size
+
+         self.layers = [
+             SAM3DetrDecoderLayer(
+                 hidden_dim=self.hidden_dim,
+                 intermediate_dim=self.intermediate_dim,
+                 num_heads=self.num_heads,
+                 hidden_activation=self.hidden_activation,
+                 dropout_rate=self.dropout_rate,
+                 layer_norm_epsilon=self.layer_norm_epsilon,
+                 dtype=self.dtype_policy,
+                 name=f"layer_{i}",
+             )
+             for i in range(self.num_layers)
+         ]
+         self.output_layer_norm = layers.LayerNormalization(
+             epsilon=self.layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="output_layer_norm",
+         )
+         self.box_head = SAM3DecoderMLP(
+             num_layers=3,
+             hidden_dim=self.hidden_dim,
+             output_dim=4,
+             dtype=self.dtype_policy,
+             name="box_head",
+         )
+         self.query_embed = layers.Embedding(
+             self.num_queries,
+             self.hidden_dim,
+             dtype=self.dtype_policy,
+             name="query_embed",
+         )
+         self.reference_points = layers.Embedding(
+             self.num_queries,
+             4,
+             dtype=self.dtype_policy,
+             name="reference_points",
+         )
+         self.presence_token = layers.Embedding(
+             1,
+             self.hidden_dim,
+             dtype=self.dtype_policy,
+             name="presence_token",
+         )
+         self.presence_head = SAM3DecoderMLP(
+             num_layers=3,
+             hidden_dim=self.hidden_dim,
+             output_dim=1,
+             dtype=self.dtype_policy,
+             name="presence_head",
+         )
+         self.presence_layer_norm = layers.LayerNormalization(
+             epsilon=self.layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="presence_layer_norm",
+         )
+         self.clamp_presence_logit_max_val = 10.0
+         self.ref_point_head = SAM3DecoderMLP(
+             num_layers=2,
+             hidden_dim=self.hidden_dim,
+             output_dim=self.hidden_dim,
+             dtype=self.dtype_policy,
+             name="ref_point_head",
+         )
+         self.box_rpb_embed_x = SAM3DecoderMLP(
+             num_layers=2,
+             hidden_dim=self.hidden_dim,
+             output_dim=self.num_heads,
+             dtype=self.dtype_policy,
+             name="box_rpb_embed_x",
+         )
+         self.box_rpb_embed_y = SAM3DecoderMLP(
+             num_layers=2,
+             hidden_dim=self.hidden_dim,
+             output_dim=self.num_heads,
+             dtype=self.dtype_policy,
+             name="box_rpb_embed_y",
+         )
+         self.position_encoding = SAM3SinePositionEmbedding(
+             num_pos_feats=self.hidden_dim // 2,
+             normalize=False,
+             dtype=self.dtype_policy,
+             name="position_encoding",
+         )
+
+     def build(
+         self,
+         vision_features_shape,
+         text_features_shape,
+         vision_pos_encodings_shape,
+         text_masks_shape,
+     ):
+         self.query_embed.build()
+         self.reference_points.build()
+         self.presence_token.build()
+         self.position_encoding.build()
+         batch_size = vision_features_shape[0]
+         vision_len = vision_features_shape[1]
+         hidden_states_shape = [
+             batch_size,
+             1 + self.num_queries,
+             self.hidden_dim,
+         ]
+         text_cross_attn_masks_shape = [
+             batch_size,
+             1,
+             1 + self.num_queries,
+             text_masks_shape[-1],
+         ]
+         query_pos_shape = [batch_size, self.num_queries, self.hidden_dim]
+         vision_cross_attn_masks_shape = [
+             batch_size,
+             self.num_heads,
+             1 + self.num_queries,
+             vision_len,
+         ]
+         query_hidden_state_shape = [
+             batch_size,
+             self.num_queries,
+             self.hidden_dim,
+         ]
+         presence_hidden_shape = [batch_size, 1, self.hidden_dim]
+         query_sine_embed_shape = [
+             batch_size,
+             self.num_queries,
+             self.hidden_dim // 2 * 4,
+         ]
+         deltas_x_log_shape = [batch_size, self.num_queries, self.width, 2]
+         deltas_y_log_shape = [batch_size, self.num_queries, self.height, 2]
+
+         self.output_layer_norm.build(query_hidden_state_shape)
+         self.box_head.build(query_hidden_state_shape)
+         self.presence_layer_norm.build(presence_hidden_shape)
+         self.presence_head.build(presence_hidden_shape)
+         self.ref_point_head.build(query_sine_embed_shape)
+         self.box_rpb_embed_x.build(deltas_x_log_shape)
+         self.box_rpb_embed_y.build(deltas_y_log_shape)
+         for layer in self.layers:
+             layer.build(
+                 hidden_states_shape,
+                 query_pos_shape,
+                 text_features_shape,
+                 vision_features_shape,
+                 vision_pos_encodings_shape,
+                 text_cross_attn_masks_shape,
+                 vision_cross_attn_masks_shape,
+             )
+
+     def _get_coords(self, height, width, dtype):
+         coords_h = ops.divide(ops.arange(height, dtype=dtype), height)
+         coords_w = ops.divide(ops.arange(width, dtype=dtype), width)
+         return coords_h, coords_w
+
+     def _get_rpb_matrix(self, reference_boxes):
+         boxes_xyxy = box_cxcywh_to_xyxy(reference_boxes)
+
+         # Generate coordinate grids.
+         coords_h, coords_w = self._get_coords(
+             self.height, self.width, reference_boxes.dtype
+         )
+
+         # Compute deltas between coordinates and box boundaries.
+         deltas_y = ops.subtract(
+             ops.reshape(coords_h, (1, -1, 1)),
+             ops.reshape(boxes_xyxy, (-1, 1, 4))[:, :, 1:4:2],
+         )
+         deltas_y = ops.reshape(deltas_y, (-1, self.num_queries, self.height, 2))
+         deltas_x = ops.subtract(
+             ops.reshape(coords_w, (1, -1, 1)),
+             ops.reshape(boxes_xyxy, (-1, 1, 4))[:, :, 0:3:2],
+         )
+         deltas_x = ops.reshape(deltas_x, (-1, self.num_queries, self.width, 2))
+
+         # Apply log-scale encoding.
+         deltas_x_log = ops.multiply(deltas_x, 8.0)
+         deltas_x_log = ops.divide(
+             ops.multiply(
+                 ops.sign(deltas_x_log),
+                 ops.log2(ops.add(ops.abs(deltas_x_log), 1.0)),
+             ),
+             math.log2(8),
+         )
+         deltas_y_log = ops.multiply(deltas_y, 8.0)
+         deltas_y_log = ops.divide(
+             ops.multiply(
+                 ops.sign(deltas_y_log),
+                 ops.log2(ops.add(ops.abs(deltas_y_log), 1.0)),
+             ),
+             math.log2(8),
+         )
+
+         # Embed deltas.
+         deltas_x = self.box_rpb_embed_x(deltas_x_log)
+         deltas_y = self.box_rpb_embed_y(deltas_y_log)
+
+         # Combine into 2D bias matrix.
+         rpb_matrix = ops.add(
+             ops.expand_dims(deltas_y, axis=3),
+             ops.expand_dims(deltas_x, axis=2),
+         )
+         rpb_matrix = ops.reshape(
+             rpb_matrix,
+             (-1, self.num_queries, self.height * self.width, self.num_heads),
+         )
+         rpb_matrix = ops.transpose(rpb_matrix, (0, 3, 1, 2))
+         return rpb_matrix
+
+     def call(
+         self,
+         vision_features,
+         text_features,
+         vision_pos_encodings,
+         text_masks,
+         training=None,
+     ):
+         batch_size = ops.shape(vision_features)[0]
+         query_embeds = ops.tile(
+             ops.expand_dims(self.query_embed.embeddings, axis=0),
+             [batch_size, 1, 1],
+         )
+         query_embeds = ops.cast(query_embeds, vision_features.dtype)
+         reference_boxes = ops.tile(
+             ops.expand_dims(self.reference_points.embeddings, axis=0),
+             [batch_size, 1, 1],
+         )
+         reference_boxes = ops.cast(reference_boxes, vision_features.dtype)
+         reference_boxes = ops.sigmoid(reference_boxes)
+         presence_token = ops.tile(
+             ops.expand_dims(self.presence_token.embeddings, axis=0),
+             [batch_size, 1, 1],
+         )
+         presence_token = ops.cast(presence_token, vision_features.dtype)
+
+         # Concatenate presence token with query embeddings
+         hidden_states = ops.concatenate([presence_token, query_embeds], axis=1)
+         text_cross_attn_masks = create_bidirectional_mask(
+             hidden_states, text_masks
+         )
+
+         intermediate_outputs = []
+         intermediate_boxes = [reference_boxes]
+         intermediate_presence_logits = []
+         for layer in self.layers:
+             # Generate sine embeddings for conditional queries.
+             reference_points_input = ops.expand_dims(reference_boxes, axis=2)
+             query_sine_embed = self.position_encoding.encode_boxes(
+                 reference_points_input[:, :, 0, :]
+             )
+             query_pos = self.ref_point_head(query_sine_embed)
+
+             # Compute box relative position bias (RPB) attention mask.
+             rpb_matrix = self._get_rpb_matrix(reference_boxes)
+             vision_cross_attn_masks = ops.pad(
+                 rpb_matrix, [[0, 0], [0, 0], [1, 0], [0, 0]]
+             )
+
+             hidden_states = layer(
+                 hidden_states,
+                 query_pos,
+                 text_features,
+                 vision_features,
+                 vision_pos_encodings,
+                 text_cross_attn_masks,
+                 vision_cross_attn_masks,
+                 training=training,
+             )
+
+             # Extract query hidden states (without presence token) for box
+             # refinement.
+             query_hidden_states = hidden_states[:, 1:]
+
+             # Box refinement: predict delta and update reference boxes.
+             reference_boxes_before_sigmoid = inverse_sigmoid(reference_boxes)
+             output_hidden_states = self.output_layer_norm(
+                 query_hidden_states, training=training
+             )
+             delta_boxes = self.box_head(output_hidden_states, training=training)
+             new_reference_boxes = ops.sigmoid(
+                 ops.add(delta_boxes, reference_boxes_before_sigmoid)
+             )
+             # For next layer.
+             reference_boxes = ops.stop_gradient(new_reference_boxes)
+
+             intermediate_outputs.append(output_hidden_states)
+             intermediate_boxes.append(reference_boxes)
+
+             # Process presence token.
+             presence_hidden = hidden_states[:, :1]
+             presence_logits = self.presence_head(
+                 self.presence_layer_norm(presence_hidden, training=training),
+                 training=training,
+             )
+             presence_logits = ops.squeeze(presence_logits, axis=-1)
+             presence_logits = ops.clip(
+                 presence_logits,
+                 -self.clamp_presence_logit_max_val,
+                 self.clamp_presence_logit_max_val,
+             )
+             intermediate_presence_logits.append(presence_logits)
+
+         # Stack outputs from all layers.
+         intermediate_outputs = ops.stack(intermediate_outputs, axis=1)
+         intermediate_boxes = ops.stack(intermediate_boxes[:-1], axis=1)
+         intermediate_presence_logits = ops.stack(
+             intermediate_presence_logits, axis=1
+         )
+         return (
+             intermediate_outputs,
+             intermediate_boxes,
+             intermediate_presence_logits,
+         )
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "image_shape": self.image_shape,
+                 "patch_size": self.patch_size,
+                 "num_layers": self.num_layers,
+                 "hidden_dim": self.hidden_dim,
+                 "intermediate_dim": self.intermediate_dim,
+                 "num_heads": self.num_heads,
+                 "num_queries": self.num_queries,
+                 "hidden_activation": self.hidden_activation,
+                 "dropout_rate": self.dropout_rate,
+                 "layer_norm_epsilon": self.layer_norm_epsilon,
+             }
+         )
+         return config
+
+     def compute_output_shape(
+         self,
+         vision_features_shape,
+         text_features_shape,
+         vision_pos_encodings_shape,
+         text_masks_shape,
+     ):
+         batch_size = vision_features_shape[0]
+         intermediate_output_shape = [
+             batch_size,
+             self.num_layers,
+             self.num_queries,
+             self.hidden_dim,
+         ]
+         intermediate_boxes_shape = [
+             batch_size,
+             self.num_layers,
+             self.num_queries,
+             4,
+         ]
+         intermediate_presence_logits_shape = [batch_size, self.num_layers, 1]
+         return (
+             intermediate_output_shape,
+             intermediate_boxes_shape,
+             intermediate_presence_logits_shape,
+         )
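
As a quick orientation for the largest addition above, the sketch below shows how the new SAM3DetrDecoder might be driven on its own, inferred only from the signatures visible in this diff. Every size here (image shape, patch size, hidden/intermediate dims, head and query counts) is an illustrative placeholder rather than a SAM3 preset value, and the random tensors stand in for the fused vision/text features that the SAM3 backbone (sam3_pc_backbone.py, also added in this release) would normally supply.

import numpy as np
# Public alias registered by the keras_hub_export decorator above.
from keras_hub.layers import SAM3DetrDecoder

# Illustrative hyperparameters; real SAM3 presets will differ.
batch, text_len, hidden_dim, patch = 2, 16, 256, 16
decoder = SAM3DetrDecoder(
    image_shape=(224, 224, 3),
    patch_size=patch,
    num_layers=2,
    hidden_dim=hidden_dim,
    intermediate_dim=1024,
    num_heads=8,
    num_queries=100,
)
# One vision token per image patch: (224 // 16) ** 2 = 196.
vision_len = (224 // patch) * (224 // patch)
vision_features = np.random.rand(batch, vision_len, hidden_dim).astype("float32")
vision_pos_encodings = np.random.rand(batch, vision_len, hidden_dim).astype("float32")
text_features = np.random.rand(batch, text_len, hidden_dim).astype("float32")
text_masks = np.ones((batch, text_len), dtype=bool)  # all text tokens valid
outputs, boxes, presence_logits = decoder(
    vision_features, text_features, vision_pos_encodings, text_masks
)
# outputs:         (batch, num_layers, num_queries, hidden_dim)
# boxes:           (batch, num_layers, num_queries, 4), cxcywh in [0, 1]
# presence_logits: (batch, num_layers, 1), clipped to [-10, 10]

The per-layer leading axis on all three outputs reflects the iterative refinement in call(): each decoder layer re-derives its positional queries and box relative-position bias from the current reference boxes, predicts a box delta in inverse-sigmoid space, and passes the (gradient-stopped) refined boxes to the next layer.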