keras-hub-nightly 0.23.0.dev202510100415__py3-none-any.whl → 0.23.0.dev202510110411__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of keras-hub-nightly might be problematic. Click here for more details.

@@ -0,0 +1,462 @@
1
+ import keras
2
+
3
+ from keras_hub.src.models.mobilenet.util import adjust_channels
4
+
5
+
6
class DropPath(keras.layers.Layer):
    """Implements the DropPath layer.

    DropPath is a form of stochastic depth, where connections are randomly
    dropped during training. One keep/drop decision is drawn per entry of
    the leading (batch) axis and broadcast over all remaining axes.

    Args:
        drop_prob: float. The probability of dropping a path.
        scale_by_keep: bool. If `True`, scale the output by `1/keep_prob`.
        seed: int or `None`. Optional seed for the layer-owned random
            generator used to draw the drop mask.
    """

    def __init__(
        self,
        drop_prob=0.0,
        scale_by_keep=True,
        dtype=None,
        seed=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep
        self.seed = seed
        # Unseeded random ops cannot be traced on the JAX backend, so the
        # layer owns a seed generator. It is only created when the layer can
        # actually drop paths, so that layers with `drop_prob == 0` carry no
        # extra state.
        self.seed_generator = (
            keras.random.SeedGenerator(seed) if drop_prob > 0.0 else None
        )

    def call(self, x, training=False):
        # Identity at inference time or when dropping is disabled.
        if self.drop_prob == 0.0 or not training:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One random value per leading-axis entry, broadcast over the rest.
        shape = (keras.ops.shape(x)[0],) + (1,) * (len(x.shape) - 1)
        random_tensor = keep_prob + keras.random.uniform(
            shape, 0, 1, dtype=x.dtype, seed=self.seed_generator
        )
        # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
        random_tensor = keras.ops.floor(random_tensor)
        if keep_prob > 0.0 and self.scale_by_keep:
            random_tensor = random_tensor / keep_prob
        return x * random_tensor

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "drop_prob": self.drop_prob,
                "scale_by_keep": self.scale_by_keep,
                "seed": self.seed,
            }
        )
        return config
41
+
42
+
43
class LayerScale2d(keras.layers.Layer):
    """Multiplies a 4D feature map by a learnable per-channel factor.

    A trainable vector `gamma` of length `dim` is reshaped so that it
    broadcasts along the channel axis and is multiplied with the input.

    Args:
        dim: int. The number of channels in the input tensor.
        init_values: float. The constant used to initialize `gamma`.
        data_format: str. `"channels_first"` places the channels on axis 1;
            any other value is treated as channels on the last axis.
        channel_axis: int. The axis representing the channels in the input
            tensor. Stored and serialized only.
    """

    def __init__(
        self,
        dim,
        init_values=1e-5,
        data_format=None,
        channel_axis=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.dim, self.init_values = dim, init_values
        self.data_format, self.channel_axis = data_format, channel_axis

    def build(self, input_shape):
        constant_init = keras.initializers.Constant(self.init_values)
        self.gamma = self.add_weight(
            shape=(self.dim,),
            initializer=constant_init,
            trainable=True,
            name="gamma",
        )
        super().build(input_shape)

    def call(self, x):
        # Pick the 4D shape that lines `gamma` up with the channel axis.
        channels_first = self.data_format == "channels_first"
        broadcast_shape = (
            (1, self.dim, 1, 1) if channels_first else (1, 1, 1, self.dim)
        )
        return x * keras.ops.reshape(self.gamma, broadcast_shape)

    def get_config(self):
        base_config = super().get_config()
        base_config.update(
            dim=self.dim,
            init_values=self.init_values,
            data_format=self.data_format,
            channel_axis=self.channel_axis,
        )
        return base_config
100
+
101
+
102
class RmsNorm2d(keras.layers.Layer):
    """Root Mean Square normalization over the channels of a 4D input.

    Each spatial position is divided by the root mean square of its channel
    values (computed in float32) and then multiplied by a learnable
    per-channel `gamma`.

    Args:
        dim: int. The number of channels in the input tensor.
        eps: float. A small epsilon added to the mean square before the
            reciprocal square root.
        data_format: str. `"channels_first"` inputs are transposed to
            channels-last for the computation and transposed back; any other
            value is treated as channels-last.
        channel_axis: int. The axis representing the channels in the input
            tensor. Stored and serialized only.
        gamma_initializer: Initializer for the `gamma` weight.
    """

    def __init__(
        self,
        dim,
        eps=1e-6,
        data_format=None,
        channel_axis=None,
        gamma_initializer="ones",
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.dim, self.eps = dim, eps
        self.data_format, self.channel_axis = data_format, channel_axis
        self.gamma_initializer = gamma_initializer

    def build(self, input_shape):
        self.gamma = self.add_weight(
            name="gamma",
            shape=(self.dim,),
            initializer=self.gamma_initializer,
            trainable=True,
        )
        super().build(input_shape)

    def call(self, x):
        channels_first = self.data_format == "channels_first"
        original_dtype = x.dtype
        # Work in channels-last layout so the reduction is over the last axis.
        if channels_first:
            x = keras.ops.transpose(x, (0, 2, 3, 1))
        # The statistics are computed in float32, then cast back.
        x_fp32 = keras.ops.cast(x, "float32")
        mean_square = keras.ops.mean(
            keras.ops.square(x_fp32), axis=-1, keepdims=True
        )
        inv_rms = keras.ops.rsqrt(mean_square + self.eps)
        normalized = keras.ops.cast(x_fp32 * inv_rms, original_dtype)
        result = normalized * self.gamma
        if channels_first:
            result = keras.ops.transpose(result, (0, 3, 1, 2))
        return result

    def get_config(self):
        base_config = super().get_config()
        base_config.update(
            dim=self.dim,
            eps=self.eps,
            data_format=self.data_format,
            channel_axis=self.channel_axis,
            gamma_initializer=self.gamma_initializer,
        )
        return base_config
176
+
177
+
178
class ConvNormAct(keras.layers.Layer):
    """A layer that combines convolution, normalization, and activation.

    This layer provides a convenient way to create a sequence of a 2D
    convolution, a normalization layer, and an activation function.

    Args:
        out_chs: int. The number of output channels.
        kernel_size: int or tuple. The size of the convolution kernel.
        stride: int or tuple. The stride of the convolution.
        dilation: int or tuple. The dilation rate of the convolution.
        groups: int. The number of groups for a grouped convolution.
        bias: bool. If `True`, a bias term is used in the convolution.
        pad_type: str. The type of padding to use. `"same"` or `""` for same
            padding, otherwise valid padding.
        apply_act: bool. If `True`, an activation function is applied.
        act_layer: str. The name of the activation function to use.
            `"gelu"` selects the exact (non-approximate) GELU.
        norm_layer: str. The name of the normalization layer to use.
            `"batch_norm"` and `"rms_norm"` are handled explicitly; any
            other value falls back to layer normalization.
        data_format: str. The format of the input data, either
            `"channels_last"` or `"channels_first"`.
        channel_axis: int. The axis representing the channels in the input
            tensor. If `None`, the batch normalization axis is derived from
            `data_format` (or from the global Keras image data format when
            `data_format` is also `None`).
    """

    def __init__(
        self,
        out_chs,
        kernel_size,
        stride=1,
        dilation=1,
        groups=1,
        bias=False,
        pad_type="same",
        apply_act=True,
        act_layer="relu",
        norm_layer="batch_norm",
        data_format=None,
        channel_axis=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.out_chs = out_chs
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.groups = groups
        self.bias = bias
        self.pad_type = pad_type
        self.apply_act = apply_act
        self.act_layer = act_layer
        self.norm_layer = norm_layer
        # Stored exactly as passed so that `get_config` round-trips.
        self.data_format = data_format
        self.channel_axis = channel_axis
        self.kernel_initializer = keras.initializers.VarianceScaling(
            scale=2.0, mode="fan_out", distribution="untruncated_normal"
        )
        self.bias_initializer = "zeros"
        padding_mode = "valid"
        if pad_type.lower() == "" or pad_type.lower() == "same":
            padding_mode = "same"

        self.conv = keras.layers.Conv2D(
            out_chs,
            kernel_size,
            strides=stride,
            padding=padding_mode,
            dilation_rate=dilation,
            groups=groups,
            use_bias=bias,
            data_format=self.data_format,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            dtype=self.dtype_policy,
        )

        if norm_layer == "batch_norm":
            # BatchNormalization needs an integer axis. When the caller did
            # not give one, fall back to the channel axis implied by the
            # data format that the convolution will use.
            bn_axis = channel_axis
            if bn_axis is None:
                resolved_format = data_format
                if resolved_format is None:
                    resolved_format = keras.config.image_data_format()
                bn_axis = 1 if resolved_format == "channels_first" else -1
            self.norm = keras.layers.BatchNormalization(
                axis=bn_axis,
                epsilon=1e-5,
                gamma_initializer="ones",
                beta_initializer="zeros",
                dtype=self.dtype_policy,
            )
        elif norm_layer == "rms_norm":
            self.norm = RmsNorm2d(
                out_chs,
                data_format=self.data_format,
                channel_axis=self.channel_axis,
                gamma_initializer="ones",
                dtype=self.dtype_policy,
            )
        else:
            # Any other value normalizes over all non-batch axes.
            ln_axis = [1, 2, 3]
            if self.data_format == "channels_first":
                ln_axis = [2, 3, 1]
            self.norm = keras.layers.LayerNormalization(
                axis=ln_axis,
                dtype=self.dtype_policy,
            )

        if self.apply_act:
            if act_layer == "gelu":
                # Explicitly request the exact (erf-based) GELU.
                self.act = keras.layers.Activation(
                    lambda x: keras.activations.gelu(x, approximate=False),
                    dtype=self.dtype_policy,
                )
            else:
                self.act = keras.layers.Activation(
                    act_layer,
                    dtype=self.dtype_policy,
                )

    def build(self, input_shape):
        # Build the sublayers in call order, threading the conv output shape.
        self.conv.build(input_shape)
        conv_output_shape = self.conv.compute_output_shape(input_shape)
        self.norm.build(conv_output_shape)
        if self.apply_act:
            self.act.build(conv_output_shape)
        self.built = True

    def call(self, x, training=False):
        x = self.conv(x)
        x = self.norm(x, training=training)
        if self.apply_act:
            x = self.act(x)
        return x

    def compute_output_shape(self, input_shape):
        # Normalization and activation preserve the convolution's shape.
        return self.conv.compute_output_shape(input_shape)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "out_chs": self.out_chs,
                "kernel_size": self.kernel_size,
                "stride": self.stride,
                "dilation": self.dilation,
                "groups": self.groups,
                "bias": self.bias,
                "pad_type": self.pad_type,
                "apply_act": self.apply_act,
                "act_layer": self.act_layer,
                "norm_layer": self.norm_layer,
                "data_format": self.data_format,
                "channel_axis": self.channel_axis,
            }
        )
        return config
329
+
330
+
331
class SEModule(keras.layers.Layer):
    """Implements the Squeeze-and-Excitation (SE) module.

    The SE module adaptively recalibrates channel-wise feature responses by
    explicitly modeling interdependencies between channels. The input is
    pooled over its spatial axes, passed through two 1x1 convolutions with
    an activation in between, and the gated result rescales the input.

    Args:
        channels: int. The number of input channels.
        rd_ratio: float. The reduction ratio for the bottleneck channels.
        rd_channels: int. The number of bottleneck channels. If specified,
            `rd_ratio` is ignored.
        rd_divisor: int. The divisor for rounding the number of bottleneck
            channels.
        add_maxpool: bool. If `True`, max pooling is used in addition to
            average pooling for the squeeze operation.
        bias: bool. If `True`, bias terms are used in the fully connected
            layers.
        act_layer: str. The activation function for the bottleneck layer.
        norm_layer: str. The normalization layer to use. Any truthy value
            enables batch normalization after the first convolution.
        data_format: str. The format of the input data, either
            `"channels_last"` or `"channels_first"`. If `None`, the global
            Keras image data format is used.
        channel_axis: int. The axis representing the channels in the input
            tensor. If `None`, it is derived from the data format.
        gate_layer: str. The gating activation function.
    """

    def __init__(
        self,
        channels,
        rd_ratio=1.0 / 16,
        rd_channels=None,
        rd_divisor=8,
        add_maxpool=False,
        bias=True,
        act_layer="relu",
        norm_layer=None,
        data_format=None,
        channel_axis=None,
        gate_layer="sigmoid",
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.channels = channels
        self.add_maxpool = add_maxpool
        if not rd_channels:
            rd_channels = adjust_channels(
                channels * rd_ratio, rd_divisor, round_limit=0.0
            )
        self.rd_ratio = rd_ratio
        self.rd_channels = rd_channels
        self.rd_divisor = rd_divisor
        self.bias = bias
        self.act_layer_arg = act_layer
        self.kernel_initializer = keras.initializers.VarianceScaling(
            scale=2.0, mode="fan_out", distribution="untruncated_normal"
        )
        self.bias_initializer = "zeros"
        self.norm_layer_arg = norm_layer
        self.gate_layer_arg = gate_layer
        # Stored exactly as passed so that `get_config` round-trips.
        self.data_format = data_format
        self.channel_axis = channel_axis
        # The Conv2D sublayers fall back to the global image data format when
        # `data_format` is None, so the pooling axes and the batch norm axis
        # must be derived from that same resolved format.
        resolved_format = data_format
        if resolved_format is None:
            resolved_format = keras.config.image_data_format()
        channels_first = resolved_format == "channels_first"
        self.mean_axis = [2, 3] if channels_first else [1, 2]
        bn_axis = channel_axis
        if bn_axis is None:
            bn_axis = 1 if channels_first else -1
        self.fc1 = keras.layers.Conv2D(
            rd_channels,
            kernel_size=1,
            use_bias=bias,
            name="fc1",
            data_format=self.data_format,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            dtype=self.dtype_policy,
        )
        # Without a norm layer, `bn` is a pass-through with the same call
        # signature so that `call` needs no branching.
        self.bn = (
            keras.layers.BatchNormalization(
                axis=bn_axis, dtype=self.dtype_policy
            )
            if norm_layer
            else (lambda x, training: x)
        )
        self.act = keras.layers.Activation(act_layer, dtype=self.dtype_policy)
        self.fc2 = keras.layers.Conv2D(
            channels,
            kernel_size=1,
            use_bias=bias,
            name="fc2",
            data_format=self.data_format,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            dtype=self.dtype_policy,
        )
        self.gate = keras.layers.Activation(gate_layer, dtype=self.dtype_policy)

    def build(self, input_shape):
        self.fc1.build(input_shape)
        fc1_output_shape = self.fc1.compute_output_shape(input_shape)
        # `bn` is a plain function when normalization is disabled.
        if hasattr(self.bn, "build"):
            self.bn.build(fc1_output_shape)
        self.act.build(fc1_output_shape)
        self.fc2.build(fc1_output_shape)
        self.built = True

    def call(self, x, training=False):
        # Squeeze: pool over the spatial axes, keeping them as size 1.
        x_se = keras.ops.mean(x, axis=self.mean_axis, keepdims=True)
        if self.add_maxpool:
            x_se = 0.5 * x_se + 0.5 * keras.ops.max(
                x, axis=self.mean_axis, keepdims=True
            )
        # Excite: bottleneck, then expand back to `channels` and gate.
        x_se = self.fc1(x_se)
        x_se = self.bn(x_se, training=training)
        x_se = self.act(x_se)
        x_se = self.fc2(x_se)
        return x * self.gate(x_se)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "channels": self.channels,
                "rd_ratio": self.rd_ratio,
                "rd_channels": self.rd_channels,
                "rd_divisor": self.rd_divisor,
                "add_maxpool": self.add_maxpool,
                "bias": self.bias,
                "act_layer": self.act_layer_arg,
                "norm_layer": self.norm_layer_arg,
                "gate_layer": self.gate_layer_arg,
                "data_format": self.data_format,
                "channel_axis": self.channel_axis,
            }
        )
        return config
@@ -0,0 +1,146 @@
1
+ import keras
2
+
3
+ from keras_hub.src.models.mobilenet.util import adjust_channels
4
+
5
+
6
def num_groups(group_size, channels):
    """Return how many groups of `group_size` channels fit in `channels`.

    Args:
        group_size: int or `None`. Channels per group. A falsy value
            (`None` or `0`) means a single group.
        channels: int. Total number of channels.

    Returns:
        int. `1` when `group_size` is falsy, otherwise
        `channels // group_size`.

    Raises:
        ValueError: If `channels` is not divisible by `group_size`.
    """
    if not group_size:
        return 1
    if channels % group_size != 0:
        # Both halves must be f-strings so that both values are interpolated.
        raise ValueError(
            f"Number of channels ({channels}) must be divisible by "
            f"group size ({group_size})."
        )
    return channels // group_size
16
+
17
+
18
def parse_ksize(ss):
    """Parse a kernel-size string.

    Args:
        ss: str. Either a run of digits (e.g. `"3"`) or integers separated
            by `"."` (e.g. `"3.5"`).

    Returns:
        An int when `ss` consists only of digits, otherwise a list with one
        int per `"."`-separated part.
    """
    return int(ss) if ss.isdigit() else [int(part) for part in ss.split(".")]
23
+
24
+
25
def round_channels(
    channels, multiplier=1.0, divisor=8, channel_min=None, round_limit=0.9
):
    """Scale a channel count and round it via `adjust_channels`.

    Args:
        channels: int. The base number of channels.
        multiplier: float. Scaling factor. A falsy value (`None` or `0`)
            returns `channels` unchanged.
        divisor: int. Passed to `adjust_channels` as the divisor.
        channel_min: int or `None`. Passed to `adjust_channels` as the
            minimum value.
        round_limit: float. Passed to `adjust_channels` as `round_limit`.

    Returns:
        `channels` itself when `multiplier` is falsy, otherwise the result
        of `adjust_channels`.
    """
    if not multiplier:
        return channels
    # Forward `round_limit`; it was previously accepted but silently ignored.
    return adjust_channels(
        channels * multiplier, divisor, channel_min, round_limit=round_limit
    )
31
+
32
+
33
def feature_take_indices(num_stages, indices):
    """Normalize stage indices and report the largest one.

    Args:
        num_stages: int. Number of stages; negative indices count back from
            it.
        indices: int, tuple, or list. A single index is wrapped in a
            one-element tuple.

    Returns:
        A `(indices, max_index)` pair. When any index is negative, `indices`
        is a new list with negatives replaced by `num_stages + index`;
        otherwise the tuple or list is returned as is.
    """
    selected = indices if isinstance(indices, (tuple, list)) else (indices,)
    needs_wrap = any(idx < 0 for idx in selected)
    if needs_wrap:
        selected = [
            num_stages + idx if idx < 0 else idx for idx in selected
        ]
    return selected, max(selected)
39
+
40
+
41
class SelectAdaptivePool2d(keras.layers.Layer):
    """A layer that selects and applies a 2D adaptive pooling strategy.

    This layer supports various pooling types like average, max, or a
    combination of both. It can also flatten the output.

    Args:
        pool_type: str. The type of pooling to apply. One of `"avg"`, `"max"`,
            `"avgmax"`, `"catavgmax"`, or `""` (identity). Matched
            case-insensitively.
        flatten: bool. If `True`, the output is flattened after pooling.
        data_format: str. The format of the input data, either
            `"channels_last"` or `"channels_first"`.
        channel_axis: int. The axis representing the channels in the input
            tensor.

    Raises:
        ValueError: If `pool_type` is not one of the supported values.
    """

    def __init__(
        self,
        pool_type="avg",
        flatten=False,
        data_format=None,
        channel_axis=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.pool_type = pool_type.lower()
        self.flatten = flatten
        self.data_format = data_format
        # `channel_axis` matches the naming used by the other layers;
        # `channels_axis` is kept for code that reads the original attribute.
        self.channel_axis = channel_axis
        self.channels_axis = channel_axis
        self.pool = None
        self.pool_avg = None
        self.pool_max = None
        self.pool_cat = None
        self.flatten_layer = None
        if self.pool_type not in ("avg", "max", "avgmax", "catavgmax", ""):
            raise ValueError(f"Invalid pool type: {self.pool_type}")

    def build(self, input_shape):
        # The global pooling layers drop the spatial axes when
        # `keepdims=False`, which is how `flatten` is realized for them.
        if self.pool_type == "avg":
            self.pool = keras.layers.GlobalAveragePooling2D(
                data_format=self.data_format,
                keepdims=not self.flatten,
                dtype=self.dtype_policy,
            )
        elif self.pool_type == "max":
            self.pool = keras.layers.GlobalMaxPooling2D(
                data_format=self.data_format,
                keepdims=not self.flatten,
                dtype=self.dtype_policy,
            )
        elif self.pool_type in ("avgmax", "catavgmax"):
            self.pool_avg = keras.layers.GlobalAveragePooling2D(
                data_format=self.data_format,
                keepdims=not self.flatten,
                dtype=self.dtype_policy,
            )
            self.pool_max = keras.layers.GlobalMaxPooling2D(
                data_format=self.data_format,
                keepdims=not self.flatten,
                dtype=self.dtype_policy,
            )
            if self.pool_type == "catavgmax":
                axis = 1 if self.data_format == "channels_first" else -1
                self.pool_cat = keras.layers.Concatenate(
                    axis=axis, dtype=self.dtype_policy
                )
        elif not self.pool_type:
            self.pool = keras.layers.Identity(dtype=self.dtype_policy)
        if self.flatten:
            self.flatten_layer = keras.layers.Flatten(
                dtype=self.dtype_policy
            )
        super().build(input_shape)

    def call(self, x):
        if self.pool_type in ("avg", "max"):
            return self.pool(x)
        elif self.pool_type == "avgmax":
            # Average of the two pooled results.
            x_avg = self.pool_avg(x)
            x_max = self.pool_max(x)
            return 0.5 * (x_avg + x_max)
        elif self.pool_type == "catavgmax":
            # Concatenate the two pooled results along the channel axis.
            x_avg = self.pool_avg(x)
            x_max = self.pool_max(x)
            return self.pool_cat([x_avg, x_max])
        elif not self.pool_type:
            # Identity pooling: only this path uses the explicit Flatten.
            x = self.pool(x)
            if self.flatten_layer:
                x = self.flatten_layer(x)
            return x
        return x

    def feat_mult(self):
        """Return 2 for `"catavgmax"` pooling, otherwise 1."""
        return 2 if self.pool_type == "catavgmax" else 1

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "pool_type": self.pool_type,
                "flatten": self.flatten,
                "data_format": self.data_format,
                # Previously omitted, so it was lost on a config round-trip.
                "channel_axis": self.channel_axis,
            }
        )
        return config
@@ -502,10 +502,17 @@ def jax_memory_cleanup(layer):
502
502
  # For jax, delete all previous allocated memory to avoid temporarily
503
503
  # duplicating variable allocations. torch and tensorflow have stateful
504
504
  # variable types and do not need this fix.
505
+ # Skip deletion for sharded arrays to avoid breaking references in
506
+ # distributed setups.
505
507
  if keras.config.backend() == "jax":
506
508
  for weight in layer.weights:
507
- if getattr(weight, "_value", None) is not None:
508
- weight._value.delete()
509
+ if weight._value is not None:
510
+ # Do not delete sharded arrays, as they may be referenced in
511
+ # JAX's distributed computation graph and deletion can cause
512
+ # errors.
513
+ sharding = getattr(weight._value, "sharding", None)
514
+ if sharding is None:
515
+ weight._value.delete()
509
516
 
510
517
 
511
518
  def set_dtype_in_config(config, dtype=None):