keras_hub-0.25.0.dev0-py3-none-any.whl → keras_hub-0.26.0.dev0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +21 -0
- keras_hub/models/__init__.py +27 -0
- keras_hub/src/layers/modeling/non_max_supression.py +5 -2
- keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
- keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
- keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
- keras_hub/src/models/albert/albert_backbone.py +1 -3
- keras_hub/src/models/backbone.py +3 -0
- keras_hub/src/models/bart/bart_backbone.py +1 -3
- keras_hub/src/models/bert/bert_backbone.py +2 -4
- keras_hub/src/models/bloom/bloom_backbone.py +1 -3
- keras_hub/src/models/causal_lm.py +2 -2
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
- keras_hub/src/models/edrec/edrec_backbone.py +147 -0
- keras_hub/src/models/edrec/edrec_layers.py +434 -0
- keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
- keras_hub/src/models/electra/electra_backbone.py +1 -3
- keras_hub/src/models/f_net/f_net_backbone.py +1 -3
- keras_hub/src/models/falcon/falcon_backbone.py +1 -3
- keras_hub/src/models/flux/flux_layers.py +3 -3
- keras_hub/src/models/flux/flux_maths.py +29 -15
- keras_hub/src/models/gemma/gemma_backbone.py +1 -3
- keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
- keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
- keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
- keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +23 -3
- keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +79 -7
- keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
- keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
- keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
- keras_hub/src/models/llama/llama_backbone.py +1 -3
- keras_hub/src/models/masked_lm.py +1 -1
- keras_hub/src/models/mistral/mistral_backbone.py +1 -3
- keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
- keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
- keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
- keras_hub/src/models/phi3/phi3_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_presets.py +209 -0
- keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
- keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
- keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
- keras_hub/src/models/rqvae/__init__.py +5 -0
- keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
- keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
- keras_hub/src/models/rwkv7/__init__.py +5 -0
- keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
- keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
- keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
- keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
- keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
- keras_hub/src/models/sam/sam_backbone.py +5 -1
- keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
- keras_hub/src/models/sam3/__init__.py +7 -0
- keras_hub/src/models/sam3/roi_align.py +222 -0
- keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
- keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
- keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
- keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
- keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
- keras_hub/src/models/sam3/sam3_layers.py +814 -0
- keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
- keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
- keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
- keras_hub/src/models/sam3/sam3_presets.py +16 -0
- keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
- keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
- keras_hub/src/models/sam3/sam3_utils.py +134 -0
- keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
- keras_hub/src/models/segformer/segformer_backbone.py +6 -6
- keras_hub/src/models/siglip/siglip_layers.py +1 -3
- keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
- keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
- keras_hub/src/models/t5/t5_backbone.py +1 -3
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
- keras_hub/src/models/task.py +1 -1
- keras_hub/src/tests/test_case.py +394 -3
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
- keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
- keras_hub/src/utils/preset_utils.py +1 -1
- keras_hub/src/utils/tensor_utils.py +12 -0
- keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
- keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
- keras_hub/src/utils/transformers/convert_sam3.py +472 -0
- keras_hub/src/utils/transformers/export/gemma3.py +196 -0
- keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
- keras_hub/src/utils/transformers/export/qwen.py +136 -0
- keras_hub/src/utils/transformers/preset_loader.py +15 -1
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +6 -0
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
- keras_hub/src/models/gemma3/rms_normalization.py +0 -26
- {keras_hub-0.25.0.dev0.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
keras_hub/src/models/sam3/sam3_vision_encoder.py
@@ -0,0 +1,738 @@
+import numpy as np
+from keras import layers
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.sam3.sam3_layers import SAM3MLP
+from keras_hub.src.models.sam3.sam3_layers import SAM3Embedding
+from keras_hub.src.models.sam3.sam3_layers import SAM3RoPEAttention
+from keras_hub.src.models.sam3.sam3_layers import SAM3SinePositionEmbedding
+from keras_hub.src.models.sam3.sam3_utils import window_partition
+from keras_hub.src.models.sam3.sam3_utils import window_unpartition
+
+
+class SAM3ViTRotaryEmbedding(layers.Layer):
+    def __init__(self, rope_theta, head_dim, end_x, end_y, scale=1.0, **kwargs):
+        super().__init__(**kwargs)
+        self.rope_theta = float(rope_theta)
+        self.head_dim = int(head_dim)
+        self.end_x = int(end_x)
+        self.end_y = int(end_y)
+        self.scale = float(scale)
+
+        # Ensure even dimension for proper axial splitting.
+        if self.head_dim % 4 != 0:
+            raise ValueError("Dimension must be divisible by 4 for axial RoPE")
+
+    def build(self, input_shape):
+        freqs = 1.0 / (
+            self.rope_theta
+            ** (
+                np.arange(0, self.head_dim, 4)[: (self.head_dim // 4)]
+                / self.head_dim
+            )
+        )
+        flattened_indices = np.arange(self.end_x * self.end_y, dtype=np.int64)
+        x_positions = (flattened_indices % self.end_x) * self.scale
+        y_positions = (
+            np.floor_divide(flattened_indices, self.end_x) * self.scale
+        )
+        freqs_x = np.outer(x_positions, freqs).astype(np.float32)
+        freqs_y = np.outer(y_positions, freqs).astype(np.float32)
+        inv_freq = np.concatenate([freqs_x, freqs_y], axis=-1)
+        inv_freq = np.repeat(inv_freq, repeats=2, axis=-1)
+        rope_embeddings_cos = np.cos(inv_freq)
+        rope_embeddings_sin = np.sin(inv_freq)
+        self.rope_embeddings_cos = self.add_weight(
+            name="rope_embeddings_cos",
+            shape=rope_embeddings_cos.shape,
+            dtype=self.variable_dtype,
+            trainable=False,
+            initializer=rope_embeddings_cos,
+        )
+        self.rope_embeddings_sin = self.add_weight(
+            name="rope_embeddings_sin",
+            shape=rope_embeddings_sin.shape,
+            dtype=self.variable_dtype,
+            trainable=False,
+            initializer=rope_embeddings_sin,
+        )
+
+    def call(self, inputs):
+        return self.rope_embeddings_cos, self.rope_embeddings_sin
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "rope_theta": self.rope_theta,
+                "head_dim": self.head_dim,
+                "end_x": self.end_x,
+                "end_y": self.end_y,
+                "scale": self.scale,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, input_shape):
+        embedding_shape = (self.end_x * self.end_y, self.head_dim)
+        return (embedding_shape, embedding_shape)
+
+    def load_own_variables(self, store):
+        try:
+            return super().load_own_variables(store)
+        except ValueError:
+            # `SAM3ViTRotaryEmbedding` has precomputed weights only. The issue
+            # of the loading logic could be ignored.
+            pass
+
+
+class SAM3ViTLayer(layers.Layer):
+    def __init__(
+        self,
+        image_shape,
+        patch_size,
+        hidden_dim,
+        intermediate_dim,
+        num_heads,
+        hidden_activation="gelu",
+        rope_theta=10000.0,
+        window_size=0,
+        rotary_scale=1.0,
+        attention_dropout_rate=0.0,
+        hidden_dropout_rate=0.0,
+        layer_norm_epsilon=1e-6,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.image_shape = (
+            int(image_shape[0]),
+            int(image_shape[1]),
+            int(image_shape[2]),
+        )
+        self.patch_size = int(patch_size)
+        self.hidden_dim = int(hidden_dim)
+        self.intermediate_dim = int(intermediate_dim)
+        self.num_heads = int(num_heads)
+        self.hidden_activation = hidden_activation
+        self.rope_theta = float(rope_theta)
+        self.window_size = int(window_size)
+        self.rotary_scale = float(rotary_scale)
+        self.hidden_dropout_rate = float(hidden_dropout_rate)
+        self.attention_dropout_rate = float(attention_dropout_rate)
+        self.layer_norm_epsilon = float(layer_norm_epsilon)
+        self.head_dim = self.hidden_dim // self.num_heads
+        input_size = (
+            self.image_shape[0] // self.patch_size,
+            self.image_shape[1] // self.patch_size,
+        )
+        if self.window_size > 0 and (
+            input_size[0] % self.window_size != 0
+            or input_size[1] % self.window_size != 0
+        ):
+            raise ValueError(
+                "Image size must be divisible by `patch_size` and "
+                "`window_size` for windowed attention. "
+                f"Received image size: {image_shape}, "
+                f"patch_size: {patch_size}, window_size: {window_size}"
+            )
+        rotary_input_size = (
+            input_size if window_size == 0 else (window_size, window_size)
+        )
+
+        self.layer_norm1 = layers.LayerNormalization(
+            epsilon=layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="layer_norm1",
+        )
+        self.rotary_emb = SAM3ViTRotaryEmbedding(
+            rope_theta=rope_theta,
+            head_dim=self.head_dim,
+            end_x=rotary_input_size[0],
+            end_y=rotary_input_size[1],
+            scale=self.rotary_scale,
+            dtype=self.dtype_policy,
+            name="rotary_emb",
+        )
+        self.attention = SAM3RoPEAttention(
+            hidden_dim=self.hidden_dim,
+            num_heads=self.num_heads,
+            attention_dropout_rate=self.attention_dropout_rate,
+            dtype=self.dtype_policy,
+            name="attention",
+        )
+        self.layer_norm2 = layers.LayerNormalization(
+            epsilon=layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="layer_norm2",
+        )
+        self.mlp = SAM3MLP(
+            hidden_dim=self.hidden_dim,
+            intermediate_dim=self.intermediate_dim,
+            activation=self.hidden_activation,
+            dropout_rate=self.hidden_dropout_rate,
+            dtype=self.dtype_policy,
+            name="mlp",
+        )
+        self.dropout = layers.Dropout(
+            self.hidden_dropout_rate, dtype=self.dtype_policy, name="dropout"
+        )
+
+    def build(self, input_shape):
+        self.input_hidden_dim = int(input_shape[-1])
+        self.layer_norm1.build(input_shape)
+        input_shape = self.layer_norm1.compute_output_shape(input_shape)
+        self.rotary_emb.build(input_shape)
+        input_shape_before_attention = input_shape
+        if self.window_size > 0:
+            input_shape = list(input_shape)
+            input_shape = (
+                None,
+                self.window_size,
+                self.window_size,
+                input_shape[-1],
+            )
+        self.attention.build(input_shape)
+        input_shape = self.attention.compute_output_shape(input_shape)
+        if self.window_size > 0:
+            input_shape = input_shape_before_attention
+        self.layer_norm2.build(input_shape)
+        self.mlp.build(input_shape)
+        self.dropout.build(input_shape)
+
+    def call(self, hidden_states, training=None):
+        residual = hidden_states
+        hidden_states = self.layer_norm1(hidden_states, training=training)
+        if self.window_size > 0:
+            height, width = (
+                self.image_shape[0] // self.patch_size,
+                self.image_shape[1] // self.patch_size,
+            )
+            # Partition into non-overlapping windows for efficient attention.
+            hidden_states = window_partition(
+                hidden_states,
+                height,
+                width,
+                self.window_size,
+                self.input_hidden_dim,
+            )
+
+        position_embeddings = self.rotary_emb(hidden_states, training=training)
+        hidden_states = self.attention(
+            hidden_states, position_embeddings, training=training
+        )
+        if self.window_size > 0:
+            # Reverse window partition to restore original spatial layout.
+            hidden_states = window_unpartition(
+                hidden_states, height, width, self.window_size, self.hidden_dim
+            )
+        hidden_states = ops.add(residual, hidden_states)
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states, training=training)
+        hidden_states = self.mlp(hidden_states, training=training)
+        hidden_states = ops.add(
+            residual, self.dropout(hidden_states, training=training)
+        )
+        return hidden_states
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "image_shape": self.image_shape,
+                "patch_size": self.patch_size,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "num_heads": self.num_heads,
+                "hidden_activation": self.hidden_activation,
+                "rope_theta": self.rope_theta,
+                "window_size": self.window_size,
+                "rotary_scale": self.rotary_scale,
+                "attention_dropout_rate": self.attention_dropout_rate,
+                "hidden_dropout_rate": self.hidden_dropout_rate,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
+
+
+class SAM3ViTEncoder(layers.Layer):
+    def __init__(
+        self,
+        image_shape,
+        patch_size,
+        num_layers,
+        hidden_dim,
+        intermediate_dim,
+        num_heads,
+        pretrain_image_shape=(336, 336, 3),
+        hidden_activation="gelu",
+        rope_theta=100000.0,
+        window_size=0,
+        global_attn_indexes=None,
+        attention_dropout_rate=0.0,
+        hidden_dropout_rate=0.0,
+        layer_norm_epsilon=1e-6,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.image_shape = (
+            int(image_shape[0]),
+            int(image_shape[1]),
+            int(image_shape[2]),
+        )
+        self.patch_size = int(patch_size)
+        self.num_layers = int(num_layers)
+        self.hidden_dim = int(hidden_dim)
+        self.intermediate_dim = int(intermediate_dim)
+        self.num_heads = int(num_heads)
+        self.hidden_activation = hidden_activation
+        self.rope_theta = float(rope_theta)
+        self.window_size = int(window_size)
+        if global_attn_indexes is not None:
+            self.global_attn_indexes = [int(i) for i in global_attn_indexes]
+        else:
+            self.global_attn_indexes = None
+        self.pretrain_image_shape = (
+            int(pretrain_image_shape[0]),
+            int(pretrain_image_shape[1]),
+            int(pretrain_image_shape[2]),
+        )
+        self.hidden_dropout_rate = float(hidden_dropout_rate)
+        self.attention_dropout_rate = float(attention_dropout_rate)
+        self.layer_norm_epsilon = float(layer_norm_epsilon)
+        height = self.image_shape[0] // self.patch_size
+
+        self.embeddings = SAM3Embedding(
+            hidden_dim=self.hidden_dim,
+            patch_size=self.patch_size,
+            image_shape=self.image_shape,
+            dropout_rate=self.hidden_dropout_rate,
+            pretrain_image_shape=self.pretrain_image_shape,
+            dtype=self.dtype_policy,
+            name="embeddings",
+        )
+        self.layer_norm = layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="layer_norm",
+        )
+        self.layers = [
+            SAM3ViTLayer(
+                image_shape=self.image_shape,
+                patch_size=self.patch_size,
+                hidden_dim=self.hidden_dim,
+                intermediate_dim=self.intermediate_dim,
+                num_heads=self.num_heads,
+                hidden_activation=self.hidden_activation,
+                rope_theta=self.rope_theta,
+                window_size=(
+                    self.window_size if i not in self.global_attn_indexes else 0
+                ),
+                rotary_scale=(
+                    1.0
+                    if i not in self.global_attn_indexes
+                    else float(self.window_size) / height
+                ),
+                attention_dropout_rate=self.attention_dropout_rate,
+                hidden_dropout_rate=self.hidden_dropout_rate,
+                layer_norm_epsilon=self.layer_norm_epsilon,
+                dtype=self.dtype_policy,
+                name=f"layer_{i}",
+            )
+            for i in range(self.num_layers)
+        ]
+
+    def build(self, input_shape):
+        self.embeddings.build(input_shape)
+        input_shape = self.embeddings.compute_output_shape(input_shape)
+        input_shape = list(input_shape)
+        height = self.image_shape[0] // self.patch_size
+        width = self.image_shape[1] // self.patch_size
+        input_shape = [input_shape[0], height, width, self.hidden_dim]
+        self.layer_norm.build(input_shape)
+        for layer in self.layers:
+            layer.build(input_shape)
+
+    def call(self, pixel_values, training=None):
+        hidden_states = self.embeddings(pixel_values, training=training)
+        height = self.image_shape[0] // self.patch_size
+        width = self.image_shape[1] // self.patch_size
+        # Reshape to spatial format for windowed attention:
+        # [batch_size, height, width, hidden_size]
+        hidden_states = ops.reshape(
+            hidden_states, (-1, height, width, self.hidden_dim)
+        )
+        hidden_states = self.layer_norm(hidden_states, training=training)
+        for i, layer in enumerate(self.layers):
+            hidden_states = layer(hidden_states, training=training)
+
+        # Reshape back to sequence format:
+        # [batch_size, height*width, hidden_size]
+        return ops.reshape(hidden_states, (-1, height * width, self.hidden_dim))
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "image_shape": self.image_shape,
+                "patch_size": self.patch_size,
+                "num_layers": self.num_layers,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "num_heads": self.num_heads,
+                "pretrain_image_shape": self.pretrain_image_shape,
+                "hidden_activation": self.hidden_activation,
+                "rope_theta": self.rope_theta,
+                "window_size": self.window_size,
+                "global_attn_indexes": self.global_attn_indexes,
+                "attention_dropout_rate": self.attention_dropout_rate,
+                "hidden_dropout_rate": self.hidden_dropout_rate,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, input_shape):
+        input_shape = self.embeddings.compute_output_shape(input_shape)
+        return input_shape
+
+
+class SAM3FPNLayer(layers.Layer):
+    def __init__(self, input_dim, fpn_dim, scale_factor, **kwargs):
+        super().__init__(**kwargs)
+        self.input_dim = int(input_dim)
+        self.fpn_dim = int(fpn_dim)
+        self.scale_factor = float(scale_factor)
+
+        # Build the upsampling/downsampling layers based on scale factor.
+        if self.scale_factor == 4.0:
+            self.scale_layers = [
+                layers.Conv2DTranspose(
+                    self.input_dim // 2,
+                    kernel_size=2,
+                    strides=2,
+                    dtype=self.dtype_policy,
+                    name="scale_layers_0",
+                ),
+                layers.Activation(
+                    "gelu", dtype=self.dtype_policy, name="scale_layers_1"
+                ),
+                layers.Conv2DTranspose(
+                    self.input_dim // 4,
+                    kernel_size=2,
+                    strides=2,
+                    dtype=self.dtype_policy,
+                    name="scale_layers_2",
+                ),
+            ]
+        elif self.scale_factor == 2.0:
+            self.scale_layers = [
+                layers.Conv2DTranspose(
+                    self.input_dim // 2,
+                    kernel_size=2,
+                    strides=2,
+                    dtype=self.dtype_policy,
+                    name="scale_layers_0",
+                )
+            ]
+        elif self.scale_factor == 1.0:
+            self.scale_layers = []
+        elif self.scale_factor == 0.5:
+            self.scale_layers = [
+                layers.MaxPooling2D(
+                    pool_size=2,
+                    strides=2,
+                    dtype=self.dtype_policy,
+                    name="scale_layers_0",
+                )
+            ]
+        else:
+            raise ValueError(
+                f"Unsupported scale factor: {self.scale_factor}. "
+                "Supported scale factors are 4.0, 2.0, 1.0, and 0.5."
+            )
+        self.proj1 = layers.Conv2D(
+            self.fpn_dim, kernel_size=1, dtype=self.dtype_policy, name="proj1"
+        )
+        self.pad = layers.ZeroPadding2D(
+            padding=1, dtype=self.dtype_policy, name="pad"
+        )
+        self.proj2 = layers.Conv2D(
+            self.fpn_dim, kernel_size=3, dtype=self.dtype_policy, name="proj2"
+        )
+
+    def build(self, input_shape):
+        for layer in self.scale_layers:
+            layer.build(input_shape)
+            input_shape = layer.compute_output_shape(input_shape)
+        self.proj1.build(input_shape)
+        input_shape = self.proj1.compute_output_shape(input_shape)
+        self.pad.build(input_shape)
+        input_shape = self.pad.compute_output_shape(input_shape)
+        self.proj2.build(input_shape)
+
+    def call(self, inputs, training=None):
+        hidden_states = inputs
+        for layer in self.scale_layers:
+            hidden_states = layer(hidden_states, training=training)
+        hidden_states = self.proj1(hidden_states, training=training)
+        hidden_states = self.pad(hidden_states, training=training)
+        return self.proj2(hidden_states, training=training)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "input_dim": self.input_dim,
+                "fpn_dim": self.fpn_dim,
+                "scale_factor": self.scale_factor,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, input_shape):
+        output_shape = input_shape
+        for layer in self.scale_layers:
+            output_shape = layer.compute_output_shape(output_shape)
+        output_shape = self.proj1.compute_output_shape(output_shape)
+        output_shape = self.pad.compute_output_shape(output_shape)
+        return self.proj2.compute_output_shape(output_shape)
+
+
+class SAM3VisionNeck(layers.Layer):
+    def __init__(self, hidden_dim, fpn_hidden_dim, scale_factors, **kwargs):
+        super().__init__(**kwargs)
+        self.hidden_dim = int(hidden_dim)
+        self.fpn_hidden_dim = int(fpn_hidden_dim)
+        self.scale_factors = scale_factors
+
+        self.position_encoding = SAM3SinePositionEmbedding(
+            num_pos_feats=self.fpn_hidden_dim // 2,
+            normalize=True,
+            dtype=self.dtype_policy,
+            name="position_encoding",
+        )
+        self.fpn_layers = [
+            SAM3FPNLayer(
+                input_dim=self.hidden_dim,
+                fpn_dim=self.fpn_hidden_dim,
+                scale_factor=scale,
+                dtype=self.dtype_policy,
+                name=f"fpn_layer_{i}",
+            )
+            for i, scale in enumerate(self.scale_factors)
+        ]
+
+    def build(self, input_shape):
+        self.position_encoding.build()
+        self.fpn_image_shapes = []
+        for layer in self.fpn_layers:
+            layer.build(input_shape)
+            fpn_shape = layer.compute_output_shape(input_shape)
+            self.fpn_image_shapes.append([int(fpn_shape[1]), int(fpn_shape[2])])
+
+    def call(self, hidden_states, training=None):
+        fpn_hidden_states = []
+        fpn_position_encodings = []
+        for i, layer in enumerate(self.fpn_layers):
+            fpn_output = layer(hidden_states, training=training)
+            fpn_hidden_states.append(fpn_output)
+            height, width = self.fpn_image_shapes[i]
+            pos_enc = self.position_encoding(
+                fpn_output, height=height, width=width, training=training
+            )
+            fpn_position_encodings.append(pos_enc)
+        return fpn_hidden_states, fpn_position_encodings
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "fpn_hidden_dim": self.fpn_hidden_dim,
+                "scale_factors": self.scale_factors,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, input_shape):
+        fpn_hidden_state_shapes = []
+        for layer in self.fpn_layers:
+            fpn_hidden_state_shapes.append(
+                layer.compute_output_shape(input_shape)
+            )
+        # fpn_hidden_states and fpn_position_encodings have the same shapes.
+        return fpn_hidden_state_shapes, fpn_hidden_state_shapes
+
+
+@keras_hub_export("keras_hub.layers.SAM3VisionEncoder")
+class SAM3VisionEncoder(layers.Layer):
+    """A vision encoder for the Segment Anything Model 3 (SAM3).
+
+    This layer implements a Vision Transformer (ViT) backbone followed by a
+    Feature Pyramid Network (FPN) neck. It processes input images and produces
+    multi-scale feature maps and their corresponding position encodings.
+
+    Args:
+        image_shape: tuple. The shape of the input image
+            (height, width, channels).
+        patch_size: int. The size of the patches to be extracted from the image.
+        num_layers: int. The number of transformer layers in the ViT backbone.
+        hidden_dim: int. The hidden dimension of the transformer layers.
+        intermediate_dim: int. The dimension of the intermediate layer in the
+            transformer's MLP.
+        num_heads: int. The number of attention heads.
+        fpn_hidden_dim: int. The hidden dimension of the FPN.
+        fpn_scale_factors: list of floats. The scale factors for each level of
+            the feature pyramid.
+        pretrain_image_shape: tuple. The shape of the image used during
+            pretraining, for position embedding interpolation. Defaults to
+            `(336, 336, 3)`.
+        hidden_activation: str. The activation function for the transformer
+            layers. Defaults to `"gelu"`.
+        rope_theta: float. The theta value for rotary position embeddings.
+            Defaults to `10000.0`.
+        window_size: int. The size of the window for windowed attention.
+            Defaults to `0`.
+        global_attn_indexes: list of ints. The indices of the layers that use
+            global attention instead of windowed attention.
+        attention_dropout_rate: float. The dropout rate for attention. Defaults
+            to `0`.
+        hidden_dropout_rate: float. The dropout rate for the MLP. Defaults to
+            `0.0`.
+        layer_norm_epsilon: float. The epsilon value for layer normalization.
+            Defaults to `1e-6`.
+    """
+
+    def __init__(
+        self,
+        image_shape,
+        patch_size,
+        num_layers,
+        hidden_dim,
+        intermediate_dim,
+        num_heads,
+        fpn_hidden_dim,
+        fpn_scale_factors,
+        pretrain_image_shape=(336, 336, 3),
+        hidden_activation="gelu",
+        rope_theta=10000.0,
+        window_size=0,
+        global_attn_indexes=None,
+        attention_dropout_rate=0.0,
+        hidden_dropout_rate=0.0,
+        layer_norm_epsilon=1e-6,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.image_shape = (
+            int(image_shape[0]),
+            int(image_shape[1]),
+            int(image_shape[2]),
+        )
+        self.patch_size = int(patch_size)
+        self.num_layers = int(num_layers)
+        self.hidden_dim = int(hidden_dim)
+        self.intermediate_dim = int(intermediate_dim)
+        self.num_heads = int(num_heads)
+        self.fpn_hidden_dim = int(fpn_hidden_dim)
+        self.fpn_scale_factors = fpn_scale_factors
+        self.hidden_activation = hidden_activation
+        self.rope_theta = float(rope_theta)
+        self.window_size = int(window_size)
+        if global_attn_indexes is not None:
+            self.global_attn_indexes = [int(i) for i in global_attn_indexes]
+        else:
+            self.global_attn_indexes = None
+        self.pretrain_image_shape = (
+            int(pretrain_image_shape[0]),
+            int(pretrain_image_shape[1]),
+            int(pretrain_image_shape[2]),
+        )
+        self.hidden_dropout_rate = float(hidden_dropout_rate)
+        self.attention_dropout_rate = float(attention_dropout_rate)
+        self.layer_norm_epsilon = float(layer_norm_epsilon)
+
+        self.backbone = SAM3ViTEncoder(
+            image_shape=self.image_shape,
+            patch_size=self.patch_size,
+            num_layers=self.num_layers,
+            hidden_dim=self.hidden_dim,
+            intermediate_dim=self.intermediate_dim,
+            num_heads=self.num_heads,
+            pretrain_image_shape=self.pretrain_image_shape,
+            hidden_activation=self.hidden_activation,
+            rope_theta=self.rope_theta,
+            window_size=self.window_size,
+            global_attn_indexes=self.global_attn_indexes,
+            attention_dropout_rate=self.attention_dropout_rate,
+            hidden_dropout_rate=self.hidden_dropout_rate,
+            layer_norm_epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="backbone",
+        )
+        self.vision_neck = SAM3VisionNeck(
+            hidden_dim=self.hidden_dim,
+            fpn_hidden_dim=self.fpn_hidden_dim,
+            scale_factors=self.fpn_scale_factors,
+            dtype=self.dtype_policy,
+            name="vision_neck",
+        )
+
+    def build(self, input_shape):
+        self.backbone.build(input_shape)
+        input_shape = self.backbone.compute_output_shape(input_shape)
+        height = self.image_shape[0] // self.patch_size
+        width = self.image_shape[1] // self.patch_size
+        input_shape = (input_shape[0], height, width, input_shape[-1])
+        self.vision_neck.build(input_shape)
+
+    def call(self, pixel_values, training=None):
+        hidden_states = self.backbone(pixel_values, training=training)
+        height = self.image_shape[0] // self.patch_size
+        width = self.image_shape[1] // self.patch_size
+        spatial_hidden_states = ops.reshape(
+            hidden_states, (-1, height, width, self.hidden_dim)
+        )
+        fpn_hidden_states, fpn_position_encodings = self.vision_neck(
+            spatial_hidden_states, training=training
+        )
+        return fpn_hidden_states, fpn_position_encodings
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "image_shape": self.image_shape,
+                "patch_size": self.patch_size,
+                "num_layers": self.num_layers,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "num_heads": self.num_heads,
+                "fpn_hidden_dim": self.fpn_hidden_dim,
+                "fpn_scale_factors": self.fpn_scale_factors,
+                "pretrain_image_shape": self.pretrain_image_shape,
+                "hidden_activation": self.hidden_activation,
+                "rope_theta": self.rope_theta,
+                "window_size": self.window_size,
+                "global_attn_indexes": self.global_attn_indexes,
+                "attention_dropout_rate": self.attention_dropout_rate,
+                "hidden_dropout_rate": self.hidden_dropout_rate,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, input_shape):
+        input_shape = self.backbone.compute_output_shape(input_shape)
+        height = self.image_shape[0] // self.patch_size
+        width = self.image_shape[1] // self.patch_size
+        input_shape = (input_shape[0], height, width, input_shape[-1])
+        fpn_hidden_state_shapes, fpn_position_encoding_shapes = (
+            self.vision_neck.compute_output_shape(input_shape)
+        )
+        return fpn_hidden_state_shapes, fpn_position_encoding_shapes
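
For orientation, below is a minimal usage sketch of the new `keras_hub.layers.SAM3VisionEncoder` layer, based only on the constructor and `call` signatures visible in the diff above (the export path comes from the `keras_hub_export` decorator). The hyperparameter values are illustrative toy values, not a shipped SAM3 preset configuration; they are chosen to satisfy the checks in the code (head dimension divisible by 4 for axial RoPE, and a patch grid that divides evenly by `window_size`).

import numpy as np
import keras_hub

# Toy configuration (hypothetical values, not an official SAM3 preset).
encoder = keras_hub.layers.SAM3VisionEncoder(
    image_shape=(64, 64, 3),  # 64 / patch_size = 4 patches per side
    patch_size=16,
    num_layers=2,
    hidden_dim=32,  # head_dim = 32 / 2 = 16, divisible by 4
    intermediate_dim=64,
    num_heads=2,
    fpn_hidden_dim=32,
    fpn_scale_factors=[4.0, 2.0, 1.0, 0.5],
    window_size=2,  # the 4x4 patch grid splits evenly into 2x2 windows
    global_attn_indexes=[1],  # last layer uses global attention
)

images = np.random.uniform(size=(1, 64, 64, 3)).astype("float32")
fpn_features, fpn_position_encodings = encoder(images)
# One feature map and one sine position encoding per entry in
# `fpn_scale_factors`, each with `fpn_hidden_dim` channels.

In the full SAM3 model, these multi-scale outputs presumably feed the DETR-style encoder/decoder and mask decoder added alongside this file (`sam3_detr_encoder.py`, `sam3_detr_decoder.py`, `sam3_mask_decoder.py`).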