keras-hub-nightly 0.23.0.dev202510090417__py3-none-any.whl → 0.23.0.dev202510110411__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +3 -0
- keras_hub/models/__init__.py +9 -0
- keras_hub/src/models/mobilenetv5/__init__.py +0 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py +699 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_backbone.py +396 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py +890 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py +436 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py +157 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_converter.py +10 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py +462 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_utils.py +146 -0
- keras_hub/src/models/qwen3_moe/__init__.py +5 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +30 -0
- keras_hub/src/utils/preset_utils.py +9 -2
- keras_hub/src/utils/timm/convert_mobilenetv5.py +321 -0
- keras_hub/src/utils/timm/preset_loader.py +8 -4
- keras_hub/src/version.py +1 -1
- {keras_hub_nightly-0.23.0.dev202510090417.dist-info → keras_hub_nightly-0.23.0.dev202510110411.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.23.0.dev202510090417.dist-info → keras_hub_nightly-0.23.0.dev202510110411.dist-info}/RECORD +22 -9
- {keras_hub_nightly-0.23.0.dev202510090417.dist-info → keras_hub_nightly-0.23.0.dev202510110411.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.23.0.dev202510090417.dist-info → keras_hub_nightly-0.23.0.dev202510110411.dist-info}/top_level.txt +0 -0
keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py (added)
@@ -0,0 +1,699 @@
import keras

from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import DropPath
from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import LayerScale2d
from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import RmsNorm2d


class MultiQueryAttention2d(keras.layers.Layer):
    """Implements 2D Multi-Query Attention.

    This layer performs attention on 2D spatial inputs. It uses a multi-query
    attention mechanism where multiple query heads attend to a single key and
    value.

    Args:
        filters: int. The output channel dimension.
        num_heads: int. The number of attention heads.
        key_dim: int. The dimension of the key. If `None`, it is calculated as
            `dim // num_heads`.
        value_dim: int. The dimension of the value. If `None`, it is calculated
            as `dim // num_heads`.
        query_strides: int or tuple. The stride for downsampling the query.
        kv_stride: int. The stride for downsampling the key and value.
        dw_kernel_size: int. The kernel size for the depthwise convolution used
            for downsampling.
        dilation: int. The dilation rate for the depthwise convolution.
        padding: str. The padding type for convolutions.
        attn_drop: float. The dropout rate for the attention weights.
        proj_drop: float. The dropout rate for the output projection.
        norm_layer: keras.layers.Layer. The normalization layer to use.
        use_bias: bool. If `True`, bias terms are used in convolutions.
        channel_axis: int. The axis representing the channels in the input
            tensor.
        data_format: str. The format of the input data, either
            `"channels_last"` or `"channels_first"`.
    """

    def __init__(
        self,
        filters,
        num_heads=8,
        key_dim=None,
        value_dim=None,
        query_strides=1,
        kv_stride=1,
        dw_kernel_size=3,
        dilation=1,
        padding="same",
        attn_drop=0.0,
        proj_drop=0.0,
        norm_layer=keras.layers.BatchNormalization,
        use_bias=False,
        channel_axis=None,
        data_format=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.filters = filters
        self.num_heads = num_heads
        self.key_dim_arg = key_dim
        self.value_dim_arg = value_dim
        self.query_strides_arg = query_strides
        self.kv_stride = kv_stride
        self.dw_kernel_size = dw_kernel_size
        self.dilation = dilation
        self.padding_arg = padding
        self.attn_drop_rate = attn_drop
        self.proj_drop_rate = proj_drop
        self.norm_layer = norm_layer
        self.use_bias = use_bias
        self.channel_axis = channel_axis
        self.data_format = data_format
        self.query_strides = (
            query_strides
            if isinstance(query_strides, (list, tuple))
            else (query_strides, query_strides)
        )
        self.has_query_strides = any([s > 1 for s in self.query_strides])
        self.padding = padding
        self.conv_kernel_initializer = keras.initializers.VarianceScaling(
            scale=2.0, mode="fan_out", distribution="untruncated_normal"
        )
        self.bias_initializer = "zeros"
        self.attn_drop_layer = keras.layers.Dropout(
            attn_drop, dtype=self.dtype_policy
        )

    def build(self, input_shape):
        super().build(input_shape)
        dim = input_shape[self.channel_axis]
        self.key_dim = self.key_dim_arg or dim // self.num_heads
        self.value_dim = self.value_dim_arg or dim // self.num_heads
        self.scale = self.key_dim**-0.5
        query_layers = []
        if self.has_query_strides:
            pool_padding = "valid" if self.padding == "valid" else "same"
            query_layers.append(
                keras.layers.AveragePooling2D(
                    pool_size=self.query_strides,
                    strides=self.query_strides,
                    padding=pool_padding,
                    data_format=self.data_format,
                    name="query_down_pool",
                    dtype=self.dtype_policy,
                )
            )
        if self.norm_layer is RmsNorm2d:
            norm = self.norm_layer(
                dim=dim,
                channel_axis=self.channel_axis,
                data_format=self.data_format,
                name="query_norm",
                dtype=self.dtype_policy,
            )
        else:
            norm = self.norm_layer(
                axis=self.channel_axis,
                name="query_norm",
                gamma_initializer="ones",
                beta_initializer="zeros",
                dtype=self.dtype_policy,
            )
        query_layers.append(norm)
        query_layers.append(
            keras.layers.Conv2D(
                filters=self.num_heads * self.key_dim,
                kernel_size=1,
                use_bias=self.use_bias,
                data_format=self.data_format,
                name="query_proj",
                kernel_initializer=self.conv_kernel_initializer,
                bias_initializer=self.bias_initializer,
                dtype=self.dtype_policy,
            )
        )
        self.query_layers = query_layers
        key_layers = []
        if self.kv_stride > 1:
            key_layers.append(
                keras.layers.DepthwiseConv2D(
                    kernel_size=self.dw_kernel_size,
                    strides=self.kv_stride,
                    dilation_rate=self.dilation,
                    padding=self.padding,
                    data_format=self.data_format,
                    name="key_down_conv",
                    depthwise_initializer=self.conv_kernel_initializer,
                    bias_initializer=self.bias_initializer,
                    use_bias=False,
                    dtype=self.dtype_policy,
                )
            )
        if self.norm_layer is RmsNorm2d:
            norm = self.norm_layer(
                dim=dim,
                channel_axis=self.channel_axis,
                data_format=self.data_format,
                name="key_norm",
                dtype=self.dtype_policy,
            )
        else:
            norm = self.norm_layer(
                axis=self.channel_axis,
                gamma_initializer="ones",
                beta_initializer="zeros",
                name="key_norm",
                dtype=self.dtype_policy,
            )
        key_layers.append(norm)
        key_layers.append(
            keras.layers.Conv2D(
                filters=self.key_dim,
                kernel_size=1,
                padding="valid",
                use_bias=self.use_bias,
                data_format=self.data_format,
                name="key_proj",
                kernel_initializer=self.conv_kernel_initializer,
                bias_initializer=self.bias_initializer,
                dtype=self.dtype_policy,
            )
        )
        self.key_layers = key_layers
        value_layers = []
        if self.kv_stride > 1:
            value_layers.append(
                keras.layers.DepthwiseConv2D(
                    kernel_size=self.dw_kernel_size,
                    strides=self.kv_stride,
                    dilation_rate=self.dilation,
                    padding=self.padding,
                    data_format=self.data_format,
                    name="value_down_conv",
                    depthwise_initializer=self.conv_kernel_initializer,
                    bias_initializer=self.bias_initializer,
                    use_bias=False,
                    dtype=self.dtype_policy,
                )
            )
        if self.norm_layer is RmsNorm2d:
            norm = self.norm_layer(
                dim=dim,
                channel_axis=self.channel_axis,
                data_format=self.data_format,
                name="value_norm",
                dtype=self.dtype_policy,
            )
        else:
            norm = self.norm_layer(
                axis=self.channel_axis,
                gamma_initializer="ones",
                beta_initializer="zeros",
                name="value_norm",
                dtype=self.dtype_policy,
            )
        value_layers.append(norm)
        value_layers.append(
            keras.layers.Conv2D(
                filters=self.value_dim,
                kernel_size=1,
                padding="valid",
                use_bias=self.use_bias,
                data_format=self.data_format,
                name="value_proj",
                kernel_initializer=self.conv_kernel_initializer,
                bias_initializer=self.bias_initializer,
                dtype=self.dtype_policy,
            )
        )
        self.value_layers = value_layers
        output_layers = []
        if self.has_query_strides:
            output_layers.append(
                keras.layers.UpSampling2D(
                    size=self.query_strides,
                    interpolation="bilinear",
                    data_format=self.data_format,
                    name="output_upsample",
                    dtype=self.dtype_policy,
                )
            )
        output_layers.append(
            keras.layers.Conv2D(
                filters=self.filters,
                kernel_size=1,
                use_bias=self.use_bias,
                data_format=self.data_format,
                name="output_proj",
                kernel_initializer=self.conv_kernel_initializer,
                bias_initializer=self.bias_initializer,
                dtype=self.dtype_policy,
            )
        )
        output_layers.append(
            keras.layers.Dropout(self.proj_drop_rate, dtype=self.dtype_policy)
        )
        self.output_proj_layers = output_layers

    def call(self, x, training=False):
        B = keras.ops.shape(x)[0]
        q = x
        for layer in self.query_layers:
            try:
                q = layer(q, training=training)
            except TypeError:
                q = layer(q)
        k = x
        for layer in self.key_layers:
            try:
                k = layer(k, training=training)
            except TypeError:
                k = layer(k)
        v = x
        for layer in self.value_layers:
            try:
                v = layer(v, training=training)
            except TypeError:
                v = layer(v)
        if self.data_format == "channels_last":
            q = keras.ops.transpose(q, (0, 3, 1, 2))
            k = keras.ops.transpose(k, (0, 3, 1, 2))
            v = keras.ops.transpose(v, (0, 3, 1, 2))
        s_q = keras.ops.shape(q)
        h_q, w_q = s_q[2], s_q[3]
        q = keras.ops.reshape(q, (B, self.num_heads, self.key_dim, -1))
        q = keras.ops.transpose(q, (0, 1, 3, 2))
        k = keras.ops.reshape(k, (B, self.key_dim, -1))
        k = keras.ops.transpose(k, (0, 2, 1))
        k = keras.ops.expand_dims(k, axis=1)
        v = keras.ops.reshape(v, (B, self.value_dim, -1))
        v = keras.ops.transpose(v, (0, 2, 1))
        v = keras.ops.expand_dims(v, axis=1)
        q = q * self.scale
        attn = keras.ops.matmul(q, keras.ops.transpose(k, (0, 1, 3, 2)))
        attn = keras.ops.softmax(attn, axis=-1)
        attn = self.attn_drop_layer(attn, training=training)
        o = keras.ops.matmul(attn, v)
        o = keras.ops.transpose(o, (0, 2, 1, 3))
        feat_dim = self.num_heads * self.value_dim
        o = keras.ops.reshape(o, (B, h_q, w_q, feat_dim))
        if self.data_format == "channels_first":
            o = keras.ops.transpose(o, (0, 3, 1, 2))
        x_out = o
        for layer in self.output_proj_layers:
            try:
                x_out = layer(x_out, training=training)
            except TypeError:
                x_out = layer(x_out)
        return x_out

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "num_heads": self.num_heads,
                "key_dim": self.key_dim_arg,
                "value_dim": self.value_dim_arg,
                "query_strides": self.query_strides_arg,
                "kv_stride": self.kv_stride,
                "dw_kernel_size": self.dw_kernel_size,
                "dilation": self.dilation,
                "padding": self.padding_arg,
                "attn_drop": self.attn_drop_rate,
                "proj_drop": self.proj_drop_rate,
                "norm_layer": keras.saving.serialize_keras_object(
                    self.norm_layer
                ),
                "use_bias": self.use_bias,
                "channel_axis": self.channel_axis,
                "data_format": self.data_format,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        config["norm_layer"] = keras.saving.deserialize_keras_object(
            config["norm_layer"]
        )
        return cls(**config)

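# --- Reviewer note: illustrative usage sketch, not part of the diff --------
# A minimal example of how `MultiQueryAttention2d` might be exercised on a
# `channels_last` feature map. All shapes and argument values here are
# assumptions for illustration. With `key_dim=None`, the per-head key/value
# width resolves to `channels // num_heads`, and all query heads attend to a
# single shared key/value projection.
mqa = MultiQueryAttention2d(
    filters=64,
    num_heads=4,
    channel_axis=-1,
    data_format="channels_last",
)
x = keras.random.normal((2, 16, 16, 64))
y = mqa(x)  # -> (2, 16, 16, 64): same resolution, `filters` channels
# ---------------------------------------------------------------------------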
class Attention2d(keras.layers.Layer):
    """Implements 2D Multi-Head Attention.

    This layer performs multi-head self-attention on 2D spatial inputs.

    Args:
        filters: int. The output channel dimension.
        num_heads: int. The number of attention heads.
        bias: bool. If `True`, bias terms are used in the qkv and projection
            convolutions.
        attn_drop: float. The dropout rate for the attention weights.
        proj_drop: float. The dropout rate for the output projection.
        channel_axis: int. The axis representing the channels in the input
            tensor.
        data_format: str. The format of the input data, either
            `"channels_last"` or `"channels_first"`.
    """

    def __init__(
        self,
        filters,
        num_heads=32,
        bias=True,
        attn_drop=0.0,
        proj_drop=0.0,
        channel_axis=None,
        data_format=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.filters = filters
        self.num_heads = num_heads
        self.bias = bias
        self.attn_drop_rate = attn_drop
        self.proj_drop_rate = proj_drop
        self.channel_axis = channel_axis
        self.data_format = data_format
        self.conv_kernel_initializer = keras.initializers.VarianceScaling(
            scale=2.0, mode="fan_out", distribution="untruncated_normal"
        )
        self.bias_initializer = "zeros"
        self.attn_drop_layer = keras.layers.Dropout(
            attn_drop, dtype=self.dtype_policy
        )

    def build(self, input_shape):
        super().build(input_shape)
        dim = input_shape[self.channel_axis]
        self.head_dim = dim // self.num_heads
        self.qkv = keras.layers.Conv2D(
            dim * 3,
            kernel_size=1,
            use_bias=self.bias,
            data_format=self.data_format,
            name="qkv",
            dtype=self.dtype_policy,
            kernel_initializer=self.conv_kernel_initializer,
            bias_initializer=self.bias_initializer,
        )
        self.proj = keras.layers.Conv2D(
            self.filters,
            kernel_size=1,
            use_bias=self.bias,
            data_format=self.data_format,
            name="proj",
            dtype=self.dtype_policy,
            kernel_initializer=self.conv_kernel_initializer,
            bias_initializer=self.bias_initializer,
        )
        self.proj_drop_layer = keras.layers.Dropout(
            self.proj_drop_rate, dtype=self.dtype_policy
        )

    def call(self, x, attn_mask=None, training=False):
        if self.data_format == "channels_first":
            B, C, H, W = keras.ops.shape(x)
        else:
            B, H, W, C = keras.ops.shape(x)
        qkv = self.qkv(x)
        if self.data_format == "channels_last":
            qkv = keras.ops.transpose(qkv, (0, 3, 1, 2))
        q, k, v = keras.ops.unstack(
            keras.ops.reshape(
                qkv,
                (B, 3, self.num_heads, self.head_dim, H * W),
            ),
            axis=1,
        )
        q = keras.ops.transpose(q, (0, 1, 3, 2))
        k = keras.ops.transpose(k, (0, 1, 2, 3))
        v = keras.ops.transpose(v, (0, 1, 3, 2))
        attn = keras.ops.matmul(q, k) * (self.head_dim**-0.5)
        if attn_mask is not None:
            attn = attn + attn_mask
        attn = keras.ops.softmax(attn, axis=-1)
        attn = self.attn_drop_layer(attn, training=training)
        x = keras.ops.matmul(attn, v)
        x = keras.ops.transpose(x, (0, 1, 3, 2))
        if self.data_format == "channels_first":
            x = keras.ops.reshape(x, (B, -1, H, W))
        else:
            x = keras.ops.reshape(x, (B, H, W, -1))
        x = self.proj(x)
        x = self.proj_drop_layer(x, training=training)
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "num_heads": self.num_heads,
                "bias": self.bias,
                "attn_drop": self.attn_drop_rate,
                "proj_drop": self.proj_drop_rate,
                "channel_axis": self.channel_axis,
                "data_format": self.data_format,
            }
        )
        return config

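# --- Reviewer note: illustrative usage sketch, not part of the diff --------
# `Attention2d` is plain multi-head self-attention over the flattened H*W
# grid. Unlike the multi-query variant above, every head gets its own key
# and value, so the input channel count must divide evenly by `num_heads`
# (`head_dim = channels // num_heads` in `build`). Values below are
# assumptions for illustration.
mha = Attention2d(
    filters=64,
    num_heads=4,  # head_dim = 64 // 4 = 16
    channel_axis=-1,
    data_format="channels_last",
)
x = keras.random.normal((2, 8, 8, 64))
y = mha(x)  # -> (2, 8, 8, 64)
# ---------------------------------------------------------------------------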
class MobileAttention(keras.layers.Layer):
    """MobileNetV5 attention block.

    This block combines attention with depthwise convolutions for efficiency.
    It can use either standard Multi-Head Attention or Multi-Query Attention.

    Args:
        filters: int. The number of output channels.
        stride: int. The stride for the block.
        dw_kernel_size: int. The kernel size for the depthwise convolution in
            Multi-Query Attention.
        dilation: int. The dilation rate for convolutions.
        pad_type: str. The padding type for convolutions.
        num_heads: int. The number of attention heads.
        key_dim: int. The dimension of the key.
        value_dim: int. The dimension of the value.
        use_multi_query: bool. If `True`, use `MultiQueryAttention2d`,
            otherwise use `Attention2d`.
        query_strides: tuple. The strides for the query downsampling.
        kv_stride: int. The stride for key/value downsampling.
        cpe_dw_kernel_size: int. The kernel size for the conditional position
            encoding depthwise convolution.
        noskip: bool. If `True`, the skip connection is disabled.
        norm_layer: str. The normalization layer to use (`"batch_norm"` or
            `"rms_norm"`).
        drop_path_rate: float. The stochastic depth rate.
        attn_drop: float. The dropout rate for the attention weights.
        proj_drop: float. The dropout rate for the output projection.
        layer_scale_init_value: float. The initial value for layer scale. If
            `None`, layer scale is not used.
        use_bias: bool. If `True`, bias terms are used in convolutions.
        use_cpe: bool. If `True`, a conditional position encoding is added.
        channel_axis: int. The axis representing the channels in the input
            tensor.
        data_format: str. The format of the input data, either
            `"channels_last"` or `"channels_first"`.
    """

    def __init__(
        self,
        filters,
        stride=1,
        dw_kernel_size=3,
        dilation=1,
        pad_type="same",
        num_heads=8,
        key_dim=64,
        value_dim=64,
        use_multi_query=False,
        query_strides=(1, 1),
        kv_stride=1,
        cpe_dw_kernel_size=3,
        noskip=False,
        norm_layer="batch_norm",
        drop_path_rate=0.0,
        attn_drop=0.0,
        proj_drop=0.0,
        layer_scale_init_value=1e-5,
        use_bias=False,
        use_cpe=False,
        channel_axis=None,
        data_format=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.filters = filters
        self.stride = stride
        self.dw_kernel_size = dw_kernel_size
        self.dilation = dilation
        self.pad_type = pad_type
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.value_dim = value_dim
        self.use_multi_query = use_multi_query
        self.query_strides = query_strides
        self.kv_stride = kv_stride
        self.cpe_dw_kernel_size = cpe_dw_kernel_size
        self.noskip = noskip
        self.norm_layer_name = norm_layer
        self.drop_path_rate = drop_path_rate
        self.attn_drop_rate = attn_drop
        self.proj_drop_rate = proj_drop
        self.layer_scale_init_value = layer_scale_init_value
        self.use_bias = use_bias
        self.use_cpe = use_cpe
        self.channel_axis = channel_axis
        self.data_format = data_format
        self.conv_kernel_initializer = keras.initializers.VarianceScaling(
            scale=2.0, mode="fan_out", distribution="untruncated_normal"
        )
        self.bias_initializer = "zeros"

    def build(self, input_shape):
        super().build(input_shape)
        in_chs = input_shape[self.channel_axis]
        self.has_skip = (
            self.stride == 1 and in_chs == self.filters
        ) and not self.noskip
        if self.use_cpe:
            self.conv_cpe_dw = keras.layers.DepthwiseConv2D(
                kernel_size=self.cpe_dw_kernel_size,
                strides=1,
                padding="same",
                dilation_rate=self.dilation,
                use_bias=True,
                data_format=self.data_format,
                name="conv_cpe_dw",
                depthwise_initializer=self.conv_kernel_initializer,
                bias_initializer=self.bias_initializer,
                dtype=self.dtype_policy,
            )
        else:
            self.conv_cpe_dw = None
        if self.norm_layer_name == "batch_norm":
            self.norm = keras.layers.BatchNormalization(
                axis=self.channel_axis,
                name="norm",
                gamma_initializer="ones",
                beta_initializer="zeros",
                dtype=self.dtype_policy,
            )
        elif self.norm_layer_name == "rms_norm":
            self.norm = RmsNorm2d(
                in_chs,
                data_format=self.data_format,
                gamma_initializer="ones",
                channel_axis=self.channel_axis,
                name="norm",
                dtype=self.dtype_policy,
            )
        else:
            raise ValueError(f"Unsupported norm_layer: {self.norm_layer_name}")
        num_heads = self.num_heads
        if num_heads is None:
            assert in_chs % self.key_dim == 0
            num_heads = in_chs // self.key_dim
        attn_norm_layer = (
            RmsNorm2d
            if self.norm_layer_name == "rms_norm"
            else keras.layers.BatchNormalization
        )
        if self.use_multi_query:
            self.attn = MultiQueryAttention2d(
                filters=self.filters,
                num_heads=num_heads,
                key_dim=self.key_dim,
                value_dim=self.value_dim,
                query_strides=self.query_strides,
                kv_stride=self.kv_stride,
                dw_kernel_size=self.dw_kernel_size,
                dilation=self.dilation,
                padding=self.pad_type,
                attn_drop=self.attn_drop_rate,
                proj_drop=self.proj_drop_rate,
                norm_layer=attn_norm_layer,
                use_bias=self.use_bias,
                channel_axis=self.channel_axis,
                data_format=self.data_format,
                name="attn",
                dtype=self.dtype_policy,
            )
        else:
            self.attn = Attention2d(
                filters=self.filters,
                num_heads=num_heads,
                attn_drop=self.attn_drop_rate,
                proj_drop=self.proj_drop_rate,
                bias=self.use_bias,
                channel_axis=self.channel_axis,
                data_format=self.data_format,
                name="attn",
                dtype=self.dtype_policy,
            )
        if self.layer_scale_init_value is not None:
            self.layer_scale = LayerScale2d(
                self.filters,
                self.layer_scale_init_value,
                name="layer_scale",
                channel_axis=self.channel_axis,
                data_format=self.data_format,
                dtype=self.dtype_policy,
            )
        else:
            self.layer_scale = lambda x: x
        self.drop_path = (
            DropPath(self.drop_path_rate, dtype=self.dtype_policy)
            if self.drop_path_rate > 0.0
            else lambda x, training: x
        )

    def call(self, x, training=False):
        if self.conv_cpe_dw is not None:
            x = x + self.conv_cpe_dw(x)
        shortcut = x
        x_normed = self.norm(x, training=training)
        x_attn = self.attn(x_normed, training=training)
        x_scaled = self.layer_scale(x_attn)
        if self.has_skip:
            return self.drop_path(x_scaled, training=training) + shortcut
        else:
            return x_scaled

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "stride": self.stride,
                "dw_kernel_size": self.dw_kernel_size,
                "dilation": self.dilation,
                "pad_type": self.pad_type,
                "num_heads": self.num_heads,
                "key_dim": self.key_dim,
                "value_dim": self.value_dim,
                "use_multi_query": self.use_multi_query,
                "query_strides": self.query_strides,
                "kv_stride": self.kv_stride,
                "cpe_dw_kernel_size": self.cpe_dw_kernel_size,
                "noskip": self.noskip,
                "norm_layer": self.norm_layer_name,
                "drop_path_rate": self.drop_path_rate,
                "attn_drop": self.attn_drop_rate,
                "proj_drop": self.proj_drop_rate,
                "layer_scale_init_value": self.layer_scale_init_value,
                "use_bias": self.use_bias,
                "use_cpe": self.use_cpe,
                "channel_axis": self.channel_axis,
                "data_format": self.data_format,
            }
        )
        return config
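MobileAttention ties the pieces of this file together: an optional depthwise conditional position encoding, a pre-attention norm, either attention variant, optional layer scale, and a residual connection with stochastic depth whenever stride == 1 and the channel count is preserved. A minimal sketch of instantiating the block follows; the argument values are illustrative assumptions, not presets shipped in the package:

block = MobileAttention(
    filters=64,
    num_heads=4,
    use_multi_query=True,  # route through MultiQueryAttention2d
    channel_axis=-1,
    data_format="channels_last",
)
x = keras.random.normal((2, 16, 16, 64))
y = block(x)  # -> (2, 16, 16, 64); `has_skip` is True here, so the scaled
              # attention branch is added back onto the input

Note that with `use_multi_query=True` the block forwards its `key_dim`/`value_dim` (default 64) to `MultiQueryAttention2d`, while the plain `Attention2d` path instead derives `head_dim` from the input channel count.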