keras-hub-nightly 0.16.1.dev202409240339__py3-none-any.whl → 0.16.1.dev202409250340__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
--- /dev/null
+++ keras_hub/src/models/sam/sam_layers.py
@@ -0,0 +1,402 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+
+ import keras
+ from keras import ops
+
+
+ class MLP(keras.layers.Layer):
+     """A multi-layer perceptron block.
+
+     The block has the architecture
+     `input_dim -> [hidden_dim] * (num_layers - 1) -> output_dim`.
+
+     Args:
+         hidden_dim: int. The number of units in the hidden layers.
+         output_dim: int. The number of units in the output layer.
+         num_layers: int. The total number of dense layers to use.
+         activation: str. Activation to use in the hidden layers.
+             Defaults to `"relu"`.
+     """
+
+     def __init__(
+         self, hidden_dim, output_dim, num_layers, activation="relu", **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.hidden_dim = hidden_dim
+         self.output_dim = output_dim
+         self.num_layers = num_layers
+         self.activation = activation
+         h = [hidden_dim] * (num_layers - 1)
+         self.mlp_block = []
+         for hidden_dim in h:
+             self.mlp_block.append(
+                 keras.layers.Dense(hidden_dim, dtype=self.dtype_policy)
+             )
+             self.mlp_block.append(
+                 keras.layers.Activation(activation, dtype=self.dtype_policy)
+             )
+         self.mlp_block.append(
+             keras.layers.Dense(output_dim, dtype=self.dtype_policy)
+         )
+         self.mlp_block = keras.models.Sequential(self.mlp_block)
+
+     def build(self, input_shape):
+         self.mlp_block.build(input_shape)
+         self.built = True
+
+     def call(self, x):
+         return self.mlp_block(x)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_dim": self.hidden_dim,
+                 "output_dim": self.output_dim,
+                 "num_layers": self.num_layers,
+                 "activation": self.activation,
+             }
+         )
+         return config
+
+
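For illustration, a minimal sketch of exercising this block on its own; the dimensions below are arbitrary assumptions, not values used by the package:

```python
import numpy as np

# A 3-layer MLP: Dense(128) -> relu -> Dense(128) -> relu -> Dense(32).
mlp = MLP(hidden_dim=128, output_dim=32, num_layers=3)
x = np.ones((2, 16, 256), dtype="float32")  # (batch, tokens, features)
y = mlp(x)
print(y.shape)  # (2, 16, 32)
```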
+ class MultiHeadAttentionWithDownsampling(keras.layers.Layer):
+     """Multi-head attention with downsampling.
+
+     An attention layer that allows for downscaling the size of the embedding
+     after projection to queries, keys, and values. This layer first
+     downscales the features of input queries, keys, and values using a
+     dense layer. Multi-head attention is then performed, and the attention
+     map is projected back (upscaled) to the number of input features.
+
+     Args:
+         num_heads: int. Number of attention heads.
+         key_dim: int. Size of each attention head for query, key, and
+             value.
+         downsample_rate: int, optional. The factor by which to downscale the
+             input features, i.e. the input features of size `key_dim` are
+             projected down to `key_dim // downsample_rate`. Defaults to `1`.
+     """
+
+     def __init__(self, num_heads, key_dim, downsample_rate=1, **kwargs):
+         super().__init__(**kwargs)
+         self.num_heads = num_heads
+         self.key_dim = key_dim
+         self.downsample_rate = downsample_rate
+         self.internal_dims = key_dim // downsample_rate
+
+         # Downsample
+         self.query_proj = keras.layers.Dense(
+             self.internal_dims * self.num_heads, dtype=self.dtype_policy
+         )
+         self.key_proj = keras.layers.Dense(
+             self.internal_dims * self.num_heads, dtype=self.dtype_policy
+         )
+         self.value_proj = keras.layers.Dense(
+             self.internal_dims * self.num_heads, dtype=self.dtype_policy
+         )
+
+         # Upsample
+         self.out_proj = keras.layers.Dense(
+             self.key_dim * self.num_heads, dtype=self.dtype_policy
+         )
+
+     def build(self, input_shape=None):
+         self.query_proj.build([None, None, self.num_heads * self.key_dim])
+         self.key_proj.build([None, None, self.num_heads * self.key_dim])
+         self.value_proj.build([None, None, self.num_heads * self.key_dim])
+         self.out_proj.build([None, None, self.internal_dims * self.num_heads])
+         self.built = True
+
+     def _separate_heads(self, x):
+         shape = ops.shape(x)
+         batch_size, N, channels = shape[0], shape[1], shape[2]
+         x = ops.reshape(
+             x, (batch_size, N, self.num_heads, channels // self.num_heads)
+         )
+         return ops.transpose(x, axes=(0, 2, 1, 3))
+
+     def _recombine_heads(self, x):
+         shape = ops.shape(x)
+         batch_size, num_heads, N_T, channels_per_head = (
+             shape[0],
+             shape[1],
+             shape[2],
+             shape[3],
+         )
+         x = ops.transpose(x, axes=(0, 2, 1, 3))
+         return ops.reshape(x, (batch_size, N_T, num_heads * channels_per_head))
+
+     def call(self, query, value, key):
+         query = self.query_proj(query)
+         key = self.key_proj(key)
+         value = self.value_proj(value)
+
+         # Separate into heads
+         query = self._separate_heads(query)
+         key = self._separate_heads(key)
+         value = self._separate_heads(value)
+
+         # Attention
+         channels_per_head = ops.shape(query)[-1]
+         out = ops.matmul(query, ops.transpose(key, (0, 1, 3, 2)))
+         out = out / ops.sqrt(
+             ops.cast(channels_per_head, dtype=self.compute_dtype)
+         )
+         out = ops.softmax(out, axis=-1)
+
+         # Get output
+         attention_map = out @ value
+         attention_map = self._recombine_heads(attention_map)
+         return self.out_proj(attention_map)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "num_heads": self.num_heads,
+                 "key_dim": self.key_dim,
+                 "downsample_rate": self.downsample_rate,
+             }
+         )
+         return config
+
+
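To make the downsampling arithmetic concrete: with `num_heads=8`, `key_dim=64`, and `downsample_rate=2`, the projections work in `64 // 2 = 32` dimensions per head, and `out_proj` restores the full `8 * 64 = 512` features. A minimal sketch under those assumed sizes:

```python
import numpy as np

attn = MultiHeadAttentionWithDownsampling(
    num_heads=8, key_dim=64, downsample_rate=2
)
# The layer expects inputs carrying `num_heads * key_dim` (= 512) features.
q = np.ones((1, 10, 512), dtype="float32")  # e.g. 10 query tokens
k = np.ones((1, 25, 512), dtype="float32")  # e.g. 25 key/value tokens
v = np.ones((1, 25, 512), dtype="float32")
out = attn(query=q, value=v, key=k)
print(out.shape)  # (1, 10, 512)
```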
+ class TwoWayMultiHeadAttention(keras.layers.Layer):
+     """Two-way multi-head attention layer.
+
+     Args:
+         num_heads: int. Number of attention heads.
+         key_dim: int. Size of each attention head for query, key, and
+             value.
+         intermediate_dim: int. Number of hidden units in the MLP block.
+         skip_first_layer_pos_embedding: bool. Whether to skip the positional
+             embeddings in the first self-attention layer.
+         attention_downsample_rate: int, optional. The downsample rate to use
+             in the attention layers. Defaults to `2`.
+         activation: str, optional. The activation to use in the MLP block's
+             hidden layers. Defaults to `"relu"`.
+     """
+
+     def __init__(
+         self,
+         num_heads,
+         key_dim,
+         intermediate_dim,
+         skip_first_layer_pos_embedding,
+         attention_downsample_rate=2,
+         activation="relu",
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.num_heads = num_heads
+         self.key_dim = key_dim
+         self.intermediate_dim = intermediate_dim
+         self.skip_first_layer_pos_embedding = skip_first_layer_pos_embedding
+         self.attention_downsample_rate = attention_downsample_rate
+         self.activation = activation
+
+         self.self_attention = MultiHeadAttentionWithDownsampling(
+             num_heads=num_heads, key_dim=key_dim, dtype=self.dtype_policy
+         )
+         self.layer_norm1 = keras.layers.LayerNormalization(
+             epsilon=1e-5, dtype=self.dtype_policy
+         )
+         self.cross_attention_token_to_image = (
+             MultiHeadAttentionWithDownsampling(
+                 num_heads=num_heads,
+                 key_dim=key_dim,
+                 downsample_rate=attention_downsample_rate,
+                 dtype=self.dtype_policy,
+             )
+         )
+         self.layer_norm2 = keras.layers.LayerNormalization(
+             epsilon=1e-5, dtype=self.dtype_policy
+         )
+
+         self.mlp_block = MLP(
+             intermediate_dim,
+             key_dim * num_heads,
+             num_layers=2,
+             activation=activation,
+             dtype=self.dtype_policy,
+         )
+
+         self.layer_norm3 = keras.layers.LayerNormalization(
+             epsilon=1e-5, dtype=self.dtype_policy
+         )
+         self.cross_attention_image_to_token = (
+             MultiHeadAttentionWithDownsampling(
+                 num_heads=num_heads,
+                 key_dim=key_dim,
+                 downsample_rate=attention_downsample_rate,
+                 dtype=self.dtype_policy,
+             )
+         )
+         self.layer_norm4 = keras.layers.LayerNormalization(
+             epsilon=1e-5, dtype=self.dtype_policy
+         )
+
+     def build(self, input_shape=None):
+         self.self_attention.build()
+         self.layer_norm1.build([None, None, self.num_heads * self.key_dim])
+         self.cross_attention_token_to_image.build()
+         self.layer_norm2.build([None, None, self.num_heads * self.key_dim])
+         self.mlp_block.build([None, None, self.num_heads * self.key_dim])
+         self.layer_norm3.build([None, None, self.num_heads * self.key_dim])
+         self.cross_attention_image_to_token.build()
+         self.layer_norm4.build([None, None, self.num_heads * self.key_dim])
+         self.built = True
+
+     def call(self, queries, keys, query_pos_embedding, key_pos_embedding):
+         if self.skip_first_layer_pos_embedding:
+             queries = self.self_attention(
+                 query=queries, value=queries, key=queries
+             )
+         else:
+             queries_with_pos_embedding = queries + query_pos_embedding
+             attention_map = self.self_attention(
+                 query=queries_with_pos_embedding,
+                 key=queries_with_pos_embedding,
+                 value=queries,
+             )
+             queries = queries + attention_map
+         queries = self.layer_norm1(queries)
+
+         queries_with_pos_embedding = queries + query_pos_embedding
+         keys_with_pos_embedding = keys + key_pos_embedding
+         attention_map = self.cross_attention_token_to_image(
+             query=queries_with_pos_embedding,
+             key=keys_with_pos_embedding,
+             value=keys,
+         )
+         queries = queries + attention_map
+         queries = self.layer_norm2(queries)
+
+         mlp_out = self.mlp_block(queries)
+         queries = queries + mlp_out
+         queries = self.layer_norm3(queries)
+
+         queries_with_pos_embedding = queries + query_pos_embedding
+         keys_with_pos_embedding = keys + key_pos_embedding
+         attention_map = self.cross_attention_image_to_token(
+             query=keys_with_pos_embedding,
+             key=queries_with_pos_embedding,
+             value=queries,
+         )
+         keys = keys + attention_map
+         keys = self.layer_norm4(keys)
+
+         return queries, keys
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "num_heads": self.num_heads,
+                 "key_dim": self.key_dim,
+                 "intermediate_dim": self.intermediate_dim,
+                 "skip_first_layer_pos_embedding": self.skip_first_layer_pos_embedding,
+                 "attention_downsample_rate": self.attention_downsample_rate,
+                 "activation": self.activation,
+             }
+         )
+         return config
+
+
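A minimal sketch of one two-way block updating both token and image embeddings; the token counts and the zero positional embeddings are illustrative assumptions:

```python
import numpy as np

block = TwoWayMultiHeadAttention(
    num_heads=8,
    key_dim=32,
    intermediate_dim=1024,
    skip_first_layer_pos_embedding=True,
)
embed_dim = 8 * 32  # num_heads * key_dim
queries = np.ones((1, 5, embed_dim), dtype="float32")  # prompt tokens
keys = np.ones((1, 4096, embed_dim), dtype="float32")  # 64x64 image tokens
query_pos = np.zeros_like(queries)  # placeholder positional embeddings
key_pos = np.zeros_like(keys)
queries, keys = block(queries, keys, query_pos, key_pos)
print(queries.shape, keys.shape)  # (1, 5, 256) (1, 4096, 256)
```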
+ class RandomFrequencyPositionalEmbeddings(keras.layers.Layer):
+     """Positional encoding using random spatial frequencies.
+
+     This layer maps coordinates/points in 2D space to positional
+     encodings using random spatial frequencies.
+
+     Args:
+         num_positional_features: int. Number of positional features
+             in the output.
+         scale: float. The standard deviation of the random frequencies.
+     """
+
+     def __init__(self, num_positional_features, scale, **kwargs):
+         super().__init__(**kwargs)
+         self.num_positional_features = num_positional_features
+         self.scale = scale
+         self.positional_encoding_gaussian_matrix = self.add_weight(
+             name="positional_encoding_gaussian_matrix",
+             shape=(2, self.num_positional_features),
+             dtype=self.variable_dtype,
+             trainable=False,
+             initializer=keras.initializers.get("normal"),
+         )
+
+     def build(self, input_shape=None):
+         self.built = True
+
+     def _positional_encodings(self, coords):
+         coords = coords * 2 - 1
+         coords = coords @ ops.cast(
+             self.positional_encoding_gaussian_matrix, dtype=self.compute_dtype
+         )
+         coords = coords * (2 * math.pi)
+         return ops.concatenate([ops.sin(coords), ops.cos(coords)], axis=-1)
+
+     def call(self, size):
+         return self.encode_image(size)
+
+     def encode_image(self, size):
+         """Generate a positional encoding for an image of any given size.
+
+         Args:
+             size: tuple[int, int]. The size of the image.
+
+         Returns:
+             tensor: Positional encoding of the image.
+         """
+         height, width = size
+         grid = ops.ones(shape=(height, width), dtype=self.compute_dtype)
+         y_embed = ops.cumsum(grid, axis=0) - 0.5
+         x_embed = ops.cumsum(grid, axis=1) - 0.5
+         y_embed = y_embed / ops.cast(height, self.compute_dtype)
+         x_embed = x_embed / ops.cast(width, self.compute_dtype)
+         return self._positional_encodings(
+             ops.stack([x_embed, y_embed], axis=-1)
+         )
+
+     def encode_coordinates(self, coords_input, image_size):
+         """Positionally encode points that are not normalized to `[0, 1]`.
+
+         Args:
+             coords_input: tensor. 2D coordinates/points to map.
+             image_size: tuple[int, int]. Height and width of the image
+                 being prompted.
+
+         Returns:
+             tensor: Positional encodings of the normalized coordinates.
+         """
+         coords_normalized = ops.stack(
+             [
+                 coords_input[..., 0] / image_size[1],
+                 coords_input[..., 1] / image_size[0],
+             ],
+             axis=-1,
+         )
+         return self._positional_encodings(coords_normalized)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "num_positional_features": self.num_positional_features,
+                 "scale": self.scale,
+             }
+         )
+         return config
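A minimal sketch of both encoding paths; the grid size, point coordinates, and image size are arbitrary assumptions:

```python
import numpy as np
from keras import ops

pe = RandomFrequencyPositionalEmbeddings(num_positional_features=128, scale=1.0)

# Dense positional encoding for a 64x64 feature grid: sin and cos of 128
# random frequencies are concatenated, giving 256 output features.
grid_pe = pe.encode_image((64, 64))
print(grid_pe.shape)  # (64, 64, 256)

# Sparse encoding of prompt points given in pixel coordinates.
points = ops.convert_to_tensor([[[256.0, 256.0], [100.0, 400.0]]])
point_pe = pe.encode_coordinates(points, image_size=(1024, 1024))
print(point_pe.shape)  # (1, 2, 256)
```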
--- /dev/null
+++ keras_hub/src/models/sam/sam_mask_decoder.py
@@ -0,0 +1,270 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import keras
+ from keras import ops
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.sam.sam_layers import MLP
+ from keras_hub.src.models.sam.sam_transformer import TwoWayTransformer
+
+
+ @keras_hub_export("keras_hub.layers.SAMMaskDecoder")
+ class SAMMaskDecoder(keras.layers.Layer):
+     """Mask decoder for the Segment Anything Model (SAM).
+
+     This lightweight module efficiently maps the image embedding and a set of
+     prompt embeddings to an output mask. Before applying the transformer
+     decoder, the layer first inserts into the set of prompt embeddings a
+     learned output token embedding that will be used at the decoder's output.
+     For simplicity, these embeddings (not including the image embedding) are
+     collectively called "tokens".
+
+     The image embeddings, positional image embeddings, and tokens are passed
+     through a transformer decoder. After running the decoder, the layer
+     upsamples the updated image embedding by 4x with two transposed
+     convolutional layers (so that it is downscaled only 4x relative to the
+     input image). Then, the tokens attend once more to the image embedding,
+     and the updated output token embeddings are passed to a small 3-layer
+     MLP that outputs a vector matching the channel dimension of the
+     upscaled image embedding.
+
+     Finally, a mask is predicted with a spatially point-wise
+     product between the upscaled image embedding and the MLP's output.
+
+     Args:
+         hidden_size: int. The hidden size of the TwoWayTransformer.
+         num_layers: int. The number of layers in the TwoWayTransformer.
+         intermediate_dim: int. The intermediate dimension of the
+             TwoWayTransformer.
+         num_heads: int. The number of heads in the TwoWayTransformer.
+         embedding_dim: int, optional. The number of input features to the
+             transformer decoder. Defaults to `256`.
+         num_multimask_outputs: int, optional. Number of multimask outputs.
+             The model generates this many extra masks; the total number of
+             masks generated by the model is `1 + num_multimask_outputs`.
+             Defaults to `3`.
+         iou_head_depth: int, optional. The depth of the MLP used to
+             predict the IoU confidence score. Defaults to `3`.
+         iou_head_hidden_dim: int, optional. The number of units in the hidden
+             layers of the MLP used to predict the IoU confidence score.
+             Defaults to `256`.
+         activation: str, optional. Activation to use in the mask upscaler
+             network. Defaults to `"gelu"`.
+     """
+
+     def __init__(
+         self,
+         *,
+         hidden_size,
+         num_layers,
+         intermediate_dim,
+         num_heads,
+         embedding_dim=256,
+         num_multimask_outputs=3,
+         iou_head_depth=3,
+         iou_head_hidden_dim=256,
+         activation="gelu",
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.hidden_size = hidden_size
+         self.num_layers = num_layers
+         self.intermediate_dim = intermediate_dim
+         self.num_heads = num_heads
+         self.embedding_dim = embedding_dim
+         transformer = TwoWayTransformer(
+             num_layers=num_layers,
+             hidden_size=hidden_size,
+             intermediate_dim=intermediate_dim,
+             num_heads=num_heads,
+             dtype=self.dtype_policy,
+         )
+         self.transformer = transformer
+         self.num_multimask_outputs = num_multimask_outputs
+         self.iou_head_depth = iou_head_depth
+         self.iou_head_hidden_dim = iou_head_hidden_dim
+         self.activation = activation
+
+         self.iou_token = keras.layers.Embedding(
+             1, embedding_dim, dtype=self.dtype_policy
+         )
+         self.num_mask_tokens = num_multimask_outputs + 1
+         self.mask_tokens = keras.layers.Embedding(
+             self.num_mask_tokens, embedding_dim, dtype=self.dtype_policy
+         )
+
+         self.output_upscaling = keras.models.Sequential(
+             [
+                 keras.layers.Conv2DTranspose(
+                     embedding_dim // 4,
+                     kernel_size=2,
+                     strides=2,
+                     dtype=self.dtype_policy,
+                 ),
+                 keras.layers.LayerNormalization(
+                     epsilon=1e-6, dtype=self.dtype_policy
+                 ),
+                 keras.layers.Activation(activation, dtype=self.dtype_policy),
+                 keras.layers.Conv2DTranspose(
+                     embedding_dim // 8,
+                     kernel_size=2,
+                     strides=2,
+                     dtype=self.dtype_policy,
+                 ),
+                 keras.layers.Activation(activation, dtype=self.dtype_policy),
+             ]
+         )
+
+         self.output_hypernetworks_mlps = [
+             MLP(embedding_dim, embedding_dim // 8, 3, dtype=self.dtype_policy)
+             for _ in range(self.num_mask_tokens)
+         ]
+
+         self.iou_prediction_head = MLP(
+             iou_head_hidden_dim,
+             self.num_mask_tokens,
+             iou_head_depth,
+             dtype=self.dtype_policy,
+         )
+
+     def build(self, input_shape=None, **kwargs):
+         self.transformer.build()
+         self.iou_token.build([None])
+         self.mask_tokens.build([None])
+         self.output_upscaling.build([None, None, None, self.embedding_dim])
+         for mlp in self.output_hypernetworks_mlps:
+             mlp.build([None, self.embedding_dim])
+         self.iou_prediction_head.build([None, self.embedding_dim])
+         self.built = True
+
+     def call(
+         self,
+         image_embeddings,
+         prompt_dense_positional_embeddings,
+         prompt_sparse_embeddings,
+         prompt_dense_embeddings,
+     ):
+         masks, iou_pred = self._predict_masks(
+             image_embeddings=image_embeddings,
+             image_positional_embeddings=prompt_dense_positional_embeddings,
+             prompt_sparse_embeddings=prompt_sparse_embeddings,
+             prompt_dense_embeddings=prompt_dense_embeddings,
+         )
+
+         return {"masks": masks, "iou_pred": iou_pred}
+
+     def _predict_masks(
+         self,
+         image_embeddings,
+         image_positional_embeddings,
+         prompt_sparse_embeddings,
+         prompt_dense_embeddings,
+     ):
+         # Prepend the learned IoU token and mask tokens to the sparse
+         # prompt embeddings.
+         indices_iou = ops.arange(1, dtype="int32")
+         indices_mask = ops.arange(self.num_mask_tokens, dtype="int32")
+
+         output_tokens = ops.concatenate(
+             [self.iou_token(indices_iou), self.mask_tokens(indices_mask)],
+             axis=0,
+         )
+         output_tokens = ops.broadcast_to(
+             output_tokens[None, ...],
+             shape=(
+                 ops.shape(prompt_sparse_embeddings)[0],
+                 ops.shape(output_tokens)[0],
+                 ops.shape(output_tokens)[1],
+             ),
+         )
+         tokens = ops.concatenate(
+             [output_tokens, prompt_sparse_embeddings], axis=1
+         )
+
+         # Broadcast the image embedding across the batch of prompts and
+         # mix in the dense prompt embeddings.
+         source = ops.broadcast_to(
+             image_embeddings,
+             shape=(
+                 ops.shape(tokens)[0],
+                 ops.shape(image_embeddings)[1],
+                 ops.shape(image_embeddings)[2],
+                 ops.shape(image_embeddings)[3],
+             ),
+         )
+         source = source + prompt_dense_embeddings
+         positional_source = ops.broadcast_to(
+             image_positional_embeddings,
+             shape=(
+                 ops.shape(tokens)[0],
+                 ops.shape(image_embeddings)[1],
+                 ops.shape(image_embeddings)[2],
+                 ops.shape(image_embeddings)[3],
+             ),
+         )
+         shape = ops.shape(source)
+         batch_dim, height, width, channels = (
+             shape[0],
+             shape[1],
+             shape[2],
+             shape[3],
+         )
+
+         hidden_state, source = self.transformer(
+             source, positional_source, tokens
+         )
+         iou_token_out = hidden_state[:, 0, :]
+         mask_tokens_out = hidden_state[:, 1 : (1 + self.num_mask_tokens), :]
+
+         source = ops.reshape(source, (batch_dim, height, width, channels))
+         upscaled_embeddings = self.output_upscaling(source)
+         hyper_in_list = []
+         for i in range(self.num_mask_tokens):
+             hyper_in_list.append(
+                 self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
+             )
+         hyper_in = ops.stack(hyper_in_list, axis=1)
+         shape = ops.shape(upscaled_embeddings)
+         batch_dim, height, width, channels = (
+             shape[0],
+             shape[1],
+             shape[2],
+             shape[3],
+         )
+         # Predict masks as a spatially point-wise product between the
+         # upscaled image embedding and the per-token hypernetwork outputs.
+         upscaled_embeddings = ops.reshape(
+             ops.transpose(upscaled_embeddings, axes=(0, 3, 1, 2)),
+             (batch_dim, channels, height * width),
+         )
+         masks = ops.reshape(
+             hyper_in @ upscaled_embeddings,
+             (batch_dim, self.num_mask_tokens, height, width),
+         )
+
+         iou_pred = self.iou_prediction_head(iou_token_out)
+
+         return masks, iou_pred
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_size": self.hidden_size,
+                 "num_layers": self.num_layers,
+                 "intermediate_dim": self.intermediate_dim,
+                 "num_heads": self.num_heads,
+                 "embedding_dim": self.embedding_dim,
+                 "num_multimask_outputs": self.num_multimask_outputs,
+                 "iou_head_depth": self.iou_head_depth,
+                 "iou_head_hidden_dim": self.iou_head_hidden_dim,
+                 "activation": self.activation,
+             }
+         )
+         return config
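A minimal end-to-end sketch with dummy embeddings, assuming the companion `TwoWayTransformer` (defined in `sam_transformer.py`, not part of this diff) accepts `hidden_size=256` with 8 heads; all shapes and parameter values below are illustrative assumptions:

```python
import numpy as np

decoder = SAMMaskDecoder(
    hidden_size=256,
    num_layers=2,
    intermediate_dim=2048,
    num_heads=8,
)
outputs = decoder(
    image_embeddings=np.ones((1, 64, 64, 256), dtype="float32"),
    prompt_dense_positional_embeddings=np.ones((1, 64, 64, 256), dtype="float32"),
    prompt_sparse_embeddings=np.ones((1, 2, 256), dtype="float32"),
    prompt_dense_embeddings=np.ones((1, 64, 64, 256), dtype="float32"),
)
# Each 64x64 source is upscaled 4x by `output_upscaling`, and
# `1 + num_multimask_outputs = 4` masks are predicted per prompt.
print(outputs["masks"].shape)  # (1, 4, 256, 256)
print(outputs["iou_pred"].shape)  # (1, 4)
```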