keras-hub 0.25.1__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. keras_hub/layers/__init__.py +21 -0
  2. keras_hub/models/__init__.py +27 -0
  3. keras_hub/src/layers/modeling/non_max_supression.py +5 -2
  4. keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
  5. keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
  6. keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
  7. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
  8. keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
  9. keras_hub/src/models/albert/albert_backbone.py +1 -3
  10. keras_hub/src/models/backbone.py +3 -0
  11. keras_hub/src/models/bart/bart_backbone.py +1 -3
  12. keras_hub/src/models/bert/bert_backbone.py +2 -4
  13. keras_hub/src/models/bloom/bloom_backbone.py +1 -3
  14. keras_hub/src/models/causal_lm.py +2 -2
  15. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
  16. keras_hub/src/models/edrec/edrec_backbone.py +147 -0
  17. keras_hub/src/models/edrec/edrec_layers.py +434 -0
  18. keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
  19. keras_hub/src/models/electra/electra_backbone.py +1 -3
  20. keras_hub/src/models/f_net/f_net_backbone.py +1 -3
  21. keras_hub/src/models/falcon/falcon_backbone.py +1 -3
  22. keras_hub/src/models/flux/flux_layers.py +3 -3
  23. keras_hub/src/models/flux/flux_maths.py +29 -15
  24. keras_hub/src/models/gemma/gemma_backbone.py +1 -3
  25. keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
  26. keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
  27. keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
  28. keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
  29. keras_hub/src/models/gemma3/gemma3_decoder_block.py +1 -1
  30. keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
  31. keras_hub/src/models/gemma3/gemma3_presets.py +67 -7
  32. keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
  33. keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
  34. keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
  35. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
  36. keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
  37. keras_hub/src/models/llama/llama_backbone.py +1 -3
  38. keras_hub/src/models/masked_lm.py +1 -1
  39. keras_hub/src/models/mistral/mistral_backbone.py +1 -3
  40. keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
  41. keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
  42. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
  43. keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
  44. keras_hub/src/models/phi3/phi3_backbone.py +1 -3
  45. keras_hub/src/models/qwen/qwen_backbone.py +1 -3
  46. keras_hub/src/models/qwen/qwen_presets.py +209 -0
  47. keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
  48. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
  49. keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
  50. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
  51. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
  52. keras_hub/src/models/rqvae/__init__.py +5 -0
  53. keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
  54. keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
  55. keras_hub/src/models/rwkv7/__init__.py +5 -0
  56. keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
  57. keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
  58. keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
  59. keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
  60. keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
  61. keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
  62. keras_hub/src/models/sam/sam_backbone.py +5 -1
  63. keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
  64. keras_hub/src/models/sam3/__init__.py +7 -0
  65. keras_hub/src/models/sam3/roi_align.py +222 -0
  66. keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
  67. keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
  68. keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
  69. keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
  70. keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
  71. keras_hub/src/models/sam3/sam3_layers.py +814 -0
  72. keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
  73. keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
  74. keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
  75. keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
  76. keras_hub/src/models/sam3/sam3_presets.py +16 -0
  77. keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
  78. keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
  79. keras_hub/src/models/sam3/sam3_utils.py +134 -0
  80. keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
  81. keras_hub/src/models/segformer/segformer_backbone.py +6 -6
  82. keras_hub/src/models/siglip/siglip_layers.py +1 -3
  83. keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
  84. keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
  85. keras_hub/src/models/t5/t5_backbone.py +1 -3
  86. keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
  87. keras_hub/src/models/task.py +1 -1
  88. keras_hub/src/tests/test_case.py +394 -3
  89. keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
  90. keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
  91. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
  92. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
  93. keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
  94. keras_hub/src/utils/preset_utils.py +1 -1
  95. keras_hub/src/utils/tensor_utils.py +12 -0
  96. keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
  97. keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
  98. keras_hub/src/utils/transformers/convert_sam3.py +472 -0
  99. keras_hub/src/utils/transformers/export/gemma3.py +196 -0
  100. keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
  101. keras_hub/src/utils/transformers/export/qwen.py +136 -0
  102. keras_hub/src/utils/transformers/preset_loader.py +15 -1
  103. keras_hub/src/version.py +1 -1
  104. keras_hub/tokenizers/__init__.py +6 -0
  105. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
  106. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
  107. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
  108. keras_hub/src/models/gemma3/rms_normalization.py +0 -26
  109. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
keras_hub/src/models/sam3/sam3_mask_decoder.py
@@ -0,0 +1,374 @@
+from keras import layers
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.sam3.sam3_layers import SAM3Attention
+from keras_hub.src.models.sam3.sam3_utils import create_bidirectional_mask
+from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+class SAM3MaskEmbedder(layers.Layer):
+    def __init__(self, hidden_dim, **kwargs):
+        super().__init__(**kwargs)
+        self.hidden_dim = int(hidden_dim)
+
+        self.layers = [
+            layers.Dense(
+                self.hidden_dim, dtype=self.dtype_policy, name="layer_0"
+            ),
+            layers.Dense(
+                self.hidden_dim, dtype=self.dtype_policy, name="layer_1"
+            ),
+            layers.Dense(
+                self.hidden_dim, dtype=self.dtype_policy, name="layer_2"
+            ),
+        ]
+        self.activation = layers.ReLU(
+            dtype=self.dtype_policy, name="activation"
+        )
+
+    def build(self, queries_shape):
+        hidden_state_shape = queries_shape
+        self.activation.build(hidden_state_shape)
+        for layer in self.layers:
+            layer.build(hidden_state_shape)
+            hidden_state_shape = layer.compute_output_shape(hidden_state_shape)
+
+    def call(self, queries, training=None):
+        hidden_states = queries
+        for i, layer in enumerate(self.layers):
+            hidden_states = layer(hidden_states, training=training)
+            if i < len(self.layers) - 1:
+                hidden_states = self.activation(
+                    hidden_states, training=training
+                )
+        return hidden_states
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"hidden_dim": self.hidden_dim})
+        return config
+
+    def compute_output_shape(self, queries_shape):
+        hidden_state_shape = list(queries_shape)
+        hidden_state_shape[-1] = self.hidden_dim
+        return hidden_state_shape
+
+
+class SAM3PixelDecoder(layers.Layer):
+    def __init__(
+        self, num_upsampling_stages, hidden_dim, data_format=None, **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.num_upsampling_stages = int(num_upsampling_stages)
+        self.hidden_dim = int(hidden_dim)
+        self.data_format = standardize_data_format(data_format)
+
+        # Create conv layers and norms for FPN.
+        self.pad_layers = [
+            layers.ZeroPadding2D(
+                padding=1,
+                data_format=self.data_format,
+                dtype=self.dtype_policy,
+                name=f"pad_layer_{i}",
+            )
+            for i in range(self.num_upsampling_stages)
+        ]
+        self.conv_layers = [
+            layers.Conv2D(
+                self.hidden_dim,
+                3,
+                1,
+                data_format=self.data_format,
+                dtype=self.dtype_policy,
+                name=f"conv_layer_{i}",
+            )
+            for i in range(self.num_upsampling_stages)
+        ]
+        self.norms = [
+            layers.GroupNormalization(
+                8, epsilon=1e-5, dtype=self.dtype_policy, name=f"norm_{i}"
+            )
+            for i in range(self.num_upsampling_stages)
+        ]
+
+    def build(self, backbone_features_shapes):
+        self.sizes = []
+        for i, feature_shape in enumerate(
+            reversed(backbone_features_shapes[:-1])
+        ):
+            if self.data_format == "channels_last":
+                self.sizes.append(
+                    (int(feature_shape[1]), int(feature_shape[2]))
+                )
+            else:
+                self.sizes.append(
+                    (int(feature_shape[2]), int(feature_shape[3]))
+                )
+            pad_layer = self.pad_layers[i]
+            conv_layer = self.conv_layers[i]
+            norm_layer = self.norms[i]
+            pad_layer.build(feature_shape)
+            feature_shape = pad_layer.compute_output_shape(feature_shape)
+            conv_layer.build(feature_shape)
+            feature_shape = conv_layer.compute_output_shape(feature_shape)
+            norm_layer.build(feature_shape)
+
+    def call(self, backbone_features, training=None):
+        prev_fpn = backbone_features[-1]
+        for i, feature in enumerate(reversed(backbone_features[:-1])):
+            prev_fpn = ops.image.resize(
+                prev_fpn,
+                size=self.sizes[i],
+                interpolation="nearest",
+                data_format=self.data_format,
+            )
+            prev_fpn = ops.add(prev_fpn, feature)
+            prev_fpn = self.pad_layers[i](prev_fpn, training=training)
+            prev_fpn = self.conv_layers[i](prev_fpn, training=training)
+            prev_fpn = self.norms[i](prev_fpn, training=training)
+            prev_fpn = ops.relu(prev_fpn)
+        return prev_fpn
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_upsampling_stages": self.num_upsampling_stages,
+                "hidden_dim": self.hidden_dim,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, backbone_features_shapes):
+        return backbone_features_shapes[0]
+
+
+@keras_hub_export("keras_hub.layers.SAM3MaskDecoder")
+class SAM3MaskDecoder(layers.Layer):
+    """A mask decoder for the Segment Anything Model 3 (SAM3).
+
+    This layer generates segmentation masks given the object queries from the
+    DETR decoder and fused features. It uses a pixel decoder to upsample
+    backbone features and predicts instance masks and semantic segmentation.
+
+    Args:
+        num_upsampling_stages: int. The number of upsampling stages in the
+            pixel decoder.
+        hidden_dim: int. The hidden dimension of the decoder.
+        num_heads: int. The number of attention heads.
+        dropout_rate: float. The dropout rate for attention. Defaults to `0.0`.
+        layer_norm_epsilon: float. The epsilon value for layer normalization.
+            Defaults to `1e-6`.
+        data_format: str. The data format, either `"channels_last"` or
+            `"channels_first"`.
+    """
+
+    def __init__(
+        self,
+        num_upsampling_stages,
+        hidden_dim,
+        num_heads,
+        dropout_rate=0.0,
+        layer_norm_epsilon=1e-6,
+        data_format=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.num_upsampling_stages = int(num_upsampling_stages)
+        self.hidden_dim = int(hidden_dim)
+        self.num_heads = int(num_heads)
+        self.dropout_rate = float(dropout_rate)
+        self.layer_norm_epsilon = float(layer_norm_epsilon)
+        self.data_format = standardize_data_format(data_format)
+
+        self.pixel_decoder = SAM3PixelDecoder(
+            num_upsampling_stages=self.num_upsampling_stages,
+            hidden_dim=self.hidden_dim,
+            data_format=self.data_format,
+            dtype=self.dtype_policy,
+            name="pixel_decoder",
+        )
+        self.mask_embedder = SAM3MaskEmbedder(
+            hidden_dim=self.hidden_dim,
+            dtype=self.dtype_policy,
+            name="mask_embedder",
+        )
+        self.instance_projection = layers.Conv2D(
+            self.hidden_dim,
+            1,
+            data_format=self.data_format,
+            dtype=self.dtype_policy,
+            name="instance_projection",
+        )
+        self.semantic_projection = layers.Conv2D(
+            1,
+            1,
+            data_format=self.data_format,
+            dtype=self.dtype_policy,
+            name="semantic_projection",
+        )
+        self.prompt_cross_attn = SAM3Attention(
+            hidden_dim=self.hidden_dim,
+            num_heads=self.num_heads,
+            dtype=self.dtype_policy,
+            name="prompt_cross_attn",
+        )
+        self.prompt_cross_attn_norm = layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="prompt_cross_attn_norm",
+        )
+        self.prompt_cross_attn_dropout = layers.Dropout(
+            self.dropout_rate,
+            dtype=self.dtype_policy,
+            name="prompt_cross_attn_dropout",
+        )
+
+    def build(
+        self,
+        decoder_queries_shape,
+        backbone_features_shape,
+        encoder_hidden_states_shape,
+        prompt_features_shape,
+        prompt_masks_shape,
+    ):
+        if self.data_format == "channels_last":
+            self.height = int(backbone_features_shape[-1][1])
+            self.width = int(backbone_features_shape[-1][2])
+        else:
+            self.height = int(backbone_features_shape[-1][2])
+            self.width = int(backbone_features_shape[-1][3])
+        self.prompt_cross_attn_norm.build(encoder_hidden_states_shape)
+        self.prompt_cross_attn.build(
+            encoder_hidden_states_shape,
+            prompt_features_shape,
+            prompt_features_shape,
+        )
+        self.prompt_cross_attn_dropout.build(encoder_hidden_states_shape)
+        # _embed_pixels.
+        encoder_visual_embeds_shape = [
+            encoder_hidden_states_shape[0],
+            self.height * self.width,
+            encoder_hidden_states_shape[-1],
+        ]
+        backbone_features_shape = list(backbone_features_shape)
+        backbone_features_shape[-1] = encoder_visual_embeds_shape
+        self.pixel_decoder.build(backbone_features_shape)
+        pixel_embeds_shape = self.pixel_decoder.compute_output_shape(
+            backbone_features_shape
+        )
+        self.instance_projection.build(pixel_embeds_shape)
+        self.mask_embedder.build(decoder_queries_shape)
+        self.semantic_projection.build(pixel_embeds_shape)
+
+    def _embed_pixels(self, backbone_features, encoder_hidden_states):
+        spatial_dim = self.height * self.width
+        encoder_visual_embed = encoder_hidden_states[:, :spatial_dim, :]
+        encoder_visual_embed = ops.reshape(
+            encoder_visual_embed, (-1, self.height, self.width, self.hidden_dim)
+        )
+        if self.data_format == "channels_first":
+            encoder_visual_embed = ops.transpose(
+                encoder_visual_embed, (0, 3, 1, 2)
+            )
+        backbone_features = list(backbone_features)
+        backbone_features[-1] = encoder_visual_embed
+        return self.pixel_decoder(backbone_features)
+
+    def call(
+        self,
+        decoder_queries,
+        backbone_features,
+        encoder_hidden_states,
+        prompt_features,
+        prompt_masks,
+        training=None,
+    ):
+        # Cross-attention: encoder features attend to prompt features.
+        residual = encoder_hidden_states
+        normed_hidden_states = self.prompt_cross_attn_norm(
+            encoder_hidden_states, training=training
+        )
+        cross_attn_mask = create_bidirectional_mask(
+            normed_hidden_states, prompt_masks
+        )
+        attn_output = self.prompt_cross_attn(
+            query=normed_hidden_states,
+            key=prompt_features,
+            value=prompt_features,
+            attention_mask=cross_attn_mask,
+            training=training,
+        )
+        encoder_hidden_states = ops.add(
+            residual,
+            self.prompt_cross_attn_dropout(attn_output, training=training),
+        )
+
+        # Process backbone features through FPN to get pixel embeddings.
+        pixel_embed = self._embed_pixels(
+            backbone_features, encoder_hidden_states
+        )
+
+        # Predict instance masks via dot product between query embeddings and
+        # pixel embeddings.
+        instance_embeds = self.instance_projection(
+            pixel_embed, training=training
+        )
+        mask_embeddings = self.mask_embedder(decoder_queries, training=training)
+        if self.data_format == "channels_last":
+            pred_masks = ops.einsum(
+                "bqc,bhwc->bhwq", mask_embeddings, instance_embeds
+            )
+        else:
+            pred_masks = ops.einsum(
+                "bqc,bchw->bqhw", mask_embeddings, instance_embeds
+            )
+
+        # Generate semantic segmentation.
+        semantic_segs = self.semantic_projection(pixel_embed, training=training)
+        return pred_masks, semantic_segs
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_upsampling_stages": self.num_upsampling_stages,
+                "hidden_dim": self.hidden_dim,
+                "num_heads": self.num_heads,
+                "dropout_rate": self.dropout_rate,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+            }
+        )
+        return config
+
+    def compute_output_shape(
+        self,
+        decoder_queries_shape,
+        backbone_features_shape,
+        encoder_hidden_states_shape,
+        prompt_features_shape,
+        prompt_masks_shape,
+    ):
+        batch_size = encoder_hidden_states_shape[0]
+        if self.data_format == "channels_last":
+            output_height = int(backbone_features_shape[0][1])
+            output_width = int(backbone_features_shape[0][2])
+            pred_masks_shape = [
+                batch_size,
+                output_height,
+                output_width,
+                self.hidden_dim,
+            ]
+            semantic_segs_shape = [batch_size, output_height, output_width, 1]
+        else:
+            output_height = int(backbone_features_shape[0][2])
+            output_width = int(backbone_features_shape[0][3])
+            pred_masks_shape = [
+                batch_size,
+                self.hidden_dim,
+                output_height,
+                output_width,
+            ]
+            semantic_segs_shape = [batch_size, 1, output_height, output_width]
+        return pred_masks_shape, semantic_segs_shape
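For orientation, here is a minimal sketch of exercising the new `SAM3MaskDecoder` end to end. Only the constructor and call signature come from the hunk above; the feature-map sizes, query and prompt counts, and random inputs are illustrative assumptions, not values from any released preset.

```python
import numpy as np
import keras_hub

batch, hidden_dim = 2, 32
decoder = keras_hub.layers.SAM3MaskDecoder(
    num_upsampling_stages=2, hidden_dim=hidden_dim, num_heads=2
)
# FPN maps, largest first. The smallest (last) map is replaced internally
# by the reshaped encoder hidden states, so its 16 * 16 = 256 pixels must
# match the encoder sequence length below.
backbone_features = [
    np.random.rand(batch, size, size, hidden_dim).astype("float32")
    for size in (64, 32, 16)
]
pred_masks, semantic_segs = decoder(
    decoder_queries=np.random.rand(batch, 8, hidden_dim).astype("float32"),
    backbone_features=backbone_features,
    encoder_hidden_states=np.random.rand(batch, 256, hidden_dim).astype(
        "float32"
    ),
    prompt_features=np.random.rand(batch, 5, hidden_dim).astype("float32"),
    prompt_masks=np.ones((batch, 5), dtype="bool"),
)
# With channels_last data, pred_masks is (2, 64, 64, 8) and
# semantic_segs is (2, 64, 64, 1).
```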
keras_hub/src/models/sam3/sam3_pc_backbone.py
@@ -0,0 +1,306 @@
+import keras
+from keras import layers
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.sam3.sam3_dot_product_scoring import (
+    SAM3DotProductScoring,
+)
+from keras_hub.src.models.sam3.sam3_layers import SAM3BoxDecoder
+
+
+@keras_hub_export("keras_hub.models.SAM3PromptableConceptBackbone")
+class SAM3PromptableConceptBackbone(Backbone):
+    """A backbone for the Segment Anything Model 3 (SAM3).
+
+    SAM3 is a multi-modal model that supports text and geometry prompts (boxes)
+    to perform object segmentation. It consists of a vision encoder, a text
+    encoder, a geometry encoder for processing box prompts, and a DETR-based
+    encoder-decoder architecture to fuse multi-modal features and predict
+    segmentation masks.
+
+    Args:
+        vision_encoder: `keras_hub.layers.SAM3VisionEncoder`. A feature
+            extractor for the input images.
+        text_encoder: `keras_hub.layers.SAM3TextEncoder`. A Keras layer to
+            compute embeddings for text prompts.
+        geometry_encoder: `keras_hub.layers.SAM3GeometryEncoder`. A Keras layer
+            to compute embeddings for geometry (box) prompts.
+        detr_encoder: `keras_hub.layers.SAM3DetrEncoder`. A transformer-based
+            encoder that fuses vision and prompt features.
+        detr_decoder: `keras_hub.layers.SAM3DetrDecoder`. A transformer-based
+            decoder that predicts object queries.
+        mask_decoder: `keras_hub.layers.SAM3MaskDecoder`. A Keras layer to
+            generate segmentation masks given the object queries and fused
+            features.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for the model's computations and weights. Note that some
+            computations, such as softmax and layer normalization, will always
+            be done in float32 precision regardless of dtype. Defaults to
+            `bfloat16`.
+
+    Example:
+    ```python
+    import numpy as np
+    import keras_hub
+
+    vision_encoder = keras_hub.layers.SAM3VisionEncoder(
+        image_shape=(224, 224, 3),
+        patch_size=14,
+        num_layers=2,
+        hidden_dim=32,
+        intermediate_dim=128,
+        num_heads=2,
+        fpn_hidden_dim=32,
+        fpn_scale_factors=[4.0, 2.0, 1.0, 0.5],
+        pretrain_image_shape=(112, 112, 3),
+        window_size=2,
+        global_attn_indexes=[1, 2],
+    )
+    text_encoder = keras_hub.layers.SAM3TextEncoder(
+        vocabulary_size=1024,
+        embedding_dim=32,
+        hidden_dim=32,
+        num_layers=2,
+        num_heads=2,
+        intermediate_dim=128,
+    )
+    geometry_encoder = keras_hub.layers.SAM3GeometryEncoder(
+        num_layers=3,
+        hidden_dim=32,
+        intermediate_dim=128,
+        num_heads=2,
+        roi_size=7,
+    )
+    detr_encoder = keras_hub.layers.SAM3DetrEncoder(
+        num_layers=3,
+        hidden_dim=32,
+        intermediate_dim=128,
+        num_heads=2,
+    )
+    detr_decoder = keras_hub.layers.SAM3DetrDecoder(
+        image_shape=(224, 224, 3),
+        patch_size=14,
+        num_layers=2,
+        hidden_dim=32,
+        intermediate_dim=128,
+        num_heads=2,
+        num_queries=100,
+    )
+    mask_decoder = keras_hub.layers.SAM3MaskDecoder(
+        num_upsampling_stages=3,
+        hidden_dim=32,
+        num_heads=2,
+    )
+    backbone = keras_hub.models.SAM3PromptableConceptBackbone(
+        vision_encoder=vision_encoder,
+        text_encoder=text_encoder,
+        geometry_encoder=geometry_encoder,
+        detr_encoder=detr_encoder,
+        detr_decoder=detr_decoder,
+        mask_decoder=mask_decoder,
+    )
+    input_data = {
+        "pixel_values": np.ones((2, 224, 224, 3), dtype="float32"),
+        "token_ids": np.ones((2, 32), dtype="int32"),
+        "padding_mask": np.ones((2, 32), dtype="bool"),
+        "boxes": np.zeros((2, 1, 5), dtype="float32"),
+        "box_labels": np.zeros((2, 1), dtype="int32"),
+    }
+    outputs = backbone(input_data)
+    ```
+    """
+
+    def __init__(
+        self,
+        vision_encoder,
+        text_encoder,
+        geometry_encoder,
+        detr_encoder,
+        detr_decoder,
+        mask_decoder,
+        dtype=None,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.vision_encoder = vision_encoder
+        self.text_encoder = text_encoder
+        self.geometry_encoder = geometry_encoder
+        self.detr_encoder = detr_encoder
+        self.detr_decoder = detr_decoder
+        self.mask_decoder = mask_decoder
+
+        self.text_projection = layers.Dense(
+            self.detr_encoder.hidden_dim, dtype=dtype, name="text_projection"
+        )
+        self.dot_product_scoring = SAM3DotProductScoring(
+            hidden_dim=self.detr_decoder.hidden_dim,
+            intermediate_dim=self.detr_decoder.intermediate_dim,
+            dropout_rate=self.detr_decoder.dropout_rate,
+            layer_norm_epsilon=1e-6,
+            dtype=dtype,
+            name="dot_product_scoring",
+        )
+        self.box_decoder = SAM3BoxDecoder(dtype=dtype, name="box_decoder")
+
+        # === Functional Model ===
+        pixel_value_input = layers.Input(
+            shape=self.vision_encoder.image_shape, name="pixel_values"
+        )
+        token_id_input = keras.Input(
+            shape=(None,), dtype="int32", name="token_ids"
+        )
+        padding_mask_input = keras.Input(
+            shape=(None,), dtype="int32", name="padding_mask"
+        )
+        box_input = keras.Input(shape=(None, 5), dtype="float32", name="boxes")
+        box_label_input = keras.Input(
+            shape=(None,), dtype="int32", name="box_labels"
+        )
+
+        padding_mask = ops.cast(padding_mask_input, dtype="bool")
+        box_masks = ops.cast(
+            ops.where(ops.not_equal(box_label_input, -10), 1, 0), dtype="bool"
+        )
+
+        fpn_hidden_states, fpn_position_encodings = self.vision_encoder(
+            pixel_value_input
+        )
+        fpn_hidden_states = fpn_hidden_states[:-1]
+        fpn_position_encodings = fpn_position_encodings[:-1]
+        text_features = self.text_encoder(token_id_input, padding_mask)
+        text_features = self.text_projection(text_features)
+        geometry_prompt_features, geometry_prompt_mask = self.geometry_encoder(
+            box_input,
+            box_label_input,
+            box_masks,
+            fpn_hidden_states=fpn_hidden_states[-1],
+            fpn_position_encodings=fpn_position_encodings[-1],
+        )
+        combined_prompt_features = ops.concatenate(
+            [text_features, geometry_prompt_features], axis=1
+        )
+        combined_prompt_masks = ops.concatenate(
+            [padding_mask, geometry_prompt_mask], axis=1
+        )
+        encoder_outputs = self.detr_encoder(
+            vision_features=fpn_hidden_states[-1],
+            text_features=combined_prompt_features,
+            vision_pos_embeds=fpn_position_encodings[-1],
+            text_masks=combined_prompt_masks,
+        )
+        decoder_outputs = self.detr_decoder(
+            vision_features=encoder_outputs[0],
+            text_features=combined_prompt_features,
+            vision_pos_encodings=encoder_outputs[1],
+            text_masks=combined_prompt_masks,
+        )
+        decoder_hidden_states = decoder_outputs[0]
+        decoder_presence_logits = decoder_outputs[2]
+        all_box_offsets = self.detr_decoder.box_head(decoder_hidden_states)
+        all_pred_logits = self.dot_product_scoring(
+            decoder_hidden_states=decoder_hidden_states,
+            text_features=combined_prompt_features,
+            text_masks=combined_prompt_masks,
+        )
+        pred_boxes, pred_logits, presence_logits = self.box_decoder(
+            box_offsets=all_box_offsets,
+            reference_boxes=decoder_outputs[1],
+            pred_logits=all_pred_logits,
+            presence_logits=decoder_presence_logits,
+        )
+        pred_masks, semantic_segs = self.mask_decoder(
+            decoder_queries=decoder_hidden_states[:, -1],
+            backbone_features=fpn_hidden_states,
+            encoder_hidden_states=encoder_outputs[0],
+            prompt_features=combined_prompt_features,
+            prompt_masks=combined_prompt_masks,
+        )
+
+        super().__init__(
+            inputs={
+                "pixel_values": pixel_value_input,
+                "token_ids": token_id_input,
+                "padding_mask": padding_mask_input,
+                "boxes": box_input,
+                "box_labels": box_label_input,
+            },
+            outputs={
+                "pred_masks": pred_masks,
+                "pred_boxes": pred_boxes,
+                "pred_logits": pred_logits,
+                "presence_logits": presence_logits,
+                "semantic_segs": semantic_segs,
+            },
+            dtype=dtype,
+            **kwargs,
+        )
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "vision_encoder": keras.layers.serialize(self.vision_encoder),
+                "text_encoder": keras.layers.serialize(self.text_encoder),
+                "geometry_encoder": keras.layers.serialize(
+                    self.geometry_encoder
+                ),
+                "detr_encoder": keras.layers.serialize(self.detr_encoder),
+                "detr_decoder": keras.layers.serialize(self.detr_decoder),
+                "mask_decoder": keras.layers.serialize(self.mask_decoder),
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        config = config.copy()
+
+        # Propagate `dtype` to submodels if needed.
+        if "dtype" in config and config["dtype"] is not None:
+            dtype_config = config["dtype"]
+            if "dtype" not in config["vision_encoder"]["config"]:
+                config["vision_encoder"]["config"]["dtype"] = dtype_config
+            if "dtype" not in config["text_encoder"]["config"]:
+                config["text_encoder"]["config"]["dtype"] = dtype_config
+            if "dtype" not in config["geometry_encoder"]["config"]:
+                config["geometry_encoder"]["config"]["dtype"] = dtype_config
+            if "dtype" not in config["detr_encoder"]["config"]:
+                config["detr_encoder"]["config"]["dtype"] = dtype_config
+            if "dtype" not in config["detr_decoder"]["config"]:
+                config["detr_decoder"]["config"]["dtype"] = dtype_config
+            if "dtype" not in config["mask_decoder"]["config"]:
+                config["mask_decoder"]["config"]["dtype"] = dtype_config
+
+        # Propagate `image_shape` to submodels if needed.
+        if "image_shape" in config and config["image_shape"] is not None:
+            image_shape = config.pop("image_shape")
+            if "image_shape" in config["vision_encoder"]["config"]:
+                config["vision_encoder"]["config"]["image_shape"] = image_shape
+            if "image_shape" in config["detr_decoder"]["config"]:
+                config["detr_decoder"]["config"]["image_shape"] = image_shape
+
+        config.update(
+            {
+                "vision_encoder": keras.layers.deserialize(
+                    config["vision_encoder"]
+                ),
+                "text_encoder": keras.layers.deserialize(
+                    config["text_encoder"]
+                ),
+                "geometry_encoder": keras.layers.deserialize(
+                    config["geometry_encoder"]
+                ),
+                "detr_encoder": keras.layers.deserialize(
+                    config["detr_encoder"]
+                ),
+                "detr_decoder": keras.layers.deserialize(
+                    config["detr_decoder"]
+                ),
+                "mask_decoder": keras.layers.deserialize(
+                    config["mask_decoder"]
+                ),
+            }
+        )
+        return super().from_config(config)
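The `from_config` override above does the heavy lifting for serialization: every sub-model arrives as a nested layer config, and a top-level `dtype` or `image_shape` is pushed into any sub-config that does not already pin its own. A minimal round-trip sketch, assuming the toy `backbone` built in the docstring example:

```python
# Rebuild the backbone from its own config. Sub-layers are deserialized
# inside `from_config`, with `dtype`/`image_shape` propagated as needed.
config = backbone.get_config()
restored = keras_hub.models.SAM3PromptableConceptBackbone.from_config(config)
assert type(restored.mask_decoder) is type(backbone.mask_decoder)
```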