keras-hub-nightly 0.22.0.dev202507110420__py3-none-any.whl → 0.22.0.dev202507130422__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,918 @@
1
+ import keras
2
+
3
+
4
@keras.saving.register_keras_serializable(package="keras_hub")
class HGNetV2LearnableAffineBlock(keras.layers.Layer):
    """
    Learnable scalar affine transform used by HGNetV2.

    Computes ``scale * x + bias`` where ``scale`` and ``bias`` are trainable
    weights of shape ``()``, so the same two scalars are applied to every
    element of the input.

    Args:
        scale_value: float, optional. Initial value of the ``scale`` weight.
            Defaults to 1.0.
        bias_value: float, optional. Initial value of the ``bias`` weight.
            Defaults to 0.0.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(self, scale_value=1.0, bias_value=0.0, **kwargs):
        super().__init__(**kwargs)
        self.scale_value = scale_value
        self.bias_value = bias_value

    def _add_scalar_weight(self, weight_name, initial_value):
        # Both weights differ only in their name and starting value.
        return self.add_weight(
            name=weight_name,
            shape=(),
            initializer=keras.initializers.Constant(initial_value),
            trainable=True,
            dtype=self.dtype,
        )

    def build(self, input_shape):
        # Creation order (scale first, then bias) fixes the order of
        # `self.weights`, which weight loading relies on.
        self.scale = self._add_scalar_weight("scale", self.scale_value)
        self.bias = self._add_scalar_weight("bias", self.bias_value)
        super().build(input_shape)

    def call(self, hidden_state):
        scaled = self.scale * hidden_state
        return scaled + self.bias

    def get_config(self):
        return {
            **super().get_config(),
            "scale_value": self.scale_value,
            "bias_value": self.bias_value,
        }
51
+
52
+
53
@keras.saving.register_keras_serializable(package="keras_hub")
class HGNetV2ConvLayer(keras.layers.Layer):
    """
    HGNetV2 convolutional layer.

    Applies, in order:
    - zero-padding of ``(kernel_size - 1) // 2`` pixels on every spatial edge;
    - a bias-free 2D convolution with ``"valid"`` padding;
    - batch normalization;
    - an activation;
    - either a learnable affine block or an identity.

    For odd kernel sizes with ``stride=1`` the padding preserves the spatial
    size. For even kernel sizes (e.g. 2) no padding is added. ``stride > 1``
    downsamples.

    All sublayers use this layer's dtype policy, so the compute dtype is not
    silently changed between stages.

    Args:
        in_channels: int. Number of input channels. Only stored for
            serialization; the convolution infers its input channels from
            the input shape.
        out_channels: int. Number of output channels.
        kernel_size: int. Size of the convolutional kernel.
        stride: int. Stride of the convolution.
        groups: int. Number of groups for group convolution.
        activation: string, optional. Activation function to use ('relu',
            'gelu', 'tanh', or None). Any other value raises a ValueError.
            Defaults to 'relu'.
        use_learnable_affine_block: bool, optional. Whether to include a
            learnable affine block after activation. Defaults to False.
        data_format: string, optional. Data format of the input
            ('channels_last' or 'channels_first'). Defaults to None.
        channel_axis: int, optional. Axis of the channel dimension, used by
            batch normalization. If None, it is derived from `data_format`
            (or `keras.config.image_data_format()` when `data_format` is
            None): -1 for 'channels_last', 1 otherwise. Defaults to None.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        groups,
        activation="relu",
        use_learnable_affine_block=False,
        data_format=None,
        channel_axis=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if channel_axis is None:
            # BatchNormalization requires a concrete integer axis.
            resolved_format = data_format or keras.config.image_data_format()
            channel_axis = -1 if resolved_format == "channels_last" else 1
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.groups = groups
        self.activation_name = activation
        self.use_learnable_affine_block = use_learnable_affine_block
        self.data_format = data_format
        self.channel_axis = channel_axis
        pad = (self.kernel_size - 1) // 2
        self.padding = keras.layers.ZeroPadding2D(
            padding=((pad, pad), (pad, pad)),
            data_format=self.data_format,
            name=f"{self.name}_pad" if self.name else None,
            dtype=self.dtype_policy,
        )
        self.convolution = keras.layers.Conv2D(
            filters=self.out_channels,
            kernel_size=self.kernel_size,
            strides=self.stride,
            groups=self.groups,
            padding="valid",
            use_bias=False,
            data_format=self.data_format,
            name=f"{self.name}_conv" if self.name else None,
            dtype=self.dtype_policy,
        )
        self.normalization = keras.layers.BatchNormalization(
            axis=self.channel_axis,
            epsilon=1e-5,
            momentum=0.9,
            name=f"{self.name}_bn" if self.name else None,
            dtype=self.dtype_policy,
        )

        if self.activation_name == "relu":
            self.activation_layer = keras.layers.ReLU(
                name=f"{self.name}_relu" if self.name else None,
                dtype=self.dtype_policy,
            )
        elif self.activation_name == "gelu":
            self.activation_layer = keras.layers.Activation(
                "gelu",
                name=f"{self.name}_gelu" if self.name else None,
                dtype=self.dtype_policy,
            )
        elif self.activation_name == "tanh":
            self.activation_layer = keras.layers.Activation(
                "tanh",
                name=f"{self.name}_tanh" if self.name else None,
                dtype=self.dtype_policy,
            )
        elif self.activation_name is None:
            self.activation_layer = keras.layers.Identity(
                name=f"{self.name}_identity_activation" if self.name else None,
                dtype=self.dtype_policy,
            )
        else:
            raise ValueError(f"Unsupported activation: {self.activation_name}")

        if self.use_learnable_affine_block:
            self.lab = HGNetV2LearnableAffineBlock(
                name=f"{self.name}_lab" if self.name else None,
                dtype=self.dtype_policy,
            )
        else:
            # Pass the dtype policy so this no-op layer does not autocast
            # inputs to a different (default) compute dtype.
            self.lab = keras.layers.Identity(
                name=f"{self.name}_identity_lab" if self.name else None,
                dtype=self.dtype_policy,
            )

    def build(self, input_shape):
        super().build(input_shape)
        self.padding.build(input_shape)
        padded_shape = self.padding.compute_output_shape(input_shape)
        self.convolution.build(padded_shape)
        conv_output_shape = self.convolution.compute_output_shape(padded_shape)
        self.normalization.build(conv_output_shape)
        self.lab.build(conv_output_shape)

    def call(self, inputs, training=None):
        hidden_state = self.padding(inputs)
        hidden_state = self.convolution(hidden_state)
        hidden_state = self.normalization(hidden_state, training=training)
        hidden_state = self.activation_layer(hidden_state)
        hidden_state = self.lab(hidden_state)
        return hidden_state

    def compute_output_shape(self, input_shape):
        # Batch norm, activation and the affine block are shape-preserving,
        # so the output shape is determined by padding + convolution.
        padded_shape = self.padding.compute_output_shape(input_shape)
        return self.convolution.compute_output_shape(padded_shape)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "in_channels": self.in_channels,
                "out_channels": self.out_channels,
                "kernel_size": self.kernel_size,
                "stride": self.stride,
                "groups": self.groups,
                "activation": self.activation_name,
                "use_learnable_affine_block": self.use_learnable_affine_block,
                "data_format": self.data_format,
                "channel_axis": self.channel_axis,
            }
        )
        return config
200
+
201
+
202
@keras.saving.register_keras_serializable(package="keras_hub")
class HGNetV2ConvLayerLight(keras.layers.Layer):
    """
    Lightweight HGNetV2 convolution: pointwise projection, then depthwise conv.

    The first sublayer is a 1x1 `HGNetV2ConvLayer` mapping `in_channels` to
    `out_channels`, with no activation and no learnable affine block. The
    second is an `HGNetV2ConvLayer` with `groups=out_channels` (one group per
    channel), the given `kernel_size`, ReLU activation and an optional
    learnable affine block. Both use stride 1.

    Args:
        in_channels: int. Number of input channels.
        out_channels: int. Number of output channels.
        kernel_size: int. Kernel size of the grouped (depthwise) convolution.
        use_learnable_affine_block: bool, optional. Whether the second
            convolution ends with a learnable affine block. Defaults to False.
        data_format: string, optional. Data format of the input. Defaults to
            None.
        channel_axis: int, optional. Axis of the channel dimension. Defaults to
            None.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        use_learnable_affine_block=False,
        data_format=None,
        channel_axis=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.use_learnable_affine_block = use_learnable_affine_block
        self.data_format = data_format
        self.channel_axis = channel_axis

        prefix = f"{self.name}_" if self.name else ""
        shared_kwargs = {
            "stride": 1,
            "data_format": self.data_format,
            "channel_axis": self.channel_axis,
            "dtype": self.dtype_policy,
        }
        # Pointwise projection: in_channels -> out_channels.
        self.conv1_layer = HGNetV2ConvLayer(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=1,
            groups=1,
            activation=None,
            use_learnable_affine_block=False,
            name=f"{prefix}conv1",
            **shared_kwargs,
        )
        # Grouped conv with one group per channel.
        self.conv2_layer = HGNetV2ConvLayer(
            in_channels=self.out_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            groups=self.out_channels,
            activation="relu",
            use_learnable_affine_block=self.use_learnable_affine_block,
            name=f"{prefix}conv2",
            **shared_kwargs,
        )

    def build(self, input_shape):
        super().build(input_shape)
        self.conv1_layer.build(input_shape)
        self.conv2_layer.build(
            self.conv1_layer.compute_output_shape(input_shape)
        )

    def call(self, hidden_state, training=None):
        for conv in (self.conv1_layer, self.conv2_layer):
            hidden_state = conv(hidden_state, training=training)
        return hidden_state

    def get_config(self):
        return {
            **super().get_config(),
            "in_channels": self.in_channels,
            "out_channels": self.out_channels,
            "kernel_size": self.kernel_size,
            "use_learnable_affine_block": self.use_learnable_affine_block,
            "data_format": self.data_format,
            "channel_axis": self.channel_axis,
        }

    def compute_output_shape(self, input_shape):
        return self.conv2_layer.compute_output_shape(
            self.conv1_layer.compute_output_shape(input_shape)
        )
299
+
300
+
301
@keras.saving.register_keras_serializable(package="keras_hub")
class HGNetV2Embeddings(keras.layers.Layer):
    """
    HGNetV2 embedding (stem) layer.

    Data flow:
    1. `stem1`: 3x3 stride-2 conv.
    2. Pad the result by one pixel on the bottom/right (`padding1`).
    3. Branch A: `stem2a` (2x2 conv), pad again (`padding2`), then `stem2b`
       (2x2 conv).
    4. Branch B: 2x2 stride-1 max pooling of the `padding1` output.
    5. Concatenate branches B and A along `channel_axis`.
    6. `stem3`: 3x3 stride-2 conv.
    7. `stem4`: 1x1 conv.

    All sublayers share this layer's dtype policy.

    Args:
        stem_channels: list of int. Must have at least three entries:
            `stem_channels[0]` is the input channel count (stored only),
            `stem_channels[1]` is the width of the intermediate stem layers,
            and `stem_channels[2]` is the output channel count.
        hidden_act: string. Activation function to use in the convolutional
            layers.
        use_learnable_affine_block: bool. Whether to use learnable affine blocks
            in the convolutional layers.
        data_format: string, optional. Data format of the input. Defaults to
            None.
        channel_axis: int, optional. Axis of the channel dimension. If None,
            it is derived from `data_format` (or
            `keras.config.image_data_format()` when `data_format` is None):
            -1 for 'channels_last', 1 otherwise. Defaults to None.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(
        self,
        stem_channels,
        hidden_act,
        use_learnable_affine_block,
        data_format=None,
        channel_axis=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if channel_axis is None:
            # Concatenate and BatchNormalization need a concrete integer axis.
            resolved_format = data_format or keras.config.image_data_format()
            channel_axis = -1 if resolved_format == "channels_last" else 1
        self.stem_channels = stem_channels
        self.hidden_act = hidden_act
        self.use_learnable_affine_block = use_learnable_affine_block
        self.data_format = data_format
        self.channel_axis = channel_axis
        # Sublayers are created in forward order; keep this order stable so
        # weight tracking stays compatible with existing checkpoints.
        self.stem1_layer = self._make_stem_conv(
            "stem1",
            in_channels=self.stem_channels[0],
            out_channels=self.stem_channels[1],
            kernel_size=3,
            stride=2,
        )
        self.padding1 = keras.layers.ZeroPadding2D(
            padding=((0, 1), (0, 1)),
            data_format=self.data_format,
            name=f"{self.name}_padding1" if self.name else "padding1",
            dtype=self.dtype_policy,
        )
        self.stem2a_layer = self._make_stem_conv(
            "stem2a",
            in_channels=self.stem_channels[1],
            out_channels=self.stem_channels[1] // 2,
            kernel_size=2,
            stride=1,
        )
        self.padding2 = keras.layers.ZeroPadding2D(
            padding=((0, 1), (0, 1)),
            data_format=self.data_format,
            name=f"{self.name}_padding2" if self.name else "padding2",
            dtype=self.dtype_policy,
        )
        self.stem2b_layer = self._make_stem_conv(
            "stem2b",
            in_channels=self.stem_channels[1] // 2,
            out_channels=self.stem_channels[1],
            kernel_size=2,
            stride=1,
        )
        self.pool_layer = keras.layers.MaxPool2D(
            pool_size=2,
            strides=1,
            padding="valid",
            data_format=self.data_format,
            name=f"{self.name}_pool" if self.name else "pool",
            dtype=self.dtype_policy,
        )
        self.concatenate_layer = keras.layers.Concatenate(
            axis=self.channel_axis,
            name=f"{self.name}_concat" if self.name else "concat",
            dtype=self.dtype_policy,
        )
        self.stem3_layer = self._make_stem_conv(
            "stem3",
            in_channels=self.stem_channels[1] * 2,
            out_channels=self.stem_channels[1],
            kernel_size=3,
            stride=2,
        )
        self.stem4_layer = self._make_stem_conv(
            "stem4",
            in_channels=self.stem_channels[1],
            out_channels=self.stem_channels[2],
            kernel_size=1,
            stride=1,
        )

    def _make_stem_conv(
        self, suffix, in_channels, out_channels, kernel_size, stride
    ):
        """Create one ungrouped stem `HGNetV2ConvLayer` named `suffix`."""
        return HGNetV2ConvLayer(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=1,
            activation=self.hidden_act,
            use_learnable_affine_block=self.use_learnable_affine_block,
            data_format=self.data_format,
            channel_axis=self.channel_axis,
            name=f"{self.name}_{suffix}" if self.name else suffix,
            dtype=self.dtype_policy,
        )

    def build(self, input_shape):
        super().build(input_shape)
        current_shape = input_shape
        self.stem1_layer.build(current_shape)
        current_shape = self.stem1_layer.compute_output_shape(current_shape)
        padded_shape1 = self.padding1.compute_output_shape(current_shape)
        self.stem2a_layer.build(padded_shape1)
        shape_after_stem2a = self.stem2a_layer.compute_output_shape(
            padded_shape1
        )
        padded_shape2 = self.padding2.compute_output_shape(shape_after_stem2a)
        self.stem2b_layer.build(padded_shape2)
        shape_after_stem2b = self.stem2b_layer.compute_output_shape(
            padded_shape2
        )
        # The pool branch reads the padding1 output, not the stem2a branch.
        shape_after_pool = self.pool_layer.compute_output_shape(padded_shape1)
        concat_input_shapes = [shape_after_pool, shape_after_stem2b]
        shape_after_concat = self.concatenate_layer.compute_output_shape(
            concat_input_shapes
        )
        self.stem3_layer.build(shape_after_concat)
        shape_after_stem3 = self.stem3_layer.compute_output_shape(
            shape_after_concat
        )
        self.stem4_layer.build(shape_after_stem3)

    def compute_output_shape(self, input_shape):
        current_shape = self.stem1_layer.compute_output_shape(input_shape)
        padded_shape1 = self.padding1.compute_output_shape(current_shape)
        shape_after_stem2a = self.stem2a_layer.compute_output_shape(
            padded_shape1
        )
        padded_shape2 = self.padding2.compute_output_shape(shape_after_stem2a)
        shape_after_stem2b = self.stem2b_layer.compute_output_shape(
            padded_shape2
        )
        shape_after_pool = self.pool_layer.compute_output_shape(padded_shape1)
        concat_input_shapes = [shape_after_pool, shape_after_stem2b]
        shape_after_concat = self.concatenate_layer.compute_output_shape(
            concat_input_shapes
        )
        shape_after_stem3 = self.stem3_layer.compute_output_shape(
            shape_after_concat
        )
        return self.stem4_layer.compute_output_shape(shape_after_stem3)

    def call(self, pixel_values, training=None):
        embedding = self.stem1_layer(pixel_values, training=training)
        embedding_padded_for_2a_and_pool = self.padding1(embedding)
        emb_stem_2a = self.stem2a_layer(
            embedding_padded_for_2a_and_pool, training=training
        )
        emb_stem_2a_padded = self.padding2(emb_stem_2a)
        emb_stem_2a_processed = self.stem2b_layer(
            emb_stem_2a_padded, training=training
        )
        pooled_emb = self.pool_layer(embedding_padded_for_2a_and_pool)
        # Order matters for channel layout: pooled branch first.
        embedding_concatenated = self.concatenate_layer(
            [pooled_emb, emb_stem_2a_processed]
        )
        embedding_after_stem3 = self.stem3_layer(
            embedding_concatenated, training=training
        )
        return self.stem4_layer(embedding_after_stem3, training=training)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "stem_channels": self.stem_channels,
                "hidden_act": self.hidden_act,
                "use_learnable_affine_block": self.use_learnable_affine_block,
                "data_format": self.data_format,
                "channel_axis": self.channel_axis,
            }
        )
        return config
506
+
507
+
508
@keras.saving.register_keras_serializable(package="keras_hub")
class HGNetV2BasicLayer(keras.layers.Layer):
    """
    HGNetV2 basic layer.

    Runs `layer_num` convolutional blocks in sequence. It then concatenates
    the layer input with every block output along `channel_axis` and
    aggregates the result with two 1x1 `HGNetV2ConvLayer`s: a squeeze to
    `out_channels // 2`, then an excitation to `out_channels`.

    When `residual` is True, drop path is applied to the aggregated features
    and the layer input is added back. In that case `in_channels` must equal
    `out_channels`. Drop path is not used when `residual` is False.

    Args:
        in_channels: int. Number of input channels.
        middle_channels: int. Number of channels in the intermediate
            convolutional blocks.
        out_channels: int. Number of output channels.
        layer_num: int. Number of convolutional blocks in the layer.
        kernel_size: int, optional. Kernel size for the convolutional blocks.
            Defaults to 3.
        residual: bool, optional. Whether to include a residual connection.
            Defaults to False.
        light_block: bool, optional. Whether to use lightweight convolutional
            blocks. Defaults to False.
        drop_path: float, optional. Drop path rate, implemented as a Dropout
            whose mask is shared across all non-batch axes. Defaults to 0.0.
        use_learnable_affine_block: bool, optional. Whether to use learnable
            affine blocks in the convolutional blocks. Defaults to False.
        data_format: string, optional. Data format of the input. Defaults to
            None.
        channel_axis: int, optional. Axis of the channel dimension. If None,
            it is derived from `data_format` (or
            `keras.config.image_data_format()` when `data_format` is None):
            -1 for 'channels_last', 1 otherwise. Defaults to None.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(
        self,
        in_channels,
        middle_channels,
        out_channels,
        layer_num,
        kernel_size=3,
        residual=False,
        light_block=False,
        drop_path=0.0,
        use_learnable_affine_block=False,
        data_format=None,
        channel_axis=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if channel_axis is None:
            # Concatenate and BatchNormalization need a concrete integer axis.
            resolved_format = data_format or keras.config.image_data_format()
            channel_axis = -1 if resolved_format == "channels_last" else 1
        self.in_channels_arg = in_channels
        self.middle_channels = middle_channels
        self.out_channels = out_channels
        self.layer_num = layer_num
        self.kernel_size = kernel_size
        self.residual = residual
        self.light_block = light_block
        self.drop_path_rate = drop_path
        self.use_learnable_affine_block = use_learnable_affine_block
        self.data_format = data_format
        self.channel_axis = channel_axis

        self.layer_list = []
        for i in range(self.layer_num):
            # Only the first block sees the layer input; later blocks chain
            # on the previous block's `middle_channels` output.
            block_input_channels = (
                self.in_channels_arg if i == 0 else self.middle_channels
            )
            if self.light_block:
                block = HGNetV2ConvLayerLight(
                    in_channels=block_input_channels,
                    out_channels=self.middle_channels,
                    kernel_size=self.kernel_size,
                    use_learnable_affine_block=self.use_learnable_affine_block,
                    data_format=self.data_format,
                    channel_axis=self.channel_axis,
                    name=f"{self.name}_light_block_{i}"
                    if self.name
                    else f"light_block_{i}",
                    dtype=self.dtype_policy,
                )
            else:
                block = HGNetV2ConvLayer(
                    in_channels=block_input_channels,
                    out_channels=self.middle_channels,
                    kernel_size=self.kernel_size,
                    stride=1,
                    groups=1,
                    activation="relu",
                    use_learnable_affine_block=self.use_learnable_affine_block,
                    data_format=self.data_format,
                    channel_axis=self.channel_axis,
                    name=f"{self.name}_conv_block_{i}"
                    if self.name
                    else f"conv_block_{i}",
                    dtype=self.dtype_policy,
                )
            self.layer_list.append(block)
        # Channels after concatenating the input with every block output.
        self.total_channels_for_aggregation = (
            self.in_channels_arg + self.layer_num * self.middle_channels
        )
        self.aggregation_squeeze_conv = HGNetV2ConvLayer(
            in_channels=self.total_channels_for_aggregation,
            out_channels=self.out_channels // 2,
            kernel_size=1,
            stride=1,
            groups=1,
            activation="relu",
            use_learnable_affine_block=self.use_learnable_affine_block,
            data_format=self.data_format,
            channel_axis=self.channel_axis,
            name=f"{self.name}_agg_squeeze" if self.name else "agg_squeeze",
            dtype=self.dtype_policy,
        )
        self.aggregation_excitation_conv = HGNetV2ConvLayer(
            in_channels=self.out_channels // 2,
            out_channels=self.out_channels,
            kernel_size=1,
            stride=1,
            groups=1,
            activation="relu",
            use_learnable_affine_block=self.use_learnable_affine_block,
            data_format=self.data_format,
            channel_axis=self.channel_axis,
            name=f"{self.name}_agg_excite" if self.name else "agg_excite",
            dtype=self.dtype_policy,
        )

        if self.drop_path_rate > 0.0:
            # noise_shape (None, 1, 1, 1) drops whole samples, i.e.
            # stochastic depth, for either data format.
            self.drop_path_layer = keras.layers.Dropout(
                self.drop_path_rate,
                noise_shape=(None, 1, 1, 1),
                name=f"{self.name}_drop_path" if self.name else "drop_path",
                dtype=self.dtype_policy,
            )
        else:
            self.drop_path_layer = keras.layers.Identity(
                name=f"{self.name}_identity_drop_path"
                if self.name
                else "identity_drop_path",
                dtype=self.dtype_policy,
            )

        self.concatenate_layer = keras.layers.Concatenate(
            axis=self.channel_axis,
            name=f"{self.name}_concat" if self.name else "concat",
            dtype=self.dtype_policy,
        )
        if self.residual:
            self.add_layer = keras.layers.Add(
                name=f"{self.name}_add_residual"
                if self.name
                else "add_residual",
                dtype=self.dtype_policy,
            )

    def build(self, input_shape):
        super().build(input_shape)
        current_block_input_shape = input_shape
        output_shapes_for_concat = [input_shape]
        for layer_block in self.layer_list:
            layer_block.build(current_block_input_shape)
            current_block_output_shape = layer_block.compute_output_shape(
                current_block_input_shape
            )
            output_shapes_for_concat.append(current_block_output_shape)
            current_block_input_shape = current_block_output_shape
        concatenated_shape = self.concatenate_layer.compute_output_shape(
            output_shapes_for_concat
        )
        self.aggregation_squeeze_conv.build(concatenated_shape)
        agg_squeeze_output_shape = (
            self.aggregation_squeeze_conv.compute_output_shape(
                concatenated_shape
            )
        )
        self.aggregation_excitation_conv.build(agg_squeeze_output_shape)

    def compute_output_shape(self, input_shape):
        output_tensors_shapes = [input_shape]
        current_block_input_shape = input_shape
        for layer_block in self.layer_list:
            current_block_output_shape = layer_block.compute_output_shape(
                current_block_input_shape
            )
            output_tensors_shapes.append(current_block_output_shape)
            current_block_input_shape = current_block_output_shape
        concatenated_features_shape = (
            self.concatenate_layer.compute_output_shape(output_tensors_shapes)
        )
        aggregated_features_shape = (
            self.aggregation_squeeze_conv.compute_output_shape(
                concatenated_features_shape
            )
        )
        # The residual add does not change the shape.
        return self.aggregation_excitation_conv.compute_output_shape(
            aggregated_features_shape
        )

    def call(self, hidden_state, training=None):
        identity = hidden_state
        output_tensors = [hidden_state]

        current_feature_map = hidden_state
        for layer_block in self.layer_list:
            current_feature_map = layer_block(
                current_feature_map, training=training
            )
            output_tensors.append(current_feature_map)
        concatenated_features = self.concatenate_layer(output_tensors)
        aggregated_features = self.aggregation_squeeze_conv(
            concatenated_features, training=training
        )
        aggregated_features = self.aggregation_excitation_conv(
            aggregated_features, training=training
        )
        if not self.residual:
            return aggregated_features
        dropped_features = self.drop_path_layer(
            aggregated_features, training=training
        )
        return self.add_layer([dropped_features, identity])

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "in_channels": self.in_channels_arg,
                "middle_channels": self.middle_channels,
                "out_channels": self.out_channels,
                "layer_num": self.layer_num,
                "kernel_size": self.kernel_size,
                "residual": self.residual,
                "light_block": self.light_block,
                "drop_path": self.drop_path_rate,
                "use_learnable_affine_block": self.use_learnable_affine_block,
                "data_format": self.data_format,
                "channel_axis": self.channel_axis,
            }
        )
        return config
748
+
749
+
750
@keras.saving.register_keras_serializable(package="keras_hub")
class HGNetV2Stage(keras.layers.Layer):
    """
    HGNetV2 stage layer.

    Selects this stage's settings from the per-stage lists using
    `stage_index`. Optionally applies a downsample, then runs
    `stage_num_blocks[stage_index]` `HGNetV2BasicLayer`s.

    The downsample is a 3x3, stride-2 grouped conv with one group per
    channel, no activation and no learnable affine block. The first basic
    layer maps the stage input channels to the stage output channels without
    a residual connection. Every later basic layer is residual.

    Args:
        stage_in_channels: list of int. Input channels for each stage.
        stage_mid_channels: list of int. Middle channels for each stage.
        stage_out_channels: list of int. Output channels for each stage.
        stage_num_blocks: list of int. Number of basic layers in each stage.
        stage_num_of_layers: list of int. Number of convolutional blocks in
            each basic layer.
        apply_downsample: list of bools. Whether to downsample at the beginning
            of each stage.
        use_lightweight_conv_block: list of bools. Whether to use HGNetV2
            lightweight convolutional block in the stage.
        stage_kernel_size: list of int. Kernel sizes for each stage.
        use_learnable_affine_block: bool. Whether to use learnable affine
            blocks.
        stage_index: int. Index into the per-stage lists above.
        drop_path: float, optional. Drop path rate passed to every basic
            layer. It only takes effect in the residual layers. Defaults to
            0.0.
        data_format: string, optional. Data format of the input. Defaults to
            None.
        channel_axis: int, optional. Axis of the channel dimension. Defaults to
            None.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(
        self,
        stage_in_channels,
        stage_mid_channels,
        stage_out_channels,
        stage_num_blocks,
        stage_num_of_layers,
        apply_downsample,
        use_lightweight_conv_block,
        stage_kernel_size,
        use_learnable_affine_block,
        stage_index: int,
        drop_path: float = 0.0,
        data_format=None,
        channel_axis=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Full per-stage lists are kept so get_config() round-trips.
        self.stage_in_channels = stage_in_channels
        self.stage_mid_channels = stage_mid_channels
        self.stage_out_channels = stage_out_channels
        self.stage_num_blocks = stage_num_blocks
        self.stage_num_of_layers = stage_num_of_layers
        self.apply_downsample = apply_downsample
        self.use_lightweight_conv_block = use_lightweight_conv_block
        self.stage_kernel_size = stage_kernel_size
        self.use_learnable_affine_block = use_learnable_affine_block
        self.stage_index = stage_index
        self.drop_path = drop_path
        self.data_format = data_format
        self.channel_axis = channel_axis
        # Settings for this particular stage.
        self.current_stage_in_channels = stage_in_channels[stage_index]
        self.current_stage_mid_channels = stage_mid_channels[stage_index]
        self.current_stage_out_channels = stage_out_channels[stage_index]
        self.current_stage_num_blocks = stage_num_blocks[stage_index]
        self.current_stage_num_layers_per_block = stage_num_of_layers[
            stage_index
        ]
        self.current_stage_is_downsample_active = apply_downsample[stage_index]
        self.current_stage_is_light_block = use_lightweight_conv_block[
            stage_index
        ]
        self.current_stage_kernel_size = stage_kernel_size[stage_index]
        self.current_stage_use_lab = use_learnable_affine_block
        self.current_stage_drop_path = drop_path
        if self.current_stage_is_downsample_active:
            self.downsample_layer = HGNetV2ConvLayer(
                in_channels=self.current_stage_in_channels,
                out_channels=self.current_stage_in_channels,
                kernel_size=3,
                stride=2,
                groups=self.current_stage_in_channels,
                activation=None,
                use_learnable_affine_block=False,
                data_format=self.data_format,
                channel_axis=self.channel_axis,
                name=f"{self.name}_downsample" if self.name else "downsample",
                dtype=self.dtype_policy,
            )
        else:
            # Carry the dtype policy so the pass-through does not autocast
            # inputs to the default compute dtype.
            self.downsample_layer = keras.layers.Identity(
                name=f"{self.name}_identity_downsample"
                if self.name
                else "identity_downsample",
                dtype=self.dtype_policy,
            )

        self.blocks_list = []
        for i in range(self.current_stage_num_blocks):
            basic_layer_input_channels = (
                self.current_stage_in_channels
                if i == 0
                else self.current_stage_out_channels
            )
            block = HGNetV2BasicLayer(
                in_channels=basic_layer_input_channels,
                middle_channels=self.current_stage_mid_channels,
                out_channels=self.current_stage_out_channels,
                layer_num=self.current_stage_num_layers_per_block,
                # The first block changes channel count, so it cannot be
                # residual.
                residual=i > 0,
                kernel_size=self.current_stage_kernel_size,
                light_block=self.current_stage_is_light_block,
                drop_path=self.current_stage_drop_path,
                use_learnable_affine_block=self.current_stage_use_lab,
                data_format=self.data_format,
                channel_axis=self.channel_axis,
                name=f"{self.name}_block_{i}" if self.name else f"block_{i}",
                dtype=self.dtype_policy,
            )
            self.blocks_list.append(block)

    def build(self, input_shape):
        super().build(input_shape)
        current_input_shape = input_shape
        self.downsample_layer.build(current_input_shape)
        current_input_shape = self.downsample_layer.compute_output_shape(
            current_input_shape
        )

        for block_item in self.blocks_list:
            block_item.build(current_input_shape)
            current_input_shape = block_item.compute_output_shape(
                current_input_shape
            )

    def compute_output_shape(self, input_shape):
        current_shape = self.downsample_layer.compute_output_shape(input_shape)
        for block_item in self.blocks_list:
            current_shape = block_item.compute_output_shape(current_shape)
        return current_shape

    def call(self, hidden_state, training=None):
        hidden_state = self.downsample_layer(hidden_state, training=training)
        for block_item in self.blocks_list:
            hidden_state = block_item(hidden_state, training=training)
        return hidden_state

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "stage_in_channels": self.stage_in_channels,
                "stage_mid_channels": self.stage_mid_channels,
                "stage_out_channels": self.stage_out_channels,
                "stage_num_blocks": self.stage_num_blocks,
                "stage_num_of_layers": self.stage_num_of_layers,
                "apply_downsample": self.apply_downsample,
                "use_lightweight_conv_block": self.use_lightweight_conv_block,
                "stage_kernel_size": self.stage_kernel_size,
                "use_learnable_affine_block": self.use_learnable_affine_block,
                "stage_index": self.stage_index,
                "drop_path": self.drop_path,
                "data_format": self.data_format,
                "channel_axis": self.channel_axis,
            }
        )
        return config