keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.16.1.dev202410040340__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. keras_hub/api/layers/__init__.py +3 -3
  2. keras_hub/api/models/__init__.py +10 -1
  3. keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
  4. keras_hub/src/layers/preprocessing/image_converter.py +164 -34
  5. keras_hub/src/models/backbone.py +3 -9
  6. keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
  7. keras_hub/src/models/deeplab_v3/__init__.py +7 -0
  8. keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +196 -0
  9. keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
  10. keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
  11. keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
  12. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +4 -0
  13. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +109 -0
  14. keras_hub/src/models/densenet/densenet_image_classifier.py +0 -128
  15. keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
  16. keras_hub/src/models/feature_pyramid_backbone.py +1 -1
  17. keras_hub/src/models/image_classifier.py +147 -2
  18. keras_hub/src/models/image_classifier_preprocessor.py +3 -3
  19. keras_hub/src/models/image_segmenter.py +0 -5
  20. keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
  21. keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -109
  22. keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
  23. keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
  24. keras_hub/src/models/preprocessor.py +3 -5
  25. keras_hub/src/models/resnet/resnet_backbone.py +1 -11
  26. keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
  27. keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
  28. keras_hub/src/models/sam/__init__.py +5 -0
  29. keras_hub/src/models/sam/sam_image_converter.py +2 -4
  30. keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
  31. keras_hub/src/models/sam/sam_presets.py +3 -3
  32. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
  33. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +57 -93
  34. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
  35. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +5 -3
  36. keras_hub/src/models/task.py +39 -36
  37. keras_hub/src/models/vae/__init__.py +1 -0
  38. keras_hub/src/models/vae/vae_backbone.py +172 -0
  39. keras_hub/src/models/vae/vae_layers.py +740 -0
  40. keras_hub/src/models/vgg/vgg_backbone.py +1 -20
  41. keras_hub/src/models/vgg/vgg_image_classifier.py +108 -29
  42. keras_hub/src/tokenizers/tokenizer.py +3 -6
  43. keras_hub/src/utils/preset_utils.py +103 -61
  44. keras_hub/src/utils/timm/preset_loader.py +8 -9
  45. keras_hub/src/version_utils.py +1 -1
  46. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/METADATA +1 -1
  47. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/RECORD +49 -41
  48. keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
  49. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
  50. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/WHEEL +0 -0
  51. {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.16.1.dev202410040340.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,16 @@
1
+ from keras_hub.src.api_export import keras_hub_export
2
+ from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import (
3
+ DeepLabV3Backbone,
4
+ )
5
+ from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import (
6
+ DeepLabV3ImageConverter,
7
+ )
8
+ from keras_hub.src.models.image_segmenter_preprocessor import (
9
+ ImageSegmenterPreprocessor,
10
+ )
11
+
12
+
13
+ @keras_hub_export("keras_hub.models.DeepLabV3ImageSegmenterPreprocessor")
14
+ class DeepLabV3ImageSegmenterPreprocessor(ImageSegmenterPreprocessor):
15
+ backbone_cls = DeepLabV3Backbone
16
+ image_converter_cls = DeepLabV3ImageConverter
@@ -0,0 +1,215 @@
1
+ import keras
2
+ from keras import ops
3
+
4
+
5
+ class SpatialPyramidPooling(keras.layers.Layer):
6
+ """Implements the Atrous Spatial Pyramid Pooling.
7
+
8
+ Reference for Atrous Spatial Pyramid Pooling [Rethinking Atrous Convolution
9
+ for Semantic Image Segmentation](https://arxiv.org/pdf/1706.05587.pdf) and
10
+ [Encoder-Decoder with Atrous Separable Convolution for Semantic Image
11
+ Segmentation](https://arxiv.org/pdf/1802.02611.pdf)
12
+
13
+ Args:
14
+ dilation_rates: list of ints. The dilation rate for parallel dilated conv.
15
+ A typical choice of rates is `[6, 12, 18]`.
16
+ num_channels: int. The number of output channels, defaults to `256`.
17
+ activation: str. Activation to be used, defaults to `relu`.
18
+ dropout: float. The dropout rate of the final projection output after the
19
+ activations and batch norm, defaults to `0.0`, which means no dropout is
20
+ applied to the output.
21
+
22
+ Example:
23
+ ```python
24
+ inp = keras.layers.Input((384, 384, 3))
25
+ backbone = keras.applications.EfficientNetB0(
26
+ input_tensor=inp,
27
+ include_top=False)
28
+ output = backbone(inp)
29
+ output = SpatialPyramidPooling(
30
+ dilation_rates=[6, 12, 18])(output)
31
+ ```
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ dilation_rates,
37
+ num_channels=256,
38
+ activation="relu",
39
+ dropout=0.0,
40
+ **kwargs,
41
+ ):
42
+ super().__init__(**kwargs)
43
+ self.dilation_rates = dilation_rates
44
+ self.num_channels = num_channels
45
+ self.activation = activation
46
+ self.dropout = dropout
47
+ self.data_format = keras.config.image_data_format()
48
+ self.channel_axis = -1 if self.data_format == "channels_last" else 1
49
+
50
+ def build(self, input_shape):
51
+ channels = input_shape[self.channel_axis]
52
+
53
+ # These are the parallel networks that process the input features with
54
+ # different dilation rates. The output from each channel will be merged
55
+ # together and fed to the output.
56
+ self.aspp_parallel_channels = []
57
+
58
+ # Channel1 with Conv2D and 1x1 kernel size.
59
+ conv_sequential = keras.Sequential(
60
+ [
61
+ keras.layers.Conv2D(
62
+ filters=self.num_channels,
63
+ kernel_size=(1, 1),
64
+ use_bias=False,
65
+ data_format=self.data_format,
66
+ name="aspp_conv_1",
67
+ ),
68
+ keras.layers.BatchNormalization(
69
+ axis=self.channel_axis, name="aspp_bn_1"
70
+ ),
71
+ keras.layers.Activation(
72
+ self.activation, name="aspp_activation_1"
73
+ ),
74
+ ]
75
+ )
76
+ conv_sequential.build(input_shape)
77
+ self.aspp_parallel_channels.append(conv_sequential)
78
+
79
+ # Channel 2 and afterwards are based on self.dilation_rates, and each of
80
+ # them will have conv2D with 3x3 kernel size.
81
+ for i, dilation_rate in enumerate(self.dilation_rates):
82
+ conv_sequential = keras.Sequential(
83
+ [
84
+ keras.layers.Conv2D(
85
+ filters=self.num_channels,
86
+ kernel_size=(3, 3),
87
+ padding="same",
88
+ dilation_rate=dilation_rate,
89
+ use_bias=False,
90
+ data_format=self.data_format,
91
+ name=f"aspp_conv_{i+2}",
92
+ ),
93
+ keras.layers.BatchNormalization(
94
+ axis=self.channel_axis, name=f"aspp_bn_{i+2}"
95
+ ),
96
+ keras.layers.Activation(
97
+ self.activation, name=f"aspp_activation_{i+2}"
98
+ ),
99
+ ]
100
+ )
101
+ conv_sequential.build(input_shape)
102
+ self.aspp_parallel_channels.append(conv_sequential)
103
+
104
+ # Last channel is the global average pooling with conv2D 1x1 kernel.
105
+ if self.channel_axis == -1:
106
+ reshape = keras.layers.Reshape((1, 1, channels), name="reshape")
107
+ else:
108
+ reshape = keras.layers.Reshape((channels, 1, 1), name="reshape")
109
+ pool_sequential = keras.Sequential(
110
+ [
111
+ keras.layers.GlobalAveragePooling2D(
112
+ data_format=self.data_format, name="average_pooling"
113
+ ),
114
+ reshape,
115
+ keras.layers.Conv2D(
116
+ filters=self.num_channels,
117
+ kernel_size=(1, 1),
118
+ use_bias=False,
119
+ data_format=self.data_format,
120
+ name="conv_pooling",
121
+ ),
122
+ keras.layers.BatchNormalization(
123
+ axis=self.channel_axis, name="bn_pooling"
124
+ ),
125
+ keras.layers.Activation(
126
+ self.activation, name="activation_pooling"
127
+ ),
128
+ ]
129
+ )
130
+ pool_sequential.build(input_shape)
131
+ self.aspp_parallel_channels.append(pool_sequential)
132
+
133
+ # Final projection layers
134
+ projection = keras.Sequential(
135
+ [
136
+ keras.layers.Conv2D(
137
+ filters=self.num_channels,
138
+ kernel_size=(1, 1),
139
+ use_bias=False,
140
+ data_format=self.data_format,
141
+ name="conv_projection",
142
+ ),
143
+ keras.layers.BatchNormalization(
144
+ axis=self.channel_axis, name="bn_projection"
145
+ ),
146
+ keras.layers.Activation(
147
+ self.activation, name="activation_projection"
148
+ ),
149
+ keras.layers.Dropout(rate=self.dropout, name="dropout"),
150
+ ],
151
+ )
152
+ projection_input_channels = (
153
+ 2 + len(self.dilation_rates)
154
+ ) * self.num_channels
155
+ if self.data_format == "channels_first":
156
+ projection.build(
157
+ (input_shape[0],)
158
+ + (projection_input_channels,)
159
+ + (input_shape[2:])
160
+ )
161
+ else:
162
+ projection.build((input_shape[:-1]) + (projection_input_channels,))
163
+ self.projection = projection
164
+ self.built = True
165
+
166
+ def call(self, inputs):
167
+ """Calls the Atrous Spatial Pyramid Pooling layer on an input.
168
+
169
+ Args:
170
+ inputs: A tensor of shape [batch, height, width, channels]
171
+
172
+ Returns:
173
+ A tensor of shape [batch, height, width, num_channels]
174
+ """
175
+ result = []
176
+
177
+ for channel in self.aspp_parallel_channels:
178
+ temp = ops.cast(channel(inputs), inputs.dtype)
179
+ result.append(temp)
180
+
181
+ image_shape = ops.shape(inputs)
182
+ if self.channel_axis == -1:
183
+ height, width = image_shape[1], image_shape[2]
184
+ else:
185
+ height, width = image_shape[2], image_shape[3]
186
+ result[-1] = keras.layers.Resizing(
187
+ height,
188
+ width,
189
+ interpolation="bilinear",
190
+ data_format=self.data_format,
191
+ name="resizing",
192
+ )(result[-1])
193
+
194
+ result = ops.concatenate(result, axis=self.channel_axis)
195
+ return self.projection(result)
196
+
197
+ def compute_output_shape(self, inputs_shape):
198
+ if self.data_format == "channels_first":
199
+ return tuple(
200
+ (inputs_shape[0],) + (self.num_channels,) + (inputs_shape[2:])
201
+ )
202
+ else:
203
+ return tuple((inputs_shape[:-1]) + (self.num_channels,))
204
+
205
+ def get_config(self):
206
+ config = super().get_config()
207
+ config.update(
208
+ {
209
+ "dilation_rates": self.dilation_rates,
210
+ "num_channels": self.num_channels,
211
+ "activation": self.activation,
212
+ "dropout": self.dropout,
213
+ }
214
+ )
215
+ return config
@@ -0,0 +1,4 @@
1
+ """DeepLabV3 preset configurations."""
2
+
3
+ # TODO: https://github.com/keras-team/keras-hub/issues/1896
4
+ backbone_presets = {}
@@ -0,0 +1,109 @@
1
+ import keras
2
+
3
+ from keras_hub.src.api_export import keras_hub_export
4
+ from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import (
5
+ DeepLabV3Backbone,
6
+ )
7
+ from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor import (
8
+ DeepLabV3ImageSegmenterPreprocessor,
9
+ )
10
+ from keras_hub.src.models.image_segmenter import ImageSegmenter
11
+
12
+
13
+ @keras_hub_export("keras_hub.models.DeepLabV3ImageSegmenter")
14
+ class DeepLabV3ImageSegmenter(ImageSegmenter):
15
+ """DeepLabV3 and DeepLabV3Plus segmentation task.
16
+
17
+ Args:
18
+ backbone: A `keras_hub.models.DeepLabV3Backbone` instance.
19
+ num_classes: int. The number of classes for the segmentation model. Note
20
+ that the `num_classes` contains the background class, and the
21
+ classes from the data should be represented by integers with range
22
+ `[0, num_classes)`.
23
+ activation: str or callable. The activation function to use on
24
+ the final `Conv2D` layer. Set `activation=None` to return the output
25
+ logits. Defaults to `None`.
26
+ preprocessor: A `keras_hub.models.DeepLabV3ImageSegmenterPreprocessor`
27
+ or `None`. If `None`, this model will not apply preprocessing, and
28
+ inputs should be preprocessed before calling the model.
29
+
30
+ Example:
31
+ Load a DeepLabV3 preset with the pretrained 21-class segmentation head.
32
+ ```python
33
+ images = np.ones(shape=(1, 96, 96, 3))
34
+ labels = np.zeros(shape=(1, 96, 96, 1))
35
+ segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset(
36
+ "deeplabv3_resnet50_pascalvoc",
37
+ )
38
+ segmenter.predict(images)
39
+ ```
40
+
41
+ Specify `num_classes` to load randomly initialized segmentation head.
42
+ ```python
43
+ segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset(
44
+ "deeplabv3_resnet50_pascalvoc",
45
+ num_classes=2,
46
+ )
47
+ segmenter.fit(images, labels, epochs=3)
48
+ segmenter.predict(images) # Trained 2 class segmentation.
49
+ ```
50
+ Load DeepLabV3+ presets, an extension of DeepLabV3 that adds a simple yet
51
+ effective decoder module to refine the segmentation results especially
52
+ along object boundaries.
53
+ ```python
54
+ segmenter = keras_hub.models.DeepLabV3ImageSegmenter.from_preset(
55
+ "deeplabv3_plus_resnet50_pascalvoc",
56
+ )
57
+ segmenter.predict(images)
58
+ ```
59
+ """
60
+
61
+ backbone_cls = DeepLabV3Backbone
62
+ preprocessor_cls = DeepLabV3ImageSegmenterPreprocessor
63
+
64
+ def __init__(
65
+ self,
66
+ backbone,
67
+ num_classes,
68
+ activation=None,
69
+ preprocessor=None,
70
+ **kwargs,
71
+ ):
72
+ data_format = keras.config.image_data_format()
73
+ # === Layers ===
74
+ self.output_conv = keras.layers.Conv2D(
75
+ name="segmentation_output",
76
+ filters=num_classes,
77
+ kernel_size=1,
78
+ use_bias=False,
79
+ padding="same",
80
+ activation=activation,
81
+ data_format=data_format,
82
+ )
83
+
84
+ # === Functional Model ===
85
+ inputs = backbone.input
86
+ x = backbone(inputs)
87
+ outputs = self.output_conv(x)
88
+ super().__init__(
89
+ inputs=inputs,
90
+ outputs=outputs,
91
+ **kwargs,
92
+ )
93
+
94
+ # === Config ===
95
+ self.backbone = backbone
96
+ self.num_classes = num_classes
97
+ self.activation = activation
98
+ self.preprocessor = preprocessor
99
+
100
+ def get_config(self):
101
+ # Backbone serialized in `super`
102
+ config = super().get_config()
103
+ config.update(
104
+ {
105
+ "num_classes": self.num_classes,
106
+ "activation": self.activation,
107
+ }
108
+ )
109
+ return config
@@ -1,5 +1,3 @@
1
- import keras
2
-
3
1
  from keras_hub.src.api_export import keras_hub_export
4
2
  from keras_hub.src.models.densenet.densenet_backbone import DenseNetBackbone
5
3
  from keras_hub.src.models.densenet.densenet_image_classifier_preprocessor import (
@@ -10,131 +8,5 @@ from keras_hub.src.models.image_classifier import ImageClassifier
10
8
 
11
9
  @keras_hub_export("keras_hub.models.DenseNetImageClassifier")
12
10
  class DenseNetImageClassifier(ImageClassifier):
13
- """DenseNet image classifier task model.
14
-
15
- To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
16
- where `x` is a tensor and `y` is a integer from `[0, num_classes)`.
17
- All `ImageClassifier` tasks include a `from_preset()` constructor which can
18
- be used to load a pre-trained config and weights.
19
-
20
- Args:
21
- backbone: A `keras_hub.models.DenseNetBackbone` instance.
22
- num_classes: int. The number of classes to predict.
23
- activation: `None`, str or callable. The activation function to use on
24
- the `Dense` layer. Set `activation=None` to return the output
25
- logits. Defaults to `None`.
26
- pooling: A pooling layer to use before the final classification layer,
27
- must be one of "avg" or "max". Use "avg" for
28
- `GlobalAveragePooling2D` and "max" for "GlobalMaxPooling2D.
29
- preprocessor: A `keras_hub.models.DenseNetImageClassifierPreprocessor`
30
- or `None`. If `None`, this model will not apply preprocessing, and
31
- inputs should be preprocessed before calling the model.
32
-
33
- Examples:
34
-
35
- Call `predict()` to run inference.
36
- ```python
37
- # Load preset and train
38
- images = np.ones((2, 224, 224, 3), dtype="float32")
39
- classifier = keras_hub.models.DenseNetImageClassifier.from_preset(
40
- "densenet121_imagenet")
41
- classifier.predict(images)
42
- ```
43
-
44
- Call `fit()` on a single batch.
45
- ```python
46
- # Load preset and train
47
- images = np.ones((2, 224, 224, 3), dtype="float32")
48
- labels = [0, 3]
49
- classifier = keras_hub.models.DenseNetImageClassifier.from_preset(
50
- "densenet121_imagenet")
51
- classifier.fit(x=images, y=labels, batch_size=2)
52
- ```
53
-
54
- Call `fit()` with custom loss, optimizer and backbone.
55
- ```python
56
- classifier = keras_hub.models.DenseNetImageClassifier.from_preset(
57
- "densenet121_imagenet")
58
- classifier.compile(
59
- loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
60
- optimizer=keras.optimizers.Adam(5e-5),
61
- )
62
- classifier.backbone.trainable = False
63
- classifier.fit(x=images, y=labels, batch_size=2)
64
- ```
65
-
66
- Custom backbone.
67
- ```python
68
- images = np.ones((2, 224, 224, 3), dtype="float32")
69
- labels = [0, 3]
70
- backbone = keras_hub.models.DenseNetBackbone(
71
- stackwise_num_filters=[128, 256, 512, 1024],
72
- stackwise_depth=[3, 9, 9, 3],
73
- block_type="basic_block",
74
- image_shape = (224, 224, 3),
75
- )
76
- classifier = keras_hub.models.DenseNetImageClassifier(
77
- backbone=backbone,
78
- num_classes=4,
79
- )
80
- classifier.fit(x=images, y=labels, batch_size=2)
81
- ```
82
- """
83
-
84
11
  backbone_cls = DenseNetBackbone
85
12
  preprocessor_cls = DenseNetImageClassifierPreprocessor
86
-
87
- def __init__(
88
- self,
89
- backbone,
90
- num_classes,
91
- activation=None,
92
- pooling="avg",
93
- preprocessor=None,
94
- **kwargs,
95
- ):
96
- # === Layers ===
97
- self.backbone = backbone
98
- self.preprocessor = preprocessor
99
- if pooling == "avg":
100
- self.pooler = keras.layers.GlobalAveragePooling2D()
101
- elif pooling == "max":
102
- self.pooler = keras.layers.GlobalMaxPooling2D()
103
- else:
104
- raise ValueError(
105
- "Unknown `pooling` type. Polling should be either `'avg'` or "
106
- f"`'max'`. Received: pooling={pooling}."
107
- )
108
- self.output_dense = keras.layers.Dense(
109
- num_classes,
110
- activation=activation,
111
- name="predictions",
112
- )
113
-
114
- # === Functional Model ===
115
- inputs = self.backbone.input
116
- x = self.backbone(inputs)
117
- x = self.pooler(x)
118
- outputs = self.output_dense(x)
119
- super().__init__(
120
- inputs=inputs,
121
- outputs=outputs,
122
- **kwargs,
123
- )
124
-
125
- # === Config ===
126
- self.num_classes = num_classes
127
- self.activation = activation
128
- self.pooling = pooling
129
-
130
- def get_config(self):
131
- # Backbone serialized in `super`
132
- config = super().get_config()
133
- config.update(
134
- {
135
- "num_classes": self.num_classes,
136
- "activation": self.activation,
137
- "pooling": self.pooling,
138
- }
139
- )
140
- return config
@@ -1,10 +1,8 @@
1
1
  from keras_hub.src.api_export import keras_hub_export
2
- from keras_hub.src.layers.preprocessing.resizing_image_converter import (
3
- ResizingImageConverter,
4
- )
2
+ from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
5
3
  from keras_hub.src.models.densenet.densenet_backbone import DenseNetBackbone
6
4
 
7
5
 
8
6
  @keras_hub_export("keras_hub.layers.DenseNetImageConverter")
9
- class DenseNetImageConverter(ResizingImageConverter):
7
+ class DenseNetImageConverter(ImageConverter):
10
8
  backbone_cls = DenseNetBackbone
@@ -15,7 +15,7 @@ class FeaturePyramidBackbone(Backbone):
15
15
  Example:
16
16
 
17
17
  ```python
18
- input_data = np.random.uniform(0, 255, size=(2, 224, 224, 3))
18
+ input_data = np.random.uniform(0, 256, size=(2, 224, 224, 3))
19
19
 
20
20
  # Convert to feature pyramid output format using ResNet.
21
21
  backbone = ResNetBackbone.from_preset("resnet50")
@@ -15,11 +15,156 @@ class ImageClassifier(Task):
15
15
 
16
16
  To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
17
17
  labels where `x` is a string and `y` is a integer from `[0, num_classes)`.
18
+ All `ImageClassifier` tasks include a `from_preset()` constructor which can
19
+ be used to load a pre-trained config and weights.
18
20
 
19
- All `ImageClassifier` tasks include a `from_preset()` constructor which can be
20
- used to load a pre-trained config and weights.
21
+ Args:
22
+ backbone: A `keras_hub.models.Backbone` instance or a `keras.Model`.
23
+ num_classes: int. The number of classes to predict.
24
+ preprocessor: `None`, a `keras_hub.models.Preprocessor` instance,
25
+ a `keras.Layer` instance, or a callable. If `None` no preprocessing
26
+ will be applied to the inputs.
27
+ pooling: `"avg"` or `"max"`. The type of pooling to apply on backbone
28
+ output. Defaults to average pooling.
29
+ activation: `None`, str, or callable. The activation function to use on
30
+ the `Dense` layer. Set `activation=None` to return the output
31
+ logits. Defaults to `None`.
32
+ head_dtype: `None`, str, or `keras.mixed_precision.DTypePolicy`. The
33
+ dtype to use for the classification head's computations and weights.
34
+
35
+ Examples:
36
+
37
+ Call `predict()` to run inference.
38
+ ```python
39
+ # Load preset and run inference
40
+ images = np.random.randint(0, 256, size=(2, 224, 224, 3))
41
+ classifier = keras_hub.models.ImageClassifier.from_preset(
42
+ "resnet_50_imagenet"
43
+ )
44
+ classifier.predict(images)
45
+ ```
46
+
47
+ Call `fit()` on a single batch.
48
+ ```python
49
+ # Load preset and train
50
+ images = np.random.randint(0, 256, size=(2, 224, 224, 3))
51
+ labels = [0, 3]
52
+ classifier = keras_hub.models.ImageClassifier.from_preset(
53
+ "resnet_50_imagenet"
54
+ )
55
+ classifier.fit(x=images, y=labels, batch_size=2)
56
+ ```
57
+
58
+ Call `fit()` with custom loss, optimizer and backbone.
59
+ ```python
60
+ classifier = keras_hub.models.ImageClassifier.from_preset(
61
+ "resnet_50_imagenet"
62
+ )
63
+ classifier.compile(
64
+ loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
65
+ optimizer=keras.optimizers.Adam(5e-5),
66
+ )
67
+ classifier.backbone.trainable = False
68
+ classifier.fit(x=images, y=labels, batch_size=2)
69
+ ```
70
+
71
+ Custom backbone.
72
+ ```python
73
+ images = np.random.randint(0, 256, size=(2, 224, 224, 3))
74
+ labels = [0, 3]
75
+ backbone = keras_hub.models.ResNetBackbone(
76
+ stackwise_num_filters=[64, 64, 64],
77
+ stackwise_num_blocks=[2, 2, 2],
78
+ stackwise_num_strides=[1, 2, 2],
79
+ block_type="basic_block",
80
+ use_pre_activation=True,
81
+ pooling="avg",
82
+ )
83
+ classifier = keras_hub.models.ImageClassifier(
84
+ backbone=backbone,
85
+ num_classes=4,
86
+ )
87
+ classifier.fit(x=images, y=labels, batch_size=2)
88
+ ```
21
89
  """
22
90
 
91
+ def __init__(
92
+ self,
93
+ backbone,
94
+ num_classes,
95
+ preprocessor=None,
96
+ pooling="avg",
97
+ activation=None,
98
+ dropout=0.0,
99
+ head_dtype=None,
100
+ **kwargs,
101
+ ):
102
+ head_dtype = head_dtype or backbone.dtype_policy
103
+ data_format = getattr(backbone, "data_format", None)
104
+
105
+ # === Layers ===
106
+ self.backbone = backbone
107
+ self.preprocessor = preprocessor
108
+ if pooling == "avg":
109
+ self.pooler = keras.layers.GlobalAveragePooling2D(
110
+ data_format,
111
+ dtype=head_dtype,
112
+ name="pooler",
113
+ )
114
+ elif pooling == "max":
115
+ self.pooler = keras.layers.GlobalMaxPooling2D(
116
+ data_format,
117
+ dtype=head_dtype,
118
+ name="pooler",
119
+ )
120
+ else:
121
+ raise ValueError(
122
+ "Unknown `pooling` type. Polling should be either `'avg'` or "
123
+ f"`'max'`. Received: pooling={pooling}."
124
+ )
125
+ self.output_dropout = keras.layers.Dropout(
126
+ dropout,
127
+ dtype=head_dtype,
128
+ name="output_dropout",
129
+ )
130
+ self.output_dense = keras.layers.Dense(
131
+ num_classes,
132
+ activation=activation,
133
+ dtype=head_dtype,
134
+ name="predictions",
135
+ )
136
+
137
+ # === Functional Model ===
138
+ inputs = self.backbone.input
139
+ x = self.backbone(inputs)
140
+ x = self.pooler(x)
141
+ x = self.output_dropout(x)
142
+ outputs = self.output_dense(x)
143
+ super().__init__(
144
+ inputs=inputs,
145
+ outputs=outputs,
146
+ **kwargs,
147
+ )
148
+
149
+ # === Config ===
150
+ self.num_classes = num_classes
151
+ self.activation = activation
152
+ self.pooling = pooling
153
+ self.dropout = dropout
154
+
155
+ def get_config(self):
156
+ # Backbone serialized in `super`
157
+ config = super().get_config()
158
+ config.update(
159
+ {
160
+ "num_classes": self.num_classes,
161
+ "pooling": self.pooling,
162
+ "activation": self.activation,
163
+ "dropout": self.dropout,
164
+ }
165
+ )
166
+ return config
167
+
23
168
  def compile(
24
169
  self,
25
170
  optimizer="auto",
@@ -38,15 +38,15 @@ class ImageClassifierPreprocessor(Preprocessor):
38
38
  )
39
39
 
40
40
  # Resize a single image for resnet 50.
41
- x = np.ones((512, 512, 3))
41
+ x = np.random.randint(0, 256, (512, 512, 3))
42
42
  x = preprocessor(x)
43
43
 
44
44
  # Resize a labeled image.
45
- x, y = np.ones((512, 512, 3)), 1
45
+ x, y = np.random.randint(0, 256, (512, 512, 3)), 1
46
46
  x, y = preprocessor(x, y)
47
47
 
48
48
  # Resize a batch of labeled images.
49
- x, y = [np.ones((512, 512, 3)), np.zeros((512, 512, 3))], [1, 0]
49
+ x, y = [np.random.randint(0, 256, (512, 512, 3)), np.zeros((512, 512, 3))], [1, 0]
50
50
  x, y = preprocessor(x, y)
51
51
 
52
52
  # Use a `tf.data.Dataset`.