keras-hub-nightly 0.23.0.dev202510100415__py3-none-any.whl → 0.23.0.dev202510110411__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of keras-hub-nightly might be problematic. Click here for more details.
- keras_hub/layers/__init__.py +3 -0
- keras_hub/models/__init__.py +9 -0
- keras_hub/src/models/mobilenetv5/__init__.py +0 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py +699 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_backbone.py +396 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py +890 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py +436 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py +157 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_converter.py +10 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py +462 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_utils.py +146 -0
- keras_hub/src/utils/preset_utils.py +9 -2
- keras_hub/src/utils/timm/convert_mobilenetv5.py +321 -0
- keras_hub/src/utils/timm/preset_loader.py +8 -4
- keras_hub/src/version.py +1 -1
- {keras_hub_nightly-0.23.0.dev202510100415.dist-info → keras_hub_nightly-0.23.0.dev202510110411.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.23.0.dev202510100415.dist-info → keras_hub_nightly-0.23.0.dev202510110411.dist-info}/RECORD +20 -9
- {keras_hub_nightly-0.23.0.dev202510100415.dist-info → keras_hub_nightly-0.23.0.dev202510110411.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.23.0.dev202510100415.dist-info → keras_hub_nightly-0.23.0.dev202510110411.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,396 @@
|
|
|
1
|
+
import keras
|
|
2
|
+
from keras.src import saving
|
|
3
|
+
|
|
4
|
+
from keras_hub.src.api_export import keras_hub_export
|
|
5
|
+
from keras_hub.src.models.backbone import Backbone
|
|
6
|
+
from keras_hub.src.models.mobilenet.mobilenet_backbone import SqueezeAndExcite2D
|
|
7
|
+
from keras_hub.src.models.mobilenetv5.mobilenetv5_blocks import (
|
|
8
|
+
MobileNetV5MultiScaleFusionAdapter,
|
|
9
|
+
)
|
|
10
|
+
from keras_hub.src.models.mobilenetv5.mobilenetv5_builder import (
|
|
11
|
+
MobileNetV5Builder,
|
|
12
|
+
)
|
|
13
|
+
from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import ConvNormAct
|
|
14
|
+
from keras_hub.src.models.mobilenetv5.mobilenetv5_utils import (
|
|
15
|
+
feature_take_indices,
|
|
16
|
+
)
|
|
17
|
+
from keras_hub.src.models.mobilenetv5.mobilenetv5_utils import round_channels
|
|
18
|
+
from keras_hub.src.utils.keras_utils import standardize_data_format
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@keras_hub_export("keras_hub.models.MobileNetV5Backbone")
class MobileNetV5Backbone(Backbone):
    """MobileNetV5 backbone network.

    This class represents the backbone of the MobileNetV5 architecture, which
    can be used as a feature extractor for various downstream tasks.

    Args:
        stackwise_block_types: list of list of strings. The block type for each
            block in each stack. Recognized types are `"ir"`, `"uir"`, `"er"`,
            `"mqa"` and `"mha"`.
        stackwise_num_blocks: list of ints. The number of blocks for each
            stack.
        stackwise_num_filters: list of list of ints. The number of filters for
            each block in each stack.
        stackwise_strides: list of list of ints. The stride for each block in
            each stack.
        stackwise_act_layers: list of list of strings. The activation function
            for each block in each stack.
        stackwise_exp_ratios: list of list of floats. The expansion ratio for
            each `ir`, `uir` or `er` block in each stack.
        stackwise_se_ratios: list of list of floats. The SE ratio for each
            `ir`, `uir` or `er` block in each stack.
        stackwise_dw_kernel_sizes: list of list of ints. The depthwise kernel
            size for each `ir` block, and the middle depthwise kernel size for
            each `uir` block, in each stack.
        stackwise_dw_start_kernel_sizes: list of list of ints. The start
            depthwise kernel size for each `uir` block in each stack.
        stackwise_dw_end_kernel_sizes: list of list of ints. The end depthwise
            kernel size for each `uir` block in each stack.
        stackwise_exp_kernel_sizes: list of list of ints. The expansion kernel
            size for each `er` block in each stack.
        stackwise_pw_kernel_sizes: list of list of ints. The pointwise kernel
            size for each `er` block in each stack.
        stackwise_num_heads: list of list of ints. The number of heads for each
            `mqa` or `mha` block in each stack.
        stackwise_key_dims: list of list of ints. The key dimension for each
            `mqa` or `mha` block in each stack.
        stackwise_value_dims: list of list of ints. The value dimension for each
            `mqa` or `mha` block in each stack.
        stackwise_kv_strides: list of list of ints. The key-value stride for
            each `mqa` or `mha` block in each stack.
        stackwise_use_cpe: list of list of bools. Whether to use conditional
            position encoding for each `mqa` or `mha` block in each stack.
        filters: int. The number of input channels. This value is only
            recorded in the config; the input layer is built from
            `image_shape`.
        stem_size: int. The number of channels in the stem convolution.
        stem_bias: bool. If `True`, a bias term is used in the stem
            convolution.
        fix_stem: bool. If `True`, the stem size is not rounded with
            `round_chs_fn`.
        num_features: int. The number of output features, used when `use_msfa`
            is `True`.
        pad_type: str. The padding type for convolutions.
        use_msfa: bool. If `True`, the Multi-Scale Fusion Adapter is used.
        msfa_indices: tuple. The indices of the feature maps to be used by the
            MSFA. Index 0 refers to the stem output and index `k` to the
            output of the `k`-th stage.
        msfa_output_resolution: int. The output resolution of the MSFA.
        act_layer: str. The activation function to use.
        norm_layer: str. The normalization layer to use.
        se_layer: keras.layers.Layer. The Squeeze-and-Excitation layer to use.
        se_from_exp: bool. If `True`, SE channel reduction is based on the
            expanded channels.
        round_chs_fn: callable. A function to round the number of channels.
        drop_path_rate: float. The stochastic depth rate.
        layer_scale_init_value: float. The initial value for layer scale.
        image_shape: tuple. The shape of the input image. Defaults to
            `(None, None, 3)`.
        data_format: str. The data format of the image channels. Can be either
            `"channels_first"` or `"channels_last"`. If `None` is specified,
            it will use the `image_data_format` value found in your Keras
            config file at `~/.keras/keras.json`. Defaults to `None`.
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
            to use for the model's computations and weights. Defaults to `None`.

    Example:
    ```python
    import keras
    from keras_hub.models import MobileNetV5Backbone

    # Randomly initialized backbone with a custom config.
    model_args = {
        "stackwise_block_types": [["er"], ["uir", "uir"]],
        "stackwise_num_blocks": [1, 2],
        "stackwise_num_filters": [[24], [48, 48]],
        "stackwise_strides": [[2], [2, 1]],
        "stackwise_act_layers": [["relu"], ["relu", "relu"]],
        "stackwise_exp_ratios": [[4.0], [6.0, 6.0]],
        "stackwise_se_ratios": [[0.0], [0.0, 0.0]],
        "stackwise_dw_kernel_sizes": [[0], [5, 5]],
        "stackwise_dw_start_kernel_sizes": [[0], [0, 0]],
        "stackwise_dw_end_kernel_sizes": [[0], [0, 0]],
        "stackwise_exp_kernel_sizes": [[3], [0, 0]],
        "stackwise_pw_kernel_sizes": [[1], [0, 0]],
        "stackwise_num_heads": [[0], [0, 0]],
        "stackwise_key_dims": [[0], [0, 0]],
        "stackwise_value_dims": [[0], [0, 0]],
        "stackwise_kv_strides": [[0], [0, 0]],
        "stackwise_use_cpe": [[False], [False, False]],
        "use_msfa": False,
    }
    model = MobileNetV5Backbone(**model_args)
    input_data = keras.ops.ones((1, 224, 224, 3))
    output = model(input_data)

    # Load the backbone from a preset and run a prediction.
    backbone = MobileNetV5Backbone.from_preset("mobilenetv5_300m_gemma3n")

    # Expected output shape = (1, 16, 16, 2048).
    outputs = backbone.predict(keras.ops.ones((1, 224, 224, 3)))
    ```
    """

    def __init__(
        self,
        stackwise_block_types,
        stackwise_num_blocks,
        stackwise_num_filters,
        stackwise_strides,
        stackwise_act_layers,
        stackwise_exp_ratios,
        stackwise_se_ratios,
        stackwise_dw_kernel_sizes,
        stackwise_dw_start_kernel_sizes,
        stackwise_dw_end_kernel_sizes,
        stackwise_exp_kernel_sizes,
        stackwise_pw_kernel_sizes,
        stackwise_num_heads,
        stackwise_key_dims,
        stackwise_value_dims,
        stackwise_kv_strides,
        stackwise_use_cpe,
        filters=3,
        stem_size=16,
        stem_bias=True,
        fix_stem=False,
        num_features=2048,
        pad_type="same",
        use_msfa=True,
        msfa_indices=(-2, -1),
        msfa_output_resolution=16,
        act_layer="gelu",
        norm_layer="rms_norm",
        se_layer=SqueezeAndExcite2D,
        se_from_exp=True,
        round_chs_fn=round_channels,
        drop_path_rate=0.0,
        layer_scale_init_value=None,
        image_shape=(None, None, 3),
        data_format=None,
        dtype=None,
        **kwargs,
    ):
        data_format = standardize_data_format(data_format)
        channel_axis = -1 if data_format == "channels_last" else 1

        # Translate the stackwise config lists into one kwargs dict per block,
        # grouped per stack, in the form consumed by `MobileNetV5Builder`.
        # Only the keys relevant to each block type are added.
        block_args = []
        for i in range(len(stackwise_num_blocks)):
            stack_args = []
            for j in range(stackwise_num_blocks[i]):
                block_type = stackwise_block_types[i][j]
                args = {
                    "block_type": block_type,
                    "out_chs": stackwise_num_filters[i][j],
                    "stride": stackwise_strides[i][j],
                    "act_layer": stackwise_act_layers[i][j],
                }
                if block_type == "ir":
                    args.update(
                        {
                            "dw_kernel_size": stackwise_dw_kernel_sizes[i][j],
                            "exp_ratio": stackwise_exp_ratios[i][j],
                            "se_ratio": stackwise_se_ratios[i][j],
                        }
                    )
                elif block_type == "uir":
                    args.update(
                        {
                            "dw_kernel_size_mid": stackwise_dw_kernel_sizes[i][
                                j
                            ],
                            "dw_kernel_size_start": stackwise_dw_start_kernel_sizes[  # noqa: E501
                                i
                            ][j],
                            "dw_kernel_size_end": stackwise_dw_end_kernel_sizes[
                                i
                            ][j],
                            "exp_ratio": stackwise_exp_ratios[i][j],
                            "se_ratio": stackwise_se_ratios[i][j],
                        }
                    )
                elif block_type == "er":
                    args.update(
                        {
                            "exp_kernel_size": stackwise_exp_kernel_sizes[i][j],
                            "pw_kernel_size": stackwise_pw_kernel_sizes[i][j],
                            "exp_ratio": stackwise_exp_ratios[i][j],
                            "se_ratio": stackwise_se_ratios[i][j],
                        }
                    )
                elif block_type in ("mqa", "mha"):
                    args.update(
                        {
                            "num_heads": stackwise_num_heads[i][j],
                            "key_dim": stackwise_key_dims[i][j],
                            "value_dim": stackwise_value_dims[i][j],
                            "kv_stride": stackwise_kv_strides[i][j],
                            "use_cpe": stackwise_use_cpe[i][j],
                        }
                    )
                stack_args.append(args)
            block_args.append(stack_args)

        # === Layers ===
        # Round into a local so that `self.stem_size` (and thus the config)
        # keeps the constructor argument. Storing the rounded value while
        # `fix_stem` is False would make `from_config()` round it twice.
        stem_chs = stem_size if fix_stem else round_chs_fn(stem_size)
        conv_stem = ConvNormAct(
            stem_chs,
            kernel_size=3,
            stride=2,
            pad_type=pad_type,
            bias=stem_bias,
            norm_layer=norm_layer,
            act_layer=act_layer,
            name="conv_stem",
            data_format=data_format,
            channel_axis=channel_axis,
            dtype=dtype,
        )
        builder = MobileNetV5Builder(
            output_stride=32,
            pad_type=pad_type,
            round_chs_fn=round_chs_fn,
            se_from_exp=se_from_exp,
            act_layer=act_layer,
            norm_layer=norm_layer,
            se_layer=se_layer,
            drop_path_rate=drop_path_rate,
            layer_scale_init_value=layer_scale_init_value,
            data_format=data_format,
            channel_axis=channel_axis,
            dtype=dtype,
        )
        blocks = builder(stem_chs, block_args)
        feature_info = builder.features
        msfa = None
        if use_msfa:
            # Resolve `msfa_indices` against the builder's feature list; the
            # first returned value is used as the concrete index collection
            # (presumably with negative indices normalized — see
            # `feature_take_indices`).
            msfa_indices_calc, _ = feature_take_indices(
                len(feature_info), msfa_indices
            )
            msfa_in_chs = [
                feature_info[mi]["num_chs"] for mi in msfa_indices_calc
            ]
            msfa = MobileNetV5MultiScaleFusionAdapter(
                in_chs=msfa_in_chs,
                filters=num_features,
                output_resolution=msfa_output_resolution,
                norm_layer=norm_layer,
                act_layer=act_layer,
                name="msfa",
                channel_axis=channel_axis,
                data_format=data_format,
                dtype=dtype,
            )

        # === Functional Model ===
        image_input = keras.layers.Input(shape=image_shape)
        x = conv_stem(image_input)
        if use_msfa:
            # Feature index 0 is the stem output; index k is the output of
            # the k-th stage in `blocks`. Selected features feed the MSFA.
            intermediates = []
            feat_idx = 0
            if feat_idx in msfa_indices_calc:
                intermediates.append(x)

            for stage in blocks:
                for block in stage:
                    x = block(x)
                feat_idx += 1
                if feat_idx in msfa_indices_calc:
                    intermediates.append(x)
            x = msfa(intermediates)
        else:
            for stage in blocks:
                for block in stage:
                    x = block(x)

        super().__init__(inputs=image_input, outputs=x, dtype=dtype, **kwargs)

        # === Config ===
        self.stackwise_block_types = stackwise_block_types
        self.stackwise_num_blocks = stackwise_num_blocks
        self.stackwise_num_filters = stackwise_num_filters
        self.stackwise_strides = stackwise_strides
        self.stackwise_act_layers = stackwise_act_layers
        self.stackwise_exp_ratios = stackwise_exp_ratios
        self.stackwise_se_ratios = stackwise_se_ratios
        self.stackwise_dw_kernel_sizes = stackwise_dw_kernel_sizes
        self.stackwise_dw_start_kernel_sizes = stackwise_dw_start_kernel_sizes
        self.stackwise_dw_end_kernel_sizes = stackwise_dw_end_kernel_sizes
        self.stackwise_exp_kernel_sizes = stackwise_exp_kernel_sizes
        self.stackwise_pw_kernel_sizes = stackwise_pw_kernel_sizes
        self.stackwise_num_heads = stackwise_num_heads
        self.stackwise_key_dims = stackwise_key_dims
        self.stackwise_value_dims = stackwise_value_dims
        self.stackwise_kv_strides = stackwise_kv_strides
        self.stackwise_use_cpe = stackwise_use_cpe
        self.filters = filters
        # The un-rounded constructor argument (see the stem comment above).
        self.stem_size = stem_size
        self.stem_bias = stem_bias
        self.fix_stem = fix_stem
        self.num_features = num_features
        self.pad_type = pad_type
        self.use_msfa = use_msfa
        self.msfa_indices = msfa_indices
        self.msfa_output_resolution = msfa_output_resolution
        self.act_layer = act_layer
        self.norm_layer = norm_layer
        self.se_layer = se_layer
        self.se_from_exp = se_from_exp
        self.round_chs_fn = round_chs_fn
        self.drop_path_rate = drop_path_rate
        self.layer_scale_init_value = layer_scale_init_value
        self.image_shape = image_shape
        self.data_format = data_format
        self.channel_axis = channel_axis

    def get_config(self):
        """Return the constructor config for this backbone.

        Chains to the base `Backbone` config so that keys it provides (such as
        `name`) survive a `get_config()` / `from_config()` round trip.
        `round_chs_fn` and `se_layer` are serialized only when they differ
        from the defaults.
        """
        config = super().get_config()
        config.update(
            {
                "stackwise_block_types": self.stackwise_block_types,
                "stackwise_num_blocks": self.stackwise_num_blocks,
                "stackwise_num_filters": self.stackwise_num_filters,
                "stackwise_strides": self.stackwise_strides,
                "stackwise_act_layers": self.stackwise_act_layers,
                "stackwise_exp_ratios": self.stackwise_exp_ratios,
                "stackwise_se_ratios": self.stackwise_se_ratios,
                "stackwise_dw_kernel_sizes": self.stackwise_dw_kernel_sizes,
                "stackwise_dw_start_kernel_sizes": self.stackwise_dw_start_kernel_sizes,  # noqa: E501
                "stackwise_dw_end_kernel_sizes": self.stackwise_dw_end_kernel_sizes,  # noqa: E501
                "stackwise_exp_kernel_sizes": self.stackwise_exp_kernel_sizes,
                "stackwise_pw_kernel_sizes": self.stackwise_pw_kernel_sizes,
                "stackwise_num_heads": self.stackwise_num_heads,
                "stackwise_key_dims": self.stackwise_key_dims,
                "stackwise_value_dims": self.stackwise_value_dims,
                "stackwise_kv_strides": self.stackwise_kv_strides,
                "stackwise_use_cpe": self.stackwise_use_cpe,
                "filters": self.filters,
                "stem_size": self.stem_size,
                "stem_bias": self.stem_bias,
                "fix_stem": self.fix_stem,
                "num_features": self.num_features,
                "pad_type": self.pad_type,
                "use_msfa": self.use_msfa,
                "msfa_indices": self.msfa_indices,
                "msfa_output_resolution": self.msfa_output_resolution,
                "act_layer": self.act_layer,
                "norm_layer": self.norm_layer,
                "se_from_exp": self.se_from_exp,
                "drop_path_rate": self.drop_path_rate,
                "layer_scale_init_value": self.layer_scale_init_value,
                "image_shape": self.image_shape,
                "data_format": self.data_format,
            }
        )
        if self.round_chs_fn is not round_channels:
            config["round_chs_fn"] = saving.serialize_keras_object(
                self.round_chs_fn
            )
        if self.se_layer is not SqueezeAndExcite2D:
            config["se_layer"] = saving.serialize_keras_object(self.se_layer)
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild a backbone from a config produced by `get_config()`.

        The config is copied rather than mutated. Custom `round_chs_fn` /
        `se_layer` entries are deserialized here; construction is then
        delegated to the base class so its own config handling applies.
        """
        config = dict(config)
        if "round_chs_fn" in config:
            config["round_chs_fn"] = saving.deserialize_keras_object(
                config["round_chs_fn"]
            )
        if "se_layer" in config:
            config["se_layer"] = saving.deserialize_keras_object(
                config["se_layer"]
            )
        return super().from_config(config)
|