keras-hub-nightly 0.16.1.dev202409230338__py3-none-any.whl → 0.16.1.dev202409250340__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -56,6 +56,8 @@ from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import (
 from keras_hub.src.models.resnet.resnet_image_converter import (
     ResNetImageConverter,
 )
+from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder
+from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder
 from keras_hub.src.models.whisper.whisper_audio_converter import (
     WhisperAudioConverter,
 )
@@ -175,6 +175,7 @@ from keras_hub.src.models.image_classifier import ImageClassifier
 from keras_hub.src.models.image_classifier_preprocessor import (
     ImageClassifierPreprocessor,
 )
+from keras_hub.src.models.image_segmenter import ImageSegmenter
 from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone
 from keras_hub.src.models.llama3.llama3_causal_lm import Llama3CausalLM
 from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import (
@@ -255,6 +256,8 @@ from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import (
     RobertaTextClassifierPreprocessor as RobertaPreprocessor,
 )
 from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer
+from keras_hub.src.models.sam.sam_backbone import SAMBackbone
+from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter
 from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM
 from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor
 from keras_hub.src.models.t5.t5_backbone import T5Backbone
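
Taken together, these import hunks wire the new SAM classes into the public API. A minimal sketch of what becomes importable after this nightly; the exact export namespaces are inferred from the `keras_hub_export` decorators and docstrings further down in this diff:

```python
# A minimal sketch of the symbols this release adds to the public API.
# Namespaces are inferred from the keras_hub_export decorators below.
import keras_hub

# New task-level classes:
keras_hub.models.ImageSegmenter  # new base class for segmentation tasks
keras_hub.models.SAMBackbone
keras_hub.models.SAMImageSegmenter

# New prompt-handling layers (referenced as keras_hub.layers.* in the
# SAMBackbone docstring below):
keras_hub.layers.SAMPromptEncoder
keras_hub.layers.SAMMaskDecoder
```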
@@ -0,0 +1,86 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.task import Task
+
+
+@keras_hub_export("keras_hub.models.ImageSegmenter")
+class ImageSegmenter(Task):
+    """Base class for all image segmentation tasks.
+
+    `ImageSegmenter` tasks wrap a `keras_hub.models.Backbone` and
+    a `keras_hub.models.Preprocessor` to create a model that can be used for
+    image segmentation.
+
+    All `ImageSegmenter` tasks include a `from_preset()` constructor which can
+    be used to load a pre-trained config and weights.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Default compilation.
+        self.compile()
+
+    def compile(
+        self,
+        optimizer="auto",
+        loss="auto",
+        *,
+        metrics="auto",
+        **kwargs,
+    ):
+        """Configures the `ImageSegmenter` task for training.
+
+        The `ImageSegmenter` task extends the default compilation signature
+        of `keras.Model.compile` with defaults for `optimizer`, `loss`, and
+        `metrics`. To override these defaults, pass any value to these
+        arguments during compilation.
+
+        Args:
+            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
+                instance. Defaults to `"auto"`, which uses the default
+                optimizer for the given model and task. See
+                `keras.Model.compile` and `keras.optimizers` for more info
+                on possible `optimizer` values.
+            loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
+                Defaults to `"auto"`, where a
+                `keras.losses.CategoricalCrossentropy` loss will be applied
+                for the segmentation task. See `keras.Model.compile` and
+                `keras.losses` for more info on possible `loss` values.
+            metrics: `"auto"`, or a list of metrics to be evaluated by the
+                model during training and testing. Defaults to `"auto"`,
+                where a `keras.metrics.CategoricalAccuracy` will be applied
+                to track the accuracy of the model during training. See
+                `keras.Model.compile` and `keras.metrics` for more info on
+                possible `metrics` values.
+            **kwargs: See `keras.Model.compile` for a full list of arguments
+                supported by the compile method.
+        """
+        if optimizer == "auto":
+            optimizer = keras.optimizers.Adam(5e-5)
+        if loss == "auto":
+            activation = getattr(self, "activation", None)
+            activation = keras.activations.get(activation)
+            from_logits = activation != keras.activations.softmax
+            loss = keras.losses.CategoricalCrossentropy(from_logits=from_logits)
+        if metrics == "auto":
+            metrics = [keras.metrics.CategoricalAccuracy()]
+        super().compile(
+            optimizer=optimizer,
+            loss=loss,
+            metrics=metrics,
+            **kwargs,
+        )
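
One subtlety in the `"auto"` loss above: `from_logits` is inferred from the task's `activation` attribute, resolved through `keras.activations.get`, and the crossentropy operates on logits unless that activation is exactly `softmax`. A standalone sketch of the same inference, using plain Keras only:

```python
import keras

# Mirror of the from_logits inference in ImageSegmenter.compile():
# any activation other than softmax means the model outputs logits.
for name in (None, "softmax", "sigmoid"):
    activation = keras.activations.get(name)
    from_logits = activation != keras.activations.softmax
    loss = keras.losses.CategoricalCrossentropy(from_logits=from_logits)
    print(f"activation={name!r} -> from_logits={from_logits}")
```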
@@ -0,0 +1,13 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -0,0 +1,153 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.backbone import Backbone
+
+
+@keras_hub_export("keras_hub.models.SAMBackbone")
+class SAMBackbone(Backbone):
+    """A backbone for the Segment Anything Model (SAM).
+
+    Args:
+        image_encoder: `keras_hub.models.ViTDetBackbone`. A feature
+            extractor for the input images.
+        prompt_encoder: `keras_hub.layers.SAMPromptEncoder`. A Keras layer
+            that computes embeddings for the point, box, and mask prompts.
+        mask_decoder: `keras_hub.layers.SAMMaskDecoder`. A Keras layer that
+            generates segmentation masks given the embeddings produced by
+            the backbone and the prompt encoder.
+        dtype: The dtype of the layer weights.
+
+    Example:
+    ```python
+    image_size = 128
+    batch_size = 2
+    input_data = {
+        "images": np.ones(
+            (batch_size, image_size, image_size, 3),
+            dtype="float32",
+        ),
+        "points": np.ones((batch_size, 1, 2), dtype="float32"),
+        "labels": np.ones((batch_size, 1), dtype="float32"),
+        "boxes": np.ones((batch_size, 1, 2, 2), dtype="float32"),
+        "masks": np.zeros(
+            (batch_size, 0, image_size, image_size, 1)
+        ),
+    }
+    image_encoder = keras_hub.models.ViTDetBackbone(
+        hidden_size=16,
+        num_layers=16,
+        intermediate_dim=16 * 4,
+        num_heads=16,
+        global_attention_layer_indices=[2, 5, 8, 11],
+        patch_size=16,
+        num_output_channels=8,
+        window_size=2,
+        image_shape=(image_size, image_size, 3),
+    )
+    prompt_encoder = keras_hub.layers.SAMPromptEncoder(
+        hidden_size=8,
+        image_embedding_size=(8, 8),
+        input_image_size=(
+            image_size,
+            image_size,
+        ),
+        mask_in_channels=16,
+    )
+    mask_decoder = keras_hub.layers.SAMMaskDecoder(
+        num_layers=2,
+        hidden_size=8,
+        intermediate_dim=32,
+        num_heads=8,
+        embedding_dim=8,
+        num_multimask_outputs=3,
+        iou_head_depth=3,
+        iou_head_hidden_dim=8,
+    )
+    backbone = keras_hub.models.SAMBackbone(
+        image_encoder=image_encoder,
+        prompt_encoder=prompt_encoder,
+        mask_decoder=mask_decoder,
+    )
+    backbone(input_data)
+    ```
+    """
+
+    def __init__(
+        self,
+        image_encoder,
+        prompt_encoder,
+        mask_decoder,
+        dtype=None,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.image_encoder = image_encoder
+        self.prompt_encoder = prompt_encoder
+        self.mask_decoder = mask_decoder
+        # === Functional Model ===
+        image_input = self.image_encoder.input
+
+        inputs = {
+            "images": image_input,
+            "points": keras.Input(shape=[None, 2], name="points"),
+            "labels": keras.Input(shape=[None], name="labels"),
+            "boxes": keras.Input(shape=[None, 2, 2], name="boxes"),
+            "masks": keras.Input(shape=[None, None, None, 1], name="masks"),
+        }
+        image_embeddings = self.image_encoder.output
+        prompt_embeddings = self.prompt_encoder(**inputs)
+        outputs = {
+            "image_embeddings": image_embeddings,
+        }
+        outputs.update(prompt_embeddings)
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            dtype=dtype,
+            **kwargs,
+        )
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "image_encoder": keras.layers.serialize(self.image_encoder),
+                "prompt_encoder": keras.layers.serialize(self.prompt_encoder),
+                "mask_decoder": keras.layers.serialize(self.mask_decoder),
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        config.update(
+            {
+                "image_encoder": keras.layers.deserialize(
+                    config["image_encoder"]
+                ),
+                "prompt_encoder": keras.layers.deserialize(
+                    config["prompt_encoder"]
+                ),
+                "mask_decoder": keras.layers.deserialize(
+                    config["mask_decoder"]
+                ),
+            }
+        )
+
+        return super().from_config(config)
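
Because `get_config()` stores the three sub-models in serialized form, a `SAMBackbone` should survive a config round trip. A hedged sketch, assuming `backbone` is the instance built in the docstring example above:

```python
# Sketch of the config round trip enabled by the overrides above.
# Assumes `backbone` is the SAMBackbone from the docstring example.
config = backbone.get_config()

# serialize() stores each sub-model as a JSON-friendly dict, not as a
# live Keras object:
assert isinstance(config["image_encoder"], dict)

# from_config() deserializes the sub-models before calling the parent
# constructor, producing a structurally equivalent (fresh) backbone:
restored = SAMBackbone.from_config(config)
assert restored.prompt_encoder is not backbone.prompt_encoder
```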
@@ -0,0 +1,237 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.image_segmenter import ImageSegmenter
+from keras_hub.src.models.sam.sam_backbone import SAMBackbone
+
+
+@keras_hub_export("keras_hub.models.SAMImageSegmenter")
+class SAMImageSegmenter(ImageSegmenter):
+    """The Segment Anything (SAM) image segmenter model.
+
+    SAM works by prompting the input images. There are three ways to prompt:
+
+    (1) Labelled points: foreground points (points with label 1) are
+        encoded such that the output masks generated by the mask decoder
+        contain them, and background points (points with label 0) are
+        encoded such that the generated masks don't contain them.
+    (2) Box: a box tells the model which part/crop of the image to segment.
+    (3) Mask: an input mask can be used to refine the output of the mask
+        decoder.
+
+    These prompts can be mixed and matched, but at least one of them must
+    be present. To turn off a particular prompt, simply exclude it from the
+    inputs to the model.
+
+    (1) For point prompts, the expected shape is `(batch, num_points, 2)`.
+        The labels must have a corresponding shape of `(batch, num_points)`.
+    (2) For box prompts, the expected shape is `(batch, 1, 2, 2)`.
+    (3) Similarly, mask prompts have shape `(batch, 1, H, W, 1)`.
+
+    Args:
+        backbone: A `keras_hub.models.SAMBackbone` instance.
+
+    Example:
+
+    Load a pretrained model using `from_preset`.
+
+    ```python
+    image_size = 128
+    batch_size = 2
+    input_data = {
+        "images": np.ones(
+            (batch_size, image_size, image_size, 3),
+            dtype="float32",
+        ),
+        "points": np.ones((batch_size, 1, 2), dtype="float32"),
+        "labels": np.ones((batch_size, 1), dtype="float32"),
+        "boxes": np.ones((batch_size, 1, 2, 2), dtype="float32"),
+        "masks": np.zeros(
+            (batch_size, 0, image_size, image_size, 1)
+        ),
+    }
+    # todo: update preset name
+    sam = keras_hub.models.SAMImageSegmenter.from_preset("sam_base")
+    sam(input_data)
+    ```
+
+    Load a Segment Anything image segmenter with a custom backbone:
+
+    ```python
+    image_size = 128
+    batch_size = 2
+    images = np.ones(
+        (batch_size, image_size, image_size, 3),
+        dtype="float32",
+    )
+    image_encoder = keras_hub.models.ViTDetBackbone(
+        hidden_size=16,
+        num_layers=16,
+        intermediate_dim=16 * 4,
+        num_heads=16,
+        global_attention_layer_indices=[2, 5, 8, 11],
+        patch_size=16,
+        num_output_channels=8,
+        window_size=2,
+        image_shape=(image_size, image_size, 3),
+    )
+    prompt_encoder = keras_hub.layers.SAMPromptEncoder(
+        hidden_size=8,
+        image_embedding_size=(8, 8),
+        input_image_size=(
+            image_size,
+            image_size,
+        ),
+        mask_in_channels=16,
+    )
+    mask_decoder = keras_hub.layers.SAMMaskDecoder(
+        num_layers=2,
+        hidden_size=8,
+        intermediate_dim=32,
+        num_heads=8,
+        embedding_dim=8,
+        num_multimask_outputs=3,
+        iou_head_depth=3,
+        iou_head_hidden_dim=8,
+    )
+    backbone = keras_hub.models.SAMBackbone(
+        image_encoder=image_encoder,
+        prompt_encoder=prompt_encoder,
+        mask_decoder=mask_decoder,
+    )
+    sam = keras_hub.models.SAMImageSegmenter(
+        backbone=backbone,
+    )
+    ```
+
+    For example, to pass in all the prompts, do:
+
+    ```python
+    points = np.array([[[512., 512.], [100., 100.]]])
+    # For labels: 1 means foreground point, 0 means background.
+    labels = np.array([[1., 0.]])
+    box = np.array([[[[384., 384.], [640., 640.]]]])
+    input_mask = np.ones((1, 1, 256, 256, 1))
+    # Prepare an input dictionary:
+    inputs = {
+        "images": images,
+        "points": points,
+        "labels": labels,
+        "boxes": box,
+        "masks": input_mask,
+    }
+    outputs = sam.predict(inputs)
+    masks, iou_pred = outputs["masks"], outputs["iou_pred"]
+    ```
+
+    The first mask in the output `masks` (i.e. `masks[:, 0, ...]`) is the
+    best mask predicted by the model based on the prompts. Other `masks`
+    (i.e. `masks[:, 1:, ...]`) are alternate predictions that can be used
+    if they are preferred over the first one.
+
+    If only point and box prompts are present, simply exclude the mask:
+
+    ```python
+    inputs = {
+        "images": images,
+        "points": points,
+        "labels": labels,
+        "boxes": box,
+    }
+    outputs = sam.predict(inputs)
+    masks, iou_pred = outputs["masks"], outputs["iou_pred"]
+    ```
+
+    Another possibility is that only point prompts are present. Note that
+    if point prompts are present but no box prompt is present, the points
+    must be padded using a zero point and a -1 label:
+
+    ```python
+    padded_points = np.concatenate(
+        [points, np.zeros((1, 1, 2))], axis=1
+    )
+    padded_labels = np.concatenate(
+        [labels, -np.ones((1, 1))], axis=1
+    )
+    inputs = {
+        "images": images,
+        "points": padded_points,
+        "labels": padded_labels,
+    }
+    outputs = sam.predict(inputs)
+    masks, iou_pred = outputs["masks"], outputs["iou_pred"]
+    ```
+    """
+
+    backbone_cls = SAMBackbone
+    preprocessor_cls = None
+
+    def __init__(self, backbone, preprocessor=None, **kwargs):
+        # The implementation has been adapted from the [Segment Anything
+        # paper](https://arxiv.org/abs/2304.02643), the [Segment Anything
+        # GitHub repo](https://github.com/facebookresearch/segment-anything),
+        # and [Detectron2](https://github.com/facebookresearch/detectron2).
+        # === Layers ===
+        self.backbone = backbone
+        # === Functional Model ===
+        inputs = self.backbone.input
+        x = self.backbone(inputs)
+        outputs = self.backbone.mask_decoder(**x)
+        super().__init__(inputs=inputs, outputs=outputs, **kwargs)
+
+    def predict_step(self, *args, **kwargs):
+        if len(args) == 2:
+            args = (args[0], self._add_placeholder_prompts(args[-1]))
+        else:
+            args = (self._add_placeholder_prompts(args[0]),)
+
+        return super().predict_step(*args, **kwargs)
+
+    def fit(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Segment Anything Model only supports inference for now. Training"
+            " the model isn't supported yet."
+        )
+
+    def _add_placeholder_prompts(self, inputs):
+        """Adds placeholder prompt inputs for a call to SAM.
+
+        Because SAM is a functional subclass model, all inputs must be
+        specified in calls to the model. However, prompt inputs are all
+        optional, so we have to add placeholders when they're not specified
+        by the user.
+        """
+        inputs = inputs.copy()
+
+        # Get the batch shape based on the image input.
+        batch_size = ops.shape(inputs["images"])[0]
+
+        # The type of the placeholders must match the existing inputs with
+        # respect to whether or not they are tensors (as opposed to Numpy
+        # arrays).
+        zeros = ops.zeros if ops.is_tensor(inputs["images"]) else np.zeros
+
+        # Fill in missing inputs.
+        if "points" not in inputs:
+            inputs["points"] = zeros((batch_size, 0, 2))
+        if "labels" not in inputs:
+            inputs["labels"] = zeros((batch_size, 0))
+        if "boxes" not in inputs:
+            inputs["boxes"] = zeros((batch_size, 0, 2, 2))
+        if "masks" not in inputs:
+            inputs["masks"] = zeros((batch_size, 0, 256, 256, 1))

+        return inputs
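
The zero-length placeholders deserve a concrete look, since they are what lets optional prompts work on a functional model whose inputs are all mandatory. A standalone sketch that mirrors the body of `_add_placeholder_prompts` outside the class:

```python
import numpy as np
from keras import ops

# Inputs with the box and mask prompts deliberately omitted.
inputs = {
    "images": np.ones((2, 128, 128, 3), dtype="float32"),
    "points": np.ones((2, 1, 2), dtype="float32"),
    "labels": np.ones((2, 1), dtype="float32"),
}

batch_size = ops.shape(inputs["images"])[0]
# Match the input type: backend tensors get ops.zeros, numpy gets np.zeros.
zeros = ops.zeros if ops.is_tensor(inputs["images"]) else np.zeros

if "boxes" not in inputs:
    inputs["boxes"] = zeros((batch_size, 0, 2, 2))
if "masks" not in inputs:
    inputs["masks"] = zeros((batch_size, 0, 256, 256, 1))

# The placeholders have a zero-length prompt axis, so they flow through
# the prompt encoder without contributing any prompt information:
print(ops.shape(inputs["boxes"]))  # (2, 0, 2, 2)
print(ops.shape(inputs["masks"]))  # (2, 0, 256, 256, 1)
```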