keras-hub 0.25.1__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (109)
  1. keras_hub/layers/__init__.py +21 -0
  2. keras_hub/models/__init__.py +27 -0
  3. keras_hub/src/layers/modeling/non_max_supression.py +5 -2
  4. keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
  5. keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
  6. keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
  7. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
  8. keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
  9. keras_hub/src/models/albert/albert_backbone.py +1 -3
  10. keras_hub/src/models/backbone.py +3 -0
  11. keras_hub/src/models/bart/bart_backbone.py +1 -3
  12. keras_hub/src/models/bert/bert_backbone.py +2 -4
  13. keras_hub/src/models/bloom/bloom_backbone.py +1 -3
  14. keras_hub/src/models/causal_lm.py +2 -2
  15. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
  16. keras_hub/src/models/edrec/edrec_backbone.py +147 -0
  17. keras_hub/src/models/edrec/edrec_layers.py +434 -0
  18. keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
  19. keras_hub/src/models/electra/electra_backbone.py +1 -3
  20. keras_hub/src/models/f_net/f_net_backbone.py +1 -3
  21. keras_hub/src/models/falcon/falcon_backbone.py +1 -3
  22. keras_hub/src/models/flux/flux_layers.py +3 -3
  23. keras_hub/src/models/flux/flux_maths.py +29 -15
  24. keras_hub/src/models/gemma/gemma_backbone.py +1 -3
  25. keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
  26. keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
  27. keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
  28. keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
  29. keras_hub/src/models/gemma3/gemma3_decoder_block.py +1 -1
  30. keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
  31. keras_hub/src/models/gemma3/gemma3_presets.py +67 -7
  32. keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
  33. keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
  34. keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
  35. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
  36. keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
  37. keras_hub/src/models/llama/llama_backbone.py +1 -3
  38. keras_hub/src/models/masked_lm.py +1 -1
  39. keras_hub/src/models/mistral/mistral_backbone.py +1 -3
  40. keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
  41. keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
  42. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
  43. keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
  44. keras_hub/src/models/phi3/phi3_backbone.py +1 -3
  45. keras_hub/src/models/qwen/qwen_backbone.py +1 -3
  46. keras_hub/src/models/qwen/qwen_presets.py +209 -0
  47. keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
  48. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
  49. keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
  50. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
  51. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
  52. keras_hub/src/models/rqvae/__init__.py +5 -0
  53. keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
  54. keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
  55. keras_hub/src/models/rwkv7/__init__.py +5 -0
  56. keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
  57. keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
  58. keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
  59. keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
  60. keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
  61. keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
  62. keras_hub/src/models/sam/sam_backbone.py +5 -1
  63. keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
  64. keras_hub/src/models/sam3/__init__.py +7 -0
  65. keras_hub/src/models/sam3/roi_align.py +222 -0
  66. keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
  67. keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
  68. keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
  69. keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
  70. keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
  71. keras_hub/src/models/sam3/sam3_layers.py +814 -0
  72. keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
  73. keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
  74. keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
  75. keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
  76. keras_hub/src/models/sam3/sam3_presets.py +16 -0
  77. keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
  78. keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
  79. keras_hub/src/models/sam3/sam3_utils.py +134 -0
  80. keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
  81. keras_hub/src/models/segformer/segformer_backbone.py +6 -6
  82. keras_hub/src/models/siglip/siglip_layers.py +1 -3
  83. keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
  84. keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
  85. keras_hub/src/models/t5/t5_backbone.py +1 -3
  86. keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
  87. keras_hub/src/models/task.py +1 -1
  88. keras_hub/src/tests/test_case.py +394 -3
  89. keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
  90. keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
  91. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
  92. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
  93. keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
  94. keras_hub/src/utils/preset_utils.py +1 -1
  95. keras_hub/src/utils/tensor_utils.py +12 -0
  96. keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
  97. keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
  98. keras_hub/src/utils/transformers/convert_sam3.py +472 -0
  99. keras_hub/src/utils/transformers/export/gemma3.py +196 -0
  100. keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
  101. keras_hub/src/utils/transformers/export/qwen.py +136 -0
  102. keras_hub/src/utils/transformers/preset_loader.py +15 -1
  103. keras_hub/src/version.py +1 -1
  104. keras_hub/tokenizers/__init__.py +6 -0
  105. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
  106. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
  107. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
  108. keras_hub/src/models/gemma3/rms_normalization.py +0 -26
  109. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,814 @@
+import math
+
+from keras import backend
+from keras import config
+from keras import initializers
+from keras import layers
+from keras import ops
+
+from keras_hub.src.models.sam3.sam3_utils import box_cxcywh_to_xyxy
+from keras_hub.src.models.sam3.sam3_utils import inverse_sigmoid
+from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+class SAM3MLP(layers.Layer):
+    def __init__(
+        self,
+        hidden_dim,
+        intermediate_dim,
+        activation="gelu",
+        dropout_rate=0.0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.hidden_dim = int(hidden_dim)
+        self.intermediate_dim = int(intermediate_dim)
+        self.activation = activation
+        self.dropout_rate = float(dropout_rate)
+
+        self.fc1 = layers.Dense(
+            intermediate_dim, dtype=self.dtype_policy, name="fc1"
+        )
+        self.act = layers.Activation(activation, dtype=self.dtype_policy)
+        self.fc2 = layers.Dense(hidden_dim, dtype=self.dtype_policy, name="fc2")
+        self.dropout = layers.Dropout(
+            dropout_rate, dtype=self.dtype_policy, name="dropout"
+        )
+
+    def build(self, input_shape):
+        self.fc1.build(input_shape)
+        input_shape = self.fc1.compute_output_shape(input_shape)
+        self.dropout.build(input_shape)
+        self.act.build(input_shape)
+        self.fc2.build(input_shape)
+        input_shape = self.fc2.compute_output_shape(input_shape)
+
+    def call(self, inputs, training=None):
+        x = self.fc1(inputs, training=training)
+        x = self.dropout(x, training=training)
+        x = self.act(x)
+        return self.fc2(x, training=training)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "activation": self.activation,
+                "dropout_rate": self.dropout_rate,
+            }
+        )
+        return config
+
+
+class SAM3Attention(layers.Layer):
+    def __init__(self, hidden_dim, num_heads, **kwargs):
+        super().__init__(**kwargs)
+        self.hidden_dim = int(hidden_dim)
+        self.num_heads = int(num_heads)
+        self.head_dim = self.hidden_dim // self.num_heads
+        self.scale = self.head_dim**-0.5
+
+        self.q_proj = layers.Dense(
+            self.hidden_dim, dtype=self.dtype_policy, name="q_proj"
+        )
+        self.k_proj = layers.Dense(
+            self.hidden_dim, dtype=self.dtype_policy, name="k_proj"
+        )
+        self.v_proj = layers.Dense(
+            self.hidden_dim, dtype=self.dtype_policy, name="v_proj"
+        )
+        self.o_proj = layers.Dense(
+            self.hidden_dim, dtype=self.dtype_policy, name="o_proj"
+        )
+
+    def build(self, query_shape, key_shape, value_shape):
+        self.q_proj.build(query_shape)
+        self.k_proj.build(key_shape)
+        self.v_proj.build(value_shape)
+        self.o_proj.build(value_shape)
+
+    def call(
+        self,
+        query,
+        key,
+        value,
+        attention_mask=None,
+        attention_bias=None,
+        training=None,
+    ):
+        batch_size = ops.shape(query)[0]
+
+        query = self.q_proj(query, training=training)
+        query = ops.reshape(
+            query, (batch_size, -1, self.num_heads, self.head_dim)
+        )
+        key = self.k_proj(key, training=training)
+        key = ops.reshape(key, (batch_size, -1, self.num_heads, self.head_dim))
+        value = self.v_proj(value, training=training)
+        value = ops.reshape(
+            value, (batch_size, -1, self.num_heads, self.head_dim)
+        )
+
+        if (
+            backend.backend() == "torch"
+            and attention_mask is None
+            and attention_bias is not None
+        ):
+            # TODO: Torch backend doesn't support attention_bias in
+            # ops.dot_product_attention yet.
+            # Fixed by https://github.com/keras-team/keras/pull/22045
+            import torch
+
+            query = torch.transpose(query, 1, 2).contiguous()
+            key = torch.transpose(key, 1, 2).contiguous()
+            value = torch.transpose(value, 1, 2).contiguous()
+            attention_bias = attention_bias.contiguous()
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query,
+                key,
+                value,
+                attn_mask=attention_bias,
+                is_causal=False,
+                scale=self.scale,
+            )
+            attn_output = torch.transpose(attn_output, 2, 1)
+        else:
+            if attention_mask is not None:
+                attention_mask = ops.cast(attention_mask, dtype="bool")
+            attn_output = ops.dot_product_attention(
+                query,
+                key,
+                value,
+                bias=attention_bias,
+                mask=attention_mask,
+                scale=self.scale,
+                is_causal=False,
+            )
+        attn_output = ops.reshape(
+            attn_output, (batch_size, -1, self.num_heads * self.head_dim)
+        )
+        return self.o_proj(attn_output, training=training)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "num_heads": self.num_heads,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
+
+
+class SAM3RoPEAttention(layers.Layer):
+    def __init__(
+        self,
+        hidden_dim,
+        num_heads,
+        attention_dropout_rate=0.0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.hidden_dim = int(hidden_dim)
+        self.num_heads = int(num_heads)
+        self.attention_dropout_rate = float(attention_dropout_rate)
+        self.head_dim = self.hidden_dim // self.num_heads
+        self.scale = self.head_dim**-0.5
+
+        self.q_proj = layers.Dense(
+            self.hidden_dim, dtype=self.dtype_policy, name="q_proj"
+        )
+        self.k_proj = layers.Dense(
+            self.hidden_dim, dtype=self.dtype_policy, name="k_proj"
+        )
+        self.v_proj = layers.Dense(
+            self.hidden_dim, dtype=self.dtype_policy, name="v_proj"
+        )
+        self.o_proj = layers.Dense(
+            self.hidden_dim, dtype=self.dtype_policy, name="o_proj"
+        )
+
+    def build(self, input_shape):
+        self.height = input_shape[1]
+        self.width = input_shape[2]
+        self.q_proj.build(input_shape)
+        self.k_proj.build(input_shape)
+        self.v_proj.build(input_shape)
+        self.o_proj.build(input_shape)
+
+    def apply_rotary_pos_emb_2d(self, query, key, cos, sin):
+        def rotate_pairwise(x):
+            x = ops.reshape(
+                x,
+                (
+                    -1,
+                    self.num_heads,
+                    self.height * self.width,
+                    self.head_dim // 2,
+                    2,
+                ),
+            )
+            x1 = x[..., 0]
+            x2 = x[..., 1]
+            x = ops.stack((-x2, x1), axis=-1)
+            return ops.reshape(
+                x, (-1, self.num_heads, self.height * self.width, self.head_dim)
+            )
+
+        query = ops.transpose(query, axes=(0, 2, 1, 3))
+        key = ops.transpose(key, axes=(0, 2, 1, 3))
+
+        original_dtype = backend.standardize_dtype(query.dtype)
+        query_embed = ops.cast(query, dtype="float32")
+        query_embed = ops.add(
+            ops.multiply(query_embed, cos),
+            ops.multiply(rotate_pairwise(query_embed), sin),
+        )
+        key_embed = ops.cast(key, dtype="float32")
+        key_embed = ops.add(
+            ops.multiply(key_embed, cos),
+            ops.multiply(rotate_pairwise(key_embed), sin),
+        )
+        query_embed = ops.cast(query_embed, dtype=original_dtype)
+        key_embed = ops.cast(key_embed, dtype=original_dtype)
+
+        query_embed = ops.transpose(query_embed, axes=(0, 2, 1, 3))
+        key_embed = ops.transpose(key_embed, axes=(0, 2, 1, 3))
+        return query_embed, key_embed
+
+    def call(self, hidden_states, position_embeddings, training=None):
+        new_shape = (
+            -1,
+            self.height * self.width,
+            self.num_heads,
+            self.head_dim,
+        )
+
+        query = self.q_proj(hidden_states, training=training)
+        query = ops.reshape(query, new_shape)
+        key = self.k_proj(hidden_states, training=training)
+        key = ops.reshape(key, new_shape)
+        value = self.v_proj(hidden_states, training=training)
+        value = ops.reshape(value, new_shape)
+        cos, sin = position_embeddings
+        query, key = self.apply_rotary_pos_emb_2d(query, key, cos=cos, sin=sin)
+
+        attention_output = ops.dot_product_attention(
+            query, key, value, scale=self.scale, is_causal=False
+        )
+        attention_output = ops.reshape(
+            attention_output, (-1, self.height, self.width, self.hidden_dim)
+        )
+        attention_output = self.o_proj(attention_output, training=training)
+        return attention_output
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "num_heads": self.num_heads,
+                "attention_dropout_rate": self.attention_dropout_rate,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
+
+
+class SAM3PatchEmbedding(layers.Layer):
+    def __init__(self, hidden_dim, patch_size, data_format=None, **kwargs):
+        super().__init__(**kwargs)
+        self.hidden_dim = int(hidden_dim)
+        self.patch_size = int(patch_size)
+        self.data_format = standardize_data_format(data_format)
+
+        self.projection = layers.Conv2D(
+            self.hidden_dim,
+            kernel_size=self.patch_size,
+            strides=self.patch_size,
+            use_bias=False,
+            dtype=self.dtype_policy,
+            name="projection",
+        )
+
+    def build(self, input_shape):
+        self.projection.build(input_shape)
+        output_shape = self.projection.compute_output_shape(input_shape)
+        if self.data_format == "channels_last":
+            self.seq_len = int(output_shape[1]) * int(output_shape[2])
+        else:
+            self.seq_len = int(output_shape[2]) * int(output_shape[3])
+
+    def call(self, inputs, training=None):
+        embeddings = self.projection(inputs, training=training)
+        if self.data_format == "channels_last":
+            embeddings = ops.reshape(
+                embeddings, (-1, self.seq_len, self.hidden_dim)
+            )
+        else:
+            embeddings = ops.reshape(
+                embeddings, (-1, self.hidden_dim, self.seq_len)
+            )
+            embeddings = ops.transpose(embeddings, (0, 2, 1))
+        return embeddings
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "patch_size": self.patch_size,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, input_shape):
+        output_shape = [input_shape[0], None, self.hidden_dim]
+        if self.data_format == "channels_last":
+            if input_shape[1] is not None and input_shape[2] is not None:
+                patch_num = input_shape[1] // self.patch_size
+                output_shape[1] = patch_num**2
+        else:
+            if input_shape[2] is not None and input_shape[3] is not None:
+                patch_num = input_shape[2] // self.patch_size
+                output_shape[1] = patch_num**2
+        return output_shape
+
+
+class SAM3Embedding(layers.Layer):
+    def __init__(
+        self,
+        hidden_dim,
+        patch_size,
+        image_shape,
+        dropout_rate=0.0,
+        pretrain_image_shape=(336, 336, 3),
+        data_format=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.hidden_dim = int(hidden_dim)
+        self.patch_size = int(patch_size)
+        self.image_shape = (
+            int(image_shape[0]),
+            int(image_shape[1]),
+            int(image_shape[2]),
+        )
+        self.dropout_rate = float(dropout_rate)
+        self.pretrain_image_shape = (
+            int(pretrain_image_shape[0]),
+            int(pretrain_image_shape[1]),
+            int(pretrain_image_shape[2]),
+        )
+        self.data_format = standardize_data_format(data_format)
+        self.num_patches = (self.pretrain_image_shape[0] // self.patch_size) * (
+            self.pretrain_image_shape[1] // self.patch_size
+        )
+        self.tiled_num_patches = (self.image_shape[0] // self.patch_size) * (
+            self.image_shape[1] // self.patch_size
+        )
+
+        self.patch_embeddings = SAM3PatchEmbedding(
+            hidden_dim=self.hidden_dim,
+            patch_size=self.patch_size,
+            data_format=self.data_format,
+            dtype=self.dtype_policy,
+            name="patch_embeddings",
+        )
+        self.dropout = layers.Dropout(
+            self.dropout_rate, dtype=self.dtype_policy, name="dropout"
+        )
+
+    def build(self, input_shape):
+        self.patch_embeddings.build(input_shape)
+        embedding_shape = self.patch_embeddings.compute_output_shape(
+            input_shape
+        )
+        self.dropout.build(embedding_shape)
+
+        # Note that there are two position embeddings:
+        # `self.tiled_position_embeddings` is used for the image inputs during
+        # both training and inference.
+        # `self.position_embeddings` is used to load pretrained weights and
+        # remains unchanged during training and inference. It will be updated
+        # during saving once `self.tiled_position_embeddings` is modified.
+        self.position_embeddings = self.add_weight(
+            shape=(1, self.num_patches, self.hidden_dim),
+            initializer=initializers.TruncatedNormal(stddev=0.02),
+            trainable=False,
+            name="position_embeddings",
+        )
+        self.tiled_position_embeddings = self.add_weight(
+            shape=(1, self.tiled_num_patches, self.hidden_dim),
+            initializer="zeros",  # Will be initialized by tiling.
+            trainable=True,
+            name="tiled_position_embeddings",
+        )
+
+        # Initialize the interpolated position embeddings.
+        self.tiled_position_embeddings.assign(
+            self._tile_position_embeddings(
+                self.position_embeddings,
+                patch_size=self.patch_size,
+                source_shape=self.pretrain_image_shape,
+                target_shape=self.image_shape,
+            )
+        )
+
+    def call(self, inputs, training=None):
+        x = inputs
+        patch_embeddings = self.patch_embeddings(x, training=training)
+        if self.data_format == "channels_last":
+            patch_embeddings = ops.reshape(
+                patch_embeddings,
+                (-1, self.patch_embeddings.seq_len, self.hidden_dim),
+            )
+        else:
+            patch_embeddings = ops.reshape(
+                patch_embeddings,
+                (-1, self.hidden_dim, self.patch_embeddings.seq_len),
+            )
+            patch_embeddings = ops.transpose(patch_embeddings, (0, 2, 1))
+        embeddings = ops.add(patch_embeddings, self.tiled_position_embeddings)
+        embeddings = self.dropout(embeddings, training=training)
+        return embeddings
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "patch_size": self.patch_size,
+                "image_shape": self.image_shape,
+                "dropout_rate": self.dropout_rate,
+                "pretrain_image_shape": self.pretrain_image_shape,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, input_shape):
+        if input_shape is None:
+            input_shape = [None, None, None, None]
+        output_shape = [input_shape[0], None, self.hidden_dim]
+        if self.data_format == "channels_last":
+            if input_shape[1] is not None and input_shape[2] is not None:
+                patch_num = input_shape[1] // self.patch_size
+                output_shape[1] = patch_num**2
+        else:
+            if input_shape[2] is not None and input_shape[3] is not None:
+                patch_num = input_shape[2] // self.patch_size
+                output_shape[1] = patch_num**2
+        return output_shape
+
+    @staticmethod
+    def _tile_position_embeddings(
+        position_embeddings, patch_size, source_shape, target_shape
+    ):
+        """Tile position embeddings to match the target image shape.
+
+        Reference:
+        - https://github.com/huggingface/transformers/blob/main/src/transformers/models/sam3/modeling_sam3.py
+        """
+        position_embeddings = ops.convert_to_tensor(position_embeddings)
+        patch_size = int(patch_size)
+        source_shape = (int(source_shape[0]), int(source_shape[1]))
+        target_shape = (int(target_shape[0]), int(target_shape[1]))
+        hidden_dim = int(position_embeddings.shape[-1])
+
+        if (
+            source_shape[0] == target_shape[0]
+            and source_shape[1] == target_shape[1]
+        ):
+            # No need to tile if the image size is the same as the
+            # position embedding image size.
+            return ops.copy(position_embeddings)
+
+        # Tile position embeddings to match target image size.
+        source_embedding_shape = (
+            source_shape[0] // patch_size,
+            source_shape[1] // patch_size,
+        )
+        target_embedding_shape = (
+            target_shape[0] // patch_size,
+            target_shape[1] // patch_size,
+        )
+        position_embeddings = ops.reshape(
+            position_embeddings,
+            (
+                1,
+                source_embedding_shape[0],
+                source_embedding_shape[1],
+                hidden_dim,
+            ),
+        )
+        repeat_h = target_embedding_shape[0] // source_embedding_shape[0] + 1
+        repeat_w = target_embedding_shape[1] // source_embedding_shape[1] + 1
+        position_embeddings = ops.tile(
+            position_embeddings, (1, repeat_h, repeat_w, 1)
+        )
+        position_embeddings = position_embeddings[
+            :, : target_embedding_shape[0], : target_embedding_shape[1], :
+        ]
+        return ops.reshape(position_embeddings, (1, -1, hidden_dim))
+
+    def _is_tiled_position_embeddings_updated(self):
+        """Check if the tiled position embeddings are updated."""
+        original_tiled_position_embeddings = self._tile_position_embeddings(
+            self.position_embeddings,
+            patch_size=self.patch_size,
+            source_shape=self.pretrain_image_shape,
+            target_shape=self.image_shape,
+        )
+        diff = ops.sum(
+            ops.subtract(
+                original_tiled_position_embeddings,
+                self.tiled_position_embeddings,
+            )
+        )
+        return ops.cond(
+            ops.greater(diff, config.epsilon()), lambda: True, lambda: False
+        )
+
+    def save_own_variables(self, store):
+        if self._is_tiled_position_embeddings_updated():
+            self.position_embeddings.assign(
+                self._tile_position_embeddings(
+                    self.tiled_position_embeddings,
+                    patch_size=self.patch_size,
+                    source_shape=self.image_shape,
+                    target_shape=self.pretrain_image_shape,
+                )
+            )
+        super().save_own_variables(store)
+
+    def load_own_variables(self, store):
+        all_vars = self._trainable_variables + self._non_trainable_variables
+        for i, v in enumerate(all_vars):
+            if v is self.tiled_position_embeddings:
+                continue
+            v.assign(store[f"{i}"])
+        self.tiled_position_embeddings.assign(
+            self._tile_position_embeddings(
+                self.position_embeddings,
+                patch_size=self.patch_size,
+                source_shape=self.pretrain_image_shape,
+                target_shape=self.image_shape,
+            )
+        )
+
+
+class SAM3SinePositionEmbedding(layers.Layer):
+    def __init__(
+        self,
+        num_pos_feats=64,
+        temperature=10000,
+        normalize=False,
+        scale=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.num_pos_feats = int(num_pos_feats)
+        self.temperature = float(temperature)
+        self.normalize = bool(normalize)
+        if scale is not None and normalize is False:
+            raise ValueError("normalize should be True if scale is passed")
+        self.scale = 2 * math.pi if scale is None else scale
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+
+    def encode_1d_positions(self, x, y):
+        x_embed = ops.multiply(x, self.scale)
+        y_embed = ops.multiply(y, self.scale)
+        dim_t = ops.cast(ops.arange(self.num_pos_feats), dtype=x.dtype)
+        dim_t = ops.power(
+            self.temperature,
+            ops.divide(
+                ops.multiply(2, ops.floor_divide(dim_t, 2)), self.num_pos_feats
+            ),
+        )
+        pos_x = ops.divide(ops.expand_dims(x_embed, -1), dim_t)
+        pos_y = ops.divide(ops.expand_dims(y_embed, -1), dim_t)
+        pos_x = ops.stack(
+            (ops.sin(pos_x[:, 0::2]), ops.cos(pos_x[:, 1::2])), axis=2
+        )
+        pos_x = ops.reshape(pos_x, (-1, self.num_pos_feats))
+        pos_y = ops.stack(
+            (ops.sin(pos_y[:, 0::2]), ops.cos(pos_y[:, 1::2])), axis=2
+        )
+        pos_y = ops.reshape(pos_y, (-1, self.num_pos_feats))
+        return pos_x, pos_y
+
+    def encode_boxes(self, boxes):
+        dim_t = ops.cast(ops.arange(self.num_pos_feats), dtype=boxes.dtype)
+        dim_t = ops.power(
+            self.temperature,
+            ops.divide(
+                ops.multiply(2, ops.floor_divide(dim_t, 2)), self.num_pos_feats
+            ),
+        )
+
+        x_embed = ops.multiply(boxes[..., 0], self.scale)
+        y_embed = ops.multiply(boxes[..., 1], self.scale)
+        w_embed = ops.multiply(boxes[..., 2], self.scale)
+        h_embed = ops.multiply(boxes[..., 3], self.scale)
+        pos_x = ops.divide(ops.expand_dims(x_embed, -1), dim_t)
+        pos_y = ops.divide(ops.expand_dims(y_embed, -1), dim_t)
+        pos_w = ops.divide(ops.expand_dims(w_embed, -1), dim_t)
+        pos_h = ops.divide(ops.expand_dims(h_embed, -1), dim_t)
+        pos_x_shape = ops.shape(pos_x)
+        newshape = (pos_x_shape[0], pos_x_shape[1], self.num_pos_feats)
+        pos_x = ops.stack(
+            (ops.sin(pos_x[..., 0::2]), ops.cos(pos_x[..., 1::2])), axis=3
+        )
+        pos_x = ops.reshape(pos_x, newshape)
+        pos_y = ops.stack(
+            (ops.sin(pos_y[..., 0::2]), ops.cos(pos_y[..., 1::2])), axis=3
+        )
+        pos_y = ops.reshape(pos_y, newshape)
+        pos_w = ops.stack(
+            (ops.sin(pos_w[..., 0::2]), ops.cos(pos_w[..., 1::2])), axis=3
+        )
+        pos_w = ops.reshape(pos_w, newshape)
+        pos_h = ops.stack(
+            (ops.sin(pos_h[..., 0::2]), ops.cos(pos_h[..., 1::2])), axis=3
+        )
+        pos_h = ops.reshape(pos_h, newshape)
+        return ops.concatenate([pos_y, pos_x, pos_w, pos_h], axis=2)
+
+    def call(self, inputs, height, width, training=None):
+        not_mask = ops.ones((1, height, width), dtype=self.compute_dtype)
+        y_embed = ops.cumsum(not_mask, axis=1)
+        x_embed = ops.cumsum(not_mask, axis=2)
+        if self.normalize:
+            eps = 1e-6
+            y_embed = ops.multiply(
+                ops.divide(y_embed, ops.add(y_embed[:, -1:, :], eps)),
+                self.scale,
+            )
+            x_embed = ops.multiply(
+                ops.divide(x_embed, ops.add(x_embed[:, :, -1:], eps)),
+                self.scale,
+            )
+        dim_t = ops.cast(
+            ops.arange(self.num_pos_feats), dtype=self.compute_dtype
+        )
+        dim_t = ops.power(
+            self.temperature,
+            ops.divide(
+                ops.multiply(2, ops.floor_divide(dim_t, 2)), self.num_pos_feats
+            ),
+        )
+
+        pos_x = ops.divide(ops.expand_dims(x_embed, -1), dim_t)
+        pos_y = ops.divide(ops.expand_dims(y_embed, -1), dim_t)
+        newshape = (1, height, width, self.num_pos_feats)
+        pos_x = ops.stack(
+            (ops.sin(pos_x[..., 0::2]), ops.cos(pos_x[..., 1::2])), axis=4
+        )
+        pos_x = ops.reshape(pos_x, newshape)
+        pos_y = ops.stack(
+            (ops.sin(pos_y[..., 0::2]), ops.cos(pos_y[..., 1::2])), axis=4
+        )
+        pos_y = ops.reshape(pos_y, newshape)
+        pos = ops.concatenate([pos_y, pos_x], axis=3)
+        pos = ops.tile(pos, (ops.shape(inputs)[0], 1, 1, 1))
+        return pos
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_pos_feats": self.num_pos_feats,
+                "temperature": self.temperature,
+                "normalize": self.normalize,
+                "scale": self.scale,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, input_shape):
+        output_shape = list(input_shape)
+        output_shape[1] = self.num_pos_feats * 2
+        return output_shape
+
+
+class SAM3DecoderMLP(layers.Layer):
+    def __init__(self, num_layers, hidden_dim, output_dim, **kwargs):
+        super().__init__(**kwargs)
+        self.num_layers = int(num_layers)
+        self.hidden_dim = int(hidden_dim)
+        self.output_dim = int(output_dim)
+
+        if self.num_layers == 2:
+            self.layer1 = layers.Dense(
+                hidden_dim, dtype=self.dtype_policy, name="layer1"
+            )
+            self.layer2 = layers.Dense(
+                output_dim, dtype=self.dtype_policy, name="layer2"
+            )
+        elif num_layers == 3:
+            self.layer1 = layers.Dense(
+                hidden_dim, dtype=self.dtype_policy, name="layer1"
+            )
+            self.layer2 = layers.Dense(
+                hidden_dim, dtype=self.dtype_policy, name="layer2"
+            )
+            self.layer3 = layers.Dense(
+                output_dim, dtype=self.dtype_policy, name="layer3"
+            )
+        else:
+            raise ValueError("num_layers should be 2 or 3.")
+
+    def build(self, input_shape):
+        self.layer1.build(input_shape)
+        input_shape = self.layer1.compute_output_shape(input_shape)
+        self.layer2.build(input_shape)
+        if self.num_layers == 3:
+            input_shape = self.layer2.compute_output_shape(input_shape)
+            self.layer3.build(input_shape)
+
+    def call(self, inputs, training=None):
+        x = ops.relu(self.layer1(inputs, training=training))
+        if self.num_layers == 2:
+            return self.layer2(x, training=training)
+        else:
+            x = ops.relu(self.layer2(x, training=training))
+            return self.layer3(x, training=training)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_layers": self.num_layers,
+                "hidden_dim": self.hidden_dim,
+                "output_dim": self.output_dim,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, input_shape):
+        output_shape = list(input_shape)
+        output_shape[-1] = self.output_dim
+        return output_shape
+
+
+class SAM3BoxDecoder(layers.Layer):
+    def build(
+        self,
+        box_offsets_shape,
+        reference_boxes_shape,
+        pred_logits_shape,
+        presence_logits_shape,
+    ):
+        pass
+
+    def call(
+        self,
+        box_offsets,
+        reference_boxes,
+        pred_logits,
+        presence_logits,
+        training=None,
+    ):
+        reference_boxes_inv_sig = inverse_sigmoid(reference_boxes)
+        pred_boxes_cxcywh = ops.nn.sigmoid(
+            ops.add(reference_boxes_inv_sig, box_offsets)
+        )
+        pred_boxes = box_cxcywh_to_xyxy(pred_boxes_cxcywh)
+        return (
+            pred_boxes[:, -1],
+            pred_logits[:, -1, :, 0],
+            presence_logits[:, -1],
+        )
+
+    def compute_output_shape(
+        self,
+        box_offsets_shape,
+        reference_boxes_shape,
+        pred_logits_shape,
+        presence_logits_shape,
+    ):
+        pred_boxes_shape = [
+            box_offsets_shape[0],
+            box_offsets_shape[-2],
+            box_offsets_shape[-1],
+        ]
+        pred_logits_shape = [
+            pred_logits_shape[0],
+            pred_logits_shape[-2],
+        ]
+        presence_logits_shape = [
+            presence_logits_shape[0],
+            presence_logits_shape[-2],
+            presence_logits_shape[-1],
+        ]
+        return pred_boxes_shape, pred_logits_shape, presence_logits_shape
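
As a rough illustration of how the layers added in this file compose, the sketch below runs SAM3Attention and SAM3MLP on random inputs. The module path matches the new sam3_layers.py file above, but the dimensions are arbitrary placeholders, not taken from any SAM3 preset.

import numpy as np
from keras import ops

from keras_hub.src.models.sam3.sam3_layers import SAM3Attention
from keras_hub.src.models.sam3.sam3_layers import SAM3MLP

# Arbitrary illustrative sizes (not a SAM3 preset configuration).
batch_size, seq_len, hidden_dim, num_heads = 2, 16, 64, 4

attention = SAM3Attention(hidden_dim=hidden_dim, num_heads=num_heads)
mlp = SAM3MLP(hidden_dim=hidden_dim, intermediate_dim=4 * hidden_dim)

x = ops.convert_to_tensor(
    np.random.rand(batch_size, seq_len, hidden_dim).astype("float32")
)
x = attention(x, x, x)  # Self-attention over the sequence.
x = mlp(x)  # Position-wise feed-forward block.
print(x.shape)  # (2, 16, 64)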