keras-hub 0.25.1__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +21 -0
- keras_hub/models/__init__.py +27 -0
- keras_hub/src/layers/modeling/non_max_supression.py +5 -2
- keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
- keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
- keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
- keras_hub/src/models/albert/albert_backbone.py +1 -3
- keras_hub/src/models/backbone.py +3 -0
- keras_hub/src/models/bart/bart_backbone.py +1 -3
- keras_hub/src/models/bert/bert_backbone.py +2 -4
- keras_hub/src/models/bloom/bloom_backbone.py +1 -3
- keras_hub/src/models/causal_lm.py +2 -2
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
- keras_hub/src/models/edrec/edrec_backbone.py +147 -0
- keras_hub/src/models/edrec/edrec_layers.py +434 -0
- keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
- keras_hub/src/models/electra/electra_backbone.py +1 -3
- keras_hub/src/models/f_net/f_net_backbone.py +1 -3
- keras_hub/src/models/falcon/falcon_backbone.py +1 -3
- keras_hub/src/models/flux/flux_layers.py +3 -3
- keras_hub/src/models/flux/flux_maths.py +29 -15
- keras_hub/src/models/gemma/gemma_backbone.py +1 -3
- keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
- keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
- keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
- keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +1 -1
- keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +67 -7
- keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
- keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
- keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
- keras_hub/src/models/llama/llama_backbone.py +1 -3
- keras_hub/src/models/masked_lm.py +1 -1
- keras_hub/src/models/mistral/mistral_backbone.py +1 -3
- keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
- keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
- keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
- keras_hub/src/models/phi3/phi3_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_presets.py +209 -0
- keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
- keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
- keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
- keras_hub/src/models/rqvae/__init__.py +5 -0
- keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
- keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
- keras_hub/src/models/rwkv7/__init__.py +5 -0
- keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
- keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
- keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
- keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
- keras_hub/src/models/sam/sam_backbone.py +5 -1
- keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
- keras_hub/src/models/sam3/__init__.py +7 -0
- keras_hub/src/models/sam3/roi_align.py +222 -0
- keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
- keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
- keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
- keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
- keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
- keras_hub/src/models/sam3/sam3_layers.py +814 -0
- keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
- keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
- keras_hub/src/models/sam3/sam3_presets.py +16 -0
- keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
- keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
- keras_hub/src/models/sam3/sam3_utils.py +134 -0
- keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
- keras_hub/src/models/segformer/segformer_backbone.py +6 -6
- keras_hub/src/models/siglip/siglip_layers.py +1 -3
- keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
- keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
- keras_hub/src/models/t5/t5_backbone.py +1 -3
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
- keras_hub/src/models/task.py +1 -1
- keras_hub/src/tests/test_case.py +394 -3
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
- keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
- keras_hub/src/utils/preset_utils.py +1 -1
- keras_hub/src/utils/tensor_utils.py +12 -0
- keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
- keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
- keras_hub/src/utils/transformers/convert_sam3.py +472 -0
- keras_hub/src/utils/transformers/export/gemma3.py +196 -0
- keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
- keras_hub/src/utils/transformers/export/qwen.py +136 -0
- keras_hub/src/utils/transformers/preset_loader.py +15 -1
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +6 -0
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
- keras_hub/src/models/gemma3/rms_normalization.py +0 -26
- {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
keras_hub/src/models/sam3/sam3_geometry_encoder.py (new file)

@@ -0,0 +1,517 @@
+import keras
+from keras import layers
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.sam3.roi_align import roi_align
+from keras_hub.src.models.sam3.sam3_layers import SAM3MLP
+from keras_hub.src.models.sam3.sam3_layers import SAM3Attention
+from keras_hub.src.models.sam3.sam3_layers import SAM3SinePositionEmbedding
+from keras_hub.src.models.sam3.sam3_utils import box_cxcywh_to_xyxy
+from keras_hub.src.models.sam3.sam3_utils import concatenate_padded_sequences
+
+
+class SAM3GeometryEncoderLayer(layers.Layer):
+    def __init__(
+        self,
+        hidden_dim,
+        intermediate_dim,
+        num_heads,
+        hidden_activation="relu",
+        dropout_rate=0.0,
+        layer_norm_epsilon=1e-6,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.hidden_dim = int(hidden_dim)
+        self.intermediate_dim = int(intermediate_dim)
+        self.num_heads = int(num_heads)
+        self.hidden_activation = hidden_activation
+        self.dropout_rate = float(dropout_rate)
+        self.layer_norm_epsilon = float(layer_norm_epsilon)
+
+        self.layer_norm1 = layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="layer_norm1",
+        )
+        self.self_attn = SAM3Attention(
+            hidden_dim=self.hidden_dim,
+            num_heads=self.num_heads,
+            dtype=self.dtype_policy,
+            name="self_attn",
+        )
+        self.dropout = layers.Dropout(
+            rate=self.dropout_rate, dtype=self.dtype_policy, name="dropout"
+        )
+        self.cross_attn = SAM3Attention(
+            hidden_dim=self.hidden_dim,
+            num_heads=self.num_heads,
+            dtype=self.dtype_policy,
+            name="cross_attn",
+        )
+        self.layer_norm2 = layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="layer_norm2",
+        )
+        self.mlp = SAM3MLP(
+            hidden_dim=self.hidden_dim,
+            intermediate_dim=self.intermediate_dim,
+            activation=self.hidden_activation,
+            dropout_rate=self.dropout_rate,
+            dtype=self.dtype_policy,
+            name="mlp",
+        )
+        self.layer_norm3 = layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="layer_norm3",
+        )
+
+    def build(
+        self,
+        prompt_feats_shape,
+        vision_feats_shape,
+        vision_pos_encodings_shape,
+        prompt_masks_shape,
+    ):
+        self.layer_norm1.build(prompt_feats_shape)
+        self.self_attn.build(
+            prompt_feats_shape, prompt_feats_shape, prompt_feats_shape
+        )
+        self.dropout.build(prompt_feats_shape)
+        self.layer_norm2.build(prompt_feats_shape)
+        self.cross_attn.build(
+            prompt_feats_shape, vision_feats_shape, vision_feats_shape
+        )
+        self.layer_norm3.build(prompt_feats_shape)
+        self.mlp.build(prompt_feats_shape)
+
+    def call(
+        self,
+        prompt_feats,
+        vision_feats,
+        vision_pos_encodings,
+        prompt_masks,
+        training=None,
+    ):
+        residual = prompt_feats
+        hidden_states = self.layer_norm1(prompt_feats, training=training)
+        hidden_states = self.self_attn(
+            query=hidden_states,
+            key=hidden_states,
+            value=hidden_states,
+            attention_mask=prompt_masks,
+            training=training,
+        )
+        hidden_states = ops.add(
+            self.dropout(hidden_states, training=training), residual
+        )
+
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states, training=training)
+        key = ops.add(vision_feats, vision_pos_encodings)
+        hidden_states = self.cross_attn(
+            query=hidden_states, key=key, value=vision_feats, training=training
+        )
+        hidden_states = ops.add(
+            self.dropout(hidden_states, training=training), residual
+        )
+
+        residual = hidden_states
+        hidden_states = self.layer_norm3(hidden_states, training=training)
+        hidden_states = self.mlp(hidden_states, training=training)
+        hidden_states = ops.add(
+            self.dropout(hidden_states, training=training), residual
+        )
+        return hidden_states
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "num_heads": self.num_heads,
+                "hidden_activation": self.hidden_activation,
+                "dropout_rate": self.dropout_rate,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+            }
+        )
+        return config
+
+    def compute_output_shape(
+        self,
+        prompt_feats_shape,
+        vision_feats_shape,
+        vision_pos_encodings_shape,
+        prompt_masks_shape,
+    ):
+        return prompt_feats_shape
+
+
+@keras_hub_export("keras_hub.layers.SAM3GeometryEncoder")
+class SAM3GeometryEncoder(layers.Layer):
+    """A geometry encoder for the Segment Anything Model 3 (SAM3).
+
+    This layer implements a transformer-based encoder for processing geometry
+    prompts (boxes). It extracts features from the input boxes, pools vision
+    features based on the boxes, and fuses them with transformer layers.
+
+    Args:
+        num_layers: int. The number of transformer layers.
+        hidden_dim: int. The hidden dimension of the transformer layers.
+        intermediate_dim: int. The dimension of the intermediate layer in the
+            transformer's MLP.
+        num_heads: int. The number of attention heads.
+        roi_size: int. The size of the ROI pooling for boxes.
+        hidden_activation: str. The activation function for the transformer
+            layers. Defaults to `"relu"`.
+        dropout_rate: float. The dropout rate for the MLP and attention.
+            Defaults to `0.0`.
+        layer_norm_epsilon: float. The epsilon value for layer normalization.
+            Defaults to `1e-6`.
+    """
+
+    def __init__(
+        self,
+        num_layers,
+        hidden_dim,
+        intermediate_dim,
+        num_heads,
+        roi_size,
+        hidden_activation="relu",
+        dropout_rate=0.0,
+        layer_norm_epsilon=1e-6,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.num_layers = int(num_layers)
+        self.hidden_dim = int(hidden_dim)
+        self.intermediate_dim = int(intermediate_dim)
+        self.num_heads = int(num_heads)
+        self.roi_size = int(roi_size)
+        self.hidden_activation = hidden_activation
+        self.dropout_rate = float(dropout_rate)
+        self.layer_norm_epsilon = float(layer_norm_epsilon)
+
+        self.position_encoding = SAM3SinePositionEmbedding(
+            num_pos_feats=self.hidden_dim // 2,
+            normalize=True,
+            dtype=self.dtype_policy,
+            name="position_encoding",
+        )
+        self.label_embed = layers.Embedding(
+            2, self.hidden_dim, dtype=self.dtype_policy, name="label_embed"
+        )
+        self.cls_embed = layers.Embedding(
+            1, self.hidden_dim, dtype=self.dtype_policy, name="cls_embed"
+        )
+
+        # Box encoding layers.
+        self.boxes_direct_project = layers.Dense(
+            self.hidden_dim,
+            dtype=self.dtype_policy,
+            name="boxes_direct_project",
+        )
+        self.boxes_pool_project = layers.Conv2D(
+            self.hidden_dim,
+            kernel_size=self.roi_size,
+            dtype=self.dtype_policy,
+            name="boxes_pool_project",
+        )
+        self.boxes_pos_enc_project = layers.Dense(
+            self.hidden_dim,
+            dtype=self.dtype_policy,
+            name="boxes_pos_enc_project",
+        )
+
+        # Image feature normalization.
+        self.vision_layer_norm = layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="vision_layer_norm",
+        )
+
+        # Prompt projection and normalization.
+        self.final_proj = layers.Dense(
+            self.hidden_dim, dtype=self.dtype_policy, name="final_proj"
+        )
+        self.prompt_layer_norm = layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="prompt_layer_norm",
+        )
+
+        # Transformer layers.
+        self.layers = [
+            SAM3GeometryEncoderLayer(
+                hidden_dim=self.hidden_dim,
+                intermediate_dim=self.intermediate_dim,
+                num_heads=self.num_heads,
+                dropout_rate=self.dropout_rate,
+                hidden_activation=self.hidden_activation,
+                layer_norm_epsilon=self.layer_norm_epsilon,
+                dtype=self.dtype_policy,
+                name=f"layer_{i}",
+            )
+            for i in range(self.num_layers)
+        ]
+        self.output_layer_norm = layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="output_layer_norm",
+        )
+
+    def build(
+        self,
+        box_embeddings_shape,
+        box_masks_shape,
+        box_labels_shape,
+        fpn_hidden_states_shape,
+        fpn_position_encodings_shape,
+    ):
+        batch_size = fpn_hidden_states_shape[0]
+        self.height = fpn_hidden_states_shape[1]
+        self.width = fpn_hidden_states_shape[2]
+        self.input_hidden_dim = fpn_hidden_states_shape[-1]
+
+        self.position_encoding.build()
+        self.vision_layer_norm.build(fpn_hidden_states_shape)
+
+        box_proj_input_shape = list(box_embeddings_shape)
+        box_proj_input_shape[-1] = box_embeddings_shape[-1] - 1
+        self.boxes_direct_project.build(tuple(box_proj_input_shape))
+
+        sampled_feature_shape = [
+            batch_size,
+            self.roi_size,
+            self.roi_size,
+            self.input_hidden_dim,
+        ]
+        self.boxes_pool_project.build(sampled_feature_shape)
+
+        pos_enc_shape = [batch_size, None, self.input_hidden_dim + 2]
+        self.boxes_pos_enc_project.build(pos_enc_shape)
+        self.label_embed.build([batch_size, 1])
+        self.cls_embed.build([batch_size, 1])
+
+        prompt_embed_shape = [batch_size, None, self.hidden_dim]
+        self.final_proj.build(prompt_embed_shape)
+        self.prompt_layer_norm.build(prompt_embed_shape)
+
+        vision_feat_flat_shape = [
+            batch_size,
+            self.height * self.width,
+            self.input_hidden_dim,
+        ]
+        for layer in self.layers:
+            layer.build(
+                prompt_embed_shape,
+                vision_feat_flat_shape,
+                vision_feat_flat_shape,
+                None,
+            )
+        self.output_layer_norm.build(prompt_embed_shape)
+
+    def _encode_box_coordinates(self, center_x, center_y, width, height):
+        pos_x, pos_y = self.position_encoding.encode_1d_positions(
+            center_x, center_y
+        )
+        pos = ops.concatenate(
+            (pos_y, pos_x, height[:, None], width[:, None]), axis=1
+        )
+        return pos
+
+    def _encode_boxes(self, boxes, boxes_mask, boxes_labels, vision_features):
+        # Keras passes the masks as concrete tensors for both the
+        # true and false functions to build the output shape. So, we
+        # need to handle the case when 0 size masks is passed and
+        # dispatch the call to `_no_box_embeddings`. Note that we can't call
+        # the lambda directly since the inputs are bound to different
+        # values when called with concrete values.
+        if boxes.shape[1] == 0:
+            return self._no_box_embeddings(boxes, boxes_mask)
+
+        # The shape of boxes is different from HF's implementation.
+        # boxes: [batch_size, num_boxes, 5] where the last dimension is
+        # (batch_index, cx, cy, w, h)
+        boxes_indices = boxes[..., 0:1]
+        boxes = boxes[..., 1:]
+        batch_size = ops.shape(boxes)[0]
+        boxes_embed = self.boxes_direct_project(boxes)
+
+        # Pool features using ROI align.
+        # Convert boxes from cxcywh to xyxy format and denormalize.
+        boxes_xyxy = box_cxcywh_to_xyxy(boxes)
+        scale = ops.array(
+            [[[self.width, self.height, self.width, self.height]]],
+            dtype=boxes.dtype,
+        )
+        boxes_xyxy = ops.multiply(boxes_xyxy, scale)
+        boxes_xyxy = ops.reshape(boxes_xyxy, (-1, 4))
+        # Add batch indices to boxes for roi_align.
+        rois = ops.concatenate(
+            [ops.reshape(boxes_indices, (-1, 1)), boxes_xyxy], axis=-1
+        )
+        sampled_features = roi_align(
+            vision_features,
+            rois,
+            (self.roi_size, self.roi_size),
+            spatial_scale=1.0,
+            height=self.height,
+            width=self.width,
+            hidden_dim=self.input_hidden_dim,
+        )
+
+        pooled_projection = self.boxes_pool_project(sampled_features)
+        pooled_projection = ops.reshape(
+            pooled_projection, (batch_size, -1, self.hidden_dim)
+        )
+        boxes_embed = ops.add(boxes_embed, pooled_projection)
+
+        # Add position encoding.
+        center_x, center_y, box_width, box_height = ops.unstack(
+            boxes, num=4, axis=-1
+        )
+        pos_enc = self._encode_box_coordinates(
+            ops.reshape(center_x, (-1,)),
+            ops.reshape(center_y, (-1,)),
+            ops.reshape(box_width, (-1,)),
+            ops.reshape(box_height, (-1,)),
+        )
+        pos_enc = ops.reshape(
+            pos_enc,
+            (batch_size, -1, self.position_encoding.num_pos_feats * 2 + 2),
+        )
+        pos_projection = self.boxes_pos_enc_project(pos_enc)
+        boxes_embed = ops.add(boxes_embed, pos_projection)
+
+        # Add label embeddings (positive / negative).
+        label_embed = self.label_embed(ops.cast(boxes_labels, dtype="int32"))
+        return ops.add(label_embed, boxes_embed), boxes_mask
+
+    def _no_box_embeddings(self, box_embeddings, box_masks):
+        batch_size = ops.shape(box_embeddings)[0]
+        num_boxes = ops.shape(box_embeddings)[1]
+        return (
+            ops.zeros(
+                (batch_size, num_boxes, self.hidden_dim),
+                dtype=box_embeddings.dtype,
+            ),
+            box_masks,
+        )
+
+    def call(
+        self,
+        box_embeddings,
+        box_masks,
+        box_labels,
+        fpn_hidden_states,
+        fpn_position_encodings,
+        training=None,
+    ):
+        # Prepare vision features for cross-attention.
+        vision_feats_flat = ops.reshape(
+            fpn_hidden_states,
+            (-1, self.height * self.width, self.input_hidden_dim),
+        )
+        vision_pos_embeds_flat = ops.reshape(
+            fpn_position_encodings,
+            (-1, self.height * self.width, self.input_hidden_dim),
+        )
+
+        # Normalize image features for pooling operations.
+        normalized_image_feats = self.vision_layer_norm(fpn_hidden_states)
+
+        prompt_embeds, prompt_mask = ops.cond(
+            ops.equal(ops.shape(box_embeddings)[1], 0),
+            lambda: self._no_box_embeddings(box_embeddings, box_masks),
+            lambda: self._encode_boxes(
+                box_embeddings, box_masks, box_labels, normalized_image_feats
+            ),
+        )
+
+        # Add CLS token (always valid).
+        cls_embed = ops.reshape(
+            self.cls_embed._embeddings, (1, 1, self.hidden_dim)
+        )
+        cls_embed = ops.tile(cls_embed, (ops.shape(prompt_embeds)[0], 1, 1))
+        cls_mask = ops.ones_like(cls_embed[:, :, 0], dtype=prompt_mask.dtype)
+
+        prompt_embeds, prompt_mask = concatenate_padded_sequences(
+            prompt_embeds,
+            prompt_mask,
+            ops.shape(prompt_embeds)[1],
+            cls_embed,
+            cls_mask,
+            1,
+            self.hidden_dim,
+        )
+        prompt_embeds = self.prompt_layer_norm(self.final_proj(prompt_embeds))
+
+        # Apply transformer layers with cross-attention to vision features.
+        for layer in self.layers:
+            prompt_embeds = layer(
+                prompt_embeds,
+                vision_feats_flat,
+                vision_pos_embeds_flat,
+                prompt_mask,
+                training=training,
+            )
+
+        # Final output normalization.
+        prompt_embeds = self.output_layer_norm(prompt_embeds, training=training)
+        return prompt_embeds, prompt_mask
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_layers": self.num_layers,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "num_heads": self.num_heads,
+                "roi_size": self.roi_size,
+                "hidden_activation": self.hidden_activation,
+                "dropout_rate": self.dropout_rate,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+            }
+        )
+        return config
+
+    def compute_output_shape(
+        self,
+        box_embeddings_shape,
+        box_masks_shape,
+        box_labels_shape,
+        fpn_hidden_states_shape,
+        fpn_position_encodings_shape,
+    ):
+        batch_size = fpn_hidden_states_shape[0]
+        num_boxes = box_embeddings_shape[1]
+        seq_len = None
+        if num_boxes is not None:
+            seq_len = num_boxes + 1
+        return [batch_size, seq_len, self.hidden_dim], [batch_size, seq_len]
+
+    def compute_output_spec(
+        self,
+        box_embeddings,
+        box_masks,
+        box_labels,
+        fpn_hidden_states,
+        fpn_position_encodings,
+    ):
+        prompt_embeds_shape, prompt_mask_shape = self.compute_output_shape(
+            box_embeddings.shape,
+            box_masks.shape,
+            box_labels.shape,
+            fpn_hidden_states.shape,
+            fpn_position_encodings.shape,
+        )
+        return (
+            keras.KerasTensor(prompt_embeds_shape, dtype=self.compute_dtype),
+            keras.KerasTensor(prompt_mask_shape, dtype="bool"),
+        )
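For orientation, here is a minimal usage sketch of the new SAM3GeometryEncoder layer. The exported symbol, constructor arguments, and call signature come from the diff above; the concrete dimensions, box coordinates, and random feature maps are illustrative assumptions, and the sketch assumes the FPN channel count matches hidden_dim (which the layer's build logic expects for the position-encoding projection).

import numpy as np
import keras_hub

# Illustrative sizes only; the shipped SAM3 presets may use different values.
encoder = keras_hub.layers.SAM3GeometryEncoder(
    num_layers=2,
    hidden_dim=256,
    intermediate_dim=1024,
    num_heads=8,
    roi_size=7,
)

# One image with two box prompts. Per the comment in `_encode_boxes`, each
# box row is (batch_index, cx, cy, w, h) in normalized coordinates.
box_embeddings = np.array(
    [[[0.0, 0.5, 0.5, 0.2, 0.3], [0.0, 0.25, 0.25, 0.1, 0.1]]], "float32"
)
box_masks = np.ones((1, 2), dtype=bool)
box_labels = np.array([[1, 0]])  # positive / negative prompt labels
# FPN features and their position encodings share the same spatial shape.
fpn_hidden_states = np.random.rand(1, 32, 32, 256).astype("float32")
fpn_position_encodings = np.random.rand(1, 32, 32, 256).astype("float32")

prompt_embeds, prompt_mask = encoder(
    box_embeddings,
    box_masks,
    box_labels,
    fpn_hidden_states,
    fpn_position_encodings,
)
# prompt_embeds: (1, num_boxes + 1, hidden_dim); the extra token is the CLS
# embedding appended in `call`, and prompt_mask marks which tokens are valid.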
keras_hub/src/models/sam3/sam3_image_converter.py (new file)

@@ -0,0 +1,10 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+from keras_hub.src.models.sam3.sam3_pc_backbone import (
+    SAM3PromptableConceptBackbone,
+)
+
+
+@keras_hub_export("keras_hub.layers.SAM3ImageConverter")
+class SAM3ImageConverter(ImageConverter):
+    backbone_cls = SAM3PromptableConceptBackbone