keras-hub-nightly 0.23.0.dev202510100415__py3-none-any.whl → 0.23.0.dev202510110411__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of keras-hub-nightly might be problematic.
@@ -0,0 +1,699 @@
+import keras
+
+from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import DropPath
+from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import LayerScale2d
+from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import RmsNorm2d
+
+
+class MultiQueryAttention2d(keras.layers.Layer):
+    """Implements 2D Multi-Query Attention.
+
+    This layer performs attention on 2D spatial inputs. It uses a multi-query
+    attention mechanism where multiple query heads attend to a single key and
+    value.
+
+    Args:
+        filters: int. The output channel dimension.
+        num_heads: int. The number of attention heads.
+        key_dim: int. The dimension of the key. If `None`, it is calculated as
+            `dim // num_heads`.
+        value_dim: int. The dimension of the value. If `None`, it is calculated
+            as `dim // num_heads`.
+        query_strides: int or tuple. The stride for downsampling the query.
+        kv_stride: int. The stride for downsampling the key and value.
+        dw_kernel_size: int. The kernel size for the depthwise convolution used
+            for downsampling.
+        dilation: int. The dilation rate for the depthwise convolution.
+        padding: str. The padding type for convolutions.
+        attn_drop: float. The dropout rate for the attention weights.
+        proj_drop: float. The dropout rate for the output projection.
+        norm_layer: keras.layers.Layer. The normalization layer to use.
+        use_bias: bool. If `True`, bias terms are used in convolutions.
+        channel_axis: int. The axis representing the channels in the input
+            tensor.
+        data_format: str. The format of the input data, either
+            `"channels_last"` or `"channels_first"`.
+    """
+
+    def __init__(
+        self,
+        filters,
+        num_heads=8,
+        key_dim=None,
+        value_dim=None,
+        query_strides=1,
+        kv_stride=1,
+        dw_kernel_size=3,
+        dilation=1,
+        padding="same",
+        attn_drop=0.0,
+        proj_drop=0.0,
+        norm_layer=keras.layers.BatchNormalization,
+        use_bias=False,
+        channel_axis=None,
+        data_format=None,
+        dtype=None,
+        **kwargs,
+    ):
+        super().__init__(dtype=dtype, **kwargs)
+        self.filters = filters
+        self.num_heads = num_heads
+        self.key_dim_arg = key_dim
+        self.value_dim_arg = value_dim
+        self.query_strides_arg = query_strides
+        self.kv_stride = kv_stride
+        self.dw_kernel_size = dw_kernel_size
+        self.dilation = dilation
+        self.padding_arg = padding
+        self.attn_drop_rate = attn_drop
+        self.proj_drop_rate = proj_drop
+        self.norm_layer = norm_layer
+        self.use_bias = use_bias
+        self.channel_axis = channel_axis
+        self.data_format = data_format
+        self.query_strides = (
+            query_strides
+            if isinstance(query_strides, (list, tuple))
+            else (query_strides, query_strides)
+        )
+        self.has_query_strides = any([s > 1 for s in self.query_strides])
+        self.padding = padding
+        self.conv_kernel_initializer = keras.initializers.VarianceScaling(
+            scale=2.0, mode="fan_out", distribution="untruncated_normal"
+        )
+        self.bias_initializer = "zeros"
+        self.attn_drop_layer = keras.layers.Dropout(
+            attn_drop, dtype=self.dtype_policy
+        )
+
+    def build(self, input_shape):
+        super().build(input_shape)
+        dim = input_shape[self.channel_axis]
+        self.key_dim = self.key_dim_arg or dim // self.num_heads
+        self.value_dim = self.value_dim_arg or dim // self.num_heads
+        self.scale = self.key_dim**-0.5
+        query_layers = []
+        if self.has_query_strides:
+            pool_padding = "valid" if self.padding == "valid" else "same"
+            query_layers.append(
+                keras.layers.AveragePooling2D(
+                    pool_size=self.query_strides,
+                    strides=self.query_strides,
+                    padding=pool_padding,
+                    data_format=self.data_format,
+                    name="query_down_pool",
+                    dtype=self.dtype_policy,
+                )
+            )
+        if self.norm_layer is RmsNorm2d:
+            norm = self.norm_layer(
+                dim=dim,
+                channel_axis=self.channel_axis,
+                data_format=self.data_format,
+                name="query_norm",
+                dtype=self.dtype_policy,
+            )
+        else:
+            norm = self.norm_layer(
+                axis=self.channel_axis,
+                name="query_norm",
+                gamma_initializer="ones",
+                beta_initializer="zeros",
+                dtype=self.dtype_policy,
+            )
+        query_layers.append(norm)
+        query_layers.append(
+            keras.layers.Conv2D(
+                filters=self.num_heads * self.key_dim,
+                kernel_size=1,
+                use_bias=self.use_bias,
+                data_format=self.data_format,
+                name="query_proj",
+                kernel_initializer=self.conv_kernel_initializer,
+                bias_initializer=self.bias_initializer,
+                dtype=self.dtype_policy,
+            )
+        )
+        self.query_layers = query_layers
+        key_layers = []
+        if self.kv_stride > 1:
+            key_layers.append(
+                keras.layers.DepthwiseConv2D(
+                    kernel_size=self.dw_kernel_size,
+                    strides=self.kv_stride,
+                    dilation_rate=self.dilation,
+                    padding=self.padding,
+                    data_format=self.data_format,
+                    name="key_down_conv",
+                    depthwise_initializer=self.conv_kernel_initializer,
+                    bias_initializer=self.bias_initializer,
+                    use_bias=False,
+                    dtype=self.dtype_policy,
+                )
+            )
+        if self.norm_layer is RmsNorm2d:
+            norm = self.norm_layer(
+                dim=dim,
+                channel_axis=self.channel_axis,
+                data_format=self.data_format,
+                name="key_norm",
+                dtype=self.dtype_policy,
+            )
+        else:
+            norm = self.norm_layer(
+                axis=self.channel_axis,
+                gamma_initializer="ones",
+                beta_initializer="zeros",
+                name="key_norm",
+                dtype=self.dtype_policy,
+            )
+        key_layers.append(norm)
+        key_layers.append(
+            keras.layers.Conv2D(
+                filters=self.key_dim,
+                kernel_size=1,
+                padding="valid",
+                use_bias=self.use_bias,
+                data_format=self.data_format,
+                name="key_proj",
+                kernel_initializer=self.conv_kernel_initializer,
+                bias_initializer=self.bias_initializer,
+                dtype=self.dtype_policy,
+            )
+        )
+        self.key_layers = key_layers
+        value_layers = []
+        if self.kv_stride > 1:
+            value_layers.append(
+                keras.layers.DepthwiseConv2D(
+                    kernel_size=self.dw_kernel_size,
+                    strides=self.kv_stride,
+                    dilation_rate=self.dilation,
+                    padding=self.padding,
+                    data_format=self.data_format,
+                    name="value_down_conv",
+                    depthwise_initializer=self.conv_kernel_initializer,
+                    bias_initializer=self.bias_initializer,
+                    use_bias=False,
+                    dtype=self.dtype_policy,
+                )
+            )
+        if self.norm_layer is RmsNorm2d:
+            norm = self.norm_layer(
+                dim=dim,
+                channel_axis=self.channel_axis,
+                data_format=self.data_format,
+                name="value_norm",
+                dtype=self.dtype_policy,
+            )
+        else:
+            norm = self.norm_layer(
+                axis=self.channel_axis,
+                gamma_initializer="ones",
+                beta_initializer="zeros",
+                name="value_norm",
+                dtype=self.dtype_policy,
+            )
+        value_layers.append(norm)
+        value_layers.append(
+            keras.layers.Conv2D(
+                filters=self.value_dim,
+                kernel_size=1,
+                padding="valid",
+                use_bias=self.use_bias,
+                data_format=self.data_format,
+                name="value_proj",
+                kernel_initializer=self.conv_kernel_initializer,
+                bias_initializer=self.bias_initializer,
+                dtype=self.dtype_policy,
+            )
+        )
+        self.value_layers = value_layers
+        output_layers = []
+        if self.has_query_strides:
+            output_layers.append(
+                keras.layers.UpSampling2D(
+                    size=self.query_strides,
+                    interpolation="bilinear",
+                    data_format=self.data_format,
+                    name="output_upsample",
+                    dtype=self.dtype_policy,
+                )
+            )
+        output_layers.append(
+            keras.layers.Conv2D(
+                filters=self.filters,
+                kernel_size=1,
+                use_bias=self.use_bias,
+                data_format=self.data_format,
+                name="output_proj",
+                kernel_initializer=self.conv_kernel_initializer,
+                bias_initializer=self.bias_initializer,
+                dtype=self.dtype_policy,
+            )
+        )
+        output_layers.append(
+            keras.layers.Dropout(self.proj_drop_rate, dtype=self.dtype_policy)
+        )
+        self.output_proj_layers = output_layers
+
+    def call(self, x, training=False):
+        B = keras.ops.shape(x)[0]
+        q = x
+        for layer in self.query_layers:
+            try:
+                q = layer(q, training=training)
+            except TypeError:
+                q = layer(q)
+        k = x
+        for layer in self.key_layers:
+            try:
+                k = layer(k, training=training)
+            except TypeError:
+                k = layer(k)
+        v = x
+        for layer in self.value_layers:
+            try:
+                v = layer(v, training=training)
+            except TypeError:
+                v = layer(v)
+        if self.data_format == "channels_last":
+            q = keras.ops.transpose(q, (0, 3, 1, 2))
+            k = keras.ops.transpose(k, (0, 3, 1, 2))
+            v = keras.ops.transpose(v, (0, 3, 1, 2))
+        s_q = keras.ops.shape(q)
+        h_q, w_q = s_q[2], s_q[3]
+        q = keras.ops.reshape(q, (B, self.num_heads, self.key_dim, -1))
+        q = keras.ops.transpose(q, (0, 1, 3, 2))
+        k = keras.ops.reshape(k, (B, self.key_dim, -1))
+        k = keras.ops.transpose(k, (0, 2, 1))
+        k = keras.ops.expand_dims(k, axis=1)
+        v = keras.ops.reshape(v, (B, self.value_dim, -1))
+        v = keras.ops.transpose(v, (0, 2, 1))
+        v = keras.ops.expand_dims(v, axis=1)
+        q = q * self.scale
+        attn = keras.ops.matmul(q, keras.ops.transpose(k, (0, 1, 3, 2)))
+        attn = keras.ops.softmax(attn, axis=-1)
+        attn = self.attn_drop_layer(attn, training=training)
+        o = keras.ops.matmul(attn, v)
+        o = keras.ops.transpose(o, (0, 2, 1, 3))
+        feat_dim = self.num_heads * self.value_dim
+        o = keras.ops.reshape(o, (B, h_q, w_q, feat_dim))
+        if self.data_format == "channels_first":
+            o = keras.ops.transpose(o, (0, 3, 1, 2))
+        x_out = o
+        for layer in self.output_proj_layers:
+            try:
+                x_out = layer(x_out, training=training)
+            except TypeError:
+                x_out = layer(x_out)
+        return x_out
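
The shape bookkeeping in `call()` above reduces to one broadcasted matmul. Below is a minimal sketch of the same multi-query math on dummy tensors using only `keras.ops`; the sizes are illustrative and not taken from the package. The size-1 head axis on `k` and `v` (the `expand_dims(..., axis=1)` calls above) is what makes every query head share a single key and value.

import keras

B, H, W = 2, 8, 8  # batch and spatial size (illustrative)
heads, key_dim, value_dim = 4, 16, 16
L = H * W  # flattened spatial length

q = keras.random.uniform((B, heads, L, key_dim))  # one query per head
k = keras.random.uniform((B, 1, L, key_dim))  # single shared key
v = keras.random.uniform((B, 1, L, value_dim))  # single shared value

# The size-1 head axis of k and v broadcasts against all query heads.
scores = keras.ops.matmul(
    q * key_dim**-0.5, keras.ops.transpose(k, (0, 1, 3, 2))
)  # (B, heads, L, L)
attn = keras.ops.softmax(scores, axis=-1)
o = keras.ops.matmul(attn, v)  # (B, heads, L, value_dim)
o = keras.ops.reshape(
    keras.ops.transpose(o, (0, 2, 1, 3)), (B, H, W, heads * value_dim)
)  # back to a 2D feature map

The diff continues with the layer's serialization methods.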
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "filters": self.filters,
+                "num_heads": self.num_heads,
+                "key_dim": self.key_dim_arg,
+                "value_dim": self.value_dim_arg,
+                "query_strides": self.query_strides_arg,
+                "kv_stride": self.kv_stride,
+                "dw_kernel_size": self.dw_kernel_size,
+                "dilation": self.dilation,
+                "padding": self.padding_arg,
+                "attn_drop": self.attn_drop_rate,
+                "proj_drop": self.proj_drop_rate,
+                "norm_layer": keras.saving.serialize_keras_object(
+                    self.norm_layer
+                ),
+                "use_bias": self.use_bias,
+                "channel_axis": self.channel_axis,
+                "data_format": self.data_format,
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        config["norm_layer"] = keras.saving.deserialize_keras_object(
+            config["norm_layer"]
+        )
+        return cls(**config)
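
That completes `MultiQueryAttention2d`. A hedged usage sketch follows; the diff does not name the added file, so the import path is a guess based on the `mobilenetv5` package layout, and all shapes and argument values are illustrative.

import keras

# Assumed module path; the diff does not show the added file's name.
from keras_hub.src.models.mobilenetv5.mobilenetv5_attention import (
    MultiQueryAttention2d,
)

layer = MultiQueryAttention2d(
    filters=64,
    num_heads=4,  # key_dim defaults to 64 // 4 = 16
    kv_stride=2,  # keys/values are downsampled by a depthwise conv
    channel_axis=-1,
    data_format="channels_last",
)
x = keras.random.uniform((2, 16, 16, 64))
y = layer(x)  # queries keep the input resolution: (2, 16, 16, 64)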
+class Attention2d(keras.layers.Layer):
+    """Implements 2D Multi-Head Attention.
+
+    This layer performs multi-head self-attention on 2D spatial inputs.
+
+    Args:
+        filters: int. The output channel dimension.
+        num_heads: int. The number of attention heads.
+        bias: bool. If `True`, bias terms are used in the qkv and projection
+            convolutions.
+        attn_drop: float. The dropout rate for the attention weights.
+        proj_drop: float. The dropout rate for the output projection.
+        channel_axis: int. The axis representing the channels in the input
+            tensor.
+        data_format: str. The format of the input data, either
+            `"channels_last"` or `"channels_first"`.
+    """
+
+    def __init__(
+        self,
+        filters,
+        num_heads=32,
+        bias=True,
+        attn_drop=0.0,
+        proj_drop=0.0,
+        channel_axis=None,
+        data_format=None,
+        dtype=None,
+        **kwargs,
+    ):
+        super().__init__(dtype=dtype, **kwargs)
+        self.filters = filters
+        self.num_heads = num_heads
+        self.bias = bias
+        self.attn_drop_rate = attn_drop
+        self.proj_drop_rate = proj_drop
+        self.channel_axis = channel_axis
+        self.data_format = data_format
+        self.conv_kernel_initializer = keras.initializers.VarianceScaling(
+            scale=2.0, mode="fan_out", distribution="untruncated_normal"
+        )
+        self.bias_initializer = "zeros"
+        self.attn_drop_layer = keras.layers.Dropout(
+            attn_drop, dtype=self.dtype_policy
+        )
+
+    def build(self, input_shape):
+        super().build(input_shape)
+        dim = input_shape[self.channel_axis]
+        self.head_dim = dim // self.num_heads
+        self.qkv = keras.layers.Conv2D(
+            dim * 3,
+            kernel_size=1,
+            use_bias=self.bias,
+            data_format=self.data_format,
+            name="qkv",
+            dtype=self.dtype_policy,
+            kernel_initializer=self.conv_kernel_initializer,
+            bias_initializer=self.bias_initializer,
+        )
+        self.proj = keras.layers.Conv2D(
+            self.filters,
+            kernel_size=1,
+            use_bias=self.bias,
+            data_format=self.data_format,
+            name="proj",
+            dtype=self.dtype_policy,
+            kernel_initializer=self.conv_kernel_initializer,
+            bias_initializer=self.bias_initializer,
+        )
+        self.proj_drop_layer = keras.layers.Dropout(
+            self.proj_drop_rate, dtype=self.dtype_policy
+        )
+
+    def call(self, x, attn_mask=None, training=False):
+        if self.data_format == "channels_first":
+            B, C, H, W = keras.ops.shape(x)
+        else:
+            B, H, W, C = keras.ops.shape(x)
+        qkv = self.qkv(x)
+        if self.data_format == "channels_last":
+            qkv = keras.ops.transpose(qkv, (0, 3, 1, 2))
+        q, k, v = keras.ops.unstack(
+            keras.ops.reshape(
+                qkv,
+                (B, 3, self.num_heads, self.head_dim, H * W),
+            ),
+            axis=1,
+        )
+        q = keras.ops.transpose(q, (0, 1, 3, 2))
+        k = keras.ops.transpose(k, (0, 1, 2, 3))
+        v = keras.ops.transpose(v, (0, 1, 3, 2))
+        attn = keras.ops.matmul(q, k) * (self.head_dim**-0.5)
+        if attn_mask is not None:
+            attn = attn + attn_mask
+        attn = keras.ops.softmax(attn, axis=-1)
+        attn = self.attn_drop_layer(attn, training=training)
+        x = keras.ops.matmul(attn, v)
+        x = keras.ops.transpose(x, (0, 1, 3, 2))
+        if self.data_format == "channels_first":
+            x = keras.ops.reshape(x, (B, -1, H, W))
+        else:
+            x = keras.ops.reshape(x, (B, H, W, -1))
+        x = self.proj(x)
+        x = self.proj_drop_layer(x, training=training)
+        return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "filters": self.filters,
+                "num_heads": self.num_heads,
+                "bias": self.bias,
+                "attn_drop": self.attn_drop_rate,
+                "proj_drop": self.proj_drop_rate,
+                "channel_axis": self.channel_axis,
+                "data_format": self.data_format,
+            }
+        )
+        return config
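
Unlike the multi-query variant, `Attention2d` derives per-head keys and values from a single 1x1 `qkv` convolution, so `num_heads` must divide the input channel count. A sketch under the same path assumption as above:

import keras

from keras_hub.src.models.mobilenetv5.mobilenetv5_attention import (
    Attention2d,  # assumed module path, as above
)

layer = Attention2d(
    filters=64,
    num_heads=8,  # head_dim = 64 // 8 = 8
    channel_axis=-1,
    data_format="channels_last",
)
x = keras.random.uniform((2, 16, 16, 64))
y = layer(x)  # (2, 16, 16, 64)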
+class MobileAttention(keras.layers.Layer):
+    """MobileNetV5 attention block.
+
+    This block combines attention with depthwise convolutions for efficiency.
+    It can use either standard Multi-Head Attention or Multi-Query Attention.
+
+    Args:
+        filters: int. The number of output channels.
+        stride: int. The stride for the block.
+        dw_kernel_size: int. The kernel size for the depthwise convolution in
+            Multi-Query Attention.
+        dilation: int. The dilation rate for convolutions.
+        pad_type: str. The padding type for convolutions.
+        num_heads: int. The number of attention heads.
+        key_dim: int. The dimension of the key.
+        value_dim: int. The dimension of the value.
+        use_multi_query: bool. If `True`, use `MultiQueryAttention2d`,
+            otherwise use `Attention2d`.
+        query_strides: tuple. The strides for the query downsampling.
+        kv_stride: int. The stride for key/value downsampling.
+        cpe_dw_kernel_size: int. The kernel size for the conditional position
+            encoding depthwise convolution.
+        noskip: bool. If `True`, the skip connection is disabled.
+        norm_layer: str. The normalization layer to use (`"batch_norm"` or
+            `"rms_norm"`).
+        drop_path_rate: float. The stochastic depth rate.
+        attn_drop: float. The dropout rate for the attention weights.
+        proj_drop: float. The dropout rate for the output projection.
+        layer_scale_init_value: float. The initial value for layer scale. If
+            `None`, layer scale is not used.
+        use_bias: bool. If `True`, bias terms are used in convolutions.
+        use_cpe: bool. If `True`, a conditional position encoding is added.
+        channel_axis: int. The axis representing the channels in the input
+            tensor.
+        data_format: str. The format of the input data, either
+            `"channels_last"` or `"channels_first"`.
+    """
+
+    def __init__(
+        self,
+        filters,
+        stride=1,
+        dw_kernel_size=3,
+        dilation=1,
+        pad_type="same",
+        num_heads=8,
+        key_dim=64,
+        value_dim=64,
+        use_multi_query=False,
+        query_strides=(1, 1),
+        kv_stride=1,
+        cpe_dw_kernel_size=3,
+        noskip=False,
+        norm_layer="batch_norm",
+        drop_path_rate=0.0,
+        attn_drop=0.0,
+        proj_drop=0.0,
+        layer_scale_init_value=1e-5,
+        use_bias=False,
+        use_cpe=False,
+        channel_axis=None,
+        data_format=None,
+        dtype=None,
+        **kwargs,
+    ):
+        super().__init__(dtype=dtype, **kwargs)
+        self.filters = filters
+        self.stride = stride
+        self.dw_kernel_size = dw_kernel_size
+        self.dilation = dilation
+        self.pad_type = pad_type
+        self.num_heads = num_heads
+        self.key_dim = key_dim
+        self.value_dim = value_dim
+        self.use_multi_query = use_multi_query
+        self.query_strides = query_strides
+        self.kv_stride = kv_stride
+        self.cpe_dw_kernel_size = cpe_dw_kernel_size
+        self.noskip = noskip
+        self.norm_layer_name = norm_layer
+        self.drop_path_rate = drop_path_rate
+        self.attn_drop_rate = attn_drop
+        self.proj_drop_rate = proj_drop
+        self.layer_scale_init_value = layer_scale_init_value
+        self.use_bias = use_bias
+        self.use_cpe = use_cpe
+        self.channel_axis = channel_axis
+        self.data_format = data_format
+        self.conv_kernel_initializer = keras.initializers.VarianceScaling(
+            scale=2.0, mode="fan_out", distribution="untruncated_normal"
+        )
+        self.bias_initializer = "zeros"
+
+    def build(self, input_shape):
+        super().build(input_shape)
+        in_chs = input_shape[self.channel_axis]
+        self.has_skip = (
+            self.stride == 1 and in_chs == self.filters
+        ) and not self.noskip
+        if self.use_cpe:
+            self.conv_cpe_dw = keras.layers.DepthwiseConv2D(
+                kernel_size=self.cpe_dw_kernel_size,
+                strides=1,
+                padding="same",
+                dilation_rate=self.dilation,
+                use_bias=True,
+                data_format=self.data_format,
+                name="conv_cpe_dw",
+                depthwise_initializer=self.conv_kernel_initializer,
+                bias_initializer=self.bias_initializer,
+                dtype=self.dtype_policy,
+            )
+        else:
+            self.conv_cpe_dw = None
+        if self.norm_layer_name == "batch_norm":
+            self.norm = keras.layers.BatchNormalization(
+                axis=self.channel_axis,
+                name="norm",
+                gamma_initializer="ones",
+                beta_initializer="zeros",
+                dtype=self.dtype_policy,
+            )
+        elif self.norm_layer_name == "rms_norm":
+            self.norm = RmsNorm2d(
+                in_chs,
+                data_format=self.data_format,
+                gamma_initializer="ones",
+                channel_axis=self.channel_axis,
+                name="norm",
+                dtype=self.dtype_policy,
+            )
+        else:
+            raise ValueError(f"Unsupported norm_layer: {self.norm_layer_name}")
+        num_heads = self.num_heads
+        if num_heads is None:
+            assert in_chs % self.key_dim == 0
+            num_heads = in_chs // self.key_dim
+        attn_norm_layer = (
+            RmsNorm2d
+            if self.norm_layer_name == "rms_norm"
+            else keras.layers.BatchNormalization
+        )
+        if self.use_multi_query:
+            self.attn = MultiQueryAttention2d(
+                filters=self.filters,
+                num_heads=num_heads,
+                key_dim=self.key_dim,
+                value_dim=self.value_dim,
+                query_strides=self.query_strides,
+                kv_stride=self.kv_stride,
+                dw_kernel_size=self.dw_kernel_size,
+                dilation=self.dilation,
+                padding=self.pad_type,
+                attn_drop=self.attn_drop_rate,
+                proj_drop=self.proj_drop_rate,
+                norm_layer=attn_norm_layer,
+                use_bias=self.use_bias,
+                channel_axis=self.channel_axis,
+                data_format=self.data_format,
+                name="attn",
+                dtype=self.dtype_policy,
+            )
+        else:
+            self.attn = Attention2d(
+                filters=self.filters,
+                num_heads=num_heads,
+                attn_drop=self.attn_drop_rate,
+                proj_drop=self.proj_drop_rate,
+                bias=self.use_bias,
+                channel_axis=self.channel_axis,
+                data_format=self.data_format,
+                name="attn",
+                dtype=self.dtype_policy,
+            )
+        if self.layer_scale_init_value is not None:
+            self.layer_scale = LayerScale2d(
+                self.filters,
+                self.layer_scale_init_value,
+                name="layer_scale",
+                channel_axis=self.channel_axis,
+                data_format=self.data_format,
+                dtype=self.dtype_policy,
+            )
+        else:
+            self.layer_scale = lambda x: x
+        self.drop_path = (
+            DropPath(self.drop_path_rate, dtype=self.dtype_policy)
+            if self.drop_path_rate > 0.0
+            else lambda x, training: x
+        )
+
+    def call(self, x, training=False):
+        if self.conv_cpe_dw is not None:
+            x = x + self.conv_cpe_dw(x)
+        shortcut = x
+        x_normed = self.norm(x, training=training)
+        x_attn = self.attn(x_normed, training=training)
+        x_scaled = self.layer_scale(x_attn)
+        if self.has_skip:
+            return self.drop_path(x_scaled, training=training) + shortcut
+        else:
+            return x_scaled
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "filters": self.filters,
+                "stride": self.stride,
+                "dw_kernel_size": self.dw_kernel_size,
+                "dilation": self.dilation,
+                "pad_type": self.pad_type,
+                "num_heads": self.num_heads,
+                "key_dim": self.key_dim,
+                "value_dim": self.value_dim,
+                "use_multi_query": self.use_multi_query,
+                "query_strides": self.query_strides,
+                "kv_stride": self.kv_stride,
+                "cpe_dw_kernel_size": self.cpe_dw_kernel_size,
+                "noskip": self.noskip,
+                "norm_layer": self.norm_layer_name,
+                "drop_path_rate": self.drop_path_rate,
+                "attn_drop": self.attn_drop_rate,
+                "proj_drop": self.proj_drop_rate,
+                "layer_scale_init_value": self.layer_scale_init_value,
+                "use_bias": self.use_bias,
+                "use_cpe": self.use_cpe,
+                "channel_axis": self.channel_axis,
+                "data_format": self.data_format,
+            }
+        )
+        return config
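
`MobileAttention` ties the pieces together: an optional depthwise-conv conditional position encoding, pre-norm, attention (multi-head or multi-query), layer scale, and a stochastic-depth residual. A hedged sketch of the block as a residual stage, with the same assumed import path and illustrative values:

import keras

from keras_hub.src.models.mobilenetv5.mobilenetv5_attention import (
    MobileAttention,  # assumed module path, as above
)

block = MobileAttention(
    filters=64,
    num_heads=4,
    key_dim=16,
    value_dim=16,
    use_multi_query=True,  # selects MultiQueryAttention2d internally
    norm_layer="rms_norm",
    use_cpe=True,  # depthwise conv as conditional position encoding
    channel_axis=-1,
    data_format="channels_last",
)
x = keras.random.uniform((2, 16, 16, 64))
# stride == 1 and matching channel counts keep the skip connection, so
# y = x + drop_path(layer_scale(attn(norm(x)))).
y = block(x)  # (2, 16, 16, 64)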