keras-hub-nightly 0.23.0.dev202510090417__py3-none-any.whl → 0.23.0.dev202510110411__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of keras-hub-nightly might be problematic.
- keras_hub/layers/__init__.py +3 -0
- keras_hub/models/__init__.py +9 -0
- keras_hub/src/models/mobilenetv5/__init__.py +0 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py +699 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_backbone.py +396 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py +890 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py +436 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py +157 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_converter.py +10 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py +462 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_utils.py +146 -0
- keras_hub/src/models/qwen3_moe/__init__.py +5 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +30 -0
- keras_hub/src/utils/preset_utils.py +9 -2
- keras_hub/src/utils/timm/convert_mobilenetv5.py +321 -0
- keras_hub/src/utils/timm/preset_loader.py +8 -4
- keras_hub/src/version.py +1 -1
- {keras_hub_nightly-0.23.0.dev202510090417.dist-info → keras_hub_nightly-0.23.0.dev202510110411.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.23.0.dev202510090417.dist-info → keras_hub_nightly-0.23.0.dev202510110411.dist-info}/RECORD +22 -9
- {keras_hub_nightly-0.23.0.dev202510090417.dist-info → keras_hub_nightly-0.23.0.dev202510110411.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.23.0.dev202510090417.dist-info → keras_hub_nightly-0.23.0.dev202510110411.dist-info}/top_level.txt +0 -0
keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py
@@ -0,0 +1,462 @@
import keras

from keras_hub.src.models.mobilenet.util import adjust_channels


class DropPath(keras.layers.Layer):
    """Implements the DropPath layer.

    DropPath is a form of stochastic depth, where connections are randomly
    dropped during training.

    Args:
        drop_prob: float. The probability of dropping a path.
        scale_by_keep: bool. If `True`, scale the output by `1/keep_prob`.
    """

    def __init__(self, drop_prob=0.0, scale_by_keep=True, dtype=None, **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def call(self, x, training=False):
        if self.drop_prob == 0.0 or not training:
            return x
        keep_prob = 1.0 - self.drop_prob
        shape = (keras.ops.shape(x)[0],) + (1,) * (len(x.shape) - 1)
        random_tensor = keep_prob + keras.random.uniform(
            shape, 0, 1, dtype=x.dtype
        )
        random_tensor = keras.ops.floor(random_tensor)
        if keep_prob > 0.0 and self.scale_by_keep:
            random_tensor = random_tensor / keep_prob
        return x * random_tensor

    def get_config(self):
        config = super().get_config()
        config.update(
            {"drop_prob": self.drop_prob, "scale_by_keep": self.scale_by_keep}
        )
        return config


class LayerScale2d(keras.layers.Layer):
    """A layer that applies a learnable scaling factor to the input tensor.

    This layer scales the input tensor by a learnable `gamma` parameter. The
    scaling is applied channel-wise.

    Args:
        dim: int. The number of channels in the input tensor.
        init_values: float. The initial value for the `gamma` parameter.
        data_format: str. The format of the input data, either
            `"channels_last"` or `"channels_first"`.
        channel_axis: int. The axis representing the channels in the input
            tensor.
    """

    def __init__(
        self,
        dim,
        init_values=1e-5,
        data_format=None,
        channel_axis=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.dim = dim
        self.init_values = init_values
        self.data_format = data_format
        self.channel_axis = channel_axis

    def build(self, input_shape):
        self.gamma = self.add_weight(
            shape=(self.dim,),
            initializer=keras.initializers.Constant(self.init_values),
            trainable=True,
            name="gamma",
        )
        super().build(input_shape)

    def call(self, x):
        if self.data_format == "channels_first":
            gamma = keras.ops.reshape(self.gamma, (1, self.dim, 1, 1))
        else:
            gamma = keras.ops.reshape(self.gamma, (1, 1, 1, self.dim))
        return x * gamma

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "dim": self.dim,
                "init_values": self.init_values,
                "data_format": self.data_format,
                "channel_axis": self.channel_axis,
            }
        )
        return config


class RmsNorm2d(keras.layers.Layer):
    """A layer that applies Root Mean Square Normalization to a 2D input.

    This layer normalizes the input tensor along the channel dimension using
    the root mean square of the values, and then scales it by a learnable
    `gamma` parameter.

    Args:
        dim: int. The number of channels in the input tensor.
        eps: float. A small epsilon value to avoid division by zero.
        data_format: str. The format of the input data, either
            `"channels_last"` or `"channels_first"`.
        channel_axis: int. The axis representing the channels in the input
            tensor.
    """

    def __init__(
        self,
        dim,
        eps=1e-6,
        data_format=None,
        channel_axis=None,
        gamma_initializer="ones",
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.dim = dim
        self.eps = eps
        self.data_format = data_format
        self.channel_axis = channel_axis
        self.gamma_initializer = gamma_initializer

    def build(self, input_shape):
        self.gamma = self.add_weight(
            shape=(self.dim,),
            initializer=self.gamma_initializer,
            trainable=True,
            name="gamma",
        )
        super().build(input_shape)

    def call(self, x):
        input_dtype = x.dtype
        if self.data_format == "channels_first":
            x_permuted = keras.ops.transpose(x, (0, 2, 3, 1))
        else:
            x_permuted = x
        x_float = keras.ops.cast(x_permuted, "float32")
        norm_factor = keras.ops.rsqrt(
            keras.ops.mean(keras.ops.square(x_float), axis=-1, keepdims=True)
            + self.eps
        )
        norm_x_float = x_float * norm_factor
        norm_x = keras.ops.cast(norm_x_float, input_dtype)
        scaled_x = norm_x * self.gamma
        if self.data_format == "channels_first":
            output = keras.ops.transpose(scaled_x, (0, 3, 1, 2))
        else:
            output = scaled_x
        return output

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "dim": self.dim,
                "eps": self.eps,
                "data_format": self.data_format,
                "channel_axis": self.channel_axis,
                "gamma_initializer": self.gamma_initializer,
            }
        )
        return config


class ConvNormAct(keras.layers.Layer):
    """A layer that combines convolution, normalization, and activation.

    This layer provides a convenient way to create a sequence of a 2D
    convolution, a normalization layer, and an activation function.

    Args:
        out_chs: int. The number of output channels.
        kernel_size: int or tuple. The size of the convolution kernel.
        stride: int or tuple. The stride of the convolution.
        dilation: int or tuple. The dilation rate of the convolution.
        groups: int. The number of groups for a grouped convolution.
        bias: bool. If `True`, a bias term is used in the convolution.
        pad_type: str. The type of padding to use. `"same"` or `""` for same
            padding, otherwise valid padding.
        apply_act: bool. If `True`, an activation function is applied.
        act_layer: str. The name of the activation function to use.
        norm_layer: str. The name of the normalization layer to use.
            Supported values are `"batch_norm"` and `"rms_norm"`.
        data_format: str. The format of the input data, either
            `"channels_last"` or `"channels_first"`.
        channel_axis: int. The axis representing the channels in the input
            tensor.
    """

    def __init__(
        self,
        out_chs,
        kernel_size,
        stride=1,
        dilation=1,
        groups=1,
        bias=False,
        pad_type="same",
        apply_act=True,
        act_layer="relu",
        norm_layer="batch_norm",
        data_format=None,
        channel_axis=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.out_chs = out_chs
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.groups = groups
        self.bias = bias
        self.pad_type = pad_type
        self.apply_act = apply_act
        self.act_layer = act_layer
        self.norm_layer = norm_layer
        self.data_format = data_format
        self.channel_axis = channel_axis
        self.kernel_initializer = keras.initializers.VarianceScaling(
            scale=2.0, mode="fan_out", distribution="untruncated_normal"
        )
        self.bias_initializer = "zeros"
        padding_mode = "valid"
        if pad_type.lower() == "" or pad_type.lower() == "same":
            padding_mode = "same"

        self.conv = keras.layers.Conv2D(
            out_chs,
            kernel_size,
            strides=stride,
            padding=padding_mode,
            dilation_rate=dilation,
            groups=groups,
            use_bias=bias,
            data_format=self.data_format,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            dtype=self.dtype_policy,
        )

        if norm_layer == "batch_norm":
            self.norm = keras.layers.BatchNormalization(
                axis=self.channel_axis,
                epsilon=1e-5,
                gamma_initializer="ones",
                beta_initializer="zeros",
                dtype=self.dtype_policy,
            )
        elif norm_layer == "rms_norm":
            self.norm = RmsNorm2d(
                out_chs,
                data_format=self.data_format,
                channel_axis=self.channel_axis,
                gamma_initializer="ones",
                dtype=self.dtype_policy,
            )
        else:
            ln_axis = [1, 2, 3]
            if self.data_format == "channels_first":
                ln_axis = [2, 3, 1]
            self.norm = keras.layers.LayerNormalization(
                axis=ln_axis,
                dtype=self.dtype_policy,
            )

        if self.apply_act:
            if act_layer == "gelu":
                self.act = keras.layers.Activation(
                    lambda x: keras.activations.gelu(x, approximate=False),
                    dtype=self.dtype_policy,
                )
            else:
                self.act = keras.layers.Activation(
                    act_layer,
                    dtype=self.dtype_policy,
                )

    def build(self, input_shape):
        self.conv.build(input_shape)
        conv_output_shape = self.conv.compute_output_shape(input_shape)
        self.norm.build(conv_output_shape)
        if self.apply_act:
            self.act.build(conv_output_shape)
        self.built = True

    def call(self, x, training=False):
        x = self.conv(x)
        x = self.norm(x, training=training)
        if self.apply_act:
            x = self.act(x)
        return x

    def compute_output_shape(self, input_shape):
        return self.conv.compute_output_shape(input_shape)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "out_chs": self.out_chs,
                "kernel_size": self.kernel_size,
                "stride": self.stride,
                "dilation": self.dilation,
                "groups": self.groups,
                "bias": self.bias,
                "pad_type": self.pad_type,
                "apply_act": self.apply_act,
                "act_layer": self.act_layer,
                "norm_layer": self.norm_layer,
                "data_format": self.data_format,
                "channel_axis": self.channel_axis,
            }
        )
        return config


class SEModule(keras.layers.Layer):
    """Implements the Squeeze-and-Excitation (SE) module.

    The SE module adaptively recalibrates channel-wise feature responses by
    explicitly modeling interdependencies between channels.

    Args:
        channels: int. The number of input channels.
        rd_ratio: float. The reduction ratio for the bottleneck channels.
        rd_channels: int. The number of bottleneck channels. If specified,
            `rd_ratio` is ignored.
        rd_divisor: int. The divisor for rounding the number of bottleneck
            channels.
        add_maxpool: bool. If `True`, max pooling is used in addition to
            average pooling for the squeeze operation.
        bias: bool. If `True`, bias terms are used in the fully connected
            layers.
        act_layer: str. The activation function for the bottleneck layer.
        norm_layer: str. The normalization layer to use.
        data_format: str. The format of the input data, either
            `"channels_last"` or `"channels_first"`.
        channel_axis: int. The axis representing the channels in the input
            tensor.
        gate_layer: str. The gating activation function.
    """

    def __init__(
        self,
        channels,
        rd_ratio=1.0 / 16,
        rd_channels=None,
        rd_divisor=8,
        add_maxpool=False,
        bias=True,
        act_layer="relu",
        norm_layer=None,
        data_format=None,
        channel_axis=None,
        gate_layer="sigmoid",
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.channels = channels
        self.add_maxpool = add_maxpool
        if not rd_channels:
            rd_channels = adjust_channels(
                channels * rd_ratio, rd_divisor, round_limit=0.0
            )
        self.rd_ratio = rd_ratio
        self.rd_channels = rd_channels
        self.rd_divisor = rd_divisor
        self.bias = bias
        self.act_layer_arg = act_layer
        self.kernel_initializer = keras.initializers.VarianceScaling(
            scale=2.0, mode="fan_out", distribution="untruncated_normal"
        )
        self.bias_initializer = "zeros"
        self.norm_layer_arg = norm_layer
        self.gate_layer_arg = gate_layer
        self.data_format = data_format
        self.channel_axis = channel_axis
        self.mean_axis = [2, 3] if data_format == "channels_first" else [1, 2]
        self.fc1 = keras.layers.Conv2D(
            rd_channels,
            kernel_size=1,
            use_bias=bias,
            name="fc1",
            data_format=self.data_format,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            dtype=self.dtype_policy,
        )
        self.bn = (
            keras.layers.BatchNormalization(
                axis=channel_axis, dtype=self.dtype_policy
            )
            if norm_layer
            else (lambda x, training: x)
        )
        self.act = keras.layers.Activation(act_layer, dtype=self.dtype_policy)
        self.fc2 = keras.layers.Conv2D(
            channels,
            kernel_size=1,
            use_bias=bias,
            name="fc2",
            data_format=self.data_format,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            dtype=self.dtype_policy,
        )
        self.gate = keras.layers.Activation(gate_layer, dtype=self.dtype_policy)

    def build(self, input_shape):
        self.fc1.build(input_shape)
        fc1_output_shape = self.fc1.compute_output_shape(input_shape)
        if hasattr(self.bn, "build"):
            self.bn.build(fc1_output_shape)
        self.act.build(fc1_output_shape)
        self.fc2.build(fc1_output_shape)
        self.built = True

    def call(self, x, training=False):
        x_se = keras.ops.mean(x, axis=self.mean_axis, keepdims=True)
        if self.add_maxpool:
            x_se = 0.5 * x_se + 0.5 * keras.ops.max(
                x, axis=self.mean_axis, keepdims=True
            )
        x_se = self.fc1(x_se)
        x_se = self.bn(x_se, training=training)
        x_se = self.act(x_se)
        x_se = self.fc2(x_se)
        return x * self.gate(x_se)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "channels": self.channels,
                "rd_ratio": self.rd_ratio,
                "rd_channels": self.rd_channels,
                "rd_divisor": self.rd_divisor,
                "add_maxpool": self.add_maxpool,
                "bias": self.bias,
                "act_layer": self.act_layer_arg,
                "norm_layer": self.norm_layer_arg,
                "gate_layer": self.gate_layer_arg,
                "data_format": self.data_format,
                "channel_axis": self.channel_axis,
            }
        )
        return config

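These classes follow the standard Keras 3 layer contract, so they can be exercised in isolation. The following smoke test is a minimal sketch and not part of the diff; it assumes the module path from the file list above and a channels_last input.

# Illustrative only, not part of the release: exercising ConvNormAct and
# SEModule on a dummy image batch.
import numpy as np
from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import (
    ConvNormAct,
    SEModule,
)

x = np.random.rand(2, 32, 32, 16).astype("float32")
conv = ConvNormAct(
    out_chs=24,
    kernel_size=3,
    stride=2,
    act_layer="relu",
    norm_layer="batch_norm",
    data_format="channels_last",
    channel_axis=-1,
)
se = SEModule(
    channels=24,
    rd_ratio=0.25,
    data_format="channels_last",
    channel_axis=-1,
)
y = se(conv(x))  # conv gives (2, 16, 16, 24); SE gating keeps the same shape
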
keras_hub/src/models/mobilenetv5/mobilenetv5_utils.py
@@ -0,0 +1,146 @@
import keras

from keras_hub.src.models.mobilenet.util import adjust_channels


def num_groups(group_size, channels):
    if not group_size:
        return 1
    else:
        if channels % group_size != 0:
            raise ValueError(
                f"Number of channels ({channels}) must be divisible by "
                f"group size ({group_size})."
            )
        return channels // group_size


def parse_ksize(ss):
    if ss.isdigit():
        return int(ss)
    else:
        return [int(k) for k in ss.split(".")]


def round_channels(
    channels, multiplier=1.0, divisor=8, channel_min=None, round_limit=0.9
):
    if not multiplier:
        return channels
    return adjust_channels(channels * multiplier, divisor, channel_min)


def feature_take_indices(num_stages, indices):
    if not isinstance(indices, (tuple, list)):
        indices = (indices,)
    if any(i < 0 for i in indices):
        indices = [i if i >= 0 else num_stages + i for i in indices]
    return indices, max(indices)


class SelectAdaptivePool2d(keras.layers.Layer):
    """A layer that selects and applies a 2D adaptive pooling strategy.

    This layer supports various pooling types like average, max, or a
    combination of both. It can also flatten the output.

    Args:
        pool_type: str. The type of pooling to apply. One of `"avg"`, `"max"`,
            `"avgmax"`, `"catavgmax"`, or `""` (identity).
        flatten: bool. If `True`, the output is flattened after pooling.
        data_format: str. The format of the input data, either
            `"channels_last"` or `"channels_first"`.
        channel_axis: int. The axis representing the channels in the input
            tensor.
    """

    def __init__(
        self,
        pool_type="avg",
        flatten=False,
        data_format=None,
        channel_axis=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.pool_type = pool_type.lower()
        self.flatten = flatten
        self.data_format = data_format
        self.channels_axis = channel_axis
        self.pool = None
        self.pool_avg = None
        self.pool_max = None
        self.pool_cat = None
        self.flatten_layer = None
        if self.pool_type not in ("avg", "max", "avgmax", "catavgmax", ""):
            raise ValueError(f"Invalid pool type: {self.pool_type}")

    def build(self, input_shape):
        if self.pool_type == "avg":
            self.pool = keras.layers.GlobalAveragePooling2D(
                data_format=self.data_format,
                keepdims=not self.flatten,
                dtype=self.dtype_policy,
            )
        elif self.pool_type == "max":
            self.pool = keras.layers.GlobalMaxPooling2D(
                data_format=self.data_format,
                keepdims=not self.flatten,
                dtype=self.dtype_policy,
            )
        elif self.pool_type in ("avgmax", "catavgmax"):
            self.pool_avg = keras.layers.GlobalAveragePooling2D(
                data_format=self.data_format,
                keepdims=not self.flatten,
                dtype=self.dtype_policy,
            )
            self.pool_max = keras.layers.GlobalMaxPooling2D(
                data_format=self.data_format,
                keepdims=not self.flatten,
                dtype=self.dtype_policy,
            )
            if self.pool_type == "catavgmax":
                axis = 1 if self.data_format == "channels_first" else -1
                self.pool_cat = keras.layers.Concatenate(
                    axis=axis, dtype=self.dtype_policy
                )
        elif not self.pool_type:
            self.pool = keras.layers.Identity(dtype=self.dtype_policy)
            if self.flatten:
                self.flatten_layer = keras.layers.Flatten(
                    dtype=self.dtype_policy
                )
        super().build(input_shape)

    def call(self, x):
        if self.pool_type in ("avg", "max"):
            return self.pool(x)
        elif self.pool_type == "avgmax":
            x_avg = self.pool_avg(x)
            x_max = self.pool_max(x)
            return 0.5 * (x_avg + x_max)
        elif self.pool_type == "catavgmax":
            x_avg = self.pool_avg(x)
            x_max = self.pool_max(x)
            return self.pool_cat([x_avg, x_max])
        elif not self.pool_type:
            x = self.pool(x)
            if self.flatten_layer:
                x = self.flatten_layer(x)
            return x
        return x

    def feat_mult(self):
        return 2 if self.pool_type == "catavgmax" else 1

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "pool_type": self.pool_type,
                "flatten": self.flatten,
                "data_format": self.data_format,
            }
        )
        return config

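A minimal sketch of the helpers above, again illustrative rather than part of the diff; it assumes the `mobilenetv5_utils.py` module path from the file list.

# Illustrative only, not part of the release.
import numpy as np
from keras_hub.src.models.mobilenetv5.mobilenetv5_utils import (
    SelectAdaptivePool2d,
    parse_ksize,
    round_channels,
)

parse_ksize("3")      # -> 3
parse_ksize("3.5.7")  # -> [3, 5, 7]
round_channels(48, multiplier=1.25, divisor=8)  # -> 64 (60 rounded to a multiple of 8)

pool = SelectAdaptivePool2d(
    pool_type="catavgmax", flatten=True, data_format="channels_last"
)
x = np.random.rand(2, 7, 7, 96).astype("float32")
y = pool(x)       # concatenated average and max features, shape (2, 192)
pool.feat_mult()  # -> 2, since "catavgmax" doubles the channel count
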
keras_hub/src/models/qwen3_moe/__init__.py
@@ -0,0 +1,5 @@
from keras_hub.src.models.qwen3_moe.qwen3_moe_backbone import Qwen3MoeBackbone
from keras_hub.src.models.qwen3_moe.qwen3_moe_presets import backbone_presets
from keras_hub.src.utils.preset_utils import register_presets

register_presets(backbone_presets, Qwen3MoeBackbone)

keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py
@@ -0,0 +1,30 @@
"""Qwen3 MoE model preset configurations."""

backbone_presets = {
    "qwen3_moe_30b_a3b_en": {
        "metadata": {
            "description": (
                " Mixture-of-Experts (MoE) model has 30.5 billion total"
                " parameters with 3.3 billion activated, built on 48 layers"
                " and utilizes 32 query and 4 key/value attention heads"
                " with 128 experts (8 active)."
            ),
            "params": 30532122624,
            "path": "qwen3_moe",
        },
        "kaggle_handle": "kaggle://keras/qwen-3-moe/keras/qwen3_moe_30b_a3b_en/2",
    },
    "qwen3_moe_235b_a22b_en": {
        "metadata": {
            "description": (
                " Mixture-of-Experts (MoE) model has 235 billion"
                " total parameters with 22 billion activated, built on 94"
                " layers and utilizes 64 query and 4 key/value attention heads"
                " with 128 experts (8 active)."
            ),
            "params": 235093634560,
            "path": "qwen3_moe",
        },
        "kaggle_handle": "kaggle://keras/qwen-3-moe/keras/qwen3_moe_235b_a22b_en/1",
    },
}

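With the presets registered in the package `__init__.py` above, the new handles resolve through the standard `from_preset()` factory. A minimal sketch, assuming `Qwen3MoeBackbone` is exported under `keras_hub.models` and that Kaggle access and sufficient memory are available for the 30B checkpoint:

# Illustrative only: resolves the Kaggle handle registered above and
# downloads the weights on first use.
import keras_hub

backbone = keras_hub.models.Qwen3MoeBackbone.from_preset("qwen3_moe_30b_a3b_en")
print(backbone.count_params())  # roughly the 30.5B quoted in the metadata
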
keras_hub/src/utils/preset_utils.py
@@ -502,10 +502,17 @@ def jax_memory_cleanup(layer):
     # For jax, delete all previous allocated memory to avoid temporarily
     # duplicating variable allocations. torch and tensorflow have stateful
     # variable types and do not need this fix.
+    # Skip deletion for sharded arrays to avoid breaking references in
+    # distributed setups.
     if keras.config.backend() == "jax":
         for weight in layer.weights:
-            if weight._value is not None:
-                weight._value.delete()
+            if weight._value is not None:
+                # Do not delete sharded arrays, as they may be referenced in
+                # JAX's distributed computation graph and deletion can cause
+                # errors.
+                sharding = getattr(weight._value, "sharding", None)
+                if sharding is None:
+                    weight._value.delete()
 
 
 def set_dtype_in_config(config, dtype=None):
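
For context, `jax_memory_cleanup` is an internal helper invoked while loading presets; the change above only guards the delete call. A minimal sketch of a direct invocation, illustrative only:

# Illustrative only. On the JAX backend this frees freshly initialized weight
# buffers that do not expose a `sharding` attribute, per the guard above;
# sharded/distributed arrays are now left untouched. On other backends it is
# a no-op.
import keras
from keras_hub.src.utils.preset_utils import jax_memory_cleanup

layer = keras.layers.Dense(8)
layer.build((None, 4))
jax_memory_cleanup(layer)  # safe to call right before loading saved weights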