keras-hub 0.25.1__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (109)
  1. keras_hub/layers/__init__.py +21 -0
  2. keras_hub/models/__init__.py +27 -0
  3. keras_hub/src/layers/modeling/non_max_supression.py +5 -2
  4. keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
  5. keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
  6. keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
  7. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
  8. keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
  9. keras_hub/src/models/albert/albert_backbone.py +1 -3
  10. keras_hub/src/models/backbone.py +3 -0
  11. keras_hub/src/models/bart/bart_backbone.py +1 -3
  12. keras_hub/src/models/bert/bert_backbone.py +2 -4
  13. keras_hub/src/models/bloom/bloom_backbone.py +1 -3
  14. keras_hub/src/models/causal_lm.py +2 -2
  15. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
  16. keras_hub/src/models/edrec/edrec_backbone.py +147 -0
  17. keras_hub/src/models/edrec/edrec_layers.py +434 -0
  18. keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
  19. keras_hub/src/models/electra/electra_backbone.py +1 -3
  20. keras_hub/src/models/f_net/f_net_backbone.py +1 -3
  21. keras_hub/src/models/falcon/falcon_backbone.py +1 -3
  22. keras_hub/src/models/flux/flux_layers.py +3 -3
  23. keras_hub/src/models/flux/flux_maths.py +29 -15
  24. keras_hub/src/models/gemma/gemma_backbone.py +1 -3
  25. keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
  26. keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
  27. keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
  28. keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
  29. keras_hub/src/models/gemma3/gemma3_decoder_block.py +1 -1
  30. keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
  31. keras_hub/src/models/gemma3/gemma3_presets.py +67 -7
  32. keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
  33. keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
  34. keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
  35. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
  36. keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
  37. keras_hub/src/models/llama/llama_backbone.py +1 -3
  38. keras_hub/src/models/masked_lm.py +1 -1
  39. keras_hub/src/models/mistral/mistral_backbone.py +1 -3
  40. keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
  41. keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
  42. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
  43. keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
  44. keras_hub/src/models/phi3/phi3_backbone.py +1 -3
  45. keras_hub/src/models/qwen/qwen_backbone.py +1 -3
  46. keras_hub/src/models/qwen/qwen_presets.py +209 -0
  47. keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
  48. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
  49. keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
  50. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
  51. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
  52. keras_hub/src/models/rqvae/__init__.py +5 -0
  53. keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
  54. keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
  55. keras_hub/src/models/rwkv7/__init__.py +5 -0
  56. keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
  57. keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
  58. keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
  59. keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
  60. keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
  61. keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
  62. keras_hub/src/models/sam/sam_backbone.py +5 -1
  63. keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
  64. keras_hub/src/models/sam3/__init__.py +7 -0
  65. keras_hub/src/models/sam3/roi_align.py +222 -0
  66. keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
  67. keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
  68. keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
  69. keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
  70. keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
  71. keras_hub/src/models/sam3/sam3_layers.py +814 -0
  72. keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
  73. keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
  74. keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
  75. keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
  76. keras_hub/src/models/sam3/sam3_presets.py +16 -0
  77. keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
  78. keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
  79. keras_hub/src/models/sam3/sam3_utils.py +134 -0
  80. keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
  81. keras_hub/src/models/segformer/segformer_backbone.py +6 -6
  82. keras_hub/src/models/siglip/siglip_layers.py +1 -3
  83. keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
  84. keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
  85. keras_hub/src/models/t5/t5_backbone.py +1 -3
  86. keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
  87. keras_hub/src/models/task.py +1 -1
  88. keras_hub/src/tests/test_case.py +394 -3
  89. keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
  90. keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
  91. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
  92. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
  93. keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
  94. keras_hub/src/utils/preset_utils.py +1 -1
  95. keras_hub/src/utils/tensor_utils.py +12 -0
  96. keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
  97. keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
  98. keras_hub/src/utils/transformers/convert_sam3.py +472 -0
  99. keras_hub/src/utils/transformers/export/gemma3.py +196 -0
  100. keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
  101. keras_hub/src/utils/transformers/export/qwen.py +136 -0
  102. keras_hub/src/utils/transformers/preset_loader.py +15 -1
  103. keras_hub/src/version.py +1 -1
  104. keras_hub/tokenizers/__init__.py +6 -0
  105. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
  106. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
  107. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
  108. keras_hub/src/models/gemma3/rms_normalization.py +0 -26
  109. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
keras_hub/src/models/sam3/sam3_geometry_encoder.py
@@ -0,0 +1,517 @@
+import keras
+from keras import layers
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.sam3.roi_align import roi_align
+from keras_hub.src.models.sam3.sam3_layers import SAM3MLP
+from keras_hub.src.models.sam3.sam3_layers import SAM3Attention
+from keras_hub.src.models.sam3.sam3_layers import SAM3SinePositionEmbedding
+from keras_hub.src.models.sam3.sam3_utils import box_cxcywh_to_xyxy
+from keras_hub.src.models.sam3.sam3_utils import concatenate_padded_sequences
+
+
+class SAM3GeometryEncoderLayer(layers.Layer):
+    def __init__(
+        self,
+        hidden_dim,
+        intermediate_dim,
+        num_heads,
+        hidden_activation="relu",
+        dropout_rate=0.0,
+        layer_norm_epsilon=1e-6,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.hidden_dim = int(hidden_dim)
+        self.intermediate_dim = int(intermediate_dim)
+        self.num_heads = int(num_heads)
+        self.hidden_activation = hidden_activation
+        self.dropout_rate = float(dropout_rate)
+        self.layer_norm_epsilon = float(layer_norm_epsilon)
+
+        self.layer_norm1 = layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="layer_norm1",
+        )
+        self.self_attn = SAM3Attention(
+            hidden_dim=self.hidden_dim,
+            num_heads=self.num_heads,
+            dtype=self.dtype_policy,
+            name="self_attn",
+        )
+        self.dropout = layers.Dropout(
+            rate=self.dropout_rate, dtype=self.dtype_policy, name="dropout"
+        )
+        self.cross_attn = SAM3Attention(
+            hidden_dim=self.hidden_dim,
+            num_heads=self.num_heads,
+            dtype=self.dtype_policy,
+            name="cross_attn",
+        )
+        self.layer_norm2 = layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="layer_norm2",
+        )
+        self.mlp = SAM3MLP(
+            hidden_dim=self.hidden_dim,
+            intermediate_dim=self.intermediate_dim,
+            activation=self.hidden_activation,
+            dropout_rate=self.dropout_rate,
+            dtype=self.dtype_policy,
+            name="mlp",
+        )
+        self.layer_norm3 = layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="layer_norm3",
+        )
+
+    def build(
+        self,
+        prompt_feats_shape,
+        vision_feats_shape,
+        vision_pos_encodings_shape,
+        prompt_masks_shape,
+    ):
+        self.layer_norm1.build(prompt_feats_shape)
+        self.self_attn.build(
+            prompt_feats_shape, prompt_feats_shape, prompt_feats_shape
+        )
+        self.dropout.build(prompt_feats_shape)
+        self.layer_norm2.build(prompt_feats_shape)
+        self.cross_attn.build(
+            prompt_feats_shape, vision_feats_shape, vision_feats_shape
+        )
+        self.layer_norm3.build(prompt_feats_shape)
+        self.mlp.build(prompt_feats_shape)
+
+    def call(
+        self,
+        prompt_feats,
+        vision_feats,
+        vision_pos_encodings,
+        prompt_masks,
+        training=None,
+    ):
+        residual = prompt_feats
+        hidden_states = self.layer_norm1(prompt_feats, training=training)
+        hidden_states = self.self_attn(
+            query=hidden_states,
+            key=hidden_states,
+            value=hidden_states,
+            attention_mask=prompt_masks,
+            training=training,
+        )
+        hidden_states = ops.add(
+            self.dropout(hidden_states, training=training), residual
+        )
+
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states, training=training)
+        key = ops.add(vision_feats, vision_pos_encodings)
+        hidden_states = self.cross_attn(
+            query=hidden_states, key=key, value=vision_feats, training=training
+        )
+        hidden_states = ops.add(
+            self.dropout(hidden_states, training=training), residual
+        )
+
+        residual = hidden_states
+        hidden_states = self.layer_norm3(hidden_states, training=training)
+        hidden_states = self.mlp(hidden_states, training=training)
+        hidden_states = ops.add(
+            self.dropout(hidden_states, training=training), residual
+        )
+        return hidden_states
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "num_heads": self.num_heads,
+                "hidden_activation": self.hidden_activation,
+                "dropout_rate": self.dropout_rate,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+            }
+        )
+        return config
+
+    def compute_output_shape(
+        self,
+        prompt_feats_shape,
+        vision_feats_shape,
+        vision_pos_encodings_shape,
+        prompt_masks_shape,
+    ):
+        return prompt_feats_shape
+
+
+@keras_hub_export("keras_hub.layers.SAM3GeometryEncoder")
+class SAM3GeometryEncoder(layers.Layer):
+    """A geometry encoder for the Segment Anything Model 3 (SAM3).
+
+    This layer implements a transformer-based encoder for processing geometry
+    prompts (boxes). It extracts features from the input boxes, pools vision
+    features based on the boxes, and fuses them with transformer layers.
+
+    Args:
+        num_layers: int. The number of transformer layers.
+        hidden_dim: int. The hidden dimension of the transformer layers.
+        intermediate_dim: int. The dimension of the intermediate layer in the
+            transformer's MLP.
+        num_heads: int. The number of attention heads.
+        roi_size: int. The size of the ROI pooling for boxes.
+        hidden_activation: str. The activation function for the transformer
+            layers. Defaults to `"relu"`.
+        dropout_rate: float. The dropout rate for the MLP and attention.
+            Defaults to `0.0`.
+        layer_norm_epsilon: float. The epsilon value for layer normalization.
+            Defaults to `1e-6`.
+    """
+
+    def __init__(
+        self,
+        num_layers,
+        hidden_dim,
+        intermediate_dim,
+        num_heads,
+        roi_size,
+        hidden_activation="relu",
+        dropout_rate=0.0,
+        layer_norm_epsilon=1e-6,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.num_layers = int(num_layers)
+        self.hidden_dim = int(hidden_dim)
+        self.intermediate_dim = int(intermediate_dim)
+        self.num_heads = int(num_heads)
+        self.roi_size = int(roi_size)
+        self.hidden_activation = hidden_activation
+        self.dropout_rate = float(dropout_rate)
+        self.layer_norm_epsilon = float(layer_norm_epsilon)
+
+        self.position_encoding = SAM3SinePositionEmbedding(
+            num_pos_feats=self.hidden_dim // 2,
+            normalize=True,
+            dtype=self.dtype_policy,
+            name="position_encoding",
+        )
+        self.label_embed = layers.Embedding(
+            2, self.hidden_dim, dtype=self.dtype_policy, name="label_embed"
+        )
+        self.cls_embed = layers.Embedding(
+            1, self.hidden_dim, dtype=self.dtype_policy, name="cls_embed"
+        )
+
+        # Box encoding layers.
+        self.boxes_direct_project = layers.Dense(
+            self.hidden_dim,
+            dtype=self.dtype_policy,
+            name="boxes_direct_project",
+        )
+        self.boxes_pool_project = layers.Conv2D(
+            self.hidden_dim,
+            kernel_size=self.roi_size,
+            dtype=self.dtype_policy,
+            name="boxes_pool_project",
+        )
+        self.boxes_pos_enc_project = layers.Dense(
+            self.hidden_dim,
+            dtype=self.dtype_policy,
+            name="boxes_pos_enc_project",
+        )
+
+        # Image feature normalization.
+        self.vision_layer_norm = layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="vision_layer_norm",
+        )
+
+        # Prompt projection and normalization.
+        self.final_proj = layers.Dense(
+            self.hidden_dim, dtype=self.dtype_policy, name="final_proj"
+        )
+        self.prompt_layer_norm = layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="prompt_layer_norm",
+        )
+
+        # Transformer layers.
+        self.layers = [
+            SAM3GeometryEncoderLayer(
+                hidden_dim=self.hidden_dim,
+                intermediate_dim=self.intermediate_dim,
+                num_heads=self.num_heads,
+                dropout_rate=self.dropout_rate,
+                hidden_activation=self.hidden_activation,
+                layer_norm_epsilon=self.layer_norm_epsilon,
+                dtype=self.dtype_policy,
+                name=f"layer_{i}",
+            )
+            for i in range(self.num_layers)
+        ]
+        self.output_layer_norm = layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="output_layer_norm",
+        )
+
+    def build(
+        self,
+        box_embeddings_shape,
+        box_masks_shape,
+        box_labels_shape,
+        fpn_hidden_states_shape,
+        fpn_position_encodings_shape,
+    ):
+        batch_size = fpn_hidden_states_shape[0]
+        self.height = fpn_hidden_states_shape[1]
+        self.width = fpn_hidden_states_shape[2]
+        self.input_hidden_dim = fpn_hidden_states_shape[-1]
+
+        self.position_encoding.build()
+        self.vision_layer_norm.build(fpn_hidden_states_shape)
+
+        box_proj_input_shape = list(box_embeddings_shape)
+        box_proj_input_shape[-1] = box_embeddings_shape[-1] - 1
+        self.boxes_direct_project.build(tuple(box_proj_input_shape))
+
+        sampled_feature_shape = [
+            batch_size,
+            self.roi_size,
+            self.roi_size,
+            self.input_hidden_dim,
+        ]
+        self.boxes_pool_project.build(sampled_feature_shape)
+
+        pos_enc_shape = [batch_size, None, self.input_hidden_dim + 2]
+        self.boxes_pos_enc_project.build(pos_enc_shape)
+        self.label_embed.build([batch_size, 1])
+        self.cls_embed.build([batch_size, 1])
+
+        prompt_embed_shape = [batch_size, None, self.hidden_dim]
+        self.final_proj.build(prompt_embed_shape)
+        self.prompt_layer_norm.build(prompt_embed_shape)
+
+        vision_feat_flat_shape = [
+            batch_size,
+            self.height * self.width,
+            self.input_hidden_dim,
+        ]
+        for layer in self.layers:
+            layer.build(
+                prompt_embed_shape,
+                vision_feat_flat_shape,
+                vision_feat_flat_shape,
+                None,
+            )
+        self.output_layer_norm.build(prompt_embed_shape)
+
+    def _encode_box_coordinates(self, center_x, center_y, width, height):
+        pos_x, pos_y = self.position_encoding.encode_1d_positions(
+            center_x, center_y
+        )
+        pos = ops.concatenate(
+            (pos_y, pos_x, height[:, None], width[:, None]), axis=1
+        )
+        return pos
+
+    def _encode_boxes(self, boxes, boxes_mask, boxes_labels, vision_features):
+        # Keras passes the masks as concrete tensors for both the
+        # true and false functions to build the output shape. So, we
+        # need to handle the case when 0 size masks is passed and
+        # dispatch the call to `_no_box_embeddings`. Note that we can't call
+        # the lambda directly since the inputs are bound to different
+        # values when called with concrete values.
+        if boxes.shape[1] == 0:
+            return self._no_box_embeddings(boxes, boxes_mask)
+
+        # The shape of boxes is different from HF's implementation.
+        # boxes: [batch_size, num_boxes, 5] where the last dimension is
+        # (batch_index, cx, cy, w, h)
+        boxes_indices = boxes[..., 0:1]
+        boxes = boxes[..., 1:]
+        batch_size = ops.shape(boxes)[0]
+        boxes_embed = self.boxes_direct_project(boxes)
+
+        # Pool features using ROI align.
+        # Convert boxes from cxcywh to xyxy format and denormalize.
+        boxes_xyxy = box_cxcywh_to_xyxy(boxes)
+        scale = ops.array(
+            [[[self.width, self.height, self.width, self.height]]],
+            dtype=boxes.dtype,
+        )
+        boxes_xyxy = ops.multiply(boxes_xyxy, scale)
+        boxes_xyxy = ops.reshape(boxes_xyxy, (-1, 4))
+        # Add batch indices to boxes for roi_align.
+        rois = ops.concatenate(
+            [ops.reshape(boxes_indices, (-1, 1)), boxes_xyxy], axis=-1
+        )
+        sampled_features = roi_align(
+            vision_features,
+            rois,
+            (self.roi_size, self.roi_size),
+            spatial_scale=1.0,
+            height=self.height,
+            width=self.width,
+            hidden_dim=self.input_hidden_dim,
+        )
+
+        pooled_projection = self.boxes_pool_project(sampled_features)
+        pooled_projection = ops.reshape(
+            pooled_projection, (batch_size, -1, self.hidden_dim)
+        )
+        boxes_embed = ops.add(boxes_embed, pooled_projection)
+
+        # Add position encoding.
+        center_x, center_y, box_width, box_height = ops.unstack(
+            boxes, num=4, axis=-1
+        )
+        pos_enc = self._encode_box_coordinates(
+            ops.reshape(center_x, (-1,)),
+            ops.reshape(center_y, (-1,)),
+            ops.reshape(box_width, (-1,)),
+            ops.reshape(box_height, (-1,)),
+        )
+        pos_enc = ops.reshape(
+            pos_enc,
+            (batch_size, -1, self.position_encoding.num_pos_feats * 2 + 2),
+        )
+        pos_projection = self.boxes_pos_enc_project(pos_enc)
+        boxes_embed = ops.add(boxes_embed, pos_projection)
+
+        # Add label embeddings (positive / negative).
+        label_embed = self.label_embed(ops.cast(boxes_labels, dtype="int32"))
+        return ops.add(label_embed, boxes_embed), boxes_mask
+
+    def _no_box_embeddings(self, box_embeddings, box_masks):
+        batch_size = ops.shape(box_embeddings)[0]
+        num_boxes = ops.shape(box_embeddings)[1]
+        return (
+            ops.zeros(
+                (batch_size, num_boxes, self.hidden_dim),
+                dtype=box_embeddings.dtype,
+            ),
+            box_masks,
+        )
+
+    def call(
+        self,
+        box_embeddings,
+        box_masks,
+        box_labels,
+        fpn_hidden_states,
+        fpn_position_encodings,
+        training=None,
+    ):
+        # Prepare vision features for cross-attention.
+        vision_feats_flat = ops.reshape(
+            fpn_hidden_states,
+            (-1, self.height * self.width, self.input_hidden_dim),
+        )
+        vision_pos_embeds_flat = ops.reshape(
+            fpn_position_encodings,
+            (-1, self.height * self.width, self.input_hidden_dim),
+        )
+
+        # Normalize image features for pooling operations.
+        normalized_image_feats = self.vision_layer_norm(fpn_hidden_states)
+
+        prompt_embeds, prompt_mask = ops.cond(
+            ops.equal(ops.shape(box_embeddings)[1], 0),
+            lambda: self._no_box_embeddings(box_embeddings, box_masks),
+            lambda: self._encode_boxes(
+                box_embeddings, box_masks, box_labels, normalized_image_feats
+            ),
+        )
+
+        # Add CLS token (always valid).
+        cls_embed = ops.reshape(
+            self.cls_embed._embeddings, (1, 1, self.hidden_dim)
+        )
+        cls_embed = ops.tile(cls_embed, (ops.shape(prompt_embeds)[0], 1, 1))
+        cls_mask = ops.ones_like(cls_embed[:, :, 0], dtype=prompt_mask.dtype)
+
+        prompt_embeds, prompt_mask = concatenate_padded_sequences(
+            prompt_embeds,
+            prompt_mask,
+            ops.shape(prompt_embeds)[1],
+            cls_embed,
+            cls_mask,
+            1,
+            self.hidden_dim,
+        )
+        prompt_embeds = self.prompt_layer_norm(self.final_proj(prompt_embeds))
+
+        # Apply transformer layers with cross-attention to vision features.
+        for layer in self.layers:
+            prompt_embeds = layer(
+                prompt_embeds,
+                vision_feats_flat,
+                vision_pos_embeds_flat,
+                prompt_mask,
+                training=training,
+            )
+
+        # Final output normalization.
+        prompt_embeds = self.output_layer_norm(prompt_embeds, training=training)
+        return prompt_embeds, prompt_mask
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_layers": self.num_layers,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "num_heads": self.num_heads,
+                "roi_size": self.roi_size,
+                "hidden_activation": self.hidden_activation,
+                "dropout_rate": self.dropout_rate,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+            }
+        )
+        return config
+
+    def compute_output_shape(
+        self,
+        box_embeddings_shape,
+        box_masks_shape,
+        box_labels_shape,
+        fpn_hidden_states_shape,
+        fpn_position_encodings_shape,
+    ):
+        batch_size = fpn_hidden_states_shape[0]
+        num_boxes = box_embeddings_shape[1]
+        seq_len = None
+        if num_boxes is not None:
+            seq_len = num_boxes + 1
+        return [batch_size, seq_len, self.hidden_dim], [batch_size, seq_len]
+
+    def compute_output_spec(
+        self,
+        box_embeddings,
+        box_masks,
+        box_labels,
+        fpn_hidden_states,
+        fpn_position_encodings,
+    ):
+        prompt_embeds_shape, prompt_mask_shape = self.compute_output_shape(
+            box_embeddings.shape,
+            box_masks.shape,
+            box_labels.shape,
+            fpn_hidden_states.shape,
+            fpn_position_encodings.shape,
+        )
+        return (
+            keras.KerasTensor(prompt_embeds_shape, dtype=self.compute_dtype),
+            keras.KerasTensor(prompt_mask_shape, dtype="bool"),
+        )
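
For orientation, below is a minimal sketch of calling the new encoder directly. The import path comes from the `keras_hub_export` decorator above, and the input shapes follow the docstring, the `call` signature, and the comment in `_encode_boxes`; the concrete sizes, and the assumption that the FPN channel count equals `hidden_dim`, are hypothetical.

import numpy as np
from keras_hub.layers import SAM3GeometryEncoder

# Hypothetical sizes; the FPN channel count is assumed to equal hidden_dim.
batch_size, num_boxes, height, width, dim = 1, 2, 32, 32, 256
encoder = SAM3GeometryEncoder(
    num_layers=3,
    hidden_dim=dim,
    intermediate_dim=2048,
    num_heads=8,
    roi_size=7,
)

# Boxes are [batch_size, num_boxes, 5]: (batch_index, cx, cy, w, h),
# in normalized coordinates, per the comment in `_encode_boxes`.
boxes = np.array(
    [[[0.0, 0.5, 0.5, 0.4, 0.4], [0.0, 0.25, 0.25, 0.2, 0.2]]], "float32"
)
box_masks = np.ones((batch_size, num_boxes), "bool")
box_labels = np.ones((batch_size, num_boxes), "int32")  # 1 = positive box
fpn_feats = np.random.rand(batch_size, height, width, dim).astype("float32")
fpn_pos = np.random.rand(batch_size, height, width, dim).astype("float32")

prompt_embeds, prompt_mask = encoder(
    boxes, box_masks, box_labels, fpn_feats, fpn_pos
)
# prompt_embeds: (1, num_boxes + 1, 256) -- the box tokens plus one trailing
# CLS token, matching `compute_output_shape` above.

In the released model this layer is presumably constructed and called by `sam3_pc_backbone.py` rather than used standalone.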
keras_hub/src/models/sam3/sam3_image_converter.py
@@ -0,0 +1,10 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+from keras_hub.src.models.sam3.sam3_pc_backbone import (
+    SAM3PromptableConceptBackbone,
+)
+
+
+@keras_hub_export("keras_hub.layers.SAM3ImageConverter")
+class SAM3ImageConverter(ImageConverter):
+    backbone_cls = SAM3PromptableConceptBackbone
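
`SAM3ImageConverter` adds no behavior of its own; it binds the generic `ImageConverter` preprocessing to the SAM3 backbone so that preset loading can resolve the right converter. A hedged sketch of standalone use follows; the `image_size` and `scale` values are illustrative, not the preset's actual configuration.

import numpy as np
from keras_hub.layers import SAM3ImageConverter

# Illustrative settings; the real values ship in each preset's config.
converter = SAM3ImageConverter(image_size=(1024, 1024), scale=1.0 / 255.0)
image = np.random.randint(0, 256, size=(512, 512, 3)).astype("float32")
processed = converter(image)  # resized to (1024, 1024, 3), rescaled to [0, 1]

In practice one would typically obtain the converter via `SAM3ImageConverter.from_preset(...)` so that it carries the preset's resize and rescale settings.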