keras-hub-nightly 0.19.0.dev202412180349__py3-none-any.whl → 0.19.0.dev202412200346__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -35,6 +35,9 @@ from keras_hub.src.layers.preprocessing.multi_segment_packer import (
35
35
  from keras_hub.src.layers.preprocessing.random_deletion import RandomDeletion
36
36
  from keras_hub.src.layers.preprocessing.random_swap import RandomSwap
37
37
  from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
38
+ from keras_hub.src.models.basnet.basnet_image_converter import (
39
+ BASNetImageConverter,
40
+ )
38
41
  from keras_hub.src.models.clip.clip_image_converter import CLIPImageConverter
39
42
  from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import (
40
43
  DeepLabV3ImageConverter,
@@ -29,6 +29,9 @@ from keras_hub.src.models.bart.bart_seq_2_seq_lm_preprocessor import (
29
29
  BartSeq2SeqLMPreprocessor,
30
30
  )
31
31
  from keras_hub.src.models.bart.bart_tokenizer import BartTokenizer
32
+ from keras_hub.src.models.basnet.basnet import BASNetImageSegmenter
33
+ from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone
34
+ from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor
32
35
  from keras_hub.src.models.bert.bert_backbone import BertBackbone
33
36
  from keras_hub.src.models.bert.bert_masked_lm import BertMaskedLM
34
37
  from keras_hub.src.models.bert.bert_masked_lm_preprocessor import (
@@ -0,0 +1,5 @@
1
+ from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone
2
+ from keras_hub.src.models.basnet.basnet_presets import basnet_presets
3
+ from keras_hub.src.utils.preset_utils import register_presets
4
+
5
+ register_presets(basnet_presets, BASNetBackbone)
@@ -0,0 +1,122 @@
1
+ import keras
2
+
3
+ from keras_hub.src.api_export import keras_hub_export
4
+ from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone
5
+ from keras_hub.src.models.basnet.basnet_preprocessor import BASNetPreprocessor
6
+ from keras_hub.src.models.image_segmenter import ImageSegmenter
7
+
8
+
9
@keras_hub_export("keras_hub.models.BASNetImageSegmenter")
class BASNetImageSegmenter(ImageSegmenter):
    """BASNet image segmentation task.

    The task model exposes only the backbone's `"refine_out"` map as its
    prediction. During training, `compute_loss` additionally supervises
    every prediction-module side output against the same labels.

    Args:
        backbone: A `keras_hub.models.BASNetBackbone` instance.
        preprocessor: `None`, a `keras_hub.models.Preprocessor` instance,
            a `keras.Layer` instance, or a callable. If `None` no preprocessing
            will be applied to the inputs.

    Example:
    ```python
    import keras
    import keras_hub
    import numpy as np

    images = np.ones(shape=(1, 288, 288, 3))
    labels = np.zeros(shape=(1, 288, 288, 1))

    image_encoder = keras_hub.models.ResNetBackbone.from_preset(
        "resnet_18_imagenet",
        load_weights=False
    )
    backbone = keras_hub.models.BASNetBackbone(
        image_encoder,
        num_classes=1,
        image_shape=[288, 288, 3]
    )
    model = keras_hub.models.BASNetImageSegmenter(backbone)

    # Evaluate the model
    pred_labels = model(images)

    # Train the model
    model.compile(
        optimizer="adam",
        loss=keras.losses.BinaryCrossentropy(from_logits=False),
        metrics=["accuracy"],
    )
    model.fit(images, labels, epochs=3)
    ```
    """

    backbone_cls = BASNetBackbone
    preprocessor_cls = BASNetPreprocessor

    def __init__(
        self,
        backbone,
        preprocessor=None,
        **kwargs,
    ):
        # === Functional Model ===
        x = backbone.input
        outputs = backbone(x)
        # Only the refinement module's output is the final prediction; the
        # prediction-module side outputs are used in `compute_loss` only.
        outputs = outputs["refine_out"]
        super().__init__(inputs=x, outputs=outputs, **kwargs)

        # === Config ===
        self.backbone = backbone
        self.preprocessor = preprocessor

    def compute_loss(self, x, y, y_pred, *args, training=None, **kwargs):
        """Sum the compiled loss over every BASNet output.

        BASNet's refinement output and all prediction-module side outputs
        are trained against the same ground truth `y`. `y_pred` holds only
        the refinement output, so it is ignored. The backbone is re-run on
        `x` to obtain all outputs.

        Args:
            x: Model inputs.
            y: Ground-truth segmentation maps.
            y_pred: Unused (see above).
            *args: Forwarded to the parent `compute_loss`
                (e.g. `sample_weight`).
            training: Optional bool. When Keras supplies it, it is passed
                to the second backbone forward pass and to the parent
                `compute_loss`. Layers such as `BatchNormalization` then run
                in the same mode as when `y_pred` was computed. `None` keeps
                the previous behavior.
            **kwargs: Forwarded to the parent `compute_loss`.

        Returns:
            The sum of the per-output losses.
        """
        # Passing `training=None` explicitly is equivalent to omitting it.
        outputs = self.backbone(x, training=training)
        if training is not None:
            kwargs["training"] = training
        losses = []
        # A plain loop (not a comprehension) so zero-argument `super()`
        # resolves correctly.
        for output in outputs.values():
            losses.append(super().compute_loss(x, y, output, *args, **kwargs))
        return keras.ops.sum(losses, axis=0)

    def compile(
        self,
        optimizer="auto",
        loss="auto",
        metrics="auto",
        **kwargs,
    ):
        """Configures the `BASNet` task for training.

        `BASNet` extends the default compilation signature
        of `keras.Model.compile` with defaults for `optimizer` and `loss`. To
        override these defaults, pass any value to these arguments during
        compilation.

        Args:
            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
                instance. Defaults to `"auto"`, which uses the default
                optimizer for `BASNet`. See `keras.Model.compile` and
                `keras.optimizers` for more info on possible `optimizer`
                values.
            loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
                Defaults to `"auto"`, in which case a
                `keras.losses.BinaryCrossentropy` loss is applied.
                See `keras.Model.compile` and `keras.losses` for more info on
                possible `loss` values.
            metrics: `"auto"`, or a list of metrics to be evaluated by
                the model during training and testing. Defaults to `"auto"`,
                where a `keras.metrics.BinaryAccuracy` (named `"accuracy"`)
                is applied to track the pixel accuracy of the model.
                See `keras.Model.compile` and `keras.metrics` for
                more info on possible `metrics` values.
            **kwargs: See `keras.Model.compile` for a full list of arguments
                supported by the compile method.
        """
        if loss == "auto":
            loss = keras.losses.BinaryCrossentropy()
        if metrics == "auto":
            # `keras.metrics.Accuracy` requires exact equality between labels
            # and predictions, which essentially never holds for BASNet's
            # continuous sigmoid outputs. `BinaryAccuracy` thresholds at 0.5.
            # The name "accuracy" matches the previous metric's default name.
            metrics = [keras.metrics.BinaryAccuracy(name="accuracy")]
        super().compile(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
            **kwargs,
        )
@@ -0,0 +1,366 @@
1
+ import keras
2
+
3
+ from keras_hub.src.api_export import keras_hub_export
4
+ from keras_hub.src.models.backbone import Backbone
5
+ from keras_hub.src.models.resnet.resnet_backbone import (
6
+ apply_basic_block as resnet_basic_block,
7
+ )
8
+
9
+
10
@keras_hub_export("keras_hub.models.BASNetBackbone")
class BASNetBackbone(Backbone):
    """BASNet architecture for semantic segmentation.

    A Keras model implementing the BASNet architecture described in [BASNet:
    Boundary-Aware Segmentation Network for Mobile and Web Applications](
    https://arxiv.org/abs/2101.04704). BASNet uses a predict-refine
    architecture for highly accurate image segmentation.

    The model returns a dict of sigmoid-activated maps:
    - `"refine_out"`: the refinement module output.
    - `"predict_out_1"` ... `"predict_out_N"`: the prediction-module side
      outputs, one per prediction head.

    Args:
        image_encoder: A `keras_hub.models.ResNetBackbone` instance. The
            backbone network for the model that is used as a feature extractor
            for BASNet prediction encoder. Currently supported backbones are
            ResNet18 and ResNet34.
            (Note: Do not specify `image_shape` within the backbone.
            Please provide these while initializing the 'BASNetBackbone' model)
        num_classes: int, the number of classes for the segmentation model.
        image_shape: optional shape tuple, defaults to (None, None, 3).
        projection_filters: int, base number of filters. Used by the first
            convolution of both the prediction and refinement modules and by
            the refinement module's conv blocks. The prediction module's
            bridge and decoder use `8 * projection_filters`. Defaults to 64.
        prediction_heads: (Optional) List of `keras.layers.Layer` defining
            the prediction module heads. They are applied to the six decoder
            stages (first stage first) followed by the bridge. If not
            provided, seven default heads are created: each is a 3x3 Conv2D,
            and every head except the first is followed by bilinear
            upsampling.
        refinement_head: (Optional) a `keras.layers.Layer` defining the
            refinement module head for the model. If not provided, a default
            head is created with a Conv2D layer.
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
            to use for the model's computations and weights.

    Raises:
        ValueError: if `image_encoder` is not a `keras.Model`, or if it was
            built with an `image_shape` other than `(None, None, 3)`.
    """

    def __init__(
        self,
        image_encoder,
        num_classes,
        image_shape=(None, None, 3),
        projection_filters=64,
        prediction_heads=None,
        refinement_head=None,
        dtype=None,
        **kwargs,
    ):
        # `keras.Model` subclasses `keras.layers.Layer`, so a single check is
        # equivalent to the previous two-part test. A full model is needed
        # because `get_resnet_block` reads its layers, `pyramid_outputs`
        # and `stackwise_num_blocks`.
        if not isinstance(image_encoder, keras.Model):
            raise ValueError(
                "Argument `image_encoder` must be a `keras.Model` instance"
                " (e.g. a `keras_hub.models.ResNetBackbone`). Received instead"
                f" image_encoder={image_encoder} (of type"
                f" {type(image_encoder)})."
            )

        if tuple(image_encoder.image_shape) != (None, None, 3):
            raise ValueError(
                "Do not specify `image_shape` within the"
                " `BASNetBackbone`'s image_encoder. \nPlease provide"
                " `image_shape` while initializing the 'BASNetBackbone' model."
            )

        # === Functional Model ===
        # NOTE: layer creation order below determines Keras' auto-generated
        # layer names; keep it stable.
        inputs = keras.layers.Input(shape=image_shape)
        x = inputs

        if prediction_heads is None:
            # One head per prediction-module output (six decoder stages plus
            # the bridge). The upsampling factors presumably undo each
            # stage's downsampling so all side outputs share the input size.
            prediction_heads = []
            for size in (1, 2, 4, 8, 16, 32, 32):
                head_layers = [
                    keras.layers.Conv2D(
                        num_classes,
                        kernel_size=(3, 3),
                        padding="same",
                        dtype=dtype,
                    )
                ]
                if size != 1:
                    head_layers.append(
                        keras.layers.UpSampling2D(
                            size=size, interpolation="bilinear", dtype=dtype
                        )
                    )
                prediction_heads.append(keras.Sequential(head_layers))

        if refinement_head is None:
            refinement_head = keras.Sequential(
                [
                    keras.layers.Conv2D(
                        num_classes,
                        kernel_size=(3, 3),
                        padding="same",
                        dtype=dtype,
                    ),
                ]
            )

        # Prediction model.
        predict_model = basnet_predict(
            x, image_encoder, projection_filters, prediction_heads, dtype=dtype
        )

        # Refinement model.
        refine_model = basnet_rrm(
            predict_model, projection_filters, refinement_head, dtype=dtype
        )

        # Combine outputs (refinement output first) into a new list rather
        # than extending `refine_model.outputs` in place.
        outputs = list(refine_model.outputs) + list(predict_model.outputs)

        output_names = ["refine_out"] + [
            f"predict_out_{i}" for i in range(1, len(outputs))
        ]

        outputs = {
            output_name: keras.layers.Activation(
                "sigmoid", name=output_name, dtype=dtype
            )(output)
            for output, output_name in zip(outputs, output_names)
        }

        super().__init__(inputs=inputs, outputs=outputs, dtype=dtype, **kwargs)

        # === Config ===
        self.image_encoder = image_encoder
        self.num_classes = num_classes
        self.image_shape = image_shape
        self.projection_filters = projection_filters
        self.prediction_heads = prediction_heads
        self.refinement_head = refinement_head

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "image_encoder": keras.saving.serialize_keras_object(
                    self.image_encoder
                ),
                "num_classes": self.num_classes,
                "image_shape": self.image_shape,
                "projection_filters": self.projection_filters,
                "prediction_heads": [
                    keras.saving.serialize_keras_object(prediction_head)
                    for prediction_head in self.prediction_heads
                ],
                "refinement_head": keras.saving.serialize_keras_object(
                    self.refinement_head
                ),
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        # Work on a shallow copy and build a new `prediction_heads` list.
        # This leaves the caller's config (and its nested list) untouched,
        # so the same config can be deserialized more than once.
        config = dict(config)
        if "image_encoder" in config:
            config["image_encoder"] = keras.layers.deserialize(
                config["image_encoder"]
            )
        prediction_heads = config.get("prediction_heads")
        if isinstance(prediction_heads, list):
            config["prediction_heads"] = [
                keras.layers.deserialize(head)
                if isinstance(head, dict)
                else head
                for head in prediction_heads
            ]
        refinement_head = config.get("refinement_head")
        if isinstance(refinement_head, dict):
            config["refinement_head"] = keras.layers.deserialize(
                refinement_head
            )
        return super().from_config(config)
180
+
181
+
182
def convolution_block(x_input, filters, dilation=1, dtype=None):
    """Stack a 3x3 "same"-padded Conv2D, BatchNormalization and ReLU.

    Args:
        x_input: Input keras tensor.
        filters: int, number of output filters of the convolution.
        dilation: int, dilation rate of the convolution. Defaults to 1.
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
            to use for the created layers' computations and weights.

    Returns:
        The output tensor of the ReLU activation layer.
    """
    features = keras.layers.Conv2D(
        filters=filters,
        kernel_size=(3, 3),
        padding="same",
        dilation_rate=dilation,
        dtype=dtype,
    )(x_input)
    normalized = keras.layers.BatchNormalization(dtype=dtype)(features)
    relu = keras.layers.Activation("relu", dtype=dtype)
    return relu(normalized)
202
+
203
+
204
def get_resnet_block(_resnet, block_num):
    """Slice one stage out of a ResNet model as a standalone sub-model.

    Where the stage starts depends on `block_num`:
    - Stage 0 starts at the output of the `"pool1_pool"` layer.
    - Stage `k > 0` starts at pyramid output `P{k + 1}`.

    Every stage ends at the `add` layer of its last residual block
    (`stack{k}_block{n - 1}_add`, where `n` is
    `_resnet.stackwise_num_blocks[k]`).

    Args:
        _resnet: `keras.Model`. ResNet model instance exposing
            `stackwise_num_blocks` and `pyramid_outputs`.
        block_num: int, index of the stage to extract.

    Returns:
        A `keras.models.Model` named `resnet_block{block_num + 1}`.
    """
    blocks_per_stack = _resnet.stackwise_num_blocks
    if block_num == 0:
        stage_input = _resnet.get_layer("pool1_pool").output
    else:
        pyramid_level = ("P2", "P3", "P4", "P5")[block_num - 1]
        stage_input = _resnet.pyramid_outputs[pyramid_level]

    last_block = blocks_per_stack[block_num] - 1
    stage_output = _resnet.get_layer(
        f"stack{block_num}_block{last_block}_add"
    ).output

    return keras.models.Model(
        inputs=stage_input,
        outputs=stage_output,
        name=f"resnet_block{block_num + 1}",
    )
229
+
230
+
231
def basnet_predict(x_input, backbone, filters, segmentation_heads, dtype=None):
    """BASNet Prediction Module.

    This module outputs a coarse label map by integrating heavy
    encoder, bridge, and decoder blocks.

    Args:
        x_input: Input keras tensor.
        backbone: `keras.Model`. The backbone network used as a feature
            extractor for BASNet prediction encoder. It is sliced with
            `get_resnet_block`, so it must provide `stackwise_num_blocks`,
            `pyramid_outputs` and a `"pool1_pool"` layer.
        filters: int, filters of the stem convolution. The bridge and
            decoder convolution blocks use `filters * 8`.
        segmentation_heads: List of `keras.layers.Layer`. They are applied
            in order to the six decoder outputs (stage 0 first) followed by
            the bridge output. The pairing uses `zip`, so any surplus heads
            or tensors are silently dropped.
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
            to use for the model's computations and weights.


    Returns:
        A Keras Model mapping `x_input` to the list of side outputs, one per
        (segmentation head, decoder/bridge tensor) pair.
    """
    # NOTE(review): most layers below are unnamed, so Keras auto-generates
    # their names in creation order. Reordering statements would likely
    # change layer names / weight layout; confirm before refactoring.
    num_stages = 6

    x = x_input

    # -------------Encoder--------------
    # Stem convolution; stride 1, so spatial size is unchanged.
    x = keras.layers.Conv2D(
        filters, kernel_size=(3, 3), padding="same", dtype=dtype
    )(x)

    encoder_blocks = []
    for i in range(num_stages):
        if i < 4:  # First four stages are adopted from ResNet backbone.
            x = get_resnet_block(backbone, i)(x)
            encoder_blocks.append(x)
        else:  # Stages 4-5: 2x max-pool, then three basic resnet blocks.
            x = keras.layers.MaxPool2D(
                pool_size=(2, 2), strides=(2, 2), dtype=dtype
            )(x)
            for j in range(3):
                # Channel count is kept equal to the incoming tensor's.
                x = resnet_basic_block(
                    x,
                    filters=x.shape[3],
                    conv_shortcut=False,
                    name=f"v1_basic_block_{i + 1}_{j + 1}",
                    dtype=dtype,
                )
            encoder_blocks.append(x)

    # -------------Bridge-------------
    # Three dilated (rate 2) conv blocks at the deepest resolution.
    x = convolution_block(x, filters=filters * 8, dilation=2, dtype=dtype)
    x = convolution_block(x, filters=filters * 8, dilation=2, dtype=dtype)
    x = convolution_block(x, filters=filters * 8, dilation=2, dtype=dtype)
    encoder_blocks.append(x)

    # -------------Decoder-------------
    # Walk stages deepest-first. Stage 5 is concatenated with the bridge
    # output directly; every other stage first upsamples 2x so it matches
    # `encoder_blocks[i]` spatially before concatenation.
    decoder_blocks = []
    for i in reversed(range(num_stages)):
        if i != (num_stages - 1):  # Except first, scale other decoder stages.
            x = keras.layers.UpSampling2D(
                size=2, interpolation="bilinear", dtype=dtype
            )(x)

        x = keras.layers.concatenate([encoder_blocks[i], x], axis=-1)
        x = convolution_block(x, filters=filters * 8, dtype=dtype)
        x = convolution_block(x, filters=filters * 8, dtype=dtype)
        x = convolution_block(x, filters=filters * 8, dtype=dtype)
        decoder_blocks.append(x)

    # Reorder so index i holds the decoder output for encoder stage i.
    decoder_blocks.reverse()  # Change order from last to first decoder stage.
    decoder_blocks.append(encoder_blocks[-1])  # Copy bridge to decoder.

    # -------------Side Outputs--------------
    decoder_blocks = [
        segmentation_head(decoder_block)  # Prediction segmentation head.
        for segmentation_head, decoder_block in zip(
            segmentation_heads, decoder_blocks
        )
    ]

    return keras.models.Model(inputs=[x_input], outputs=decoder_blocks)
312
+
313
+
314
def basnet_rrm(base_model, filters, segmentation_head, dtype=None):
    """BASNet Residual Refinement Module (RRM).

    This module outputs a fine label map by integrating light encoder,
    bridge, and decoder blocks. It predicts a residual and adds it to the
    coarse map.

    Args:
        base_model: The prediction-module Keras model. Its first output
            (`base_model.output[0]`) is the coarse map that gets refined.
            Its input becomes the returned model's input.
        filters: int, the number of filters used by every convolution in
            this module.
        segmentation_head: a `keras.layers.Layer`, A Keras layer serving
            as the segmentation head for refinement module. Its output must
            match the coarse map's shape, since the two are added.
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
            to use for the model's computations and weights.

    Returns:
        A Keras Model that constructs the Residual Refinement Module (RRM).

    Note:
        The encoder max-pools 2x `num_stages` (4) times and the decoder
        upsamples back before each skip concatenation. The coarse map's
        spatial size therefore has to be divisible by 16 for the shapes
        to line up.
    """
    num_stages = 4

    # Coarse prediction to be refined.
    x_input = base_model.output[0]

    # -------------Encoder--------------
    x = keras.layers.Conv2D(
        filters, kernel_size=(3, 3), padding="same", dtype=dtype
    )(x_input)

    encoder_blocks = []
    for _ in range(num_stages):
        # `dtype` is forwarded here (it was previously omitted), so these
        # layers follow the requested dtype policy like the rest of BASNet.
        x = convolution_block(x, filters=filters, dtype=dtype)
        encoder_blocks.append(x)
        x = keras.layers.MaxPool2D(
            pool_size=(2, 2), strides=(2, 2), dtype=dtype
        )(x)

    # -------------Bridge--------------
    x = convolution_block(x, filters=filters, dtype=dtype)

    # -------------Decoder--------------
    for i in reversed(range(num_stages)):
        x = keras.layers.UpSampling2D(
            size=2, interpolation="bilinear", dtype=dtype
        )(x)
        x = keras.layers.concatenate([encoder_blocks[i], x], axis=-1)
        x = convolution_block(x, filters=filters, dtype=dtype)

    x = segmentation_head(x)  # Refinement segmentation head.

    # ------------- refined = coarse + residual
    x = keras.layers.Add(dtype=dtype)(
        [x_input, x]
    )  # Add prediction + refinement output

    return keras.models.Model(inputs=base_model.input, outputs=[x])
@@ -0,0 +1,8 @@
1
+ from keras_hub.src.api_export import keras_hub_export
2
+ from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
3
+ from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone
4
+
5
+
6
@keras_hub_export("keras_hub.layers.BASNetImageConverter")
class BASNetImageConverter(ImageConverter):
    """Image converter for BASNet models.

    Adds no behavior of its own on top of `ImageConverter`; it only binds
    `backbone_cls` to `BASNetBackbone`. Presumably this binding lets the
    preset machinery pair this converter with BASNet backbones; confirm
    against `ImageConverter`.
    """

    backbone_cls = BASNetBackbone
@@ -0,0 +1,14 @@
1
+ from keras_hub.src.api_export import keras_hub_export
2
+ from keras_hub.src.models.basnet.basnet_backbone import BASNetBackbone
3
+ from keras_hub.src.models.basnet.basnet_image_converter import (
4
+ BASNetImageConverter,
5
+ )
6
+ from keras_hub.src.models.image_segmenter_preprocessor import (
7
+ ImageSegmenterPreprocessor,
8
+ )
9
+
10
+
11
@keras_hub_export("keras_hub.models.BASNetPreprocessor")
class BASNetPreprocessor(ImageSegmenterPreprocessor):
    """Image segmenter preprocessor for BASNet models.

    Adds no logic of its own on top of `ImageSegmenterPreprocessor`. It only
    sets the backbone class (`BASNetBackbone`) and the image converter class
    (`BASNetImageConverter`) used for BASNet.
    """

    backbone_cls = BASNetBackbone
    image_converter_cls = BASNetImageConverter
@@ -0,0 +1,3 @@
1
+ """BASNet model preset configurations."""
2
+
3
# Preset table for BASNet, presumably keyed by preset name like the other
# `*_presets` tables (e.g. ViT's). It is currently empty, so no BASNet
# presets are registered.
basnet_presets = {}
@@ -46,4 +46,81 @@ backbone_presets = {
46
46
  },
47
47
  "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_384_imagenet/1",
48
48
  },
49
+ "vit_base_patch32_384_imagenet": {
50
+ "metadata": {
51
+ "description": (
52
+ "ViT-B32 model pre-trained on the ImageNet 1k dataset with "
53
+ "image resolution of 384x384 "
54
+ ),
55
+ "params": 87528192,
56
+ "path": "vit",
57
+ },
58
+ "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch32_384_imagenet/1",
59
+ },
60
+ "vit_large_patch32_384_imagenet": {
61
+ "metadata": {
62
+ "description": (
63
+ "ViT-L32 model pre-trained on the ImageNet 1k dataset with "
64
+ "image resolution of 384x384 "
65
+ ),
66
+ "params": 305607680,
67
+ "path": "vit",
68
+ },
69
+ "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch32_384_imagenet/1",
70
+ },
71
+ "vit_base_patch16_224_imagenet21k": {
72
+ "metadata": {
73
+ "description": (
74
+ "ViT-B16 backbone pre-trained on the ImageNet 21k dataset with "
75
+ "image resolution of 224x224 "
76
+ ),
77
+ "params": 85798656,
78
+ "path": "vit",
79
+ },
80
+ "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_224_imagenet21k/1",
81
+ },
82
+ "vit_base_patch32_224_imagenet21k": {
83
+ "metadata": {
84
+ "description": (
85
+ "ViT-B32 backbone pre-trained on the ImageNet 21k dataset with "
86
+ "image resolution of 224x224 "
87
+ ),
88
+ "params": 87455232,
89
+ "path": "vit",
90
+ },
91
+ "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch32_224_imagenet21k/1",
92
+ },
93
+ "vit_huge_patch14_224_imagenet21k": {
94
+ "metadata": {
95
+ "description": (
96
+ "ViT-H14 backbone pre-trained on the ImageNet 21k dataset with "
97
+ "image resolution of 224x224 "
98
+ ),
99
+ "params": 630764800,
100
+ "path": "vit",
101
+ },
102
+ "kaggle_handle": "kaggle://keras/vit/keras/vit_huge_patch14_224_imagenet21k/1",
103
+ },
104
+ "vit_large_patch16_224_imagenet21k": {
105
+ "metadata": {
106
+ "description": (
107
+ "ViT-L16 backbone pre-trained on the ImageNet 21k dataset with "
108
+ "image resolution of 224x224 "
109
+ ),
110
+ "params": 303301632,
111
+ "path": "vit",
112
+ },
113
+ "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_224_imagenet21k/1",
114
+ },
115
+ "vit_large_patch32_224_imagenet21k": {
116
+ "metadata": {
117
+ "description": (
118
+ "ViT-L32 backbone pre-trained on the ImageNet 21k dataset with "
119
+ "image resolution of 224x224 "
120
+ ),
121
+ "params": 305510400,
122
+ "path": "vit",
123
+ },
124
+ "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch32_224_imagenet21k/1",
125
+ },
49
126
  }
@@ -454,16 +454,6 @@ def load_json(preset, config_file=CONFIG_FILE):
454
454
  return config
455
455
 
456
456
 
457
- def load_serialized_object(config, **kwargs):
458
- # `dtype` in config might be a serialized `DTypePolicy` or `DTypePolicyMap`.
459
- # Ensure that `dtype` is properly configured.
460
- dtype = kwargs.pop("dtype", None)
461
- config = set_dtype_in_config(config, dtype)
462
-
463
- config["config"] = {**config["config"], **kwargs}
464
- return keras.saving.deserialize_keras_object(config)
465
-
466
-
467
457
  def check_config_class(config):
468
458
  """Validate a preset is being loaded on the correct class."""
469
459
  registered_name = config["registered_name"]
@@ -631,7 +621,7 @@ class KerasPresetLoader(PresetLoader):
631
621
  return check_config_class(self.config)
632
622
 
633
623
  def load_backbone(self, cls, load_weights, **kwargs):
634
- backbone = load_serialized_object(self.config, **kwargs)
624
+ backbone = self._load_serialized_object(self.config, **kwargs)
635
625
  if load_weights:
636
626
  jax_memory_cleanup(backbone)
637
627
  backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
@@ -639,18 +629,18 @@ class KerasPresetLoader(PresetLoader):
639
629
 
640
630
def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs):
    """Deserialize the preset's tokenizer and attach its preset assets.

    `cls` is part of the loader interface and is not used here. Extra
    `kwargs` are forwarded to `_load_serialized_object` as config overrides.
    """
    tokenizer = self._load_serialized_object(
        load_json(self.preset, config_file), **kwargs
    )
    if hasattr(tokenizer, "load_preset_assets"):
        tokenizer.load_preset_assets(self.preset)
    return tokenizer
646
636
 
647
637
def load_audio_converter(self, cls, **kwargs):
    """Deserialize the preset's audio converter (`cls` is unused here)."""
    return self._load_serialized_object(
        load_json(self.preset, AUDIO_CONVERTER_CONFIG_FILE), **kwargs
    )
650
640
 
651
641
def load_image_converter(self, cls, **kwargs):
    """Deserialize the preset's image converter (`cls` is unused here)."""
    return self._load_serialized_object(
        load_json(self.preset, IMAGE_CONVERTER_CONFIG_FILE), **kwargs
    )
654
644
 
655
645
  def load_task(self, cls, load_weights, load_task_weights, **kwargs):
656
646
  # If there is no `task.json` or it's for the wrong class delegate to the
@@ -671,7 +661,7 @@ class KerasPresetLoader(PresetLoader):
671
661
  backbone_config = task_config["config"]["backbone"]["config"]
672
662
  backbone_config = {**backbone_config, **backbone_kwargs}
673
663
  task_config["config"]["backbone"]["config"] = backbone_config
674
- task = load_serialized_object(task_config, **kwargs)
664
+ task = self._load_serialized_object(task_config, **kwargs)
675
665
  if task.preprocessor and hasattr(
676
666
  task.preprocessor, "load_preset_assets"
677
667
  ):
@@ -699,11 +689,20 @@ class KerasPresetLoader(PresetLoader):
699
689
  if not issubclass(check_config_class(preprocessor_json), cls):
700
690
  return super().load_preprocessor(cls, **kwargs)
701
691
  # We found a `preprocessing.json` with a complete config for our class.
702
- preprocessor = load_serialized_object(preprocessor_json, **kwargs)
692
+ preprocessor = self._load_serialized_object(preprocessor_json, **kwargs)
703
693
  if hasattr(preprocessor, "load_preset_assets"):
704
694
  preprocessor.load_preset_assets(self.preset)
705
695
  return preprocessor
706
696
 
697
def _load_serialized_object(self, config, **kwargs):
    """Deserialize a Keras object config with optional overrides.

    Args:
        config: A serialized Keras object dict containing a nested
            `"config"` dict.
        **kwargs: Overrides merged into the nested `"config"` dict.
            A `dtype` entry is handled separately via `set_dtype_in_config`,
            because the stored `dtype` might be a serialized `DTypePolicy`
            or `DTypePolicyMap`.

    Returns:
        The object produced by `keras.saving.deserialize_keras_object`.
    """
    dtype = kwargs.pop("dtype", None)
    config = set_dtype_in_config(config, dtype)

    # Build a new outer dict rather than assigning into `config`. This keeps
    # the caller's dict (e.g. `self.config` passed by `load_backbone`) free
    # of per-call overrides.
    config = {**config, "config": {**config["config"], **kwargs}}
    return keras.saving.deserialize_keras_object(config)
705
+
707
706
 
708
707
  class KerasPresetSaver:
709
708
  def __init__(self, preset_dir):
@@ -787,6 +786,8 @@ class KerasPresetSaver:
787
786
  tasks = list_subclasses(Task)
788
787
  tasks = filter(lambda x: x.backbone_cls is type(layer), tasks)
789
788
  tasks = [task.__base__.__name__ for task in tasks]
789
+ # Keep task list alphabetical.
790
+ tasks = sorted(tasks)
790
791
 
791
792
  keras_version = keras.version() if hasattr(keras, "version") else None
792
793
  metadata = {
@@ -1,7 +1,7 @@
1
1
  from keras_hub.src.api_export import keras_hub_export
2
2
 
3
3
  # Unique source of truth for the version number.
4
- __version__ = "0.19.0.dev202412180349"
4
+ __version__ = "0.19.0.dev202412200346"
5
5
 
6
6
 
7
7
  @keras_hub_export("keras_hub.version")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: keras-hub-nightly
3
- Version: 0.19.0.dev202412180349
3
+ Version: 0.19.0.dev202412200346
4
4
  Summary: Industry-strength Natural Language Processing extensions for Keras.
5
5
  Home-page: https://github.com/keras-team/keras-hub
6
6
  Author: Keras team
@@ -1,15 +1,15 @@
1
1
  keras_hub/__init__.py,sha256=QGdXyHgYt6cMUAP1ebxwc6oR86dE0dkMxNy2eOCQtFo,855
2
2
  keras_hub/api/__init__.py,sha256=spMxsgqzjpeuC8rY4WP-2kAZ2qwwKRSbFwddXgUjqQE,524
3
3
  keras_hub/api/bounding_box/__init__.py,sha256=T8R_X7BPm0et1xaZq8565uJmid7dylsSFSj4V-rGuFQ,1097
4
- keras_hub/api/layers/__init__.py,sha256=7-EfzihzyE-TinUW078OyLDQnjsADB2binfim8rj6TA,3102
4
+ keras_hub/api/layers/__init__.py,sha256=YO_YLbcxMEboFEgmFkzRf_JfQciQukX2AseOGpWEbDo,3195
5
5
  keras_hub/api/metrics/__init__.py,sha256=So8Ec-lOcTzn_UUMmAdzDm8RKkPu2dbRUm2px8gpUEI,381
6
- keras_hub/api/models/__init__.py,sha256=0VcBlT-Ee_-7KnMW6YmjSYd0fjligqNmIXbL7Vl0GTo,16577
6
+ keras_hub/api/models/__init__.py,sha256=suTcar7FqO5w9nNtalqmfYn7Fs6XmNEGpbojK-gaMEY,16795
7
7
  keras_hub/api/samplers/__init__.py,sha256=n-_SEXxr2LNUzK2FqVFN7alsrkx1P_HOVTeLZKeGCdE,730
8
8
  keras_hub/api/tokenizers/__init__.py,sha256=mtJgQy1spfQnPAkeLoeinsT_W9iCWHlJXwzcol5W1aU,2524
9
9
  keras_hub/api/utils/__init__.py,sha256=Gp1E6gG-RtKQS3PBEQEOz9PQvXkXaJ0ySGMqZ7myN7A,215
10
10
  keras_hub/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  keras_hub/src/api_export.py,sha256=9pQZK27JObxWZ96QPLBp1OBsjWigh1iuV6RglPGMRk0,1499
12
- keras_hub/src/version_utils.py,sha256=2C2DeD2GCwzG2BMJUN9YMKu_w131-hIZyhHDcb9XDBA,222
12
+ keras_hub/src/version_utils.py,sha256=Tk7iPoppgJnEhbF7HzURDrFGXD3R74NzoSPr_MJo9yY,222
13
13
  keras_hub/src/bounding_box/__init__.py,sha256=7i6KnGupN4AVivR_dFjQyuuTbI0GkHy8d-aMXeqZdU8,95
14
14
  keras_hub/src/bounding_box/converters.py,sha256=UUp1hwegpDZyIo8sh9TLNy1v6JjwmvwzL6wmHFMAtbk,21916
15
15
  keras_hub/src/bounding_box/formats.py,sha256=YmskOz2BOSat7NaE__J9VfpSNGPJJR0znSzA4lp8MMI,3868
@@ -85,6 +85,12 @@ keras_hub/src/models/bart/bart_presets.py,sha256=ppk9r_4Sm21XO6F9k3L946rkJBwWSLN
85
85
  keras_hub/src/models/bart/bart_seq_2_seq_lm.py,sha256=0r9snJsqqmH8F1_CDQZyFgqLNMYJM8AYFkmqfxUNB1U,19262
86
86
  keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py,sha256=3_e-ULIcm_3DKgt7X7cvyLZEDIEkpu9HdANgH6MjZgg,4373
87
87
  keras_hub/src/models/bart/bart_tokenizer.py,sha256=Q7IXmIwXzhPSN427oQRyF9ufoExQGS184Yo_4boaOZo,2811
88
+ keras_hub/src/models/basnet/__init__.py,sha256=4N6XvIUYYJl5xtoaL3_9fawUX_qP3WmTYNEEU7tn8Gw,253
89
+ keras_hub/src/models/basnet/basnet.py,sha256=JA58Q9lmygdSOm5MUaPAlaL6B8XnmqCcRaGrk9c8P3Q,4287
90
+ keras_hub/src/models/basnet/basnet_backbone.py,sha256=t_52WW6jetONS7AnPf9YsiMLDqOjVwjNuayQEv6ZAk4,13503
91
+ keras_hub/src/models/basnet/basnet_image_converter.py,sha256=DwzAwtZeggYw_qyRQ-Abnnm885Wobv3wClxRzOTscI0,342
92
+ keras_hub/src/models/basnet/basnet_preprocessor.py,sha256=uM504utaXODSqR5zpKnopRuaV_l84zCg06RkNoNSKIs,510
93
+ keras_hub/src/models/basnet/basnet_presets.py,sha256=z6tR2q_EvYnUmGfsWIWYfmR_8gvWYPH3QmtpAu_T8f8,63
88
94
  keras_hub/src/models/bert/__init__.py,sha256=K_UmCqDgOFFvXgzjXRn5oG0WWi53rAsQMOmUrsiBe1k,245
89
95
  keras_hub/src/models/bert/bert_backbone.py,sha256=o8GXUpoKPXLpfFzx5u9wI_3rZJeabPfYJEYSI09Clos,8069
90
96
  keras_hub/src/models/bert/bert_masked_lm.py,sha256=8gb1g8h5VFVLmKNEPfLe26z7SOlFnzf9R9okK3rp8AU,4045
@@ -339,7 +345,7 @@ keras_hub/src/models/vit/vit_image_classifier.py,sha256=lMVxiD1_6drx7XQ7P7YzlqnF
339
345
  keras_hub/src/models/vit/vit_image_classifier_preprocessor.py,sha256=wu6YcBlXMWB9sKCPvmNdGBZKTLQt_HyHWS6P9nyDwsk,504
340
346
  keras_hub/src/models/vit/vit_image_converter.py,sha256=5xVF04BzMcdTDc6aErAYj3_BuGmVd3zoJMcH1ho4T0g,2561
341
347
  keras_hub/src/models/vit/vit_layers.py,sha256=s4j3n3qnJnv6W9AdUkNsO3Vsi_BhxEGECYkaLVCU6XY,13238
342
- keras_hub/src/models/vit/vit_presets.py,sha256=TK7fBafZsM6S8_KCMSdayof8qrB40H8GLYMVYA83HTA,1682
348
+ keras_hub/src/models/vit/vit_presets.py,sha256=1QSyagzonaK4zpJdnjW2UL70T85xGxktsmLdSxcZTjk,4479
343
349
  keras_hub/src/models/vit_det/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
344
350
  keras_hub/src/models/vit_det/vit_det_backbone.py,sha256=DOZ5J7c1t5PAZ6y0pMmBoQTMOUup7UoUrYVfCs69ltY,7697
345
351
  keras_hub/src/models/vit_det/vit_layers.py,sha256=mnwu56chMc6zxmfp_hsLdR7TXYy1_YsWy1KwGX9M5Ic,19840
@@ -387,7 +393,7 @@ keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py,sha256=Zz1SGgArykxBVWnS
387
393
  keras_hub/src/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
388
394
  keras_hub/src/utils/keras_utils.py,sha256=lrZuC8HL2lmQfbHaS_t1JUyJann_ji2iTYE0Fzos8PU,1969
389
395
  keras_hub/src/utils/pipeline_model.py,sha256=jgzB6NQPSl0KOu08N-TazfOnXnUJbZjH2EXXhx25Ftg,9084
390
- keras_hub/src/utils/preset_utils.py,sha256=P3vqzVb3M-gJmPJDGOe1k0KLmHrhgi4ULeg-L_n5jhM,30976
396
+ keras_hub/src/utils/preset_utils.py,sha256=MFQqOIIWvfYToiUHfpPX0lERmgCkz09bM9L67E44H3s,31115
391
397
  keras_hub/src/utils/python_utils.py,sha256=N8nWeO3san4YnGkffRXG3Ix7VEIMTKSN21FX5TuL7G8,202
392
398
  keras_hub/src/utils/tensor_utils.py,sha256=YVJesN91bk-OzJXY1mOKBppuY8noBU7zhPQNXPxZVGc,14646
393
399
  keras_hub/src/utils/imagenet/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -411,7 +417,7 @@ keras_hub/src/utils/transformers/convert_pali_gemma.py,sha256=B1leeDw96Yvu81hYum
411
417
  keras_hub/src/utils/transformers/convert_vit.py,sha256=9SUZ9utNJhW_5cj3acMn9cRy47u2eIcDsrhmzj77o9k,5187
412
418
  keras_hub/src/utils/transformers/preset_loader.py,sha256=DgGJXbTSB9Na8FIR-YWWVqQPOFxHwWrGm41EwcS_EFs,3797
413
419
  keras_hub/src/utils/transformers/safetensor_utils.py,sha256=CYUHyA4y-B61r7NDnCsFb4t_UmSwZ1k9L-8gzEd6KRg,3339
414
- keras_hub_nightly-0.19.0.dev202412180349.dist-info/METADATA,sha256=MC4QMK_Qao8jT7qOGSWYpMfxofLZcUxdIn-c-d-czd0,7263
415
- keras_hub_nightly-0.19.0.dev202412180349.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
416
- keras_hub_nightly-0.19.0.dev202412180349.dist-info/top_level.txt,sha256=N4J6piIWBKa38A4uV-CnIopnOEf8mHAbkNXafXm_CuA,10
417
- keras_hub_nightly-0.19.0.dev202412180349.dist-info/RECORD,,
420
+ keras_hub_nightly-0.19.0.dev202412200346.dist-info/METADATA,sha256=gnzV7FHJ7Cx0ZFRZ-1WimsdR_QeM-z_8z8ycCRAC808,7263
421
+ keras_hub_nightly-0.19.0.dev202412200346.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
422
+ keras_hub_nightly-0.19.0.dev202412200346.dist-info/top_level.txt,sha256=N4J6piIWBKa38A4uV-CnIopnOEf8mHAbkNXafXm_CuA,10
423
+ keras_hub_nightly-0.19.0.dev202412200346.dist-info/RECORD,,