keras-hub-nightly 0.16.1.dev202409230338__py3-none-any.whl → 0.16.1.dev202409250340__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +2 -0
- keras_hub/api/models/__init__.py +3 -0
- keras_hub/src/models/image_segmenter.py +86 -0
- keras_hub/src/models/sam/__init__.py +13 -0
- keras_hub/src/models/sam/sam_backbone.py +153 -0
- keras_hub/src/models/sam/sam_image_segmenter.py +237 -0
- keras_hub/src/models/sam/sam_layers.py +402 -0
- keras_hub/src/models/sam/sam_mask_decoder.py +270 -0
- keras_hub/src/models/sam/sam_prompt_encoder.py +336 -0
- keras_hub/src/models/sam/sam_transformer.py +159 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +17 -12
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202409230338.dist-info → keras_hub_nightly-0.16.1.dev202409250340.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.16.1.dev202409230338.dist-info → keras_hub_nightly-0.16.1.dev202409250340.dist-info}/RECORD +16 -8
- {keras_hub_nightly-0.16.1.dev202409230338.dist-info → keras_hub_nightly-0.16.1.dev202409250340.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.16.1.dev202409230338.dist-info → keras_hub_nightly-0.16.1.dev202409250340.dist-info}/top_level.txt +0 -0
keras_hub/src/models/sam/sam_layers.py (new file)
@@ -0,0 +1,402 @@

```python
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import keras
from keras import ops
```
```python
class MLP(keras.layers.Layer):
    """An MLP block with architecture
    `input_dim -> [hidden_dim] * (num_layers - 1) -> output_dim`.

    Args:
        hidden_dim: int. The number of units in the hidden layers.
        output_dim: int. The number of units in the output layer.
        num_layers: int. The total number of dense layers to use.
        activation: str. Activation to use in the hidden layers.
            Defaults to `"relu"`.
    """

    def __init__(
        self, hidden_dim, output_dim, num_layers, activation="relu", **kwargs
    ):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.activation = activation
        h = [hidden_dim] * (num_layers - 1)
        self.mlp_block = []
        for hidden_dim in h:
            self.mlp_block.append(
                keras.layers.Dense(hidden_dim, dtype=self.dtype_policy)
            )
            self.mlp_block.append(
                keras.layers.Activation(activation, dtype=self.dtype_policy)
            )
        self.mlp_block.append(
            keras.layers.Dense(output_dim, dtype=self.dtype_policy)
        )
        self.mlp_block = keras.models.Sequential(self.mlp_block)

    def build(self, input_shape):
        self.mlp_block.build(input_shape)
        self.built = True

    def call(self, x):
        return self.mlp_block(x)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "output_dim": self.output_dim,
                "num_layers": self.num_layers,
                "activation": self.activation,
            }
        )
        return config
```
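For reference, a minimal usage sketch of the `MLP` block above. The sizes are illustrative, and the `keras_hub.src` import path is the internal one that `sam_mask_decoder.py` (below) also uses:

```python
import numpy as np

from keras_hub.src.models.sam.sam_layers import MLP

# Two dense layers: 256 -> 128 (relu) -> 64.
mlp = MLP(hidden_dim=128, output_dim=64, num_layers=2)
tokens = np.random.rand(1, 5, 256).astype("float32")
outputs = mlp(tokens)  # shape: (1, 5, 64)
```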
```python
class MultiHeadAttentionWithDownsampling(keras.layers.Layer):
    """Multi-Head Attention with downsampling.

    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.
    This layer first downscales the features of input queries, keys, and
    values using a dense layer. Multi-head attention is then performed
    and the attention map is projected back (upscaled) to the number of
    input features.

    Args:
        num_heads: int. Number of attention heads.
        key_dim: int. Size of each attention head for query, key, and
            value.
        downsample_rate: int, optional. The factor by which to downscale the
            input features, i.e., the input features of size `key_dim` are
            projected down to `key_dim // downsample_rate`.
    """

    def __init__(self, num_heads, key_dim, downsample_rate=1, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.downsample_rate = downsample_rate
        self.internal_dims = key_dim // downsample_rate

        # Downsample
        self.query_proj = keras.layers.Dense(
            self.internal_dims * self.num_heads, dtype=self.dtype_policy
        )
        self.key_proj = keras.layers.Dense(
            self.internal_dims * self.num_heads, dtype=self.dtype_policy
        )
        self.value_proj = keras.layers.Dense(
            self.internal_dims * self.num_heads, dtype=self.dtype_policy
        )

        # Upsample
        self.out_proj = keras.layers.Dense(
            self.key_dim * self.num_heads, dtype=self.dtype_policy
        )

    def build(self, input_shape=None):
        self.query_proj.build([None, None, self.num_heads * self.key_dim])
        self.key_proj.build([None, None, self.num_heads * self.key_dim])
        self.value_proj.build([None, None, self.num_heads * self.key_dim])
        self.out_proj.build([None, None, self.internal_dims * self.num_heads])
        self.built = True

    def _separate_heads(self, x):
        shape = ops.shape(x)
        batch_size, N, channels = shape[0], shape[1], shape[2]
        x = ops.reshape(
            x, (batch_size, N, self.num_heads, channels // self.num_heads)
        )
        return ops.transpose(x, axes=(0, 2, 1, 3))

    def _recombine_heads(self, x):
        shape = ops.shape(x)
        batch_size, num_heads, N_T, channels_per_head = (
            shape[0],
            shape[1],
            shape[2],
            shape[3],
        )
        x = ops.transpose(x, axes=(0, 2, 1, 3))
        return ops.reshape(x, (batch_size, N_T, num_heads * channels_per_head))

    def call(self, query, value, key):
        query = self.query_proj(query)
        key = self.key_proj(key)
        value = self.value_proj(value)

        # Separate into heads
        query = self._separate_heads(query)
        key = self._separate_heads(key)
        value = self._separate_heads(value)

        # Attention
        channels_per_head = ops.shape(query)[-1]
        out = ops.matmul(query, ops.transpose(key, (0, 1, 3, 2)))
        out = out / ops.sqrt(
            ops.cast(channels_per_head, dtype=self.compute_dtype)
        )
        out = ops.softmax(out, axis=-1)

        # Get output
        attention_map = out @ value
        attention_map = self._recombine_heads(attention_map)
        return self.out_proj(attention_map)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_heads": self.num_heads,
                "key_dim": self.key_dim,
                "downsample_rate": self.downsample_rate,
            }
        )
        return config
```
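A sketch of driving the layer above directly. The shapes mirror SAM's 256-dim setup (8 heads of 32 dims each) but are otherwise arbitrary; note that `call`'s positional order is `(query, value, key)`, so keyword arguments are safest:

```python
import numpy as np

from keras_hub.src.models.sam.sam_layers import (
    MultiHeadAttentionWithDownsampling,
)

# With downsample_rate=2, each 32-dim head is projected down to 16 dims
# (128 features in total); the output projection maps back to 8 * 32 = 256.
attention = MultiHeadAttentionWithDownsampling(
    num_heads=8, key_dim=32, downsample_rate=2
)
query = np.random.rand(1, 5, 256).astype("float32")  # 5 prompt tokens
image = np.random.rand(1, 4096, 256).astype("float32")  # 64 x 64 image tokens
out = attention(query=query, value=image, key=image)  # shape: (1, 5, 256)
```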
```python
class TwoWayMultiHeadAttention(keras.layers.Layer):
    """Two-way multi-head attention layer.

    Args:
        num_heads: int. Number of attention heads.
        key_dim: int. Size of each attention head for query, key, and
            value.
        intermediate_dim: int. Number of hidden units in the MLP block.
        skip_first_layer_pos_embedding: bool. Whether to skip the positional
            embedding in the first self-attention layer.
        attention_downsample_rate: int, optional. The downsample rate to use
            in the attention layers. Defaults to `2`.
        activation: str, optional. The activation for the MLP block's output
            layer. Defaults to `"relu"`.
    """

    def __init__(
        self,
        num_heads,
        key_dim,
        intermediate_dim,
        skip_first_layer_pos_embedding,
        attention_downsample_rate=2,
        activation="relu",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.intermediate_dim = intermediate_dim
        self.skip_first_layer_pos_embedding = skip_first_layer_pos_embedding
        self.attention_downsample_rate = attention_downsample_rate
        self.activation = activation

        self.self_attention = MultiHeadAttentionWithDownsampling(
            num_heads=num_heads, key_dim=key_dim, dtype=self.dtype_policy
        )
        self.layer_norm1 = keras.layers.LayerNormalization(
            epsilon=1e-5, dtype=self.dtype_policy
        )
        self.cross_attention_token_to_image = (
            MultiHeadAttentionWithDownsampling(
                num_heads=num_heads,
                key_dim=key_dim,
                downsample_rate=attention_downsample_rate,
                dtype=self.dtype_policy,
            )
        )
        self.layer_norm2 = keras.layers.LayerNormalization(
            epsilon=1e-5, dtype=self.dtype_policy
        )

        self.mlp_block = MLP(
            intermediate_dim,
            key_dim * num_heads,
            num_layers=2,
            activation=activation,
            dtype=self.dtype_policy,
        )

        self.layer_norm3 = keras.layers.LayerNormalization(
            epsilon=1e-5, dtype=self.dtype_policy
        )
        self.cross_attention_image_to_token = (
            MultiHeadAttentionWithDownsampling(
                num_heads=num_heads,
                key_dim=key_dim,
                downsample_rate=attention_downsample_rate,
                dtype=self.dtype_policy,
            )
        )
        self.layer_norm4 = keras.layers.LayerNormalization(
            epsilon=1e-5, dtype=self.dtype_policy
        )

    def build(self, input_shape=None):
        self.self_attention.build()
        self.layer_norm1.build([None, None, self.num_heads * self.key_dim])
        self.cross_attention_token_to_image.build()
        self.layer_norm2.build([None, None, self.num_heads * self.key_dim])
        self.mlp_block.build([None, None, self.num_heads * self.key_dim])
        self.layer_norm3.build([None, None, self.num_heads * self.key_dim])
        self.cross_attention_image_to_token.build()
        self.layer_norm4.build([None, None, self.num_heads * self.key_dim])
        self.built = True

    def call(self, queries, keys, query_pos_embedding, key_pos_embedding):
        if self.skip_first_layer_pos_embedding:
            queries = self.self_attention(
                query=queries, value=queries, key=queries
            )
        else:
            queries_with_pos_embedding = queries + query_pos_embedding
            attention_map = self.self_attention(
                query=queries_with_pos_embedding,
                key=queries_with_pos_embedding,
                value=queries,
            )
            queries = queries + attention_map
        queries = self.layer_norm1(queries)

        queries_with_pos_embedding = queries + query_pos_embedding
        keys_with_pos_embedding = keys + key_pos_embedding
        attention_map = self.cross_attention_token_to_image(
            query=queries_with_pos_embedding,
            key=keys_with_pos_embedding,
            value=keys,
        )
        queries = queries + attention_map
        queries = self.layer_norm2(queries)

        mlp_out = self.mlp_block(queries)
        queries = queries + mlp_out
        queries = self.layer_norm3(queries)

        queries_with_pos_embedding = queries + query_pos_embedding
        keys_with_pos_embedding = keys + key_pos_embedding
        attention_map = self.cross_attention_image_to_token(
            query=keys_with_pos_embedding,
            key=queries_with_pos_embedding,
            value=queries,
        )
        keys = keys + attention_map
        keys = self.layer_norm4(keys)

        return queries, keys

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_heads": self.num_heads,
                "key_dim": self.key_dim,
                "intermediate_dim": self.intermediate_dim,
                "skip_first_layer_pos_embedding": self.skip_first_layer_pos_embedding,
                "attention_downsample_rate": self.attention_downsample_rate,
                "activation": self.activation,
            }
        )
        return config
```
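The block updates both streams: the tokens attend to the image, an MLP refines the tokens, and the image attends back to the tokens. A minimal sketch with illustrative shapes (zero positional embeddings just to keep the example short):

```python
import numpy as np

from keras_hub.src.models.sam.sam_layers import TwoWayMultiHeadAttention

block = TwoWayMultiHeadAttention(
    num_heads=8,
    key_dim=32,
    intermediate_dim=2048,
    skip_first_layer_pos_embedding=True,
)
queries = np.random.rand(1, 5, 256).astype("float32")  # prompt tokens
keys = np.random.rand(1, 4096, 256).astype("float32")  # image tokens
queries, keys = block(
    queries, keys, np.zeros_like(queries), np.zeros_like(keys)
)
# queries: (1, 5, 256), keys: (1, 4096, 256)
```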
```python
class RandomFrequencyPositionalEmbeddings(keras.layers.Layer):
    """Positional encoding using random spatial frequencies.

    This layer maps coordinates/points in 2D space to positional
    encodings using random spatial frequencies.

    Args:
        num_positional_features: int. Number of positional features
            in the output.
        scale: float. The standard deviation of the random frequencies.
    """

    def __init__(self, num_positional_features, scale, **kwargs):
        super().__init__(**kwargs)
        self.num_positional_features = num_positional_features
        self.scale = scale
        self.positional_encoding_gaussian_matrix = self.add_weight(
            name="positional_encoding_gaussian_matrix",
            shape=(2, self.num_positional_features),
            dtype=self.variable_dtype,
            trainable=False,
            initializer=keras.initializers.get("normal"),
        )

    def build(self, input_shape=None):
        self.built = True

    def _positional_encodings(self, coords):
        coords = coords * 2 - 1
        coords = coords @ ops.cast(
            self.positional_encoding_gaussian_matrix, dtype=self.compute_dtype
        )
        coords = coords * (2 * math.pi)
        return ops.concatenate([ops.sin(coords), ops.cos(coords)], axis=-1)

    def call(self, size):
        return self.encode_image(size)

    def encode_image(self, size):
        """Generate a positional encoding for an image of any given size.

        Args:
            size: tuple[int, int]. The size of the image.

        Returns:
            tensor: Positional encoding of the image.
        """
        height, width = size
        grid = ops.ones(shape=(height, width), dtype=self.compute_dtype)
        y_embed = ops.cumsum(grid, axis=0) - 0.5
        x_embed = ops.cumsum(grid, axis=1) - 0.5
        y_embed = y_embed / ops.cast(height, self.compute_dtype)
        x_embed = x_embed / ops.cast(width, self.compute_dtype)
        return self._positional_encodings(
            ops.stack([x_embed, y_embed], axis=-1)
        )

    def encode_coordinates(self, coords_input, image_size):
        """Positionally encode points that are not normalized to `[0, 1]`.

        Args:
            coords_input: tensor. 2D coordinates/points to map.
            image_size: tuple[int, int]. Height and width of the image
                being prompted.

        Returns:
            tensor: Positional encodings of the normalized coordinates.
        """
        coords_normalized = ops.stack(
            [
                coords_input[..., 0] / image_size[1],
                coords_input[..., 1] / image_size[0],
            ],
            axis=-1,
        )
        return self._positional_encodings(coords_normalized)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_positional_features": self.num_positional_features,
                "scale": self.scale,
            }
        )
        return config
```
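This layer produces both the dense (per-pixel) and the sparse (per-point) positional encodings that SAM's prompt encoder consumes. A small sketch, assuming a 64x64 grid of image embeddings for a 1024x1024 input image; the output width is `2 * num_positional_features` because sines and cosines are concatenated:

```python
import numpy as np

from keras_hub.src.models.sam.sam_layers import (
    RandomFrequencyPositionalEmbeddings,
)

embedding = RandomFrequencyPositionalEmbeddings(
    num_positional_features=128, scale=1.0
)
# Dense encoding for every cell of a 64x64 embedding grid:
dense = embedding.encode_image((64, 64))  # shape: (64, 64, 256)
# Sparse encoding for two pixel-space points in a 1024x1024 image:
points = np.array([[[128.0, 256.0], [512.0, 512.0]]], dtype="float32")
sparse = embedding.encode_coordinates(points, (1024, 1024))  # (1, 2, 256)
```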
keras_hub/src/models/sam/sam_mask_decoder.py (new file)
@@ -0,0 +1,270 @@

```python
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import keras
from keras import ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.sam.sam_layers import MLP
from keras_hub.src.models.sam.sam_transformer import TwoWayTransformer
```
```python
@keras_hub_export("keras_hub.layers.SAMMaskDecoder")
class SAMMaskDecoder(keras.layers.Layer):
    """Mask decoder for the Segment Anything Model (SAM).

    This lightweight module efficiently maps the image embedding and a set of
    prompt embeddings to an output mask. Before applying the transformer
    decoder, the layer first inserts into the set of prompt embeddings a
    learned output token embedding that will be used at the decoder's output.
    For simplicity, these embeddings (not including the image embedding) are
    collectively called "tokens".

    The image embeddings, positional image embeddings, and tokens are passed
    through a transformer decoder. After running the decoder, the layer
    upsamples the updated image embedding by 4x with two transposed
    convolutional layers (so it is then only downscaled 4x relative to the
    input image). The tokens then attend once more to the image embedding,
    and the updated output token embeddings are passed to a small 3-layer MLP
    that outputs a vector matching the channel dimension of the upscaled
    image embedding.

    Finally, a mask is predicted with a spatially point-wise
    product between the upscaled image embedding and the MLP's output.

    Args:
        hidden_size: int. The hidden size of the TwoWayTransformer.
        num_layers: int. The number of layers in the TwoWayTransformer.
        intermediate_dim: int. The intermediate dimension of the
            TwoWayTransformer.
        num_heads: int. The number of heads in the TwoWayTransformer.
        embedding_dim: int, optional. The number of input features to the
            transformer decoder. Defaults to `256`.
        num_multimask_outputs: int, optional. Number of multimask outputs.
            The model will generate this many extra masks; the total number
            of masks generated by the model is `1 + num_multimask_outputs`.
            Defaults to `3`.
        iou_head_depth: int, optional. The depth of the dense net used to
            predict the IoU confidence score. Defaults to `3`.
        iou_head_hidden_dim: int, optional. The number of units in the hidden
            layers used in the dense net to predict the IoU confidence score.
            Defaults to `256`.
        activation: str, optional. Activation to use in the mask upscaler
            network. Defaults to `"gelu"`.
    """

    def __init__(
        self,
        *,
        hidden_size,
        num_layers,
        intermediate_dim,
        num_heads,
        embedding_dim=256,
        num_multimask_outputs=3,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
        activation="gelu",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.intermediate_dim = intermediate_dim
        self.num_heads = num_heads
        self.embedding_dim = embedding_dim
        transformer = TwoWayTransformer(
            num_layers=num_layers,
            hidden_size=hidden_size,
            intermediate_dim=intermediate_dim,
            num_heads=num_heads,
            dtype=self.dtype_policy,
        )
        self.transformer = transformer
        self.num_multimask_outputs = num_multimask_outputs
        self.iou_head_depth = iou_head_depth
        self.iou_head_hidden_dim = iou_head_hidden_dim
        self.activation = activation

        self.iou_token = keras.layers.Embedding(
            1, embedding_dim, dtype=self.dtype_policy
        )
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = keras.layers.Embedding(
            self.num_mask_tokens, embedding_dim, dtype=self.dtype_policy
        )

        self.output_upscaling = keras.models.Sequential(
            [
                keras.layers.Conv2DTranspose(
                    embedding_dim // 4,
                    kernel_size=2,
                    strides=2,
                    dtype=self.dtype_policy,
                ),
                keras.layers.LayerNormalization(
                    epsilon=1e-6, dtype=self.dtype_policy
                ),
                keras.layers.Activation(activation, dtype=self.dtype_policy),
                keras.layers.Conv2DTranspose(
                    embedding_dim // 8,
                    kernel_size=2,
                    strides=2,
                    dtype=self.dtype_policy,
                ),
                keras.layers.Activation(activation, dtype=self.dtype_policy),
            ]
        )

        self.output_hypernetworks_mlps = [
            MLP(embedding_dim, embedding_dim // 8, 3, dtype=self.dtype_policy)
            for _ in range(self.num_mask_tokens)
        ]

        self.iou_prediction_head = MLP(
            iou_head_hidden_dim,
            self.num_mask_tokens,
            iou_head_depth,
            dtype=self.dtype_policy,
        )

    def build(self, input_shape=None, **kwargs):
        self.transformer.build()
        self.iou_token.build([None])
        self.mask_tokens.build([None])
        self.output_upscaling.build([None, None, None, self.embedding_dim])
        for mlp in self.output_hypernetworks_mlps:
            mlp.build([None, self.embedding_dim])
        self.iou_prediction_head.build([None, self.embedding_dim])
        self.built = True

    def call(
        self,
        image_embeddings,
        prompt_dense_positional_embeddings,
        prompt_sparse_embeddings,
        prompt_dense_embeddings,
    ):
        masks, iou_pred = self._predict_masks(
            image_embeddings=image_embeddings,
            image_positional_embeddings=prompt_dense_positional_embeddings,
            prompt_sparse_embeddings=prompt_sparse_embeddings,
            prompt_dense_embeddings=prompt_dense_embeddings,
        )

        return {"masks": masks, "iou_pred": iou_pred}

    def _predict_masks(
        self,
        image_embeddings,
        image_positional_embeddings,
        prompt_sparse_embeddings,
        prompt_dense_embeddings,
    ):
        indices_iou = ops.arange(1, dtype="int32")
        indices_mask = ops.arange(self.num_mask_tokens, dtype="int32")

        output_tokens = ops.concatenate(
            [self.iou_token(indices_iou), self.mask_tokens(indices_mask)],
            axis=0,
        )
        output_tokens = ops.broadcast_to(
            output_tokens[None, ...],
            shape=(
                ops.shape(prompt_sparse_embeddings)[0],
                ops.shape(output_tokens)[0],
                ops.shape(output_tokens)[1],
            ),
        )
        tokens = ops.concatenate(
            [output_tokens, prompt_sparse_embeddings], axis=1
        )

        source = ops.broadcast_to(
            image_embeddings,
            shape=(
                ops.shape(tokens)[0],
                ops.shape(image_embeddings)[1],
                ops.shape(image_embeddings)[2],
                ops.shape(image_embeddings)[3],
            ),
        )
        source = source + prompt_dense_embeddings
        positional_source = ops.broadcast_to(
            image_positional_embeddings,
            shape=(
                ops.shape(tokens)[0],
                ops.shape(image_embeddings)[1],
                ops.shape(image_embeddings)[2],
                ops.shape(image_embeddings)[3],
            ),
        )
        shape = ops.shape(source)
        batch_dim, height, width, channels = (
            shape[0],
            shape[1],
            shape[2],
            shape[3],
        )

        hidden_state, source = self.transformer(
            source, positional_source, tokens
        )
        iou_token_out = hidden_state[:, 0, :]
        mask_tokens_out = hidden_state[:, 1 : (1 + self.num_mask_tokens), :]

        source = ops.reshape(source, (batch_dim, height, width, channels))
        upscaled_embeddings = self.output_upscaling(source)
        hyper_in_list = []
        for i in range(self.num_mask_tokens):
            hyper_in_list.append(
                self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
            )
        hyper_in = ops.stack(hyper_in_list, axis=1)
        shape = ops.shape(upscaled_embeddings)
        batch_dim, height, width, channels = (
            shape[0],
            shape[1],
            shape[2],
            shape[3],
        )
        upscaled_embeddings = ops.reshape(
            ops.transpose(upscaled_embeddings, axes=(0, 3, 1, 2)),
            (batch_dim, channels, height * width),
        )
        masks = ops.reshape(
            hyper_in @ upscaled_embeddings,
            (batch_dim, self.num_mask_tokens, height, width),
        )

        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_size": self.hidden_size,
                "num_layers": self.num_layers,
                "intermediate_dim": self.intermediate_dim,
                "num_heads": self.num_heads,
                "embedding_dim": self.embedding_dim,
                "num_multimask_outputs": self.num_multimask_outputs,
                "iou_head_depth": self.iou_head_depth,
                "iou_head_hidden_dim": self.iou_head_hidden_dim,
                "activation": self.activation,
            }
        )
        return config
```
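Because the layer is exported as `keras_hub.layers.SAMMaskDecoder`, it can be exercised standalone. A sketch with SAM-sized hyperparameters and random stand-ins for the inputs (in practice these come from the image encoder and the prompt encoder in `sam_prompt_encoder.py` above):

```python
import numpy as np

import keras_hub


def rand(*shape):
    return np.random.rand(*shape).astype("float32")


# SAM-sized values: 256-dim embeddings, a two-layer two-way transformer.
decoder = keras_hub.layers.SAMMaskDecoder(
    hidden_size=256, num_layers=2, intermediate_dim=2048, num_heads=8
)
outputs = decoder(
    image_embeddings=rand(1, 64, 64, 256),
    prompt_dense_positional_embeddings=rand(1, 64, 64, 256),
    prompt_sparse_embeddings=rand(1, 2, 256),  # e.g. two point prompts
    prompt_dense_embeddings=rand(1, 64, 64, 256),
)
# 1 + num_multimask_outputs = 4 candidate masks, upscaled 4x from the
# 64x64 embedding grid, plus one predicted IoU score per mask.
print(outputs["masks"].shape)  # (1, 4, 256, 256)
print(outputs["iou_pred"].shape)  # (1, 4)
```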