keras-hub 0.25.0.dev0__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +21 -0
- keras_hub/models/__init__.py +27 -0
- keras_hub/src/layers/modeling/non_max_supression.py +5 -2
- keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
- keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
- keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
- keras_hub/src/models/albert/albert_backbone.py +1 -3
- keras_hub/src/models/backbone.py +3 -0
- keras_hub/src/models/bart/bart_backbone.py +1 -3
- keras_hub/src/models/bert/bert_backbone.py +2 -4
- keras_hub/src/models/bloom/bloom_backbone.py +1 -3
- keras_hub/src/models/causal_lm.py +2 -2
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
- keras_hub/src/models/edrec/edrec_backbone.py +147 -0
- keras_hub/src/models/edrec/edrec_layers.py +434 -0
- keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
- keras_hub/src/models/electra/electra_backbone.py +1 -3
- keras_hub/src/models/f_net/f_net_backbone.py +1 -3
- keras_hub/src/models/falcon/falcon_backbone.py +1 -3
- keras_hub/src/models/flux/flux_layers.py +3 -3
- keras_hub/src/models/flux/flux_maths.py +29 -15
- keras_hub/src/models/gemma/gemma_backbone.py +1 -3
- keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
- keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
- keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
- keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +23 -3
- keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +79 -7
- keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
- keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
- keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
- keras_hub/src/models/llama/llama_backbone.py +1 -3
- keras_hub/src/models/masked_lm.py +1 -1
- keras_hub/src/models/mistral/mistral_backbone.py +1 -3
- keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
- keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
- keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
- keras_hub/src/models/phi3/phi3_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_presets.py +209 -0
- keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
- keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
- keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
- keras_hub/src/models/rqvae/__init__.py +5 -0
- keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
- keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
- keras_hub/src/models/rwkv7/__init__.py +5 -0
- keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
- keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
- keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
- keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
- keras_hub/src/models/sam/sam_backbone.py +5 -1
- keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
- keras_hub/src/models/sam3/__init__.py +7 -0
- keras_hub/src/models/sam3/roi_align.py +222 -0
- keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
- keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
- keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
- keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
- keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
- keras_hub/src/models/sam3/sam3_layers.py +814 -0
- keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
- keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
- keras_hub/src/models/sam3/sam3_presets.py +16 -0
- keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
- keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
- keras_hub/src/models/sam3/sam3_utils.py +134 -0
- keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
- keras_hub/src/models/segformer/segformer_backbone.py +6 -6
- keras_hub/src/models/siglip/siglip_layers.py +1 -3
- keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
- keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
- keras_hub/src/models/t5/t5_backbone.py +1 -3
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
- keras_hub/src/models/task.py +1 -1
- keras_hub/src/tests/test_case.py +394 -3
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
- keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
- keras_hub/src/utils/preset_utils.py +1 -1
- keras_hub/src/utils/tensor_utils.py +12 -0
- keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
- keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
- keras_hub/src/utils/transformers/convert_sam3.py +472 -0
- keras_hub/src/utils/transformers/export/gemma3.py +196 -0
- keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
- keras_hub/src/utils/transformers/export/qwen.py +136 -0
- keras_hub/src/utils/transformers/preset_loader.py +15 -1
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +6 -0
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
- keras_hub/src/models/gemma3/rms_normalization.py +0 -26
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
keras_hub/src/models/sam3/sam3_layers.py (new file)

@@ -0,0 +1,814 @@
import math

from keras import backend
from keras import config
from keras import initializers
from keras import layers
from keras import ops

from keras_hub.src.models.sam3.sam3_utils import box_cxcywh_to_xyxy
from keras_hub.src.models.sam3.sam3_utils import inverse_sigmoid
from keras_hub.src.utils.keras_utils import standardize_data_format


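# Transformer feed-forward block: Dense -> Dropout -> activation -> Dense.
# Note that dropout is applied before the activation, matching the order in
# `call` below. A minimal usage sketch (dimensions are illustrative):
#
#   mlp = SAM3MLP(hidden_dim=256, intermediate_dim=1024)
#   y = mlp(ops.ones((2, 100, 256)))  # -> (2, 100, 256)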
class SAM3MLP(layers.Layer):
    def __init__(
        self,
        hidden_dim,
        intermediate_dim,
        activation="gelu",
        dropout_rate=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.intermediate_dim = int(intermediate_dim)
        self.activation = activation
        self.dropout_rate = float(dropout_rate)

        self.fc1 = layers.Dense(
            intermediate_dim, dtype=self.dtype_policy, name="fc1"
        )
        self.act = layers.Activation(activation, dtype=self.dtype_policy)
        self.fc2 = layers.Dense(hidden_dim, dtype=self.dtype_policy, name="fc2")
        self.dropout = layers.Dropout(
            dropout_rate, dtype=self.dtype_policy, name="dropout"
        )

    def build(self, input_shape):
        self.fc1.build(input_shape)
        input_shape = self.fc1.compute_output_shape(input_shape)
        self.dropout.build(input_shape)
        self.act.build(input_shape)
        self.fc2.build(input_shape)
        input_shape = self.fc2.compute_output_shape(input_shape)

    def call(self, inputs, training=None):
        x = self.fc1(inputs, training=training)
        x = self.dropout(x, training=training)
        x = self.act(x)
        return self.fc2(x, training=training)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "activation": self.activation,
                "dropout_rate": self.dropout_rate,
            }
        )
        return config


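# Multi-head attention with separate q/k/v/o projections. Accepts a boolean
# `attention_mask` and/or an additive `attention_bias`; on the torch backend,
# the biased, unmasked case falls back to torch's
# scaled_dot_product_attention (see the TODO below). Usage sketch (shapes are
# illustrative):
#
#   attn = SAM3Attention(hidden_dim=256, num_heads=8)
#   out = attn(q, k, v)  # q: (batch, q_len, 256); k, v: (batch, kv_len, 256)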
class SAM3Attention(layers.Layer):
    def __init__(self, hidden_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.num_heads = int(num_heads)
        self.head_dim = self.hidden_dim // self.num_heads
        self.scale = self.head_dim**-0.5

        self.q_proj = layers.Dense(
            self.hidden_dim, dtype=self.dtype_policy, name="q_proj"
        )
        self.k_proj = layers.Dense(
            self.hidden_dim, dtype=self.dtype_policy, name="k_proj"
        )
        self.v_proj = layers.Dense(
            self.hidden_dim, dtype=self.dtype_policy, name="v_proj"
        )
        self.o_proj = layers.Dense(
            self.hidden_dim, dtype=self.dtype_policy, name="o_proj"
        )

    def build(self, query_shape, key_shape, value_shape):
        self.q_proj.build(query_shape)
        self.k_proj.build(key_shape)
        self.v_proj.build(value_shape)
        self.o_proj.build(value_shape)

    def call(
        self,
        query,
        key,
        value,
        attention_mask=None,
        attention_bias=None,
        training=None,
    ):
        batch_size = ops.shape(query)[0]

        query = self.q_proj(query, training=training)
        query = ops.reshape(
            query, (batch_size, -1, self.num_heads, self.head_dim)
        )
        key = self.k_proj(key, training=training)
        key = ops.reshape(key, (batch_size, -1, self.num_heads, self.head_dim))
        value = self.v_proj(value, training=training)
        value = ops.reshape(
            value, (batch_size, -1, self.num_heads, self.head_dim)
        )

        if (
            backend.backend() == "torch"
            and attention_mask is None
            and attention_bias is not None
        ):
            # TODO: Torch backend doesn't support attention_bias in
            # ops.dot_product_attention yet.
            # Fixed by https://github.com/keras-team/keras/pull/22045
            import torch

            query = torch.transpose(query, 1, 2).contiguous()
            key = torch.transpose(key, 1, 2).contiguous()
            value = torch.transpose(value, 1, 2).contiguous()
            attention_bias = attention_bias.contiguous()
            attn_output = torch.nn.functional.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=attention_bias,
                is_causal=False,
                scale=self.scale,
            )
            attn_output = torch.transpose(attn_output, 2, 1)
        else:
            if attention_mask is not None:
                attention_mask = ops.cast(attention_mask, dtype="bool")
            attn_output = ops.dot_product_attention(
                query,
                key,
                value,
                bias=attention_bias,
                mask=attention_mask,
                scale=self.scale,
                is_causal=False,
            )
        attn_output = ops.reshape(
            attn_output, (batch_size, -1, self.num_heads * self.head_dim)
        )
        return self.o_proj(attn_output, training=training)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "num_heads": self.num_heads,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        return input_shape


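# Self-attention over a flattened 2D feature map with rotary position
# embeddings (RoPE) applied to queries and keys. `position_embeddings` is a
# precomputed `(cos, sin)` pair for the `height x width` grid captured in
# `build()`; the rotation is performed in float32 and cast back to the
# original dtype for numerical stability.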
class SAM3RoPEAttention(layers.Layer):
    def __init__(
        self,
        hidden_dim,
        num_heads,
        attention_dropout_rate=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.num_heads = int(num_heads)
        self.attention_dropout_rate = float(attention_dropout_rate)
        self.head_dim = self.hidden_dim // self.num_heads
        self.scale = self.head_dim**-0.5

        self.q_proj = layers.Dense(
            self.hidden_dim, dtype=self.dtype_policy, name="q_proj"
        )
        self.k_proj = layers.Dense(
            self.hidden_dim, dtype=self.dtype_policy, name="k_proj"
        )
        self.v_proj = layers.Dense(
            self.hidden_dim, dtype=self.dtype_policy, name="v_proj"
        )
        self.o_proj = layers.Dense(
            self.hidden_dim, dtype=self.dtype_policy, name="o_proj"
        )

    def build(self, input_shape):
        self.height = input_shape[1]
        self.width = input_shape[2]
        self.q_proj.build(input_shape)
        self.k_proj.build(input_shape)
        self.v_proj.build(input_shape)
        self.o_proj.build(input_shape)

    def apply_rotary_pos_emb_2d(self, query, key, cos, sin):
        def rotate_pairwise(x):
            x = ops.reshape(
                x,
                (
                    -1,
                    self.num_heads,
                    self.height * self.width,
                    self.head_dim // 2,
                    2,
                ),
            )
            x1 = x[..., 0]
            x2 = x[..., 1]
            x = ops.stack((-x2, x1), axis=-1)
            return ops.reshape(
                x, (-1, self.num_heads, self.height * self.width, self.head_dim)
            )

        query = ops.transpose(query, axes=(0, 2, 1, 3))
        key = ops.transpose(key, axes=(0, 2, 1, 3))

        original_dtype = backend.standardize_dtype(query.dtype)
        query_embed = ops.cast(query, dtype="float32")
        query_embed = ops.add(
            ops.multiply(query_embed, cos),
            ops.multiply(rotate_pairwise(query_embed), sin),
        )
        key_embed = ops.cast(key, dtype="float32")
        key_embed = ops.add(
            ops.multiply(key_embed, cos),
            ops.multiply(rotate_pairwise(key_embed), sin),
        )
        query_embed = ops.cast(query_embed, dtype=original_dtype)
        key_embed = ops.cast(key_embed, dtype=original_dtype)

        query_embed = ops.transpose(query_embed, axes=(0, 2, 1, 3))
        key_embed = ops.transpose(key_embed, axes=(0, 2, 1, 3))
        return query_embed, key_embed

    def call(self, hidden_states, position_embeddings, training=None):
        new_shape = (
            -1,
            self.height * self.width,
            self.num_heads,
            self.head_dim,
        )

        query = self.q_proj(hidden_states, training=training)
        query = ops.reshape(query, new_shape)
        key = self.k_proj(hidden_states, training=training)
        key = ops.reshape(key, new_shape)
        value = self.v_proj(hidden_states, training=training)
        value = ops.reshape(value, new_shape)
        cos, sin = position_embeddings
        query, key = self.apply_rotary_pos_emb_2d(query, key, cos=cos, sin=sin)

        attention_output = ops.dot_product_attention(
            query, key, value, scale=self.scale, is_causal=False
        )
        attention_output = ops.reshape(
            attention_output, (-1, self.height, self.width, self.hidden_dim)
        )
        attention_output = self.o_proj(attention_output, training=training)
        return attention_output

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "num_heads": self.num_heads,
                "attention_dropout_rate": self.attention_dropout_rate,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        return input_shape


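# ViT-style patchifier: a strided, bias-free Conv2D projects non-overlapping
# `patch_size x patch_size` patches to `hidden_dim` channels, and the spatial
# grid is flattened to a sequence. With illustrative numbers, a 336x336 input
# and patch_size=14 give a 24x24 grid, i.e. a (batch, 576, hidden_dim) output.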
class SAM3PatchEmbedding(layers.Layer):
    def __init__(self, hidden_dim, patch_size, data_format=None, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.patch_size = int(patch_size)
        self.data_format = standardize_data_format(data_format)

        self.projection = layers.Conv2D(
            self.hidden_dim,
            kernel_size=self.patch_size,
            strides=self.patch_size,
            use_bias=False,
            dtype=self.dtype_policy,
            name="projection",
        )

    def build(self, input_shape):
        self.projection.build(input_shape)
        output_shape = self.projection.compute_output_shape(input_shape)
        if self.data_format == "channels_last":
            self.seq_len = int(output_shape[1]) * int(output_shape[2])
        else:
            self.seq_len = int(output_shape[2]) * int(output_shape[3])

    def call(self, inputs, training=None):
        embeddings = self.projection(inputs, training=training)
        if self.data_format == "channels_last":
            embeddings = ops.reshape(
                embeddings, (-1, self.seq_len, self.hidden_dim)
            )
        else:
            embeddings = ops.reshape(
                embeddings, (-1, self.hidden_dim, self.seq_len)
            )
            embeddings = ops.transpose(embeddings, (0, 2, 1))
        return embeddings

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "patch_size": self.patch_size,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        output_shape = [input_shape[0], None, self.hidden_dim]
        if self.data_format == "channels_last":
            if input_shape[1] is not None and input_shape[2] is not None:
                patch_num = input_shape[1] // self.patch_size
                output_shape[1] = patch_num**2
        else:
            if input_shape[2] is not None and input_shape[3] is not None:
                patch_num = input_shape[2] // self.patch_size
                output_shape[1] = patch_num**2
        return output_shape


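# Patch embeddings plus learned position embeddings. Two position-embedding
# variables are kept: `position_embeddings` stores the pretrained grid for
# `pretrain_image_shape` (336x336 by default) and is frozen, while
# `tiled_position_embeddings` is the trainable copy tiled out to
# `image_shape` and used in `call`. The `save_own_variables` /
# `load_own_variables` overrides below convert between the two so that
# checkpoints always hold the pretrained-resolution grid.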
class SAM3Embedding(layers.Layer):
    def __init__(
        self,
        hidden_dim,
        patch_size,
        image_shape,
        dropout_rate=0.0,
        pretrain_image_shape=(336, 336, 3),
        data_format=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.patch_size = int(patch_size)
        self.image_shape = (
            int(image_shape[0]),
            int(image_shape[1]),
            int(image_shape[2]),
        )
        self.dropout_rate = float(dropout_rate)
        self.pretrain_image_shape = (
            int(pretrain_image_shape[0]),
            int(pretrain_image_shape[1]),
            int(pretrain_image_shape[2]),
        )
        self.data_format = standardize_data_format(data_format)
        self.num_patches = (self.pretrain_image_shape[0] // self.patch_size) * (
            self.pretrain_image_shape[1] // self.patch_size
        )
        self.tiled_num_patches = (self.image_shape[0] // self.patch_size) * (
            self.image_shape[1] // self.patch_size
        )

        self.patch_embeddings = SAM3PatchEmbedding(
            hidden_dim=self.hidden_dim,
            patch_size=self.patch_size,
            data_format=self.data_format,
            dtype=self.dtype_policy,
            name="patch_embeddings",
        )
        self.dropout = layers.Dropout(
            self.dropout_rate, dtype=self.dtype_policy, name="dropout"
        )

    def build(self, input_shape):
        self.patch_embeddings.build(input_shape)
        embedding_shape = self.patch_embeddings.compute_output_shape(
            input_shape
        )
        self.dropout.build(embedding_shape)

        # Note that there are two position embeddings:
        # `self.tiled_position_embeddings` is used for the image inputs during
        # both training and inference.
        # `self.position_embeddings` is used to load pretrained weights and
        # remains unchanged during training and inference. It will be updated
        # during saving once `self.tiled_position_embeddings` is modified.
        self.position_embeddings = self.add_weight(
            shape=(1, self.num_patches, self.hidden_dim),
            initializer=initializers.TruncatedNormal(stddev=0.02),
            trainable=False,
            name="position_embeddings",
        )
        self.tiled_position_embeddings = self.add_weight(
            shape=(1, self.tiled_num_patches, self.hidden_dim),
            initializer="zeros",  # Will be initialized by tiling.
            trainable=True,
            name="tiled_position_embeddings",
        )

        # Initialize the interpolated position embeddings.
        self.tiled_position_embeddings.assign(
            self._tile_position_embeddings(
                self.position_embeddings,
                patch_size=self.patch_size,
                source_shape=self.pretrain_image_shape,
                target_shape=self.image_shape,
            )
        )

    def call(self, inputs, training=None):
        x = inputs
        patch_embeddings = self.patch_embeddings(x, training=training)
        if self.data_format == "channels_last":
            patch_embeddings = ops.reshape(
                patch_embeddings,
                (-1, self.patch_embeddings.seq_len, self.hidden_dim),
            )
        else:
            patch_embeddings = ops.reshape(
                patch_embeddings,
                (-1, self.hidden_dim, self.patch_embeddings.seq_len),
            )
            patch_embeddings = ops.transpose(patch_embeddings, (0, 2, 1))
        embeddings = ops.add(patch_embeddings, self.tiled_position_embeddings)
        embeddings = self.dropout(embeddings, training=training)
        return embeddings

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "patch_size": self.patch_size,
                "image_shape": self.image_shape,
                "dropout_rate": self.dropout_rate,
                "pretrain_image_shape": self.pretrain_image_shape,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        if input_shape is None:
            input_shape = [None, None, None, None]
        output_shape = [input_shape[0], None, self.hidden_dim]
        if self.data_format == "channels_last":
            if input_shape[1] is not None and input_shape[2] is not None:
                patch_num = input_shape[1] // self.patch_size
                output_shape[1] = patch_num**2
        else:
            if input_shape[2] is not None and input_shape[3] is not None:
                patch_num = input_shape[2] // self.patch_size
                output_shape[1] = patch_num**2
        return output_shape

    @staticmethod
    def _tile_position_embeddings(
        position_embeddings, patch_size, source_shape, target_shape
    ):
        """Tile position embeddings to match the target image shape.

        Reference:
        - https://github.com/huggingface/transformers/blob/main/src/transformers/models/sam3/modeling_sam3.py
        """
        position_embeddings = ops.convert_to_tensor(position_embeddings)
        patch_size = int(patch_size)
        source_shape = (int(source_shape[0]), int(source_shape[1]))
        target_shape = (int(target_shape[0]), int(target_shape[1]))
        hidden_dim = int(position_embeddings.shape[-1])

        if (
            source_shape[0] == target_shape[0]
            and source_shape[1] == target_shape[1]
        ):
            # No need to tile if the image size is the same as the
            # position embedding image size.
            return ops.copy(position_embeddings)

        # Tile position embeddings to match target image size.
        source_embedding_shape = (
            source_shape[0] // patch_size,
            source_shape[1] // patch_size,
        )
        target_embedding_shape = (
            target_shape[0] // patch_size,
            target_shape[1] // patch_size,
        )
        position_embeddings = ops.reshape(
            position_embeddings,
            (
                1,
                source_embedding_shape[0],
                source_embedding_shape[1],
                hidden_dim,
            ),
        )
        repeat_h = target_embedding_shape[0] // source_embedding_shape[0] + 1
        repeat_w = target_embedding_shape[1] // source_embedding_shape[1] + 1
        position_embeddings = ops.tile(
            position_embeddings, (1, repeat_h, repeat_w, 1)
        )
        position_embeddings = position_embeddings[
            :, : target_embedding_shape[0], : target_embedding_shape[1], :
        ]
        return ops.reshape(position_embeddings, (1, -1, hidden_dim))

    def _is_tiled_position_embeddings_updated(self):
        """Check if the tiled position embeddings are updated."""
        original_tiled_position_embeddings = self._tile_position_embeddings(
            self.position_embeddings,
            patch_size=self.patch_size,
            source_shape=self.pretrain_image_shape,
            target_shape=self.image_shape,
        )
        diff = ops.sum(
            ops.subtract(
                original_tiled_position_embeddings,
                self.tiled_position_embeddings,
            )
        )
        return ops.cond(
            ops.greater(diff, config.epsilon()), lambda: True, lambda: False
        )

    def save_own_variables(self, store):
        if self._is_tiled_position_embeddings_updated():
            self.position_embeddings.assign(
                self._tile_position_embeddings(
                    self.tiled_position_embeddings,
                    patch_size=self.patch_size,
                    source_shape=self.image_shape,
                    target_shape=self.pretrain_image_shape,
                )
            )
        super().save_own_variables(store)

    def load_own_variables(self, store):
        all_vars = self._trainable_variables + self._non_trainable_variables
        for i, v in enumerate(all_vars):
            if v is self.tiled_position_embeddings:
                continue
            v.assign(store[f"{i}"])
        self.tiled_position_embeddings.assign(
            self._tile_position_embeddings(
                self.position_embeddings,
                patch_size=self.patch_size,
                source_shape=self.pretrain_image_shape,
                target_shape=self.image_shape,
            )
        )


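# DETR-style fixed sine/cosine position encoding. Each coordinate is scaled
# and divided by per-channel frequencies
# `temperature ** (2 * (i // 2) / num_pos_feats)`, then even channels take
# the sine and odd channels the cosine. `call` returns a
# (batch, height, width, 2 * num_pos_feats) grid encoding, and
# `encode_boxes` encodes (cx, cy, w, h) boxes into 4 * num_pos_feats
# features.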
class SAM3SinePositionEmbedding(layers.Layer):
    def __init__(
        self,
        num_pos_feats=64,
        temperature=10000,
        normalize=False,
        scale=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_pos_feats = int(num_pos_feats)
        self.temperature = float(temperature)
        self.normalize = bool(normalize)
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.scale = 2 * math.pi if scale is None else scale

    def build(self, input_shape=None):
        if self.built:
            return

    def encode_1d_positions(self, x, y):
        x_embed = ops.multiply(x, self.scale)
        y_embed = ops.multiply(y, self.scale)
        dim_t = ops.cast(ops.arange(self.num_pos_feats), dtype=x.dtype)
        dim_t = ops.power(
            self.temperature,
            ops.divide(
                ops.multiply(2, ops.floor_divide(dim_t, 2)), self.num_pos_feats
            ),
        )
        pos_x = ops.divide(ops.expand_dims(x_embed, -1), dim_t)
        pos_y = ops.divide(ops.expand_dims(y_embed, -1), dim_t)
        pos_x = ops.stack(
            (ops.sin(pos_x[:, 0::2]), ops.cos(pos_x[:, 1::2])), axis=2
        )
        pos_x = ops.reshape(pos_x, (-1, self.num_pos_feats))
        pos_y = ops.stack(
            (ops.sin(pos_y[:, 0::2]), ops.cos(pos_y[:, 1::2])), axis=2
        )
        pos_y = ops.reshape(pos_y, (-1, self.num_pos_feats))
        return pos_x, pos_y

    def encode_boxes(self, boxes):
        dim_t = ops.cast(ops.arange(self.num_pos_feats), dtype=boxes.dtype)
        dim_t = ops.power(
            self.temperature,
            ops.divide(
                ops.multiply(2, ops.floor_divide(dim_t, 2)), self.num_pos_feats
            ),
        )

        x_embed = ops.multiply(boxes[..., 0], self.scale)
        y_embed = ops.multiply(boxes[..., 1], self.scale)
        w_embed = ops.multiply(boxes[..., 2], self.scale)
        h_embed = ops.multiply(boxes[..., 3], self.scale)
        pos_x = ops.divide(ops.expand_dims(x_embed, -1), dim_t)
        pos_y = ops.divide(ops.expand_dims(y_embed, -1), dim_t)
        pos_w = ops.divide(ops.expand_dims(w_embed, -1), dim_t)
        pos_h = ops.divide(ops.expand_dims(h_embed, -1), dim_t)
        pos_x_shape = ops.shape(pos_x)
        newshape = (pos_x_shape[0], pos_x_shape[1], self.num_pos_feats)
        pos_x = ops.stack(
            (ops.sin(pos_x[..., 0::2]), ops.cos(pos_x[..., 1::2])), axis=3
        )
        pos_x = ops.reshape(pos_x, newshape)
        pos_y = ops.stack(
            (ops.sin(pos_y[..., 0::2]), ops.cos(pos_y[..., 1::2])), axis=3
        )
        pos_y = ops.reshape(pos_y, newshape)
        pos_w = ops.stack(
            (ops.sin(pos_w[..., 0::2]), ops.cos(pos_w[..., 1::2])), axis=3
        )
        pos_w = ops.reshape(pos_w, newshape)
        pos_h = ops.stack(
            (ops.sin(pos_h[..., 0::2]), ops.cos(pos_h[..., 1::2])), axis=3
        )
        pos_h = ops.reshape(pos_h, newshape)
        return ops.concatenate([pos_y, pos_x, pos_w, pos_h], axis=2)

    def call(self, inputs, height, width, training=None):
        not_mask = ops.ones((1, height, width), dtype=self.compute_dtype)
        y_embed = ops.cumsum(not_mask, axis=1)
        x_embed = ops.cumsum(not_mask, axis=2)
        if self.normalize:
            eps = 1e-6
            y_embed = ops.multiply(
                ops.divide(y_embed, ops.add(y_embed[:, -1:, :], eps)),
                self.scale,
            )
            x_embed = ops.multiply(
                ops.divide(x_embed, ops.add(x_embed[:, :, -1:], eps)),
                self.scale,
            )
        dim_t = ops.cast(
            ops.arange(self.num_pos_feats), dtype=self.compute_dtype
        )
        dim_t = ops.power(
            self.temperature,
            ops.divide(
                ops.multiply(2, ops.floor_divide(dim_t, 2)), self.num_pos_feats
            ),
        )

        pos_x = ops.divide(ops.expand_dims(x_embed, -1), dim_t)
        pos_y = ops.divide(ops.expand_dims(y_embed, -1), dim_t)
        newshape = (1, height, width, self.num_pos_feats)
        pos_x = ops.stack(
            (ops.sin(pos_x[..., 0::2]), ops.cos(pos_x[..., 1::2])), axis=4
        )
        pos_x = ops.reshape(pos_x, newshape)
        pos_y = ops.stack(
            (ops.sin(pos_y[..., 0::2]), ops.cos(pos_y[..., 1::2])), axis=4
        )
        pos_y = ops.reshape(pos_y, newshape)
        pos = ops.concatenate([pos_y, pos_x], axis=3)
        pos = ops.tile(pos, (ops.shape(inputs)[0], 1, 1, 1))
        return pos

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_pos_feats": self.num_pos_feats,
                "temperature": self.temperature,
                "normalize": self.normalize,
                "scale": self.scale,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        output_shape = list(input_shape)
        output_shape[1] = self.num_pos_feats * 2
        return output_shape


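# Small prediction head for the DETR-style decoder: a 2- or 3-layer MLP with
# ReLU between layers, mapping features to `output_dim` (for example, 4 for
# box regression). `num_layers` outside {2, 3} raises a ValueError.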
class SAM3DecoderMLP(layers.Layer):
    def __init__(self, num_layers, hidden_dim, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_layers = int(num_layers)
        self.hidden_dim = int(hidden_dim)
        self.output_dim = int(output_dim)

        if self.num_layers == 2:
            self.layer1 = layers.Dense(
                hidden_dim, dtype=self.dtype_policy, name="layer1"
            )
            self.layer2 = layers.Dense(
                output_dim, dtype=self.dtype_policy, name="layer2"
            )
        elif num_layers == 3:
            self.layer1 = layers.Dense(
                hidden_dim, dtype=self.dtype_policy, name="layer1"
            )
            self.layer2 = layers.Dense(
                hidden_dim, dtype=self.dtype_policy, name="layer2"
            )
            self.layer3 = layers.Dense(
                output_dim, dtype=self.dtype_policy, name="layer3"
            )
        else:
            raise ValueError("num_layers should be 2 or 3.")

    def build(self, input_shape):
        self.layer1.build(input_shape)
        input_shape = self.layer1.compute_output_shape(input_shape)
        self.layer2.build(input_shape)
        if self.num_layers == 3:
            input_shape = self.layer2.compute_output_shape(input_shape)
            self.layer3.build(input_shape)

    def call(self, inputs, training=None):
        x = ops.relu(self.layer1(inputs, training=training))
        if self.num_layers == 2:
            return self.layer2(x, training=training)
        else:
            x = ops.relu(self.layer2(x, training=training))
            return self.layer3(x, training=training)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_layers": self.num_layers,
                "hidden_dim": self.hidden_dim,
                "output_dim": self.output_dim,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        output_shape = list(input_shape)
        output_shape[-1] = self.output_dim
        return output_shape


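# Final box decoding: offsets are added to the reference boxes in
# inverse-sigmoid (logit) space, squashed back through a sigmoid, and
# converted from (cx, cy, w, h) to corner (x1, y1, x2, y2) format. Only the
# last decoder layer's boxes, class logits, and presence logits are returned
# (the `[:, -1]` indexing below).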
class SAM3BoxDecoder(layers.Layer):
    def build(
        self,
        box_offsets_shape,
        reference_boxes_shape,
        pred_logits_shape,
        presence_logits_shape,
    ):
        pass

    def call(
        self,
        box_offsets,
        reference_boxes,
        pred_logits,
        presence_logits,
        training=None,
    ):
        reference_boxes_inv_sig = inverse_sigmoid(reference_boxes)
        pred_boxes_cxcywh = ops.nn.sigmoid(
            ops.add(reference_boxes_inv_sig, box_offsets)
        )
        pred_boxes = box_cxcywh_to_xyxy(pred_boxes_cxcywh)
        return (
            pred_boxes[:, -1],
            pred_logits[:, -1, :, 0],
            presence_logits[:, -1],
        )

    def compute_output_shape(
        self,
        box_offsets_shape,
        reference_boxes_shape,
        pred_logits_shape,
        presence_logits_shape,
    ):
        pred_boxes_shape = [
            box_offsets_shape[0],
            box_offsets_shape[-2],
            box_offsets_shape[-1],
        ]
        pred_logits_shape = [
            pred_logits_shape[0],
            pred_logits_shape[-2],
        ]
        presence_logits_shape = [
            presence_logits_shape[0],
            presence_logits_shape[-2],
            presence_logits_shape[-1],
        ]
        return pred_boxes_shape, pred_logits_shape, presence_logits_shape