keras-hub 0.25.1__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. keras_hub/layers/__init__.py +21 -0
  2. keras_hub/models/__init__.py +27 -0
  3. keras_hub/src/layers/modeling/non_max_supression.py +5 -2
  4. keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
  5. keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
  6. keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
  7. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
  8. keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
  9. keras_hub/src/models/albert/albert_backbone.py +1 -3
  10. keras_hub/src/models/backbone.py +3 -0
  11. keras_hub/src/models/bart/bart_backbone.py +1 -3
  12. keras_hub/src/models/bert/bert_backbone.py +2 -4
  13. keras_hub/src/models/bloom/bloom_backbone.py +1 -3
  14. keras_hub/src/models/causal_lm.py +2 -2
  15. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
  16. keras_hub/src/models/edrec/edrec_backbone.py +147 -0
  17. keras_hub/src/models/edrec/edrec_layers.py +434 -0
  18. keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
  19. keras_hub/src/models/electra/electra_backbone.py +1 -3
  20. keras_hub/src/models/f_net/f_net_backbone.py +1 -3
  21. keras_hub/src/models/falcon/falcon_backbone.py +1 -3
  22. keras_hub/src/models/flux/flux_layers.py +3 -3
  23. keras_hub/src/models/flux/flux_maths.py +29 -15
  24. keras_hub/src/models/gemma/gemma_backbone.py +1 -3
  25. keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
  26. keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
  27. keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
  28. keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
  29. keras_hub/src/models/gemma3/gemma3_decoder_block.py +1 -1
  30. keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
  31. keras_hub/src/models/gemma3/gemma3_presets.py +67 -7
  32. keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
  33. keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
  34. keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
  35. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
  36. keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
  37. keras_hub/src/models/llama/llama_backbone.py +1 -3
  38. keras_hub/src/models/masked_lm.py +1 -1
  39. keras_hub/src/models/mistral/mistral_backbone.py +1 -3
  40. keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
  41. keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
  42. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
  43. keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
  44. keras_hub/src/models/phi3/phi3_backbone.py +1 -3
  45. keras_hub/src/models/qwen/qwen_backbone.py +1 -3
  46. keras_hub/src/models/qwen/qwen_presets.py +209 -0
  47. keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
  48. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
  49. keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
  50. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
  51. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
  52. keras_hub/src/models/rqvae/__init__.py +5 -0
  53. keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
  54. keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
  55. keras_hub/src/models/rwkv7/__init__.py +5 -0
  56. keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
  57. keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
  58. keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
  59. keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
  60. keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
  61. keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
  62. keras_hub/src/models/sam/sam_backbone.py +5 -1
  63. keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
  64. keras_hub/src/models/sam3/__init__.py +7 -0
  65. keras_hub/src/models/sam3/roi_align.py +222 -0
  66. keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
  67. keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
  68. keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
  69. keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
  70. keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
  71. keras_hub/src/models/sam3/sam3_layers.py +814 -0
  72. keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
  73. keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
  74. keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
  75. keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
  76. keras_hub/src/models/sam3/sam3_presets.py +16 -0
  77. keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
  78. keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
  79. keras_hub/src/models/sam3/sam3_utils.py +134 -0
  80. keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
  81. keras_hub/src/models/segformer/segformer_backbone.py +6 -6
  82. keras_hub/src/models/siglip/siglip_layers.py +1 -3
  83. keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
  84. keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
  85. keras_hub/src/models/t5/t5_backbone.py +1 -3
  86. keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
  87. keras_hub/src/models/task.py +1 -1
  88. keras_hub/src/tests/test_case.py +394 -3
  89. keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
  90. keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
  91. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
  92. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
  93. keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
  94. keras_hub/src/utils/preset_utils.py +1 -1
  95. keras_hub/src/utils/tensor_utils.py +12 -0
  96. keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
  97. keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
  98. keras_hub/src/utils/transformers/convert_sam3.py +472 -0
  99. keras_hub/src/utils/transformers/export/gemma3.py +196 -0
  100. keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
  101. keras_hub/src/utils/transformers/export/qwen.py +136 -0
  102. keras_hub/src/utils/transformers/preset_loader.py +15 -1
  103. keras_hub/src/version.py +1 -1
  104. keras_hub/tokenizers/__init__.py +6 -0
  105. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
  106. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
  107. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
  108. keras_hub/src/models/gemma3/rms_normalization.py +0 -26
  109. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
keras_hub/src/models/sam3/sam3_vision_encoder.py (new file)
@@ -0,0 +1,738 @@
import numpy as np
from keras import layers
from keras import ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.sam3.sam3_layers import SAM3MLP
from keras_hub.src.models.sam3.sam3_layers import SAM3Embedding
from keras_hub.src.models.sam3.sam3_layers import SAM3RoPEAttention
from keras_hub.src.models.sam3.sam3_layers import SAM3SinePositionEmbedding
from keras_hub.src.models.sam3.sam3_utils import window_partition
from keras_hub.src.models.sam3.sam3_utils import window_unpartition


class SAM3ViTRotaryEmbedding(layers.Layer):
    def __init__(self, rope_theta, head_dim, end_x, end_y, scale=1.0, **kwargs):
        super().__init__(**kwargs)
        self.rope_theta = float(rope_theta)
        self.head_dim = int(head_dim)
        self.end_x = int(end_x)
        self.end_y = int(end_y)
        self.scale = float(scale)

        # Ensure the head dimension is divisible by 4 for axial splitting.
        if self.head_dim % 4 != 0:
            raise ValueError("Dimension must be divisible by 4 for axial RoPE")

    def build(self, input_shape):
        freqs = 1.0 / (
            self.rope_theta
            ** (
                np.arange(0, self.head_dim, 4)[: (self.head_dim // 4)]
                / self.head_dim
            )
        )
        flattened_indices = np.arange(self.end_x * self.end_y, dtype=np.int64)
        x_positions = (flattened_indices % self.end_x) * self.scale
        y_positions = (
            np.floor_divide(flattened_indices, self.end_x) * self.scale
        )
        freqs_x = np.outer(x_positions, freqs).astype(np.float32)
        freqs_y = np.outer(y_positions, freqs).astype(np.float32)
        inv_freq = np.concatenate([freqs_x, freqs_y], axis=-1)
        inv_freq = np.repeat(inv_freq, repeats=2, axis=-1)
        rope_embeddings_cos = np.cos(inv_freq)
        rope_embeddings_sin = np.sin(inv_freq)
        self.rope_embeddings_cos = self.add_weight(
            name="rope_embeddings_cos",
            shape=rope_embeddings_cos.shape,
            dtype=self.variable_dtype,
            trainable=False,
            initializer=rope_embeddings_cos,
        )
        self.rope_embeddings_sin = self.add_weight(
            name="rope_embeddings_sin",
            shape=rope_embeddings_sin.shape,
            dtype=self.variable_dtype,
            trainable=False,
            initializer=rope_embeddings_sin,
        )

    def call(self, inputs):
        return self.rope_embeddings_cos, self.rope_embeddings_sin

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "rope_theta": self.rope_theta,
                "head_dim": self.head_dim,
                "end_x": self.end_x,
                "end_y": self.end_y,
                "scale": self.scale,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        embedding_shape = (self.end_x * self.end_y, self.head_dim)
        return (embedding_shape, embedding_shape)

    def load_own_variables(self, store):
        try:
            return super().load_own_variables(store)
        except ValueError:
            # `SAM3ViTRotaryEmbedding` only has precomputed, non-trainable
            # weights, so a mismatch in the loading logic can be safely
            # ignored.
            pass


class SAM3ViTLayer(layers.Layer):
    def __init__(
        self,
        image_shape,
        patch_size,
        hidden_dim,
        intermediate_dim,
        num_heads,
        hidden_activation="gelu",
        rope_theta=10000.0,
        window_size=0,
        rotary_scale=1.0,
        attention_dropout_rate=0.0,
        hidden_dropout_rate=0.0,
        layer_norm_epsilon=1e-6,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.image_shape = (
            int(image_shape[0]),
            int(image_shape[1]),
            int(image_shape[2]),
        )
        self.patch_size = int(patch_size)
        self.hidden_dim = int(hidden_dim)
        self.intermediate_dim = int(intermediate_dim)
        self.num_heads = int(num_heads)
        self.hidden_activation = hidden_activation
        self.rope_theta = float(rope_theta)
        self.window_size = int(window_size)
        self.rotary_scale = float(rotary_scale)
        self.hidden_dropout_rate = float(hidden_dropout_rate)
        self.attention_dropout_rate = float(attention_dropout_rate)
        self.layer_norm_epsilon = float(layer_norm_epsilon)
        self.head_dim = self.hidden_dim // self.num_heads
        input_size = (
            self.image_shape[0] // self.patch_size,
            self.image_shape[1] // self.patch_size,
        )
        if self.window_size > 0 and (
            input_size[0] % self.window_size != 0
            or input_size[1] % self.window_size != 0
        ):
            raise ValueError(
                "Image size must be divisible by `patch_size` and "
                "`window_size` for windowed attention. "
                f"Received image size: {image_shape}, "
                f"patch_size: {patch_size}, window_size: {window_size}"
            )
        rotary_input_size = (
            input_size if window_size == 0 else (window_size, window_size)
        )

        self.layer_norm1 = layers.LayerNormalization(
            epsilon=layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="layer_norm1",
        )
        self.rotary_emb = SAM3ViTRotaryEmbedding(
            rope_theta=rope_theta,
            head_dim=self.head_dim,
            end_x=rotary_input_size[0],
            end_y=rotary_input_size[1],
            scale=self.rotary_scale,
            dtype=self.dtype_policy,
            name="rotary_emb",
        )
        self.attention = SAM3RoPEAttention(
            hidden_dim=self.hidden_dim,
            num_heads=self.num_heads,
            attention_dropout_rate=self.attention_dropout_rate,
            dtype=self.dtype_policy,
            name="attention",
        )
        self.layer_norm2 = layers.LayerNormalization(
            epsilon=layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="layer_norm2",
        )
        self.mlp = SAM3MLP(
            hidden_dim=self.hidden_dim,
            intermediate_dim=self.intermediate_dim,
            activation=self.hidden_activation,
            dropout_rate=self.hidden_dropout_rate,
            dtype=self.dtype_policy,
            name="mlp",
        )
        self.dropout = layers.Dropout(
            self.hidden_dropout_rate, dtype=self.dtype_policy, name="dropout"
        )

    def build(self, input_shape):
        self.input_hidden_dim = int(input_shape[-1])
        self.layer_norm1.build(input_shape)
        input_shape = self.layer_norm1.compute_output_shape(input_shape)
        self.rotary_emb.build(input_shape)
        input_shape_before_attention = input_shape
        if self.window_size > 0:
            input_shape = list(input_shape)
            input_shape = (
                None,
                self.window_size,
                self.window_size,
                input_shape[-1],
            )
        self.attention.build(input_shape)
        input_shape = self.attention.compute_output_shape(input_shape)
        if self.window_size > 0:
            input_shape = input_shape_before_attention
        self.layer_norm2.build(input_shape)
        self.mlp.build(input_shape)
        self.dropout.build(input_shape)

    def call(self, hidden_states, training=None):
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states, training=training)
        if self.window_size > 0:
            height, width = (
                self.image_shape[0] // self.patch_size,
                self.image_shape[1] // self.patch_size,
            )
            # Partition into non-overlapping windows for efficient attention.
            hidden_states = window_partition(
                hidden_states,
                height,
                width,
                self.window_size,
                self.input_hidden_dim,
            )

        position_embeddings = self.rotary_emb(hidden_states, training=training)
        hidden_states = self.attention(
            hidden_states, position_embeddings, training=training
        )
        if self.window_size > 0:
            # Reverse window partition to restore original spatial layout.
            hidden_states = window_unpartition(
                hidden_states, height, width, self.window_size, self.hidden_dim
            )
        hidden_states = ops.add(residual, hidden_states)
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states, training=training)
        hidden_states = self.mlp(hidden_states, training=training)
        hidden_states = ops.add(
            residual, self.dropout(hidden_states, training=training)
        )
        return hidden_states

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "image_shape": self.image_shape,
                "patch_size": self.patch_size,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "num_heads": self.num_heads,
                "hidden_activation": self.hidden_activation,
                "rope_theta": self.rope_theta,
                "window_size": self.window_size,
                "rotary_scale": self.rotary_scale,
                "attention_dropout_rate": self.attention_dropout_rate,
                "hidden_dropout_rate": self.hidden_dropout_rate,
                "layer_norm_epsilon": self.layer_norm_epsilon,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        return input_shape


class SAM3ViTEncoder(layers.Layer):
    def __init__(
        self,
        image_shape,
        patch_size,
        num_layers,
        hidden_dim,
        intermediate_dim,
        num_heads,
        pretrain_image_shape=(336, 336, 3),
        hidden_activation="gelu",
        rope_theta=100000.0,
        window_size=0,
        global_attn_indexes=None,
        attention_dropout_rate=0.0,
        hidden_dropout_rate=0.0,
        layer_norm_epsilon=1e-6,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.image_shape = (
            int(image_shape[0]),
            int(image_shape[1]),
            int(image_shape[2]),
        )
        self.patch_size = int(patch_size)
        self.num_layers = int(num_layers)
        self.hidden_dim = int(hidden_dim)
        self.intermediate_dim = int(intermediate_dim)
        self.num_heads = int(num_heads)
        self.hidden_activation = hidden_activation
        self.rope_theta = float(rope_theta)
        self.window_size = int(window_size)
        if global_attn_indexes is not None:
            self.global_attn_indexes = [int(i) for i in global_attn_indexes]
        else:
            self.global_attn_indexes = None
        self.pretrain_image_shape = (
            int(pretrain_image_shape[0]),
            int(pretrain_image_shape[1]),
            int(pretrain_image_shape[2]),
        )
        self.hidden_dropout_rate = float(hidden_dropout_rate)
        self.attention_dropout_rate = float(attention_dropout_rate)
        self.layer_norm_epsilon = float(layer_norm_epsilon)
        height = self.image_shape[0] // self.patch_size

        self.embeddings = SAM3Embedding(
            hidden_dim=self.hidden_dim,
            patch_size=self.patch_size,
            image_shape=self.image_shape,
            dropout_rate=self.hidden_dropout_rate,
            pretrain_image_shape=self.pretrain_image_shape,
            dtype=self.dtype_policy,
            name="embeddings",
        )
        self.layer_norm = layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="layer_norm",
        )
        self.layers = [
            SAM3ViTLayer(
                image_shape=self.image_shape,
                patch_size=self.patch_size,
                hidden_dim=self.hidden_dim,
                intermediate_dim=self.intermediate_dim,
                num_heads=self.num_heads,
                hidden_activation=self.hidden_activation,
                rope_theta=self.rope_theta,
                window_size=(
                    self.window_size if i not in self.global_attn_indexes else 0
                ),
                rotary_scale=(
                    1.0
                    if i not in self.global_attn_indexes
                    else float(self.window_size) / height
                ),
                attention_dropout_rate=self.attention_dropout_rate,
                hidden_dropout_rate=self.hidden_dropout_rate,
                layer_norm_epsilon=self.layer_norm_epsilon,
                dtype=self.dtype_policy,
                name=f"layer_{i}",
            )
            for i in range(self.num_layers)
        ]

    def build(self, input_shape):
        self.embeddings.build(input_shape)
        input_shape = self.embeddings.compute_output_shape(input_shape)
        input_shape = list(input_shape)
        height = self.image_shape[0] // self.patch_size
        width = self.image_shape[1] // self.patch_size
        input_shape = [input_shape[0], height, width, self.hidden_dim]
        self.layer_norm.build(input_shape)
        for layer in self.layers:
            layer.build(input_shape)

    def call(self, pixel_values, training=None):
        hidden_states = self.embeddings(pixel_values, training=training)
        height = self.image_shape[0] // self.patch_size
        width = self.image_shape[1] // self.patch_size
        # Reshape to spatial format for windowed attention:
        # [batch_size, height, width, hidden_size]
        hidden_states = ops.reshape(
            hidden_states, (-1, height, width, self.hidden_dim)
        )
        hidden_states = self.layer_norm(hidden_states, training=training)
        for i, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states, training=training)

        # Reshape back to sequence format:
        # [batch_size, height*width, hidden_size]
        return ops.reshape(hidden_states, (-1, height * width, self.hidden_dim))

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "image_shape": self.image_shape,
                "patch_size": self.patch_size,
                "num_layers": self.num_layers,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "num_heads": self.num_heads,
                "pretrain_image_shape": self.pretrain_image_shape,
                "hidden_activation": self.hidden_activation,
                "rope_theta": self.rope_theta,
                "window_size": self.window_size,
                "global_attn_indexes": self.global_attn_indexes,
                "attention_dropout_rate": self.attention_dropout_rate,
                "hidden_dropout_rate": self.hidden_dropout_rate,
                "layer_norm_epsilon": self.layer_norm_epsilon,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        input_shape = self.embeddings.compute_output_shape(input_shape)
        return input_shape


class SAM3FPNLayer(layers.Layer):
    def __init__(self, input_dim, fpn_dim, scale_factor, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = int(input_dim)
        self.fpn_dim = int(fpn_dim)
        self.scale_factor = float(scale_factor)

        # Build the upsampling/downsampling layers based on scale factor.
        if self.scale_factor == 4.0:
            self.scale_layers = [
                layers.Conv2DTranspose(
                    self.input_dim // 2,
                    kernel_size=2,
                    strides=2,
                    dtype=self.dtype_policy,
                    name="scale_layers_0",
                ),
                layers.Activation(
                    "gelu", dtype=self.dtype_policy, name="scale_layers_1"
                ),
                layers.Conv2DTranspose(
                    self.input_dim // 4,
                    kernel_size=2,
                    strides=2,
                    dtype=self.dtype_policy,
                    name="scale_layers_2",
                ),
            ]
        elif self.scale_factor == 2.0:
            self.scale_layers = [
                layers.Conv2DTranspose(
                    self.input_dim // 2,
                    kernel_size=2,
                    strides=2,
                    dtype=self.dtype_policy,
                    name="scale_layers_0",
                )
            ]
        elif self.scale_factor == 1.0:
            self.scale_layers = []
        elif self.scale_factor == 0.5:
            self.scale_layers = [
                layers.MaxPooling2D(
                    pool_size=2,
                    strides=2,
                    dtype=self.dtype_policy,
                    name="scale_layers_0",
                )
            ]
        else:
            raise ValueError(
                f"Unsupported scale factor: {self.scale_factor}. "
                "Supported scale factors are 4.0, 2.0, 1.0, and 0.5."
            )
        self.proj1 = layers.Conv2D(
            self.fpn_dim, kernel_size=1, dtype=self.dtype_policy, name="proj1"
        )
        self.pad = layers.ZeroPadding2D(
            padding=1, dtype=self.dtype_policy, name="pad"
        )
        self.proj2 = layers.Conv2D(
            self.fpn_dim, kernel_size=3, dtype=self.dtype_policy, name="proj2"
        )

    def build(self, input_shape):
        for layer in self.scale_layers:
            layer.build(input_shape)
            input_shape = layer.compute_output_shape(input_shape)
        self.proj1.build(input_shape)
        input_shape = self.proj1.compute_output_shape(input_shape)
        self.pad.build(input_shape)
        input_shape = self.pad.compute_output_shape(input_shape)
        self.proj2.build(input_shape)

    def call(self, inputs, training=None):
        hidden_states = inputs
        for layer in self.scale_layers:
            hidden_states = layer(hidden_states, training=training)
        hidden_states = self.proj1(hidden_states, training=training)
        hidden_states = self.pad(hidden_states, training=training)
        return self.proj2(hidden_states, training=training)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "input_dim": self.input_dim,
                "fpn_dim": self.fpn_dim,
                "scale_factor": self.scale_factor,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        output_shape = input_shape
        for layer in self.scale_layers:
            output_shape = layer.compute_output_shape(output_shape)
        output_shape = self.proj1.compute_output_shape(output_shape)
        output_shape = self.pad.compute_output_shape(output_shape)
        return self.proj2.compute_output_shape(output_shape)


class SAM3VisionNeck(layers.Layer):
    def __init__(self, hidden_dim, fpn_hidden_dim, scale_factors, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.fpn_hidden_dim = int(fpn_hidden_dim)
        self.scale_factors = scale_factors

        self.position_encoding = SAM3SinePositionEmbedding(
            num_pos_feats=self.fpn_hidden_dim // 2,
            normalize=True,
            dtype=self.dtype_policy,
            name="position_encoding",
        )
        self.fpn_layers = [
            SAM3FPNLayer(
                input_dim=self.hidden_dim,
                fpn_dim=self.fpn_hidden_dim,
                scale_factor=scale,
                dtype=self.dtype_policy,
                name=f"fpn_layer_{i}",
            )
            for i, scale in enumerate(self.scale_factors)
        ]

    def build(self, input_shape):
        self.position_encoding.build()
        self.fpn_image_shapes = []
        for layer in self.fpn_layers:
            layer.build(input_shape)
            fpn_shape = layer.compute_output_shape(input_shape)
            self.fpn_image_shapes.append([int(fpn_shape[1]), int(fpn_shape[2])])

    def call(self, hidden_states, training=None):
        fpn_hidden_states = []
        fpn_position_encodings = []
        for i, layer in enumerate(self.fpn_layers):
            fpn_output = layer(hidden_states, training=training)
            fpn_hidden_states.append(fpn_output)
            height, width = self.fpn_image_shapes[i]
            pos_enc = self.position_encoding(
                fpn_output, height=height, width=width, training=training
            )
            fpn_position_encodings.append(pos_enc)
        return fpn_hidden_states, fpn_position_encodings

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "fpn_hidden_dim": self.fpn_hidden_dim,
                "scale_factors": self.scale_factors,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        fpn_hidden_state_shapes = []
        for layer in self.fpn_layers:
            fpn_hidden_state_shapes.append(
                layer.compute_output_shape(input_shape)
            )
        # fpn_hidden_states and fpn_position_encodings have the same shapes.
        return fpn_hidden_state_shapes, fpn_hidden_state_shapes


@keras_hub_export("keras_hub.layers.SAM3VisionEncoder")
class SAM3VisionEncoder(layers.Layer):
    """A vision encoder for the Segment Anything Model 3 (SAM3).

    This layer implements a Vision Transformer (ViT) backbone followed by a
    Feature Pyramid Network (FPN) neck. It processes input images and produces
    multi-scale feature maps and their corresponding position encodings.

    Args:
        image_shape: tuple. The shape of the input image
            (height, width, channels).
        patch_size: int. The size of the patches to be extracted from the image.
        num_layers: int. The number of transformer layers in the ViT backbone.
        hidden_dim: int. The hidden dimension of the transformer layers.
        intermediate_dim: int. The dimension of the intermediate layer in the
            transformer's MLP.
        num_heads: int. The number of attention heads.
        fpn_hidden_dim: int. The hidden dimension of the FPN.
        fpn_scale_factors: list of floats. The scale factors for each level of
            the feature pyramid.
        pretrain_image_shape: tuple. The shape of the image used during
            pretraining, for position embedding interpolation. Defaults to
            `(336, 336, 3)`.
        hidden_activation: str. The activation function for the transformer
            layers. Defaults to `"gelu"`.
        rope_theta: float. The theta value for rotary position embeddings.
            Defaults to `10000.0`.
        window_size: int. The size of the window for windowed attention.
            Defaults to `0`.
        global_attn_indexes: list of ints. The indices of the layers that use
            global attention instead of windowed attention.
        attention_dropout_rate: float. The dropout rate for attention. Defaults
            to `0`.
        hidden_dropout_rate: float. The dropout rate for the MLP. Defaults to
            `0.0`.
        layer_norm_epsilon: float. The epsilon value for layer normalization.
            Defaults to `1e-6`.
    """

    def __init__(
        self,
        image_shape,
        patch_size,
        num_layers,
        hidden_dim,
        intermediate_dim,
        num_heads,
        fpn_hidden_dim,
        fpn_scale_factors,
        pretrain_image_shape=(336, 336, 3),
        hidden_activation="gelu",
        rope_theta=10000.0,
        window_size=0,
        global_attn_indexes=None,
        attention_dropout_rate=0.0,
        hidden_dropout_rate=0.0,
        layer_norm_epsilon=1e-6,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.image_shape = (
            int(image_shape[0]),
            int(image_shape[1]),
            int(image_shape[2]),
        )
        self.patch_size = int(patch_size)
        self.num_layers = int(num_layers)
        self.hidden_dim = int(hidden_dim)
        self.intermediate_dim = int(intermediate_dim)
        self.num_heads = int(num_heads)
        self.fpn_hidden_dim = int(fpn_hidden_dim)
        self.fpn_scale_factors = fpn_scale_factors
        self.hidden_activation = hidden_activation
        self.rope_theta = float(rope_theta)
        self.window_size = int(window_size)
        if global_attn_indexes is not None:
            self.global_attn_indexes = [int(i) for i in global_attn_indexes]
        else:
            self.global_attn_indexes = None
        self.pretrain_image_shape = (
            int(pretrain_image_shape[0]),
            int(pretrain_image_shape[1]),
            int(pretrain_image_shape[2]),
        )
        self.hidden_dropout_rate = float(hidden_dropout_rate)
        self.attention_dropout_rate = float(attention_dropout_rate)
        self.layer_norm_epsilon = float(layer_norm_epsilon)

        self.backbone = SAM3ViTEncoder(
            image_shape=self.image_shape,
            patch_size=self.patch_size,
            num_layers=self.num_layers,
            hidden_dim=self.hidden_dim,
            intermediate_dim=self.intermediate_dim,
            num_heads=self.num_heads,
            pretrain_image_shape=self.pretrain_image_shape,
            hidden_activation=self.hidden_activation,
            rope_theta=self.rope_theta,
            window_size=self.window_size,
            global_attn_indexes=self.global_attn_indexes,
            attention_dropout_rate=self.attention_dropout_rate,
            hidden_dropout_rate=self.hidden_dropout_rate,
            layer_norm_epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="backbone",
        )
        self.vision_neck = SAM3VisionNeck(
            hidden_dim=self.hidden_dim,
            fpn_hidden_dim=self.fpn_hidden_dim,
            scale_factors=self.fpn_scale_factors,
            dtype=self.dtype_policy,
            name="vision_neck",
        )

    def build(self, input_shape):
        self.backbone.build(input_shape)
        input_shape = self.backbone.compute_output_shape(input_shape)
        height = self.image_shape[0] // self.patch_size
        width = self.image_shape[1] // self.patch_size
        input_shape = (input_shape[0], height, width, input_shape[-1])
        self.vision_neck.build(input_shape)

    def call(self, pixel_values, training=None):
        hidden_states = self.backbone(pixel_values, training=training)
        height = self.image_shape[0] // self.patch_size
        width = self.image_shape[1] // self.patch_size
        spatial_hidden_states = ops.reshape(
            hidden_states, (-1, height, width, self.hidden_dim)
        )
        fpn_hidden_states, fpn_position_encodings = self.vision_neck(
            spatial_hidden_states, training=training
        )
        return fpn_hidden_states, fpn_position_encodings

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "image_shape": self.image_shape,
                "patch_size": self.patch_size,
                "num_layers": self.num_layers,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "num_heads": self.num_heads,
                "fpn_hidden_dim": self.fpn_hidden_dim,
                "fpn_scale_factors": self.fpn_scale_factors,
                "pretrain_image_shape": self.pretrain_image_shape,
                "hidden_activation": self.hidden_activation,
                "rope_theta": self.rope_theta,
                "window_size": self.window_size,
                "global_attn_indexes": self.global_attn_indexes,
                "attention_dropout_rate": self.attention_dropout_rate,
                "hidden_dropout_rate": self.hidden_dropout_rate,
                "layer_norm_epsilon": self.layer_norm_epsilon,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        input_shape = self.backbone.compute_output_shape(input_shape)
        height = self.image_shape[0] // self.patch_size
        width = self.image_shape[1] // self.patch_size
        input_shape = (input_shape[0], height, width, input_shape[-1])
        fpn_hidden_state_shapes, fpn_position_encoding_shapes = (
            self.vision_neck.compute_output_shape(input_shape)
        )
        return fpn_hidden_state_shapes, fpn_position_encoding_shapes
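
For reference, a minimal usage sketch of the new layer. The constructor arguments follow the docstring above; the concrete values below (a tiny 2-layer configuration at 336x336 resolution) are illustrative assumptions, not a released preset, and the import uses the source path added in this diff (the `keras_hub.layers.SAM3VisionEncoder` export is registered by the decorator above).

import numpy as np

from keras_hub.src.models.sam3.sam3_vision_encoder import SAM3VisionEncoder

# Hypothetical, illustrative configuration (not a released preset).
# Constraints visible in the code above: the image size must be divisible
# by `patch_size`, the resulting patch grid must be divisible by
# `window_size`, and `hidden_dim / num_heads` must be divisible by 4 for
# axial RoPE. `global_attn_indexes` is passed explicitly because the
# layer-construction logic checks membership in it.
encoder = SAM3VisionEncoder(
    image_shape=(336, 336, 3),  # 336 / 14 = 24 patches per side
    patch_size=14,
    num_layers=2,
    hidden_dim=64,  # head_dim = 64 / 4 = 16, divisible by 4
    intermediate_dim=128,
    num_heads=4,
    fpn_hidden_dim=32,
    fpn_scale_factors=[4.0, 2.0, 1.0, 0.5],
    window_size=8,  # 24 % 8 == 0
    global_attn_indexes=[1],  # second layer uses global attention
)

pixel_values = np.random.uniform(size=(1, 336, 336, 3)).astype("float32")
fpn_features, fpn_position_encodings = encoder(pixel_values)
# One feature map and one position encoding per FPN scale factor; with
# these assumed values the 4.0 level would be (1, 96, 96, 32) for the
# 24x24 patch grid.
for feature in fpn_features:
    print(feature.shape)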