keras-hub-nightly 0.23.0.dev202510100415__py3-none-any.whl → 0.23.0.dev202510110411__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of keras-hub-nightly might be problematic. Click here for more details.

@@ -0,0 +1,890 @@
1
+ import keras
2
+
3
+ from keras_hub.src.models.mobilenet.util import adjust_channels
4
+ from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import ConvNormAct
5
+ from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import DropPath
6
+ from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import LayerScale2d
7
+ from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import RmsNorm2d
8
+ from keras_hub.src.models.mobilenetv5.mobilenetv5_utils import num_groups
9
+
10
+
11
class UniversalInvertedResidual(keras.layers.Layer):
    """Universal Inverted Residual block.

    This block is a flexible and universal version of the inverted residual
    block, which can be configured to behave like different variants of mobile
    convolutional blocks. The computation is

    `dw_start -> pw_exp -> dw_mid -> se -> pw_proj -> dw_end -> layer_scale`

    where each optional stage that is disabled is replaced by an identity.
    A residual connection (with optional stochastic depth) is added when the
    input channel count equals `filters`, `stride == 1` and `noskip` is
    `False`.

    Args:
        filters: int. The number of output channels.
        dw_kernel_size_start: int. The kernel size for the initial depthwise
            convolution. If 0, this layer is skipped.
        dw_kernel_size_mid: int. The kernel size for the middle depthwise
            convolution. If 0, this layer is skipped.
        dw_kernel_size_end: int. The kernel size for the final depthwise
            convolution. If 0, this layer is skipped.
        stride: int. The stride for the block. It is applied by the middle
            depthwise convolution if enabled, otherwise by the start one,
            otherwise by the end one. If no depthwise convolution is enabled
            the stride is not applied.
        dilation: int. The dilation rate for convolutions.
        pad_type: str. The padding type for convolutions.
        noskip: bool. If `True`, the skip connection is disabled.
        exp_ratio: float. The expansion ratio for the middle channels.
        act_layer: str. The activation function to use.
        norm_layer: str. The normalization layer to use. Convolutions are
            created with a bias only when this is `"rms_norm"`.
        se_layer: callable or `None`. A Squeeze-and-Excitation layer class
            (called with `filters`, `bottleneck_filters`, ... keyword
            arguments), or `None` to disable squeeze-and-excitation.
        drop_path_rate: float. The stochastic depth rate.
        layer_scale_init_value: float. The initial value for layer scale. If
            `None`, layer scale is not used.
        data_format: str. The format of the input data, either
            `"channels_last"` or `"channels_first"`. If `None`, defaults to
            `keras.config.image_data_format()`.
        channel_axis: int. The axis representing the channels in the input
            tensor. If `None`, defaults to `1` for `"channels_first"` and
            `-1` otherwise.
    """

    def __init__(
        self,
        filters,
        dw_kernel_size_start=0,
        dw_kernel_size_mid=3,
        dw_kernel_size_end=0,
        stride=1,
        dilation=1,
        pad_type="same",
        noskip=False,
        exp_ratio=1.0,
        act_layer="relu",
        norm_layer="batch_norm",
        se_layer=None,
        drop_path_rate=0.0,
        layer_scale_init_value=1e-5,
        data_format=None,
        channel_axis=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        # `None` is not a valid tuple index (see `build`) and must not be
        # forwarded to the sub-layers, so resolve the layout defaults here.
        if data_format is None:
            data_format = keras.config.image_data_format()
        if channel_axis is None:
            channel_axis = 1 if data_format == "channels_first" else -1
        self.filters = filters
        self.dw_kernel_size_start = dw_kernel_size_start
        self.dw_kernel_size_mid = dw_kernel_size_mid
        self.dw_kernel_size_end = dw_kernel_size_end
        self.stride = stride
        self.dilation = dilation
        self.pad_type = pad_type
        self.noskip = noskip
        self.exp_ratio = exp_ratio
        self.act_layer = act_layer
        self.norm_layer = norm_layer
        self.se_layer = se_layer
        self.drop_path_rate = drop_path_rate
        self.layer_scale_init_value = layer_scale_init_value
        self.data_format = data_format
        self.channel_axis = channel_axis

    def build(self, input_shape):
        super().build(input_shape)
        in_chs = input_shape[self.channel_axis]
        # The residual add is only shape-compatible when the channel count
        # and the spatial size are preserved.
        self.has_skip = (
            in_chs == self.filters and self.stride == 1
        ) and not self.noskip
        use_bias = self.norm_layer == "rms_norm"

        # Disabled stages become identities that accept the same `training`
        # keyword as the layers they replace.
        if self.dw_kernel_size_start:
            self.dw_start = ConvNormAct(
                in_chs,
                self.dw_kernel_size_start,
                stride=self.stride if not self.dw_kernel_size_mid else 1,
                dilation=self.dilation,
                groups=in_chs,
                pad_type=self.pad_type,
                apply_act=False,
                act_layer=self.act_layer,
                norm_layer=self.norm_layer,
                bias=use_bias,
                data_format=self.data_format,
                channel_axis=self.channel_axis,
                dtype=self.dtype_policy,
            )
        else:
            self.dw_start = lambda x, training=False: x

        mid_chs = adjust_channels(in_chs * self.exp_ratio)
        self.pw_exp = ConvNormAct(
            mid_chs,
            1,
            pad_type=self.pad_type,
            act_layer=self.act_layer,
            norm_layer=self.norm_layer,
            bias=use_bias,
            data_format=self.data_format,
            channel_axis=self.channel_axis,
            dtype=self.dtype_policy,
        )

        if self.dw_kernel_size_mid:
            self.dw_mid = ConvNormAct(
                mid_chs,
                self.dw_kernel_size_mid,
                stride=self.stride,
                dilation=self.dilation,
                groups=mid_chs,
                pad_type=self.pad_type,
                act_layer=self.act_layer,
                norm_layer=self.norm_layer,
                bias=use_bias,
                data_format=self.data_format,
                channel_axis=self.channel_axis,
                dtype=self.dtype_policy,
            )
        else:
            self.dw_mid = lambda x, training=False: x
        self.se = (
            self.se_layer(
                filters=mid_chs,
                bottleneck_filters=adjust_channels(mid_chs * 0.25),
                squeeze_activation=self.act_layer,
                excite_activation="sigmoid",
                data_format=self.data_format,
                channel_axis=self.channel_axis,
                dtype=self.dtype_policy,
            )
            if self.se_layer
            else (lambda x, training=False: x)
        )
        self.pw_proj = ConvNormAct(
            self.filters,
            1,
            pad_type=self.pad_type,
            apply_act=False,
            act_layer=self.act_layer,
            norm_layer=self.norm_layer,
            bias=use_bias,
            data_format=self.data_format,
            channel_axis=self.channel_axis,
            dtype=self.dtype_policy,
        )

        if self.dw_kernel_size_end:
            self.dw_end = ConvNormAct(
                self.filters,
                self.dw_kernel_size_end,
                # Only strides when no earlier depthwise conv took the stride.
                stride=self.stride
                if not self.dw_kernel_size_start and not self.dw_kernel_size_mid
                else 1,
                dilation=self.dilation,
                groups=self.filters,
                pad_type=self.pad_type,
                apply_act=False,
                act_layer=self.act_layer,
                norm_layer=self.norm_layer,
                bias=use_bias,
                data_format=self.data_format,
                channel_axis=self.channel_axis,
                dtype=self.dtype_policy,
            )
        else:
            self.dw_end = lambda x, training=False: x

        self.layer_scale = (
            LayerScale2d(
                self.filters,
                self.layer_scale_init_value,
                data_format=self.data_format,
                channel_axis=self.channel_axis,
                dtype=self.dtype_policy,
            )
            if self.layer_scale_init_value is not None
            else lambda x: x
        )
        self.drop_path = (
            DropPath(self.drop_path_rate, dtype=self.dtype_policy)
            if self.drop_path_rate > 0.0
            else (lambda x, training=False: x)
        )

        # Build the sub-layers in call order, threading the shape through.
        current_shape = input_shape
        if hasattr(self.dw_start, "build"):
            self.dw_start.build(current_shape)
            current_shape = self.dw_start.compute_output_shape(current_shape)
        self.pw_exp.build(current_shape)
        current_shape = self.pw_exp.compute_output_shape(current_shape)
        if hasattr(self.dw_mid, "build"):
            self.dw_mid.build(current_shape)
            current_shape = self.dw_mid.compute_output_shape(current_shape)
        if hasattr(self.se, "build"):
            self.se.build(current_shape)
        self.pw_proj.build(current_shape)
        current_shape = self.pw_proj.compute_output_shape(current_shape)
        if hasattr(self.dw_end, "build"):
            self.dw_end.build(current_shape)
            current_shape = self.dw_end.compute_output_shape(current_shape)
        if hasattr(self.layer_scale, "build"):
            self.layer_scale.build(current_shape)

    def call(self, x, training=False):
        shortcut = x
        x = self.dw_start(x, training=training)
        x = self.pw_exp(x, training=training)
        x = self.dw_mid(x, training=training)
        x = self.se(x, training=training)
        x = self.pw_proj(x, training=training)
        x = self.dw_end(x, training=training)
        x = self.layer_scale(x)
        if self.has_skip:
            x = self.drop_path(x, training=training) + shortcut
        return x

    def compute_output_shape(self, input_shape):
        # Requires `build` to have run: it queries the sub-layers created
        # there. Identity stages (lambdas) leave the shape unchanged.
        current_shape = input_shape
        if hasattr(self.dw_start, "compute_output_shape"):
            current_shape = self.dw_start.compute_output_shape(current_shape)
        current_shape = self.pw_exp.compute_output_shape(current_shape)
        if hasattr(self.dw_mid, "compute_output_shape"):
            current_shape = self.dw_mid.compute_output_shape(current_shape)
        current_shape = self.pw_proj.compute_output_shape(current_shape)
        if hasattr(self.dw_end, "compute_output_shape"):
            current_shape = self.dw_end.compute_output_shape(current_shape)
        return current_shape

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "dw_kernel_size_start": self.dw_kernel_size_start,
                "dw_kernel_size_mid": self.dw_kernel_size_mid,
                "dw_kernel_size_end": self.dw_kernel_size_end,
                "stride": self.stride,
                "dilation": self.dilation,
                "pad_type": self.pad_type,
                "noskip": self.noskip,
                "exp_ratio": self.exp_ratio,
                "act_layer": self.act_layer,
                "norm_layer": self.norm_layer,
                "se_layer": keras.saving.serialize_keras_object(self.se_layer),
                "drop_path_rate": self.drop_path_rate,
                "layer_scale_init_value": self.layer_scale_init_value,
                "data_format": self.data_format,
                "channel_axis": self.channel_axis,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        # Work on a copy so the caller's config dict is not mutated, and
        # tolerate configs that omit `se_layer`.
        config = dict(config)
        config["se_layer"] = keras.saving.deserialize_keras_object(
            config.pop("se_layer", None)
        )
        return cls(**config)
276
+
277
+
278
class EdgeResidual(keras.layers.Layer):
    """Edge Residual block.

    This block is designed for efficiency on edge devices. It is a variant of
    the inverted residual block that uses a single expansion convolution. The
    computation is `conv_exp -> se -> conv_pwl`:

    - `conv_exp` is a (possibly grouped) convolution with norm and activation
      that carries the block's stride.
    - `conv_pwl` projects to `filters` with norm and no activation.

    A residual connection (with optional stochastic depth) is added when the
    input channel count equals `filters`, `stride == 1` and `noskip` is
    `False`.

    Args:
        filters: int. The number of output channels.
        exp_kernel_size: int. The kernel size for the expansion convolution.
        stride: int. The stride for the block.
        dilation: int. The dilation rate for convolutions.
        group_size: int. The group size for grouped convolutions.
        pad_type: str. The padding type for convolutions.
        expansion_in_chs: int. If greater than 0, forces the number of input
            channels used to compute the expansion width.
        noskip: bool. If `True`, the skip connection is disabled.
        exp_ratio: float. The expansion ratio for the middle channels.
        pw_kernel_size: int. The kernel size for the pointwise convolution.
        act_layer: str. The activation function to use.
        norm_layer: str. The normalization layer to use. Convolutions are
            created with a bias only when this is `"rms_norm"`.
        se_layer: callable or `None`. A Squeeze-and-Excitation layer class,
            or `None` to disable squeeze-and-excitation.
        drop_path_rate: float. The stochastic depth rate.
        data_format: str. The format of the input data, either
            `"channels_last"` or `"channels_first"`. If `None`, defaults to
            `keras.config.image_data_format()`.
        channel_axis: int. The axis representing the channels in the input
            tensor. If `None`, defaults to `1` for `"channels_first"` and
            `-1` otherwise.
    """

    def __init__(
        self,
        filters,
        exp_kernel_size=3,
        stride=1,
        dilation=1,
        group_size=0,
        pad_type="same",
        expansion_in_chs=0,
        noskip=False,
        exp_ratio=1.0,
        pw_kernel_size=1,
        act_layer="relu",
        norm_layer="batch_norm",
        se_layer=None,
        drop_path_rate=0.0,
        data_format=None,
        channel_axis=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        # `None` is not a valid tuple index (see `build`) and must not be
        # forwarded to the sub-layers, so resolve the layout defaults here.
        if data_format is None:
            data_format = keras.config.image_data_format()
        if channel_axis is None:
            channel_axis = 1 if data_format == "channels_first" else -1
        self.filters = filters
        self.exp_kernel_size = exp_kernel_size
        self.stride = stride
        self.dilation = dilation
        self.group_size = group_size
        self.pad_type = pad_type
        self.expansion_in_chs = expansion_in_chs
        self.noskip = noskip
        self.exp_ratio = exp_ratio
        self.pw_kernel_size = pw_kernel_size
        self.act_layer = act_layer
        self.norm_layer = norm_layer
        self.se_layer = se_layer
        self.drop_path_rate = drop_path_rate
        self.data_format = data_format
        self.channel_axis = channel_axis

    def build(self, input_shape):
        super().build(input_shape)
        in_chs = input_shape[self.channel_axis]
        self.has_skip = (
            in_chs == self.filters and self.stride == 1
        ) and not self.noskip
        if self.expansion_in_chs > 0:
            mid_chs = adjust_channels(self.expansion_in_chs * self.exp_ratio)
        else:
            mid_chs = adjust_channels(in_chs * self.exp_ratio)
        # The group count is derived from the expansion's output width.
        groups = num_groups(self.group_size, mid_chs)
        use_bias = self.norm_layer == "rms_norm"
        self.conv_exp = ConvNormAct(
            mid_chs,
            self.exp_kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            groups=groups,
            pad_type=self.pad_type,
            norm_layer=self.norm_layer,
            act_layer=self.act_layer,
            bias=use_bias,
            data_format=self.data_format,
            channel_axis=self.channel_axis,
            dtype=self.dtype_policy,
        )
        self.se = (
            self.se_layer(
                filters=mid_chs,
                bottleneck_filters=adjust_channels(mid_chs * 0.25),
                squeeze_activation=self.act_layer,
                excite_activation="sigmoid",
                data_format=self.data_format,
                channel_axis=self.channel_axis,
                dtype=self.dtype_policy,
            )
            if self.se_layer
            else (lambda x, training=False: x)
        )
        self.conv_pwl = ConvNormAct(
            self.filters,
            self.pw_kernel_size,
            pad_type=self.pad_type,
            apply_act=False,
            norm_layer=self.norm_layer,
            act_layer=self.act_layer,
            bias=use_bias,
            data_format=self.data_format,
            channel_axis=self.channel_axis,
            dtype=self.dtype_policy,
        )
        self.drop_path = (
            DropPath(self.drop_path_rate, dtype=self.dtype_policy)
            if self.drop_path_rate > 0.0
            else (lambda x, training=False: x)
        )
        self.conv_exp.build(input_shape)
        conv_exp_output_shape = self.conv_exp.compute_output_shape(input_shape)
        # SE does not change the shape, so `conv_pwl` sees conv_exp's output.
        if hasattr(self.se, "build"):
            self.se.build(conv_exp_output_shape)
        self.conv_pwl.build(conv_exp_output_shape)

    def call(self, x, training=False):
        shortcut = x
        x = self.conv_exp(x, training=training)
        x = self.se(x, training=training)
        x = self.conv_pwl(x, training=training)
        if self.has_skip:
            x = self.drop_path(x, training=training) + shortcut
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "exp_kernel_size": self.exp_kernel_size,
                "stride": self.stride,
                "dilation": self.dilation,
                "group_size": self.group_size,
                "pad_type": self.pad_type,
                "expansion_in_chs": self.expansion_in_chs,
                "noskip": self.noskip,
                "exp_ratio": self.exp_ratio,
                "pw_kernel_size": self.pw_kernel_size,
                "act_layer": self.act_layer,
                "norm_layer": self.norm_layer,
                "se_layer": keras.saving.serialize_keras_object(self.se_layer),
                "drop_path_rate": self.drop_path_rate,
                "data_format": self.data_format,
                "channel_axis": self.channel_axis,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        # Work on a copy so the caller's config dict is not mutated, and
        # tolerate configs that omit `se_layer`.
        config = dict(config)
        config["se_layer"] = keras.saving.deserialize_keras_object(
            config.pop("se_layer", None)
        )
        return cls(**config)
446
+
447
+
448
class CondConvResidual(keras.layers.Layer):
    """Conditionally Parameterized Convolutional Residual block.

    This block uses a routing function to dynamically select and combine
    different convolutional experts based on the input.

    Per sample, the routing weights are computed as
    `sigmoid(Dense(GlobalAveragePooling2D(x)))`, giving `num_experts`
    values. These weights scale and sum the outputs of `num_experts` parallel
    convolutions at each of three stages:

    1. a pointwise expansion, followed by batch norm and activation;
    2. a depthwise convolution carrying the stride, followed by batch norm
       and activation, then optional SE;
    3. a pointwise projection, followed by batch norm.

    A residual connection (with optional stochastic depth) is added when the
    input channel count equals `filters`, `stride == 1` and `noskip` is
    `False`.

    Args:
        filters: int. The number of output channels.
        dw_kernel_size: int. The kernel size for the depthwise convolution.
        stride: int. The stride for the block.
        dilation: int. The dilation rate for convolutions.
        pad_type: str. The padding type for convolutions.
        noskip: bool. If `True`, the skip connection is disabled.
        exp_ratio: float. The expansion ratio for the middle channels.
        exp_kernel_size: int. The kernel size for the expansion convolution.
        pw_kernel_size: int. The kernel size for the pointwise convolution.
        act_layer: str. The activation function to use.
        se_layer: callable or `None`. A Squeeze-and-Excitation layer class,
            or `None` to disable squeeze-and-excitation.
        num_experts: int. The number of experts to use. Must be at least 1
            by the time the layer is built.
        drop_path_rate: float. The stochastic depth rate.
        data_format: str. The format of the input data, either
            `"channels_last"` or `"channels_first"`. If `None`, defaults to
            `keras.config.image_data_format()`.
        channel_axis: int. The axis representing the channels in the input
            tensor. If `None`, defaults to `1` for `"channels_first"` and
            `-1` otherwise.
    """

    def __init__(
        self,
        filters,
        dw_kernel_size=3,
        stride=1,
        dilation=1,
        pad_type="same",
        noskip=False,
        exp_ratio=1.0,
        exp_kernel_size=1,
        pw_kernel_size=1,
        act_layer="relu",
        se_layer=None,
        num_experts=0,
        drop_path_rate=0.0,
        data_format=None,
        channel_axis=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        # `None` is not a valid tuple index (see `build`) and must not be
        # forwarded to the sub-layers, so resolve the layout defaults here.
        if data_format is None:
            data_format = keras.config.image_data_format()
        if channel_axis is None:
            channel_axis = 1 if data_format == "channels_first" else -1
        self.filters = filters
        self.dw_kernel_size = dw_kernel_size
        self.stride = stride
        self.dilation = dilation
        self.pad_type = pad_type
        self.noskip = noskip
        self.exp_ratio = exp_ratio
        self.exp_kernel_size = exp_kernel_size
        self.pw_kernel_size = pw_kernel_size
        self.act_layer = act_layer
        self.se_layer = se_layer
        self.num_experts = num_experts
        self.drop_path_rate = drop_path_rate
        self.data_format = data_format
        self.channel_axis = channel_axis
        self.conv_kernel_initializer = keras.initializers.VarianceScaling(
            scale=2.0, mode="fan_out", distribution="untruncated_normal"
        )
        self.dense_kernel_initializer = keras.initializers.VarianceScaling(
            scale=1.0, mode="fan_in", distribution="uniform"
        )
        self.bias_initializer = "zeros"

    def build(self, input_shape):
        super().build(input_shape)
        # With zero experts there is nothing to route to, and the shape
        # propagation below would fail with an opaque IndexError.
        if self.num_experts < 1:
            raise ValueError(
                "`num_experts` must be a positive integer. "
                f"Received: num_experts={self.num_experts}"
            )
        in_chs = input_shape[self.channel_axis]
        self.has_skip = (
            in_chs == self.filters and self.stride == 1
        ) and not self.noskip
        mid_chs = adjust_channels(in_chs * self.exp_ratio)
        self.routing_fn = keras.layers.Dense(
            self.num_experts,
            dtype=self.dtype_policy,
            kernel_initializer=self.dense_kernel_initializer,
            bias_initializer=self.bias_initializer,
        )
        self.pool = keras.layers.GlobalAveragePooling2D(
            data_format=self.data_format, dtype=self.dtype_policy
        )
        self.conv_pw_experts = [
            keras.layers.Conv2D(
                filters=mid_chs,
                kernel_size=self.exp_kernel_size,
                padding=self.pad_type,
                use_bias=True,
                data_format=self.data_format,
                name=f"conv_pw_expert_{i}",
                kernel_initializer=self.conv_kernel_initializer,
                bias_initializer=self.bias_initializer,
                dtype=self.dtype_policy,
            )
            for i in range(self.num_experts)
        ]
        self.conv_dw_experts = [
            keras.layers.DepthwiseConv2D(
                kernel_size=self.dw_kernel_size,
                strides=self.stride,
                padding=self.pad_type,
                dilation_rate=self.dilation,
                use_bias=True,
                data_format=self.data_format,
                name=f"conv_dw_expert_{i}",
                depthwise_initializer=self.conv_kernel_initializer,
                bias_initializer=self.bias_initializer,
                dtype=self.dtype_policy,
            )
            for i in range(self.num_experts)
        ]
        self.conv_pwl_experts = [
            keras.layers.Conv2D(
                filters=self.filters,
                kernel_size=self.pw_kernel_size,
                padding=self.pad_type,
                use_bias=True,
                data_format=self.data_format,
                name=f"conv_pwl_expert_{i}",
                kernel_initializer=self.conv_kernel_initializer,
                bias_initializer=self.bias_initializer,
                dtype=self.dtype_policy,
            )
            for i in range(self.num_experts)
        ]
        self.bn1 = keras.layers.BatchNormalization(
            axis=self.channel_axis,
            dtype=self.dtype_policy,
            gamma_initializer="ones",
            beta_initializer="zeros",
        )
        self.act1 = keras.layers.Activation(
            self.act_layer, dtype=self.dtype_policy
        )
        self.bn2 = keras.layers.BatchNormalization(
            axis=self.channel_axis,
            dtype=self.dtype_policy,
            gamma_initializer="ones",
            beta_initializer="zeros",
        )
        self.act2 = keras.layers.Activation(
            self.act_layer, dtype=self.dtype_policy
        )
        self.bn3 = keras.layers.BatchNormalization(
            axis=self.channel_axis,
            dtype=self.dtype_policy,
            gamma_initializer="ones",
            beta_initializer="zeros",
        )
        self.se = (
            self.se_layer(
                filters=mid_chs,
                bottleneck_filters=adjust_channels(mid_chs * 0.25),
                squeeze_activation=self.act_layer,
                excite_activation="sigmoid",
                data_format=self.data_format,
                channel_axis=self.channel_axis,
                dtype=self.dtype_policy,
            )
            if self.se_layer
            else (lambda x, training=False: x)
        )
        self.drop_path = (
            DropPath(self.drop_path_rate, dtype=self.dtype_policy)
            if self.drop_path_rate > 0.0
            else (lambda x, training=False: x)
        )
        # All experts of a stage share a configuration, so the first expert's
        # output shape stands for the whole stage.
        pooled_shape = self.pool.compute_output_shape(input_shape)
        self.routing_fn.build(pooled_shape)
        for expert in self.conv_pw_experts:
            expert.build(input_shape)
        pw_out_shape = self.conv_pw_experts[0].compute_output_shape(input_shape)
        self.bn1.build(pw_out_shape)
        for expert in self.conv_dw_experts:
            expert.build(pw_out_shape)
        dw_out_shape = self.conv_dw_experts[0].compute_output_shape(
            pw_out_shape
        )
        self.bn2.build(dw_out_shape)
        if hasattr(self.se, "build"):
            self.se.build(dw_out_shape)
        for expert in self.conv_pwl_experts:
            expert.build(dw_out_shape)
        pwl_out_shape = self.conv_pwl_experts[0].compute_output_shape(
            dw_out_shape
        )
        self.bn3.build(pwl_out_shape)

    def _apply_cond_conv(self, x, experts, routing_weights):
        """Return the routing-weighted sum of `expert(x)` over `experts`.

        Column `i` of `routing_weights` holds one scalar per sample for
        expert `i`. It is reshaped to `(-1, 1, 1, 1)` so it broadcasts over
        the non-batch axes of the rank-4 convolution output. The sum is
        accumulated incrementally instead of stacking all expert outputs
        into one larger tensor.
        """
        output = None
        for i, expert in enumerate(experts):
            weight = keras.ops.reshape(routing_weights[:, i], (-1, 1, 1, 1))
            weighted = expert(x) * weight
            output = weighted if output is None else output + weighted
        return output

    def call(self, x, training=False):
        shortcut = x
        pooled_inputs = self.pool(x)
        routing_weights = keras.activations.sigmoid(
            self.routing_fn(pooled_inputs)
        )
        x = self._apply_cond_conv(x, self.conv_pw_experts, routing_weights)
        x = self.bn1(x, training=training)
        x = self.act1(x)
        x = self._apply_cond_conv(x, self.conv_dw_experts, routing_weights)
        x = self.bn2(x, training=training)
        x = self.act2(x)
        x = self.se(x, training=training)
        x = self._apply_cond_conv(x, self.conv_pwl_experts, routing_weights)
        x = self.bn3(x, training=training)
        if self.has_skip:
            x = self.drop_path(x, training=training) + shortcut
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "filters": self.filters,
                "dw_kernel_size": self.dw_kernel_size,
                "stride": self.stride,
                "dilation": self.dilation,
                "pad_type": self.pad_type,
                "noskip": self.noskip,
                "exp_ratio": self.exp_ratio,
                "exp_kernel_size": self.exp_kernel_size,
                "pw_kernel_size": self.pw_kernel_size,
                "act_layer": self.act_layer,
                "se_layer": keras.saving.serialize_keras_object(self.se_layer),
                "num_experts": self.num_experts,
                "drop_path_rate": self.drop_path_rate,
                "data_format": self.data_format,
                "channel_axis": self.channel_axis,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        # Work on a copy so the caller's config dict is not mutated, and
        # tolerate configs that omit `se_layer`.
        config = dict(config)
        config["se_layer"] = keras.saving.deserialize_keras_object(
            config.pop("se_layer", None)
        )
        return cls(**config)
696
+
697
+
698
class MobileNetV5MultiScaleFusionAdapter(keras.layers.Layer):
    """Multi-Scale Fusion Adapter for MobileNetV5.

    This layer fuses feature maps from different scales of the backbone,
    concatenates them, processes them through a FFN (Feed-Forward Network),
    and then resizes the output to a target resolution.

    The steps are:

    1. Every input is resized to the spatial size of `inputs[0]` using
       `interpolation_mode`.
    2. The inputs are concatenated along the channel axis and passed through
       a `UniversalInvertedResidual` FFN (no middle depthwise convolution).
    3. If the result is not already at `output_resolution`, it is
       average-pooled when the input size is an integer multiple of the
       target, and bilinearly resized otherwise.
    4. A final normalization is applied.

    Args:
        in_chs: list of int. A list of channel counts for each input feature
            map.
        filters: int. The number of output channels.
        output_resolution: int or tuple. The target output resolution.
        expansion_ratio: float. The expansion ratio for the FFN.
        interpolation_mode: str. The interpolation mode for upsampling feature
            maps.
        layer_scale_init_value: float. The initial value for layer scale. If
            `None`, layer scale is not used.
        noskip: bool. If `True`, the skip connection in the FFN is disabled.
        act_layer: str. The activation function to use.
        norm_layer: str. The normalization layer to use: `"rms_norm"` selects
            `RmsNorm2d`, anything else selects `BatchNormalization`.
        data_format: str. The format of the input data, either
            `"channels_last"` or `"channels_first"`. If `None`, defaults to
            `keras.config.image_data_format()`.
        channel_axis: int. The axis representing the channels in the input
            tensor. If `None`, defaults to `1` for `"channels_first"` and
            `-1` otherwise.
    """

    def __init__(
        self,
        in_chs,
        filters,
        output_resolution,
        expansion_ratio=2.0,
        interpolation_mode="nearest",
        layer_scale_init_value=None,
        noskip=True,
        act_layer="gelu",
        norm_layer="rms_norm",
        data_format=None,
        channel_axis=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        # Resolve the layout defaults up front. `channel_axis=None` would
        # otherwise reach `keras.ops.concatenate(axis=None)`, and
        # `data_format=None` would make the manual channels_first branches
        # below disagree with ops that fall back to the global config.
        if data_format is None:
            data_format = keras.config.image_data_format()
        if channel_axis is None:
            channel_axis = 1 if data_format == "channels_first" else -1
        self.in_chs = in_chs
        self.filters = filters
        # Keep the raw argument so `get_config` round-trips it unchanged.
        self.output_resolution_arg = output_resolution
        self.expansion_ratio = expansion_ratio
        self.interpolation_mode = interpolation_mode
        self.layer_scale_init_value = layer_scale_init_value
        self.noskip = noskip
        self.act_layer = act_layer
        self.norm_layer_name = norm_layer
        self.data_format = data_format
        self.channel_axis = channel_axis
        self.in_channels = sum(in_chs)
        if isinstance(output_resolution, int):
            self.output_resolution = (output_resolution, output_resolution)
        else:
            # Configs restored from JSON yield lists; normalize to a tuple.
            self.output_resolution = tuple(output_resolution)
        self.ffn = UniversalInvertedResidual(
            filters=self.filters,
            dw_kernel_size_mid=0,
            exp_ratio=expansion_ratio,
            act_layer=act_layer,
            norm_layer=norm_layer,
            noskip=noskip,
            layer_scale_init_value=layer_scale_init_value,
            data_format=self.data_format,
            channel_axis=self.channel_axis,
            dtype=self.dtype_policy,
        )
        if norm_layer == "rms_norm":
            self.norm = RmsNorm2d(
                self.filters,
                data_format=self.data_format,
                gamma_initializer="ones",
                channel_axis=self.channel_axis,
                dtype=self.dtype_policy,
            )
        else:
            self.norm = keras.layers.BatchNormalization(
                axis=self.channel_axis,
                gamma_initializer="ones",
                beta_initializer="zeros",
                dtype=self.dtype_policy,
            )

    def build(self, input_shape):
        super().build(input_shape)
        # The FFN sees the first input's shape with the channel count
        # replaced by the total of all concatenated inputs.
        ffn_input_shape = list(input_shape[0])
        if self.data_format == "channels_first":
            ffn_input_shape[1] = self.in_channels
        else:
            ffn_input_shape[-1] = self.in_channels
        self.ffn.build(tuple(ffn_input_shape))
        norm_input_shape = self.ffn.compute_output_shape(tuple(ffn_input_shape))
        self.norm.build(norm_input_shape)

    def _resize_spatial(self, images, size, interpolation):
        """Resize the spatial dimensions of rank-4 `images` to `size`.

        The resize is always performed on a channels-last view, with
        `data_format="channels_last"` passed explicitly. This keeps the op
        from reinterpreting the transposed tensor according to the global
        `keras.config.image_data_format()`.
        """
        if self.data_format == "channels_first":
            images = keras.ops.transpose(images, (0, 2, 3, 1))
        resized = keras.ops.image.resize(
            images,
            size=size,
            interpolation=interpolation,
            data_format="channels_last",
        )
        if self.data_format == "channels_first":
            resized = keras.ops.transpose(resized, (0, 3, 1, 2))
        return resized

    def call(self, inputs, training=False):
        shape_hr = keras.ops.shape(inputs[0])
        if self.data_format == "channels_first":
            high_resolution = (shape_hr[2], shape_hr[3])
        else:
            high_resolution = (shape_hr[1], shape_hr[2])
        # Bring every scale to the resolution of the first input.
        resized_inputs = [
            self._resize_spatial(img, high_resolution, self.interpolation_mode)
            for img in inputs
        ]
        channel_cat_imgs = keras.ops.concatenate(
            resized_inputs, axis=self.channel_axis
        )
        img = self.ffn(channel_cat_imgs, training=training)
        if (
            high_resolution[0] != self.output_resolution[0]
            or high_resolution[1] != self.output_resolution[1]
        ):
            h_in, w_in = high_resolution
            h_out, w_out = self.output_resolution
            if h_in % h_out == 0 and w_in % w_out == 0:
                # Integer downscale factor: average-pool with matching stride.
                h_stride = h_in // h_out
                w_stride = w_in // w_out
                img = keras.ops.nn.average_pool(
                    img,
                    pool_size=(h_stride, w_stride),
                    strides=(h_stride, w_stride),
                    padding="valid",
                    data_format=self.data_format,
                )
            else:
                img = self._resize_spatial(
                    img, self.output_resolution, "bilinear"
                )
        img = self.norm(img, training=training)
        return img

    def compute_output_shape(self, input_shape):
        batch_size = input_shape[0][0]
        if self.data_format == "channels_first":
            return (
                batch_size,
                self.filters,
                self.output_resolution[0],
                self.output_resolution[1],
            )
        return (
            batch_size,
            self.output_resolution[0],
            self.output_resolution[1],
            self.filters,
        )

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "in_chs": self.in_chs,
                "filters": self.filters,
                "output_resolution": self.output_resolution_arg,
                "expansion_ratio": self.expansion_ratio,
                "interpolation_mode": self.interpolation_mode,
                "layer_scale_init_value": self.layer_scale_init_value,
                "noskip": self.noskip,
                "act_layer": self.act_layer,
                "norm_layer": self.norm_layer_name,
                "data_format": self.data_format,
                "channel_axis": self.channel_axis,
            }
        )
        return config