keras-hub-nightly 0.16.1.dev202409240339__py3-none-any.whl → 0.16.1.dev202409260340__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- keras_hub/api/layers/__init__.py +5 -0
- keras_hub/api/models/__init__.py +19 -0
- keras_hub/api/tokenizers/__init__.py +1 -0
- keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_encoder_block.py +8 -2
- keras_hub/src/models/clip/clip_preprocessor.py +147 -0
- keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_text_encoder.py +60 -57
- keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_tokenizer.py +69 -30
- keras_hub/src/models/densenet/__init__.py +6 -0
- keras_hub/src/models/densenet/densenet_backbone.py +11 -8
- keras_hub/src/models/densenet/densenet_image_classifier.py +27 -4
- keras_hub/src/models/densenet/densenet_image_classifier_preprocessor.py +27 -0
- keras_hub/src/models/densenet/densenet_image_converter.py +23 -0
- keras_hub/src/models/densenet/densenet_presets.py +56 -0
- keras_hub/src/models/image_segmenter.py +86 -0
- keras_hub/src/models/sam/__init__.py +13 -0
- keras_hub/src/models/sam/sam_backbone.py +153 -0
- keras_hub/src/models/sam/sam_image_segmenter.py +237 -0
- keras_hub/src/models/sam/sam_layers.py +402 -0
- keras_hub/src/models/sam/sam_mask_decoder.py +270 -0
- keras_hub/src/models/sam/sam_prompt_encoder.py +336 -0
- keras_hub/src/models/sam/sam_transformer.py +159 -0
- keras_hub/src/models/stable_diffusion_3/__init__.py +13 -0
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +93 -0
- keras_hub/src/models/{stable_diffusion_v3 → stable_diffusion_3}/mmdit.py +351 -26
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +630 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +151 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +77 -0
- keras_hub/src/models/{stable_diffusion_v3/t5_xxl_text_encoder.py → stable_diffusion_3/t5_encoder.py} +7 -7
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +333 -0
- keras_hub/src/models/{stable_diffusion_v3/t5_xxl_preprocessor.py → t5/t5_preprocessor.py} +12 -3
- keras_hub/src/models/text_to_image.py +295 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +17 -12
- keras_hub/src/utils/timm/convert_densenet.py +107 -0
- keras_hub/src/utils/timm/preset_loader.py +3 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202409240339.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.16.1.dev202409240339.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/RECORD +40 -24
- keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +0 -93
- keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +0 -317
- keras_hub/src/models/stable_diffusion_v3/vae_attention.py +0 -126
- keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +0 -186
- keras_hub/src/models/{stable_diffusion_v3 → clip}/__init__.py +0 -0
- {keras_hub_nightly-0.16.1.dev202409240339.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.16.1.dev202409240339.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/top_level.txt +0 -0
keras_hub/src/models/sam/sam_mask_decoder.py
@@ -0,0 +1,270 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import keras
from keras import ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.sam.sam_layers import MLP
from keras_hub.src.models.sam.sam_transformer import TwoWayTransformer


@keras_hub_export("keras_hub.layers.SAMMaskDecoder")
class SAMMaskDecoder(keras.layers.Layer):
    """Mask decoder for the Segment Anything Model (SAM).

    This lightweight module efficiently maps the image embedding and a set of
    prompt embeddings to an output mask. Before applying the transformer
    decoder, the layer first inserts into the set of prompt embeddings a
    learned output token embedding that will be used at the decoder's output.
    For simplicity, these embeddings (not including the image embedding) are
    collectively called "tokens".

    The image embeddings, positional image embeddings, and tokens are passed
    through a transformer decoder. After running the decoder, the layer
    upsamples the updated image embedding by 4x with two transposed
    convolutional layers (leaving it downscaled only 4x relative to the input
    image). Then, the tokens attend once more to the image embedding, and
    the updated output token embedding is passed to a small 3-layer MLP that
    outputs a vector matching the channel dimension of the upscaled image
    embedding.

    Finally, a mask is predicted with a spatially point-wise
    product between the upscaled image embedding and the MLP's output.

    Args:
        hidden_size: int. The hidden size of the TwoWayTransformer.
        num_layers: int. The number of layers in the TwoWayTransformer.
        intermediate_dim: int. The intermediate dimension of the
            TwoWayTransformer.
        num_heads: int. The number of heads in the TwoWayTransformer.
        embedding_dim: int, optional. The number of input features to the
            transformer decoder. Defaults to `256`.
        num_multimask_outputs: int, optional. Number of multimask outputs.
            The model generates this many extra masks, so the total number of
            masks generated by the model is `1 + num_multimask_outputs`.
            Defaults to `3`.
        iou_head_depth: int, optional. The depth of the dense net used to
            predict the IoU confidence score. Defaults to `3`.
        iou_head_hidden_dim: int, optional. The number of units in the hidden
            layers used in the dense net to predict the IoU confidence score.
            Defaults to `256`.
        activation: str, optional. Activation to use in the mask upscaler
            network. Defaults to `"gelu"`.
    """

    def __init__(
        self,
        *,
        hidden_size,
        num_layers,
        intermediate_dim,
        num_heads,
        embedding_dim=256,
        num_multimask_outputs=3,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
        activation="gelu",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.intermediate_dim = intermediate_dim
        self.num_heads = num_heads
        self.embedding_dim = embedding_dim
        transformer = TwoWayTransformer(
            num_layers=num_layers,
            hidden_size=hidden_size,
            intermediate_dim=intermediate_dim,
            num_heads=num_heads,
            dtype=self.dtype_policy,
        )
        self.transformer = transformer
        self.num_multimask_outputs = num_multimask_outputs
        self.iou_head_depth = iou_head_depth
        self.iou_head_hidden_dim = iou_head_hidden_dim
        self.activation = activation

        self.iou_token = keras.layers.Embedding(
            1, embedding_dim, dtype=self.dtype_policy
        )
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = keras.layers.Embedding(
            self.num_mask_tokens, embedding_dim, dtype=self.dtype_policy
        )

        self.output_upscaling = keras.models.Sequential(
            [
                keras.layers.Conv2DTranspose(
                    embedding_dim // 4,
                    kernel_size=2,
                    strides=2,
                    dtype=self.dtype_policy,
                ),
                keras.layers.LayerNormalization(
                    epsilon=1e-6, dtype=self.dtype_policy
                ),
                keras.layers.Activation(activation, dtype=self.dtype_policy),
                keras.layers.Conv2DTranspose(
                    embedding_dim // 8,
                    kernel_size=2,
                    strides=2,
                    dtype=self.dtype_policy,
                ),
                keras.layers.Activation(activation, dtype=self.dtype_policy),
            ]
        )

        self.output_hypernetworks_mlps = [
            MLP(embedding_dim, embedding_dim // 8, 3, dtype=self.dtype_policy)
            for _ in range(self.num_mask_tokens)
        ]

        self.iou_prediction_head = MLP(
            iou_head_hidden_dim,
            self.num_mask_tokens,
            iou_head_depth,
            dtype=self.dtype_policy,
        )

    def build(self, input_shape=None, **kwargs):
        self.transformer.build()
        self.iou_token.build([None])
        self.mask_tokens.build([None])
        self.output_upscaling.build([None, None, None, self.embedding_dim])
        for mlp in self.output_hypernetworks_mlps:
            mlp.build([None, self.embedding_dim])
        self.iou_prediction_head.build([None, self.embedding_dim])
        self.built = True

    def call(
        self,
        image_embeddings,
        prompt_dense_positional_embeddings,
        prompt_sparse_embeddings,
        prompt_dense_embeddings,
    ):
        masks, iou_pred = self._predict_masks(
            image_embeddings=image_embeddings,
            image_positional_embeddings=prompt_dense_positional_embeddings,
            prompt_sparse_embeddings=prompt_sparse_embeddings,
            prompt_dense_embeddings=prompt_dense_embeddings,
        )

        return {"masks": masks, "iou_pred": iou_pred}

    def _predict_masks(
        self,
        image_embeddings,
        image_positional_embeddings,
        prompt_sparse_embeddings,
        prompt_dense_embeddings,
    ):
        indices_iou = ops.arange(1, dtype="int32")
        indices_mask = ops.arange(self.num_mask_tokens, dtype="int32")

        output_tokens = ops.concatenate(
            [self.iou_token(indices_iou), self.mask_tokens(indices_mask)],
            axis=0,
        )
        output_tokens = ops.broadcast_to(
            output_tokens[None, ...],
            shape=(
                ops.shape(prompt_sparse_embeddings)[0],
                ops.shape(output_tokens)[0],
                ops.shape(output_tokens)[1],
            ),
        )
        tokens = ops.concatenate(
            [output_tokens, prompt_sparse_embeddings], axis=1
        )

        source = ops.broadcast_to(
            image_embeddings,
            shape=(
                ops.shape(tokens)[0],
                ops.shape(image_embeddings)[1],
                ops.shape(image_embeddings)[2],
                ops.shape(image_embeddings)[3],
            ),
        )
        source = source + prompt_dense_embeddings
        positional_source = ops.broadcast_to(
            image_positional_embeddings,
            shape=(
                ops.shape(tokens)[0],
                ops.shape(image_embeddings)[1],
                ops.shape(image_embeddings)[2],
                ops.shape(image_embeddings)[3],
            ),
        )
        shape = ops.shape(source)
        batch_dim, height, width, channels = (
            shape[0],
            shape[1],
            shape[2],
            shape[3],
        )

        hidden_state, source = self.transformer(
            source, positional_source, tokens
        )
        iou_token_out = hidden_state[:, 0, :]
        mask_tokens_out = hidden_state[:, 1 : (1 + self.num_mask_tokens), :]

        source = ops.reshape(source, (batch_dim, height, width, channels))
        upscaled_embeddings = self.output_upscaling(source)
        hyper_in_list = []
        for i in range(self.num_mask_tokens):
            hyper_in_list.append(
                self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
            )
        hyper_in = ops.stack(hyper_in_list, axis=1)
        shape = ops.shape(upscaled_embeddings)
        batch_dim, height, width, channels = (
            shape[0],
            shape[1],
            shape[2],
            shape[3],
        )
        upscaled_embeddings = ops.reshape(
            ops.transpose(upscaled_embeddings, axes=(0, 3, 1, 2)),
            (batch_dim, channels, height * width),
        )
        masks = ops.reshape(
            hyper_in @ upscaled_embeddings,
            (batch_dim, self.num_mask_tokens, height, width),
        )

        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_size": self.hidden_size,
                "num_layers": self.num_layers,
                "intermediate_dim": self.intermediate_dim,
                "num_heads": self.num_heads,
                "embedding_dim": self.embedding_dim,
                "num_multimask_outputs": self.num_multimask_outputs,
                "iou_head_depth": self.iou_head_depth,
                "iou_head_hidden_dim": self.iou_head_hidden_dim,
                "activation": self.activation,
            }
        )
        return config
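For reference, a minimal sketch of exercising the new decoder. The transformer hyperparameters (`num_layers=2`, `intermediate_dim=2048`, `num_heads=8`) and the dummy input shapes are illustrative assumptions; only `embedding_dim=256` and the 4x upscaling follow from the layer itself.

import numpy as np

from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder

# Illustrative configuration, not values fixed by this diff. The layer is
# also exported publicly as `keras_hub.layers.SAMMaskDecoder`.
decoder = SAMMaskDecoder(
    hidden_size=256,
    num_layers=2,
    intermediate_dim=2048,
    num_heads=8,
)
batch_size, num_prompts = 1, 2
outputs = decoder(
    image_embeddings=np.zeros((batch_size, 64, 64, 256), "float32"),
    prompt_dense_positional_embeddings=np.zeros(
        (batch_size, 64, 64, 256), "float32"
    ),
    prompt_sparse_embeddings=np.zeros(
        (batch_size, num_prompts, 256), "float32"
    ),
    prompt_dense_embeddings=np.zeros((batch_size, 64, 64, 256), "float32"),
)
# outputs["masks"] has shape (1, 4, 256, 256): one mask per mask token,
# i.e. 1 + num_multimask_outputs, at 4x the embedding resolution.
# outputs["iou_pred"] has shape (1, 4): one confidence score per mask.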
keras_hub/src/models/sam/sam_prompt_encoder.py
@@ -0,0 +1,336 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import keras
from keras import ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.sam.sam_layers import (
    RandomFrequencyPositionalEmbeddings,
)


@keras_hub_export("keras_hub.layers.SAMPromptEncoder")
class SAMPromptEncoder(keras.layers.Layer):
    """Prompt Encoder for the Segment Anything Model (SAM).

    The prompt encoder generates encodings for three types of prompts:
    - Point prompts: Points on the image along with a label indicating
      whether the point is in the foreground (part of the mask) or in the
      background (not a part of the mask).
    - Box prompts: A batch of bounding boxes with format [(x1, y1), (x2, y2)]
      used to determine the location of the masks in the image.
    - Masks: An input mask can be passed to refine the positional embeddings
      for the output mask.

    First, the point prompts and box prompts are concatenated and positional
    encodings are generated using random spatial frequencies. A point is
    represented as the sum of a positional encoding of the point's location
    and one of two learned embeddings that indicate if the point is either in
    the foreground or background. A box is represented by an embedding pair:
    (1) the positional encoding of its top-left corner summed with a learned
    embedding representing "top-left corner" and
    (2) the same structure but using a learned embedding indicating
    "bottom-right corner".
    The box and point encodings are referred to as "prompt_sparse encodings".
    If a mask prompt is passed, a convolutional neural net is used to
    downscale it to generate "dense encodings". If no mask prompt is passed,
    an embedding layer is used instead to generate a "no mask" embedding.

    Args:
        hidden_size: int, optional. The number of features in the output
            embeddings. Defaults to `256`.
        image_embedding_size: tuple[int], optional. The spatial size of the
            image embeddings generated by the image encoder. Defaults to
            `(64, 64)`.
        input_image_size: tuple[int], optional. A tuple of the height and
            width of the image being prompted. Defaults to `(1024, 1024)`.
        mask_in_channels: int, optional. The number of channels of the mask
            prompt. Defaults to `16`.
        activation: str, optional. The activation to use in the mask
            downscaler neural net. Defaults to `"gelu"`.
    """

    def __init__(
        self,
        *,
        hidden_size=256,
        image_embedding_size=(64, 64),
        input_image_size=(1024, 1024),
        mask_in_channels=16,
        activation="gelu",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.image_embedding_size = image_embedding_size
        self.input_image_size = input_image_size
        self.mask_in_channels = mask_in_channels
        self.activation = activation

        self.positional_embedding_layer = RandomFrequencyPositionalEmbeddings(
            num_positional_features=self.hidden_size // 2, scale=1
        )

        self.foreground_point_embed = keras.layers.Embedding(
            1, hidden_size, name="foreground_point_embed"
        )
        self.background_point_embed = keras.layers.Embedding(
            1, hidden_size, name="background_point_embed"
        )
        self.top_left_corner_embed = keras.layers.Embedding(
            1, hidden_size, name="top_left_corner_embed"
        )
        self.bottom_right_corner_embed = keras.layers.Embedding(
            1, hidden_size, name="bottom_right_corner_embed"
        )
        self.not_a_point_embed = keras.layers.Embedding(
            1, hidden_size, name="not_a_point_embed"
        )

        self.mask_downscaler = keras.models.Sequential(
            [
                keras.layers.Conv2D(
                    mask_in_channels // 4, kernel_size=2, strides=2
                ),
                keras.layers.LayerNormalization(epsilon=1e-6),
                keras.layers.Activation(activation),
                keras.layers.Conv2D(mask_in_channels, kernel_size=2, strides=2),
                keras.layers.LayerNormalization(epsilon=1e-6),
                keras.layers.Activation(activation),
                keras.layers.Conv2D(hidden_size, kernel_size=1),
            ],
            name="mask_downscaler",
        )
        self.no_mask_embed = keras.layers.Embedding(
            1, hidden_size, name="no_mask_embed"
        )

    def build(
        self,
        points_shape=None,
        labels_shape=None,
        boxes_shape=None,
        masks_shape=None,
    ):
        self.positional_embedding_layer.build()
        for layer in [
            self.foreground_point_embed,
            self.background_point_embed,
            self.top_left_corner_embed,
            self.bottom_right_corner_embed,
            self.not_a_point_embed,
            self.no_mask_embed,
        ]:
            layer.build([None])
        self.mask_downscaler.build(
            [
                None,
                4 * self.image_embedding_size[0],
                4 * self.image_embedding_size[1],
                1,
            ]
        )
        self.built = True

    def compute_output_shape(
        self,
        points_shape=None,
        labels_shape=None,
        boxes_shape=None,
        masks_shape=None,
    ):
        batch_size = None
        for shape in (points_shape, labels_shape, boxes_shape, masks_shape):
            if shape is not None:
                batch_size = shape[0]
                break
        return {
            "prompt_sparse_embeddings": (
                batch_size,
                None,
                self.hidden_size,
            ),
            "prompt_dense_embeddings": (
                batch_size,
                self.image_embedding_size[0],
                self.image_embedding_size[1],
                self.hidden_size,
            ),
            "prompt_dense_positional_embeddings": (
                batch_size,
                self.image_embedding_size[0],
                self.image_embedding_size[1],
                self.hidden_size,
            ),
        }

    def _embed_points(self, points, labels):
        points = points + 0.5
        indices = ops.arange(1, dtype="int32")

        point_embeddings = self.positional_embedding_layer.encode_coordinates(
            points, self.input_image_size
        )
        labels = ops.broadcast_to(
            labels[..., None], ops.shape(point_embeddings)
        )
        point_embeddings = ops.where(
            labels == 0,
            point_embeddings + self.background_point_embed(indices),
            point_embeddings + self.foreground_point_embed(indices),
        )
        point_embeddings = ops.where(
            labels == -1,
            self.not_a_point_embed(indices),
            point_embeddings,
        )
        return point_embeddings

    def _embed_box(self, box):
        shape = ops.shape(box)
        batch_size, N = shape[0], shape[1]
        box = box + 0.5
        indices = ops.arange(1, dtype="int32")
        corner_embedding = self.positional_embedding_layer.encode_coordinates(
            box, self.input_image_size
        )
        top_left_embedding = corner_embedding[
            :, :, 0, :
        ] + self.top_left_corner_embed(indices)
        bottom_right_embedding = corner_embedding[
            :, :, 1, :
        ] + self.bottom_right_corner_embed(indices)
        corner_embedding = ops.stack(
            [top_left_embedding, bottom_right_embedding], axis=2
        )
        return ops.reshape(
            corner_embedding, (batch_size, N * 2, self.hidden_size)
        )

    def _embed_mask(self, mask):
        mask_embedding = self.mask_downscaler(mask)
        return mask_embedding

    def call(
        self, images=None, points=None, labels=None, boxes=None, masks=None
    ):
        # Get the batch shape based on any arbitrary input, because batch
        # shapes must all match.
        valid_inputs = [
            x for x in (points, labels, boxes, masks) if x is not None
        ]

        batch_size = ops.shape(valid_inputs[0])[0]
        if points is None:
            points = ops.zeros((batch_size, 0, 2))
        if labels is None:
            labels = ops.zeros((batch_size, 0))
        if boxes is None:
            boxes = ops.zeros((batch_size, 0, 2, 2))
        if masks is None:
            masks = ops.zeros((batch_size, 0, 256, 256, 1))

        # Compute point embeddings
        point_embeddings = self._embed_points(points, labels)

        # Compute box embeddings
        box_embeddings = self._embed_box(boxes)

        # Concatenate both into a sparse embeddings tensor
        sparse_embeddings = ops.concatenate(
            [point_embeddings, box_embeddings], axis=1
        )

        # Compute the mask embeddings
        def _no_mask_embed():
            reshaped_embed = ops.reshape(
                self.no_mask_embed(ops.arange(1, dtype="int32")),
                (1, 1, 1, self.hidden_size),
            )
            broadcasted_embed = ops.broadcast_to(
                reshaped_embed,
                shape=(
                    batch_size,
                    self.image_embedding_size[0],
                    self.image_embedding_size[1],
                    self.hidden_size,
                ),
            )
            return broadcasted_embed

        def _maybe_input_mask_embed():
            # Keras passes the masks as concrete tensors for both the
            # true and false functions to build the output shape. So, we
            # need to handle the case when a zero-size mask is passed and
            # dispatch the call to `_no_mask_embed`. Note that we can't call
            # the lambda directly since the inputs are bound to different
            # values when called with concrete values.
            if masks.shape[1] == 0:
                return ops.broadcast_to(
                    ops.reshape(
                        self.no_mask_embed(ops.arange(1, dtype="int32")),
                        (1, 1, 1, self.hidden_size),
                    ),
                    shape=(
                        batch_size,
                        self.image_embedding_size[0],
                        self.image_embedding_size[1],
                        self.hidden_size,
                    ),
                )
            shape = ops.shape(masks)
            BM, N, height, width, channels = (
                shape[0],
                shape[1],
                shape[2],
                shape[3],
                shape[4],
            )
            return self._embed_mask(
                ops.reshape(masks, (BM * N, height, width, channels))
            )

        dense_embeddings = ops.cond(
            ops.equal(ops.size(masks), 0),
            _no_mask_embed,
            _maybe_input_mask_embed,
        )

        # Compute the dense positional embeddings
        prompt_dense_positional_embeddings = (
            self.positional_embedding_layer.encode_image(
                self.image_embedding_size
            )[None, ...]
        )

        return {
            "prompt_sparse_embeddings": sparse_embeddings,
            "prompt_dense_embeddings": dense_embeddings,
            "prompt_dense_positional_embeddings": prompt_dense_positional_embeddings,
        }

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_size": self.hidden_size,
                "image_embedding_size": self.image_embedding_size,
                "input_image_size": self.input_image_size,
                "mask_in_channels": self.mask_in_channels,
                "activation": self.activation,
            }
        )
        return config