keras-hub-nightly 0.16.1.dev202409230338__py3-none-any.whl → 0.16.1.dev202409250340__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- keras_hub/api/layers/__init__.py +2 -0
- keras_hub/api/models/__init__.py +3 -0
- keras_hub/src/models/image_segmenter.py +86 -0
- keras_hub/src/models/sam/__init__.py +13 -0
- keras_hub/src/models/sam/sam_backbone.py +153 -0
- keras_hub/src/models/sam/sam_image_segmenter.py +237 -0
- keras_hub/src/models/sam/sam_layers.py +402 -0
- keras_hub/src/models/sam/sam_mask_decoder.py +270 -0
- keras_hub/src/models/sam/sam_prompt_encoder.py +336 -0
- keras_hub/src/models/sam/sam_transformer.py +159 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +17 -12
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202409230338.dist-info → keras_hub_nightly-0.16.1.dev202409250340.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.16.1.dev202409230338.dist-info → keras_hub_nightly-0.16.1.dev202409250340.dist-info}/RECORD +16 -8
- {keras_hub_nightly-0.16.1.dev202409230338.dist-info → keras_hub_nightly-0.16.1.dev202409250340.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.16.1.dev202409230338.dist-info → keras_hub_nightly-0.16.1.dev202409250340.dist-info}/top_level.txt +0 -0
keras_hub/api/layers/__init__.py
CHANGED
@@ -56,6 +56,8 @@ from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import (
 from keras_hub.src.models.resnet.resnet_image_converter import (
     ResNetImageConverter,
 )
+from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder
+from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder
 from keras_hub.src.models.whisper.whisper_audio_converter import (
     WhisperAudioConverter,
 )
keras_hub/api/models/__init__.py
CHANGED
@@ -175,6 +175,7 @@ from keras_hub.src.models.image_classifier import ImageClassifier
 from keras_hub.src.models.image_classifier_preprocessor import (
     ImageClassifierPreprocessor,
 )
+from keras_hub.src.models.image_segmenter import ImageSegmenter
 from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone
 from keras_hub.src.models.llama3.llama3_causal_lm import Llama3CausalLM
 from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import (
@@ -255,6 +256,8 @@ from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import (
     RobertaTextClassifierPreprocessor as RobertaPreprocessor,
 )
 from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer
+from keras_hub.src.models.sam.sam_backbone import SAMBackbone
+from keras_hub.src.models.sam.sam_image_segmenter import SAMImageSegmenter
 from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM
 from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor
 from keras_hub.src.models.t5.t5_backbone import T5Backbone
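Taken together, the two `__init__.py` hunks above widen the public API surface. A minimal smoke-test sketch of the new exports, assuming a nightly build that contains this diff:

```python
# Check that the symbols added by this diff are importable from the public
# namespaces that keras_hub/api/layers and keras_hub/api/models map to.
from keras_hub.layers import SAMMaskDecoder, SAMPromptEncoder
from keras_hub.models import ImageSegmenter, SAMBackbone, SAMImageSegmenter

print(SAMMaskDecoder, SAMPromptEncoder)
print(ImageSegmenter, SAMBackbone, SAMImageSegmenter)
```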
keras_hub/src/models/image_segmenter.py
ADDED
@@ -0,0 +1,86 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.task import Task
+
+
+@keras_hub_export("keras_hub.models.ImageSegmenter")
+class ImageSegmenter(Task):
+    """Base class for all image segmentation tasks.
+
+    `ImageSegmenter` tasks wrap a `keras_hub.models.Backbone` and
+    a `keras_hub.models.Preprocessor` to create a model that can be used for
+    image segmentation.
+
+    All `ImageSegmenter` tasks include a `from_preset()` constructor which can
+    be used to load a pre-trained config and weights.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Default compilation.
+        self.compile()
+
+    def compile(
+        self,
+        optimizer="auto",
+        loss="auto",
+        *,
+        metrics="auto",
+        **kwargs,
+    ):
+        """Configures the `ImageSegmenter` task for training.
+
+        The `ImageSegmenter` task extends the default compilation signature of
+        `keras.Model.compile` with defaults for `optimizer`, `loss`, and
+        `metrics`. To override these defaults, pass any value
+        to these arguments during compilation.
+
+        Args:
+            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
+                instance. Defaults to `"auto"`, which uses the default optimizer
+                for the given model and task. See `keras.Model.compile` and
+                `keras.optimizers` for more info on possible `optimizer` values.
+            loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
+                Defaults to `"auto"`, where a
+                `keras.losses.CategoricalCrossentropy` loss will be
+                applied for the segmentation task. See
+                `keras.Model.compile` and `keras.losses` for more info on
+                possible `loss` values.
+            metrics: `"auto"`, or a list of metrics to be evaluated by
+                the model during training and testing. Defaults to `"auto"`,
+                where a `keras.metrics.CategoricalAccuracy` will be
+                applied to track the accuracy of the model during training.
+                See `keras.Model.compile` and `keras.metrics` for
+                more info on possible `metrics` values.
+            **kwargs: See `keras.Model.compile` for a full list of arguments
+                supported by the compile method.
+        """
+        if optimizer == "auto":
+            optimizer = keras.optimizers.Adam(5e-5)
+        if loss == "auto":
+            activation = getattr(self, "activation", None)
+            activation = keras.activations.get(activation)
+            from_logits = activation != keras.activations.softmax
+            loss = keras.losses.CategoricalCrossentropy(from_logits=from_logits)
+        if metrics == "auto":
+            metrics = [keras.metrics.CategoricalAccuracy()]
+        super().compile(
+            optimizer=optimizer,
+            loss=loss,
+            metrics=metrics,
+            **kwargs,
+        )
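Because `__init__` calls `self.compile()` immediately, every `ImageSegmenter` subclass comes pre-compiled with the `"auto"` defaults above (Adam at `5e-5`, categorical cross-entropy, categorical accuracy). Overriding them is an ordinary recompile; a minimal sketch, where `segmenter` is a placeholder for any concrete `ImageSegmenter` subclass instance:

```python
import keras

# Hypothetical: `segmenter` is any concrete ImageSegmenter subclass instance.
# Recompiling replaces the "auto" defaults chosen in compile() above.
segmenter.compile(
    optimizer=keras.optimizers.SGD(learning_rate=1e-3),
    loss=keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.CategoricalAccuracy()],
)
```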
keras_hub/src/models/sam/__init__.py
ADDED
@@ -0,0 +1,13 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
keras_hub/src/models/sam/sam_backbone.py
ADDED
@@ -0,0 +1,153 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.backbone import Backbone
+
+
+@keras_hub_export("keras_hub.models.SAMBackbone")
+class SAMBackbone(Backbone):
+    """A backbone for the Segment Anything Model (SAM).
+
+    Args:
+        image_encoder: `keras_hub.models.ViTDetBackbone`. A feature extractor
+            for the input images.
+        prompt_encoder: `keras_hub.layers.SAMPromptEncoder`. A Keras layer to
+            compute embeddings for points, box, and mask prompts.
+        mask_decoder: `keras_hub.layers.SAMMaskDecoder`. A Keras layer to
+            generate segmentation masks given the embeddings generated by the
+            backbone and the prompt encoder.
+        dtype: The dtype of the layer weights.
+
+    Example:
+    ```python
+    image_size = 128
+    batch_size = 2
+    input_data = {
+        "images": np.ones(
+            (batch_size, image_size, image_size, 3),
+            dtype="float32",
+        ),
+        "points": np.ones((batch_size, 1, 2), dtype="float32"),
+        "labels": np.ones((batch_size, 1), dtype="float32"),
+        "boxes": np.ones((batch_size, 1, 2, 2), dtype="float32"),
+        "masks": np.zeros(
+            (batch_size, 0, image_size, image_size, 1)
+        ),
+    }
+    image_encoder = keras_hub.models.ViTDetBackbone(
+        hidden_size=16,
+        num_layers=16,
+        intermediate_dim=16 * 4,
+        num_heads=16,
+        global_attention_layer_indices=[2, 5, 8, 11],
+        patch_size=16,
+        num_output_channels=8,
+        window_size=2,
+        image_shape=(image_size, image_size, 3),
+    )
+    prompt_encoder = keras_hub.layers.SAMPromptEncoder(
+        hidden_size=8,
+        image_embedding_size=(8, 8),
+        input_image_size=(
+            image_size,
+            image_size,
+        ),
+        mask_in_channels=16,
+    )
+    mask_decoder = keras_hub.layers.SAMMaskDecoder(
+        num_layers=2,
+        hidden_size=8,
+        intermediate_dim=32,
+        num_heads=8,
+        embedding_dim=8,
+        num_multimask_outputs=3,
+        iou_head_depth=3,
+        iou_head_hidden_dim=8,
+    )
+    backbone = keras_hub.models.SAMBackbone(
+        image_encoder=image_encoder,
+        prompt_encoder=prompt_encoder,
+        mask_decoder=mask_decoder,
+        image_shape=(image_size, image_size, 3),
+    )
+    backbone(input_data)
+    ```
+    """
+
+    def __init__(
+        self,
+        image_encoder,
+        prompt_encoder,
+        mask_decoder,
+        dtype=None,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.image_encoder = image_encoder
+        self.prompt_encoder = prompt_encoder
+        self.mask_decoder = mask_decoder
+        # === Functional Model ===
+        image_input = self.image_encoder.input
+
+        inputs = {
+            "images": image_input,
+            "points": keras.Input(shape=[None, 2], name="points"),
+            "labels": keras.Input(shape=[None], name="labels"),
+            "boxes": keras.Input(shape=[None, 2, 2], name="boxes"),
+            "masks": keras.Input(shape=[None, None, None, 1], name="masks"),
+        }
+        image_embeddings = self.image_encoder.output
+        prompt_embeddings = self.prompt_encoder(**inputs)
+        outputs = {
+            "image_embeddings": image_embeddings,
+        }
+        outputs.update(prompt_embeddings)
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            dtype=dtype,
+            **kwargs,
+        )
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "image_encoder": keras.layers.serialize(self.image_encoder),
+                "prompt_encoder": keras.layers.serialize(self.prompt_encoder),
+                "mask_decoder": keras.layers.serialize(self.mask_decoder),
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        config.update(
+            {
+                "image_encoder": keras.layers.deserialize(
+                    config["image_encoder"]
+                ),
+                "prompt_encoder": keras.layers.deserialize(
+                    config["prompt_encoder"]
+                ),
+                "mask_decoder": keras.layers.deserialize(
+                    config["mask_decoder"]
+                ),
+            }
+        )
+
+        return super().from_config(config)
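The functional wiring above makes the backbone return a dictionary: the image encoder's features under `"image_embeddings"`, merged with the prompt encoder's outputs. A quick inspection sketch, reusing the small demo `backbone` and `input_data` from the docstring example:

```python
# Call the demo backbone and list what the output dictionary contains.
outputs = backbone(input_data)
for name in sorted(outputs):
    # "image_embeddings" plus whatever keys the prompt encoder emits.
    print(name, outputs[name].shape)
```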
keras_hub/src/models/sam/sam_image_segmenter.py
ADDED
@@ -0,0 +1,237 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.image_segmenter import ImageSegmenter
+from keras_hub.src.models.sam.sam_backbone import SAMBackbone
+
+
+@keras_hub_export("keras_hub.models.SAMImageSegmenter")
+class SAMImageSegmenter(ImageSegmenter):
+    """The Segment Anything (SAM) image segmenter model.
+
+    SAM works by prompting the input images. There are three ways to prompt:
+    (1) Labelled Points: Foreground points (points with label 1) are encoded
+    such that the output masks generated by the mask decoder contain them
+    and background points (points with label 0) are encoded such that the
+    generated masks don't contain them.
+    (2) Box: A box tells the model which part/crop of the image to segment.
+    (3) Mask: An input mask can be used to refine the output of the mask
+    decoder.
+    These prompts can be mixed and matched, but at least one of the prompts
+    must be present. To turn off a particular prompt, simply exclude it from
+    the inputs to the model.
+    (1) For point prompts, the expected shape is `(batch, num_points, 2)`.
+    The labels must have a corresponding shape of `(batch, num_points)`.
+    (2) For the box prompt, the expected shape is `(batch, 1, 2, 2)`.
+    (3) Similarly, mask prompts have shape `(batch, 1, H, W, 1)`.
+
+
+    Args:
+        backbone: A `keras_hub.models.SAMBackbone` instance.
+
+    Example:
+    Load a pretrained model using `from_preset`.
+
+    ```python
+    image_size = 128
+    batch_size = 2
+    input_data = {
+        "images": np.ones(
+            (batch_size, image_size, image_size, 3),
+            dtype="float32",
+        ),
+        "points": np.ones((batch_size, 1, 2), dtype="float32"),
+        "labels": np.ones((batch_size, 1), dtype="float32"),
+        "boxes": np.ones((batch_size, 1, 2, 2), dtype="float32"),
+        "masks": np.zeros(
+            (batch_size, 0, image_size, image_size, 1)
+        ),
+    }
+    # todo: update preset name
+    sam = keras_hub.models.SAMImageSegmenter.from_preset("sam_base")
+    sam(input_data)
+    ```
+
+    Load a Segment Anything image segmenter with a custom backbone:
+
+    ```python
+    image_size = 128
+    batch_size = 2
+    images = np.ones(
+        (batch_size, image_size, image_size, 3),
+        dtype="float32",
+    )
+    image_encoder = ViTDetBackbone(
+        hidden_size=16,
+        num_layers=16,
+        intermediate_dim=16 * 4,
+        num_heads=16,
+        global_attention_layer_indices=[2, 5, 8, 11],
+        patch_size=16,
+        num_output_channels=8,
+        window_size=2,
+        image_shape=(image_size, image_size, 3),
+    )
+    prompt_encoder = SAMPromptEncoder(
+        hidden_size=8,
+        image_embedding_size=(8, 8),
+        input_image_size=(
+            image_size,
+            image_size,
+        ),
+        mask_in_channels=16,
+    )
+    mask_decoder = SAMMaskDecoder(
+        num_layers=2,
+        hidden_size=8,
+        intermediate_dim=32,
+        num_heads=8,
+        embedding_dim=8,
+        num_multimask_outputs=3,
+        iou_head_depth=3,
+        iou_head_hidden_dim=8,
+    )
+    backbone = SAMBackbone(
+        image_encoder=image_encoder,
+        prompt_encoder=prompt_encoder,
+        mask_decoder=mask_decoder,
+        image_shape=(image_size, image_size, 3),
+    )
+    sam = SAMImageSegmenter(
+        backbone=backbone
+    )
+    ```
+
+    For example, to pass in all the prompts, do:
+
+    ```python
+
+    points = np.array([[[512., 512.], [100., 100.]]])
+    # For labels: 1 means foreground point, 0 means background.
+    labels = np.array([[1., 0.]])
+    box = np.array([[[[384., 384.], [640., 640.]]]])
+    input_mask = np.ones((1, 1, 256, 256, 1))
+    # Prepare an input dictionary:
+    inputs = {
+        "images": image,
+        "points": points,
+        "labels": labels,
+        "boxes": box,
+        "masks": input_mask,
+    }
+    outputs = sam.predict(inputs)
+    masks, iou_pred = outputs["masks"], outputs["iou_pred"]
+    ```
+
+    The first mask in the output `masks` (i.e. `masks[:, 0, ...]`) is the best
+    mask predicted by the model based on the prompts. Other `masks`
+    (i.e. `masks[:, 1:, ...]`) are alternate predictions that can be used if
+    they are preferred over the first one.
+    Now, if only point and box prompts are given, simply exclude the masks:
+
+    ```python
+    inputs = {
+        "images": image,
+        "points": points,
+        "labels": labels,
+        "boxes": box,
+    }
+
+    outputs = sam.predict(inputs)
+    masks, iou_pred = outputs["masks"], outputs["iou_pred"]
+    ```
+
+    Another example uses only point prompts.
+    Note that if point prompts are present but no box prompt is present, the
+    points must be padded using a zero point and -1 label:
+
+    ```python
+    padded_points = np.concatenate(
+        [points, np.zeros((1, 1, 2))], axis=1
+    )
+
+    padded_labels = np.concatenate(
+        [labels, -np.ones((1, 1))], axis=1
+    )
+    inputs = {
+        "images": image,
+        "points": padded_points,
+        "labels": padded_labels,
+    }
+    outputs = sam.predict(inputs)
+    masks, iou_pred = outputs["masks"], outputs["iou_pred"]
+    ```
+    """
+
+    backbone_cls = SAMBackbone
+    preprocessor_cls = None
+
+    def __init__(self, backbone, preprocessor=None, **kwargs):
+        # The implementation has been adapted from [Segment Anything
+        # paper](https://arxiv.org/abs/2304.02643) and [Segment Anything
+        # GitHub](https://github.com/facebookresearch/segment-anything) and
+        # [Detectron2](https://github.com/facebookresearch/detectron2).
+        # === Layers ===
+        self.backbone = backbone
+        # === Functional Model ===
+        inputs = self.backbone.input
+        x = self.backbone(inputs)
+        outputs = self.backbone.mask_decoder(**x)
+        super().__init__(inputs=inputs, outputs=outputs, **kwargs)
+
+    def predict_step(self, *args, **kwargs):
+        if len(args) == 2:
+            args = (args[0], self._add_placeholder_prompts(args[-1]))
+        else:
+            args = (self._add_placeholder_prompts(args[0]),)
+
+        return super().predict_step(*args, **kwargs)
+
+    def fit(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Segment Anything Model only supports inference for now. Training"
+            " the model isn't supported yet."
+        )
+
+    def _add_placeholder_prompts(self, inputs):
+        """Adds placeholder prompt inputs for a call to SAM.
+
+        Because SAM is a functional subclass model, all inputs must be specified in
+        calls to the model. However, prompt inputs are all optional, so we have to
+        add placeholders when they're not specified by the user.
+        """
+        inputs = inputs.copy()
+
+        # Get the batch shape based on the image input.
+        batch_size = ops.shape(inputs["images"])[0]
+
+        # The type of the placeholders must match the existing inputs with respect
+        # to whether or not they are tensors (as opposed to Numpy arrays).
+        zeros = ops.zeros if ops.is_tensor(inputs["images"]) else np.zeros
+
+        # Fill in missing inputs.
+        if "points" not in inputs:
+            inputs["points"] = zeros((batch_size, 0, 2))
+        if "labels" not in inputs:
+            inputs["labels"] = zeros((batch_size, 0))
+        if "boxes" not in inputs:
+            inputs["boxes"] = zeros((batch_size, 0, 2, 2))
+        if "masks" not in inputs:
+            inputs["masks"] = zeros((batch_size, 0, 256, 256, 1))
+
+        return inputs
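Note that `_add_placeholder_prompts` only runs inside `predict_step`, so prompts are optional for `predict()` but not for a direct call, which goes straight to the functional model. A sketch of the difference, assuming the `sam` model and the `image`, `padded_points`, and `padded_labels` arrays from the docstring examples above:

```python
import numpy as np

# predict() fills in zero-sized placeholders for the missing prompts.
outputs = sam.predict(
    {"images": image, "points": padded_points, "labels": padded_labels}
)

# A direct call is the raw functional model, so every input must be given;
# the empty shapes mirror what _add_placeholder_prompts would create.
outputs = sam(
    {
        "images": image,
        "points": padded_points,
        "labels": padded_labels,
        "boxes": np.zeros((1, 0, 2, 2)),
        "masks": np.zeros((1, 0, 256, 256, 1)),
    }
)
```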