keras-hub-nightly 0.23.0.dev202510100415__py3-none-any.whl → 0.23.0.dev202510110411__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of keras-hub-nightly might be problematic.
- keras_hub/layers/__init__.py +3 -0
- keras_hub/models/__init__.py +9 -0
- keras_hub/src/models/mobilenetv5/__init__.py +0 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py +699 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_backbone.py +396 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py +890 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py +436 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py +157 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_converter.py +10 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py +462 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_utils.py +146 -0
- keras_hub/src/utils/preset_utils.py +9 -2
- keras_hub/src/utils/timm/convert_mobilenetv5.py +321 -0
- keras_hub/src/utils/timm/preset_loader.py +8 -4
- keras_hub/src/version.py +1 -1
- {keras_hub_nightly-0.23.0.dev202510100415.dist-info → keras_hub_nightly-0.23.0.dev202510110411.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.23.0.dev202510100415.dist-info → keras_hub_nightly-0.23.0.dev202510110411.dist-info}/RECORD +20 -9
- {keras_hub_nightly-0.23.0.dev202510100415.dist-info → keras_hub_nightly-0.23.0.dev202510110411.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.23.0.dev202510100415.dist-info → keras_hub_nightly-0.23.0.dev202510110411.dist-info}/top_level.txt +0 -0
keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py
@@ -0,0 +1,890 @@
import keras

from keras_hub.src.models.mobilenet.util import adjust_channels
from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import ConvNormAct
from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import DropPath
from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import LayerScale2d
from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import RmsNorm2d
from keras_hub.src.models.mobilenetv5.mobilenetv5_utils import num_groups


class UniversalInvertedResidual(keras.layers.Layer):
    """Universal Inverted Residual block.

    This block is a flexible and universal version of the inverted residual
    block, which can be configured to behave like different variants of mobile
    convolutional blocks.

    Args:
        filters: int. The number of output channels.
        dw_kernel_size_start: int. The kernel size for the initial depthwise
            convolution. If 0, this layer is skipped.
        dw_kernel_size_mid: int. The kernel size for the middle depthwise
            convolution. If 0, this layer is skipped.
        dw_kernel_size_end: int. The kernel size for the final depthwise
            convolution. If 0, this layer is skipped.
        stride: int. The stride for the block.
        dilation: int. The dilation rate for convolutions.
        pad_type: str. The padding type for convolutions.
        noskip: bool. If `True`, the skip connection is disabled.
        exp_ratio: float. The expansion ratio for the middle channels.
        act_layer: str. The activation function to use.
        norm_layer: str. The normalization layer to use.
        se_layer: keras.layers.Layer. The Squeeze-and-Excitation layer to use.
        drop_path_rate: float. The stochastic depth rate.
        layer_scale_init_value: float. The initial value for layer scale. If
            `None`, layer scale is not used.
        data_format: str. The format of the input data, either
            `"channels_last"` or `"channels_first"`.
        channel_axis: int. The axis representing the channels in the input
            tensor.
    """

    def __init__(
        self,
        filters,
        dw_kernel_size_start=0,
        dw_kernel_size_mid=3,
        dw_kernel_size_end=0,
        stride=1,
        dilation=1,
        pad_type="same",
        noskip=False,
        exp_ratio=1.0,
        act_layer="relu",
        norm_layer="batch_norm",
        se_layer=None,
        drop_path_rate=0.0,
        layer_scale_init_value=1e-5,
        data_format=None,
        channel_axis=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.filters = filters
        self.dw_kernel_size_start = dw_kernel_size_start
        self.dw_kernel_size_mid = dw_kernel_size_mid
        self.dw_kernel_size_end = dw_kernel_size_end
        self.stride = stride
        self.dilation = dilation
        self.pad_type = pad_type
        self.noskip = noskip
        self.exp_ratio = exp_ratio
        self.act_layer = act_layer
        self.norm_layer = norm_layer
        self.se_layer = se_layer
        self.drop_path_rate = drop_path_rate
        self.layer_scale_init_value = layer_scale_init_value
        self.data_format = data_format
        self.channel_axis = channel_axis

    def build(self, input_shape):
        super().build(input_shape)
        in_chs = input_shape[self.channel_axis]
        self.has_skip = (
            in_chs == self.filters and self.stride == 1
        ) and not self.noskip
        use_bias = self.norm_layer == "rms_norm"

        if self.dw_kernel_size_start:
            self.dw_start = ConvNormAct(
                in_chs,
                self.dw_kernel_size_start,
                stride=self.stride if not self.dw_kernel_size_mid else 1,
                dilation=self.dilation,
                groups=in_chs,
                pad_type=self.pad_type,
                apply_act=False,
                act_layer=self.act_layer,
                norm_layer=self.norm_layer,
                bias=use_bias,
                data_format=self.data_format,
                channel_axis=self.channel_axis,
                dtype=self.dtype_policy,
            )
        else:
            self.dw_start = lambda x, training=False: x

        mid_chs = adjust_channels(in_chs * self.exp_ratio)
        self.pw_exp = ConvNormAct(
            mid_chs,
            1,
            pad_type=self.pad_type,
            act_layer=self.act_layer,
            norm_layer=self.norm_layer,
            bias=use_bias,
            data_format=self.data_format,
            channel_axis=self.channel_axis,
            dtype=self.dtype_policy,
        )

        if self.dw_kernel_size_mid:
            self.dw_mid = ConvNormAct(
                mid_chs,
                self.dw_kernel_size_mid,
                stride=self.stride,
                dilation=self.dilation,
                groups=mid_chs,
                pad_type=self.pad_type,
                act_layer=self.act_layer,
                norm_layer=self.norm_layer,
                bias=use_bias,
                data_format=self.data_format,
                channel_axis=self.channel_axis,
                dtype=self.dtype_policy,
            )
        else:
            self.dw_mid = lambda x, training=False: x
        self.se = (
            self.se_layer(
                filters=mid_chs,
                bottleneck_filters=adjust_channels(mid_chs * 0.25),
                squeeze_activation=self.act_layer,
                excite_activation="sigmoid",
                data_format=self.data_format,
                channel_axis=self.channel_axis,
                dtype=self.dtype_policy,
            )
            if self.se_layer
            else (lambda x, training=False: x)
        )
        self.pw_proj = ConvNormAct(
            self.filters,
            1,
            pad_type=self.pad_type,
            apply_act=False,
            act_layer=self.act_layer,
            norm_layer=self.norm_layer,
            bias=use_bias,
            data_format=self.data_format,
            channel_axis=self.channel_axis,
            dtype=self.dtype_policy,
        )

        if self.dw_kernel_size_end:
            self.dw_end = ConvNormAct(
                self.filters,
                self.dw_kernel_size_end,
                stride=self.stride
                if not self.dw_kernel_size_start and not self.dw_kernel_size_mid
                else 1,
                dilation=self.dilation,
                groups=self.filters,
                pad_type=self.pad_type,
                apply_act=False,
                act_layer=self.act_layer,
                norm_layer=self.norm_layer,
                bias=use_bias,
                data_format=self.data_format,
                channel_axis=self.channel_axis,
                dtype=self.dtype_policy,
            )
        else:
            self.dw_end = lambda x, training=False: x

        self.layer_scale = (
            LayerScale2d(
                self.filters,
                self.layer_scale_init_value,
                data_format=self.data_format,
                channel_axis=self.channel_axis,
                dtype=self.dtype_policy,
            )
            if self.layer_scale_init_value is not None
            else lambda x: x
        )
        self.drop_path = (
            DropPath(self.drop_path_rate, dtype=self.dtype_policy)
            if self.drop_path_rate > 0.0
            else (lambda x, training=False: x)
        )
        current_shape = input_shape
        if hasattr(self.dw_start, "build"):
            self.dw_start.build(current_shape)
            current_shape = self.dw_start.compute_output_shape(current_shape)
        self.pw_exp.build(current_shape)
        current_shape = self.pw_exp.compute_output_shape(current_shape)
        if hasattr(self.dw_mid, "build"):
            self.dw_mid.build(current_shape)
            current_shape = self.dw_mid.compute_output_shape(current_shape)
        if hasattr(self.se, "build"):
            self.se.build(current_shape)
        self.pw_proj.build(current_shape)
        current_shape = self.pw_proj.compute_output_shape(current_shape)
        if hasattr(self.dw_end, "build"):
            self.dw_end.build(current_shape)
            current_shape = self.dw_end.compute_output_shape(current_shape)
        if hasattr(self.layer_scale, "build"):
            self.layer_scale.build(current_shape)

    def call(self, x, training=False):
        shortcut = x
        x = self.dw_start(x, training=training)
        x = self.pw_exp(x, training=training)
        x = self.dw_mid(x, training=training)
        x = self.se(x, training=training)
        x = self.pw_proj(x, training=training)
        x = self.dw_end(x, training=training)
        x = self.layer_scale(x)
        if self.has_skip:
            x = self.drop_path(x, training=training) + shortcut
        return x

    def compute_output_shape(self, input_shape):
        current_shape = input_shape
        if hasattr(self.dw_start, "compute_output_shape"):
            current_shape = self.dw_start.compute_output_shape(current_shape)
        current_shape = self.pw_exp.compute_output_shape(current_shape)
        if hasattr(self.dw_mid, "compute_output_shape"):
            current_shape = self.dw_mid.compute_output_shape(current_shape)
        current_shape = self.pw_proj.compute_output_shape(current_shape)
        if hasattr(self.dw_end, "compute_output_shape"):
            current_shape = self.dw_end.compute_output_shape(current_shape)
        return current_shape

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "dw_kernel_size_start": self.dw_kernel_size_start,
                "dw_kernel_size_mid": self.dw_kernel_size_mid,
                "dw_kernel_size_end": self.dw_kernel_size_end,
                "stride": self.stride,
                "dilation": self.dilation,
                "pad_type": self.pad_type,
                "noskip": self.noskip,
                "exp_ratio": self.exp_ratio,
                "act_layer": self.act_layer,
                "norm_layer": self.norm_layer,
                "se_layer": keras.saving.serialize_keras_object(self.se_layer),
                "drop_path_rate": self.drop_path_rate,
                "layer_scale_init_value": self.layer_scale_init_value,
                "data_format": self.data_format,
                "channel_axis": self.channel_axis,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        config["se_layer"] = keras.saving.deserialize_keras_object(
            config.pop("se_layer")
        )
        return cls(**config)
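
A minimal usage sketch of `UniversalInvertedResidual` (illustrative, not part
of the diff; the shapes and argument values are assumptions). With stride 1
and matching input/output channels, the residual skip connection stays
active:

    import keras
    from keras_hub.src.models.mobilenetv5.mobilenetv5_blocks import (
        UniversalInvertedResidual,
    )

    block = UniversalInvertedResidual(
        filters=64,
        dw_kernel_size_start=3,
        dw_kernel_size_mid=3,
        exp_ratio=4.0,
        data_format="channels_last",
        channel_axis=-1,
    )
    x = keras.random.normal((2, 32, 32, 64))  # NHWC; in_chs == filters
    y = block(x)  # -> (2, 32, 32, 64); drop_path output + shortcut
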
class EdgeResidual(keras.layers.Layer):
    """Edge Residual block.

    This block is designed for efficiency on edge devices. It is a variant of
    the inverted residual block that uses a single expansion convolution.

    Args:
        filters: int. The number of output channels.
        exp_kernel_size: int. The kernel size for the expansion convolution.
        stride: int. The stride for the block.
        dilation: int. The dilation rate for convolutions.
        group_size: int. The group size for grouped convolutions.
        pad_type: str. The padding type for convolutions.
        expansion_in_chs: int. If greater than 0, forces the number of input
            channels for the expansion.
        noskip: bool. If `True`, the skip connection is disabled.
        exp_ratio: float. The expansion ratio for the middle channels.
        pw_kernel_size: int. The kernel size for the pointwise convolution.
        act_layer: str. The activation function to use.
        norm_layer: str. The normalization layer to use.
        se_layer: keras.layers.Layer. The Squeeze-and-Excitation layer to use.
        drop_path_rate: float. The stochastic depth rate.
        data_format: str. The format of the input data, either
            `"channels_last"` or `"channels_first"`.
        channel_axis: int. The axis representing the channels in the input
            tensor.
    """

    def __init__(
        self,
        filters,
        exp_kernel_size=3,
        stride=1,
        dilation=1,
        group_size=0,
        pad_type="same",
        expansion_in_chs=0,
        noskip=False,
        exp_ratio=1.0,
        pw_kernel_size=1,
        act_layer="relu",
        norm_layer="batch_norm",
        se_layer=None,
        drop_path_rate=0.0,
        data_format=None,
        channel_axis=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.filters = filters
        self.exp_kernel_size = exp_kernel_size
        self.stride = stride
        self.dilation = dilation
        self.group_size = group_size
        self.pad_type = pad_type
        self.expansion_in_chs = expansion_in_chs
        self.noskip = noskip
        self.exp_ratio = exp_ratio
        self.pw_kernel_size = pw_kernel_size
        self.act_layer = act_layer
        self.norm_layer = norm_layer
        self.se_layer = se_layer
        self.drop_path_rate = drop_path_rate
        self.data_format = data_format
        self.channel_axis = channel_axis

    def build(self, input_shape):
        super().build(input_shape)
        in_chs = input_shape[self.channel_axis]
        self.has_skip = (
            in_chs == self.filters and self.stride == 1
        ) and not self.noskip
        if self.expansion_in_chs > 0:
            mid_chs = adjust_channels(self.expansion_in_chs * self.exp_ratio)
        else:
            mid_chs = adjust_channels(in_chs * self.exp_ratio)
        groups = num_groups(self.group_size, mid_chs)
        use_bias = self.norm_layer == "rms_norm"
        self.conv_exp = ConvNormAct(
            mid_chs,
            self.exp_kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            groups=groups,
            pad_type=self.pad_type,
            norm_layer=self.norm_layer,
            act_layer=self.act_layer,
            bias=use_bias,
            data_format=self.data_format,
            channel_axis=self.channel_axis,
            dtype=self.dtype_policy,
        )
        self.se = (
            self.se_layer(
                filters=mid_chs,
                bottleneck_filters=adjust_channels(mid_chs * 0.25),
                squeeze_activation=self.act_layer,
                excite_activation="sigmoid",
                data_format=self.data_format,
                channel_axis=self.channel_axis,
                dtype=self.dtype_policy,
            )
            if self.se_layer
            else (lambda x, training=False: x)
        )
        self.conv_pwl = ConvNormAct(
            self.filters,
            self.pw_kernel_size,
            pad_type=self.pad_type,
            apply_act=False,
            norm_layer=self.norm_layer,
            act_layer=self.act_layer,
            bias=use_bias,
            data_format=self.data_format,
            channel_axis=self.channel_axis,
            dtype=self.dtype_policy,
        )
        self.drop_path = (
            DropPath(self.drop_path_rate, dtype=self.dtype_policy)
            if self.drop_path_rate > 0.0
            else (lambda x, training=False: x)
        )
        self.conv_exp.build(input_shape)
        conv_exp_output_shape = self.conv_exp.compute_output_shape(input_shape)
        if hasattr(self.se, "build"):
            self.se.build(conv_exp_output_shape)
        self.conv_pwl.build(conv_exp_output_shape)

    def call(self, x, training=False):
        shortcut = x
        x = self.conv_exp(x, training=training)
        x = self.se(x, training=training)
        x = self.conv_pwl(x, training=training)
        if self.has_skip:
            x = self.drop_path(x, training=training) + shortcut
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "exp_kernel_size": self.exp_kernel_size,
                "stride": self.stride,
                "dilation": self.dilation,
                "group_size": self.group_size,
                "pad_type": self.pad_type,
                "expansion_in_chs": self.expansion_in_chs,
                "noskip": self.noskip,
                "exp_ratio": self.exp_ratio,
                "pw_kernel_size": self.pw_kernel_size,
                "act_layer": self.act_layer,
                "norm_layer": self.norm_layer,
                "se_layer": keras.saving.serialize_keras_object(self.se_layer),
                "drop_path_rate": self.drop_path_rate,
                "data_format": self.data_format,
                "channel_axis": self.channel_axis,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        config["se_layer"] = keras.saving.deserialize_keras_object(
            config.pop("se_layer")
        )
        return cls(**config)
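
A corresponding sketch for `EdgeResidual` (illustrative; the values are
assumptions). With stride 2 the spatial shape changes, so `has_skip` is
`False` and no shortcut is added:

    import keras
    from keras_hub.src.models.mobilenetv5.mobilenetv5_blocks import (
        EdgeResidual,
    )

    block = EdgeResidual(
        filters=32,
        exp_kernel_size=3,
        exp_ratio=4.0,
        stride=2,
        data_format="channels_last",
        channel_axis=-1,
    )
    x = keras.random.normal((2, 64, 64, 16))
    y = block(x)  # -> (2, 32, 32, 32); fused 3x3 expansion, then 1x1 linear
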
class CondConvResidual(keras.layers.Layer):
    """Conditionally Parameterized Convolutional Residual block.

    This block uses a routing function to dynamically select and combine
    different convolutional experts based on the input.

    Args:
        filters: int. The number of output channels.
        dw_kernel_size: int. The kernel size for the depthwise convolution.
        stride: int. The stride for the block.
        dilation: int. The dilation rate for convolutions.
        pad_type: str. The padding type for convolutions.
        noskip: bool. If `True`, the skip connection is disabled.
        exp_ratio: float. The expansion ratio for the middle channels.
        exp_kernel_size: int. The kernel size for the expansion convolution.
        pw_kernel_size: int. The kernel size for the pointwise convolution.
        act_layer: str. The activation function to use.
        se_layer: keras.layers.Layer. The Squeeze-and-Excitation layer to use.
        num_experts: int. The number of experts to use.
        drop_path_rate: float. The stochastic depth rate.
        data_format: str. The format of the input data, either
            `"channels_last"` or `"channels_first"`.
        channel_axis: int. The axis representing the channels in the input
            tensor.
    """

    def __init__(
        self,
        filters,
        dw_kernel_size=3,
        stride=1,
        dilation=1,
        pad_type="same",
        noskip=False,
        exp_ratio=1.0,
        exp_kernel_size=1,
        pw_kernel_size=1,
        act_layer="relu",
        se_layer=None,
        num_experts=0,
        drop_path_rate=0.0,
        data_format=None,
        channel_axis=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.filters = filters
        self.dw_kernel_size = dw_kernel_size
        self.stride = stride
        self.dilation = dilation
        self.pad_type = pad_type
        self.noskip = noskip
        self.exp_ratio = exp_ratio
        self.exp_kernel_size = exp_kernel_size
        self.pw_kernel_size = pw_kernel_size
        self.act_layer = act_layer
        self.se_layer = se_layer
        self.num_experts = num_experts
        self.drop_path_rate = drop_path_rate
        self.data_format = data_format
        self.channel_axis = channel_axis
        self.conv_kernel_initializer = keras.initializers.VarianceScaling(
            scale=2.0, mode="fan_out", distribution="untruncated_normal"
        )
        self.dense_kernel_initializer = keras.initializers.VarianceScaling(
            scale=1.0, mode="fan_in", distribution="uniform"
        )
        self.bias_initializer = "zeros"

    def build(self, input_shape):
        super().build(input_shape)
        in_chs = input_shape[self.channel_axis]
        self.has_skip = (
            in_chs == self.filters and self.stride == 1
        ) and not self.noskip
        mid_chs = adjust_channels(in_chs * self.exp_ratio)
        self.routing_fn = keras.layers.Dense(
            self.num_experts,
            dtype=self.dtype_policy,
            kernel_initializer=self.dense_kernel_initializer,
            bias_initializer=self.bias_initializer,
        )
        self.pool = keras.layers.GlobalAveragePooling2D(
            data_format=self.data_format, dtype=self.dtype_policy
        )
        self.conv_pw_experts = [
            keras.layers.Conv2D(
                filters=mid_chs,
                kernel_size=self.exp_kernel_size,
                padding=self.pad_type,
                use_bias=True,
                data_format=self.data_format,
                name=f"conv_pw_expert_{i}",
                kernel_initializer=self.conv_kernel_initializer,
                bias_initializer=self.bias_initializer,
                dtype=self.dtype_policy,
            )
            for i in range(self.num_experts)
        ]
        self.conv_dw_experts = [
            keras.layers.DepthwiseConv2D(
                kernel_size=self.dw_kernel_size,
                strides=self.stride,
                padding=self.pad_type,
                dilation_rate=self.dilation,
                use_bias=True,
                data_format=self.data_format,
                name=f"conv_dw_expert_{i}",
                depthwise_initializer=self.conv_kernel_initializer,
                bias_initializer=self.bias_initializer,
                dtype=self.dtype_policy,
            )
            for i in range(self.num_experts)
        ]
        self.conv_pwl_experts = [
            keras.layers.Conv2D(
                filters=self.filters,
                kernel_size=self.pw_kernel_size,
                padding=self.pad_type,
                use_bias=True,
                data_format=self.data_format,
                name=f"conv_pwl_expert_{i}",
                kernel_initializer=self.conv_kernel_initializer,
                bias_initializer=self.bias_initializer,
                dtype=self.dtype_policy,
            )
            for i in range(self.num_experts)
        ]
        self.bn1 = keras.layers.BatchNormalization(
            axis=self.channel_axis,
            dtype=self.dtype_policy,
            gamma_initializer="ones",
            beta_initializer="zeros",
        )
        self.act1 = keras.layers.Activation(
            self.act_layer, dtype=self.dtype_policy
        )
        self.bn2 = keras.layers.BatchNormalization(
            axis=self.channel_axis,
            dtype=self.dtype_policy,
            gamma_initializer="ones",
            beta_initializer="zeros",
        )
        self.act2 = keras.layers.Activation(
            self.act_layer, dtype=self.dtype_policy
        )
        self.bn3 = keras.layers.BatchNormalization(
            axis=self.channel_axis,
            dtype=self.dtype_policy,
            gamma_initializer="ones",
            beta_initializer="zeros",
        )
        self.se = (
            self.se_layer(
                filters=mid_chs,
                bottleneck_filters=adjust_channels(mid_chs * 0.25),
                squeeze_activation=self.act_layer,
                excite_activation="sigmoid",
                data_format=self.data_format,
                channel_axis=self.channel_axis,
                dtype=self.dtype_policy,
            )
            if self.se_layer
            else (lambda x, training=False: x)
        )
        self.drop_path = (
            DropPath(self.drop_path_rate, dtype=self.dtype_policy)
            if self.drop_path_rate > 0.0
            else (lambda x, training=False: x)
        )
        pooled_shape = self.pool.compute_output_shape(input_shape)
        self.routing_fn.build(pooled_shape)
        for expert in self.conv_pw_experts:
            expert.build(input_shape)
        pw_out_shape = self.conv_pw_experts[0].compute_output_shape(input_shape)
        self.bn1.build(pw_out_shape)
        for expert in self.conv_dw_experts:
            expert.build(pw_out_shape)
        dw_out_shape = self.conv_dw_experts[0].compute_output_shape(
            pw_out_shape
        )
        self.bn2.build(dw_out_shape)
        if hasattr(self.se, "build"):
            self.se.build(dw_out_shape)
        for expert in self.conv_pwl_experts:
            expert.build(dw_out_shape)
        pwl_out_shape = self.conv_pwl_experts[0].compute_output_shape(
            dw_out_shape
        )
        self.bn3.build(pwl_out_shape)

    def _apply_cond_conv(self, x, experts, routing_weights):
        outputs = []
        for i, expert in enumerate(experts):
            expert_out = expert(x)
            weight = keras.ops.reshape(routing_weights[:, i], (-1, 1, 1, 1))
            outputs.append(expert_out * weight)
        return keras.ops.sum(outputs, axis=0)

    def call(self, x, training=False):
        shortcut = x
        pooled_inputs = self.pool(x)
        routing_weights = keras.activations.sigmoid(
            self.routing_fn(pooled_inputs)
        )
        x = self._apply_cond_conv(x, self.conv_pw_experts, routing_weights)
        x = self.bn1(x, training=training)
        x = self.act1(x)
        x = self._apply_cond_conv(x, self.conv_dw_experts, routing_weights)
        x = self.bn2(x, training=training)
        x = self.act2(x)
        x = self.se(x, training=training)
        x = self._apply_cond_conv(x, self.conv_pwl_experts, routing_weights)
        x = self.bn3(x, training=training)
        if self.has_skip:
            x = self.drop_path(x, training=training) + shortcut
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "dw_kernel_size": self.dw_kernel_size,
                "stride": self.stride,
                "dilation": self.dilation,
                "pad_type": self.pad_type,
                "noskip": self.noskip,
                "exp_ratio": self.exp_ratio,
                "exp_kernel_size": self.exp_kernel_size,
                "pw_kernel_size": self.pw_kernel_size,
                "act_layer": self.act_layer,
                "se_layer": keras.saving.serialize_keras_object(self.se_layer),
                "num_experts": self.num_experts,
                "drop_path_rate": self.drop_path_rate,
                "data_format": self.data_format,
                "channel_axis": self.channel_axis,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        config["se_layer"] = keras.saving.deserialize_keras_object(
            config.pop("se_layer")
        )
        return cls(**config)
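
A sketch for `CondConvResidual` (illustrative; the argument values are
assumptions). The routing weights are computed once per sample from globally
pooled features and reused to mix the experts at each of the three stages:

    import keras
    from keras_hub.src.models.mobilenetv5.mobilenetv5_blocks import (
        CondConvResidual,
    )

    block = CondConvResidual(
        filters=24,
        dw_kernel_size=3,
        exp_ratio=4.0,
        num_experts=4,
        data_format="channels_last",
        channel_axis=-1,
    )
    x = keras.random.normal((2, 16, 16, 24))
    y = block(x)  # -> (2, 16, 16, 24); each sample gets its own expert mix
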
class MobileNetV5MultiScaleFusionAdapter(keras.layers.Layer):
    """Multi-Scale Fusion Adapter for MobileNetV5.

    This layer fuses feature maps from different scales of the backbone,
    concatenates them, processes them through a FFN (Feed-Forward Network),
    and then resizes the output to a target resolution.

    Args:
        in_chs: list of int. A list of channel counts for each input feature
            map.
        filters: int. The number of output channels.
        output_resolution: int or tuple. The target output resolution.
        expansion_ratio: float. The expansion ratio for the FFN.
        interpolation_mode: str. The interpolation mode for upsampling feature
            maps.
        layer_scale_init_value: float. The initial value for layer scale. If
            `None`, layer scale is not used.
        noskip: bool. If `True`, the skip connection in the FFN is disabled.
        act_layer: str. The activation function to use.
        norm_layer: str. The normalization layer to use.
        data_format: str. The format of the input data, either
            `"channels_last"` or `"channels_first"`.
        channel_axis: int. The axis representing the channels in the input
            tensor.
    """

    def __init__(
        self,
        in_chs,
        filters,
        output_resolution,
        expansion_ratio=2.0,
        interpolation_mode="nearest",
        layer_scale_init_value=None,
        noskip=True,
        act_layer="gelu",
        norm_layer="rms_norm",
        data_format=None,
        channel_axis=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.in_chs = in_chs
        self.filters = filters
        self.output_resolution_arg = output_resolution
        self.expansion_ratio = expansion_ratio
        self.interpolation_mode = interpolation_mode
        self.layer_scale_init_value = layer_scale_init_value
        self.noskip = noskip
        self.act_layer = act_layer
        self.norm_layer_name = norm_layer
        self.data_format = data_format
        self.channel_axis = channel_axis
        self.in_channels = sum(in_chs)
        if isinstance(output_resolution, int):
            self.output_resolution = (output_resolution, output_resolution)
        else:
            self.output_resolution = output_resolution
        self.ffn = UniversalInvertedResidual(
            filters=self.filters,
            dw_kernel_size_mid=0,
            exp_ratio=expansion_ratio,
            act_layer=act_layer,
            norm_layer=norm_layer,
            noskip=noskip,
            layer_scale_init_value=layer_scale_init_value,
            data_format=self.data_format,
            channel_axis=self.channel_axis,
            dtype=self.dtype_policy,
        )
        if norm_layer == "rms_norm":
            self.norm = RmsNorm2d(
                self.filters,
                data_format=self.data_format,
                gamma_initializer="ones",
                channel_axis=self.channel_axis,
                dtype=self.dtype_policy,
            )
        else:
            self.norm = keras.layers.BatchNormalization(
                axis=self.channel_axis,
                gamma_initializer="ones",
                beta_initializer="zeros",
                dtype=self.dtype_policy,
            )

    def build(self, input_shape):
        super().build(input_shape)
        ffn_input_shape = list(input_shape[0])
        if self.data_format == "channels_first":
            ffn_input_shape[1] = self.in_channels
        else:
            ffn_input_shape[-1] = self.in_channels
        self.ffn.build(tuple(ffn_input_shape))
        norm_input_shape = self.ffn.compute_output_shape(tuple(ffn_input_shape))
        self.norm.build(norm_input_shape)

    def call(self, inputs, training=False):
        shape_hr = keras.ops.shape(inputs[0])
        if self.data_format == "channels_first":
            high_resolution = (shape_hr[2], shape_hr[3])
        else:
            high_resolution = (shape_hr[1], shape_hr[2])
        resized_inputs = []
        for img in inputs:
            if self.data_format == "channels_first":
                img_transposed = keras.ops.transpose(img, (0, 2, 3, 1))
            else:
                img_transposed = img
            img_resized = keras.ops.image.resize(
                img_transposed,
                size=high_resolution,
                interpolation=self.interpolation_mode,
            )
            if self.data_format == "channels_first":
                resized_inputs.append(
                    keras.ops.transpose(img_resized, (0, 3, 1, 2))
                )
            else:
                resized_inputs.append(img_resized)
        channel_cat_imgs = keras.ops.concatenate(
            resized_inputs, axis=self.channel_axis
        )
        img = self.ffn(channel_cat_imgs, training=training)
        if (
            high_resolution[0] != self.output_resolution[0]
            or high_resolution[1] != self.output_resolution[1]
        ):
            h_in, w_in = high_resolution
            h_out, w_out = self.output_resolution
            if h_in % h_out == 0 and w_in % w_out == 0:
                h_stride = h_in // h_out
                w_stride = w_in // w_out
                img = keras.ops.nn.average_pool(
                    img,
                    pool_size=(h_stride, w_stride),
                    strides=(h_stride, w_stride),
                    padding="valid",
                    data_format=self.data_format,
                )
            else:
                if self.data_format == "channels_first":
                    img_transposed = keras.ops.transpose(img, (0, 2, 3, 1))
                else:
                    img_transposed = img
                img_resized = keras.ops.image.resize(
                    img_transposed,
                    size=self.output_resolution,
                    interpolation="bilinear",
                )
                if self.data_format == "channels_first":
                    img = keras.ops.transpose(img_resized, (0, 3, 1, 2))
                else:
                    img = img_resized
        img = self.norm(img, training=training)
        return img

    def compute_output_shape(self, input_shape):
        batch_size = input_shape[0][0]
        if self.data_format == "channels_first":
            return (
                batch_size,
                self.filters,
                self.output_resolution[0],
                self.output_resolution[1],
            )
        else:
            return (
                batch_size,
                self.output_resolution[0],
                self.output_resolution[1],
                self.filters,
            )

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "in_chs": self.in_chs,
                "filters": self.filters,
                "output_resolution": self.output_resolution_arg,
                "expansion_ratio": self.expansion_ratio,
                "interpolation_mode": self.interpolation_mode,
                "layer_scale_init_value": self.layer_scale_init_value,
                "noskip": self.noskip,
                "act_layer": self.act_layer,
                "norm_layer": self.norm_layer_name,
                "data_format": self.data_format,
                "channel_axis": self.channel_axis,
            }
        )
        return config
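
Finally, a sketch for `MobileNetV5MultiScaleFusionAdapter` (illustrative; the
channel counts and resolutions are assumptions, not values from this
release). Lower-resolution maps are upsampled to the first input's
resolution, concatenated, passed through the FFN, then average-pooled down to
`output_resolution`:

    import keras
    from keras_hub.src.models.mobilenetv5.mobilenetv5_blocks import (
        MobileNetV5MultiScaleFusionAdapter,
    )

    adapter = MobileNetV5MultiScaleFusionAdapter(
        in_chs=[48, 96, 192],
        filters=2048,
        output_resolution=16,
        data_format="channels_last",
        channel_axis=-1,
    )
    features = [
        keras.random.normal((2, 32, 32, 48)),
        keras.random.normal((2, 16, 16, 96)),
        keras.random.normal((2, 8, 8, 192)),
    ]
    y = adapter(features)  # -> (2, 16, 16, 2048)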