keras-hub-nightly 0.16.1.dev202409240339__py3-none-any.whl → 0.16.1.dev202409260340__py3-none-any.whl

This diff compares two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Files changed (44)
  1. keras_hub/api/layers/__init__.py +5 -0
  2. keras_hub/api/models/__init__.py +19 -0
  3. keras_hub/api/tokenizers/__init__.py +1 -0
  4. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_encoder_block.py +8 -2
  5. keras_hub/src/models/clip/clip_preprocessor.py +147 -0
  6. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_text_encoder.py +60 -57
  7. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_tokenizer.py +69 -30
  8. keras_hub/src/models/densenet/__init__.py +6 -0
  9. keras_hub/src/models/densenet/densenet_backbone.py +11 -8
  10. keras_hub/src/models/densenet/densenet_image_classifier.py +27 -4
  11. keras_hub/src/models/densenet/densenet_image_classifier_preprocessor.py +27 -0
  12. keras_hub/src/models/densenet/densenet_image_converter.py +23 -0
  13. keras_hub/src/models/densenet/densenet_presets.py +56 -0
  14. keras_hub/src/models/image_segmenter.py +86 -0
  15. keras_hub/src/models/sam/__init__.py +13 -0
  16. keras_hub/src/models/sam/sam_backbone.py +153 -0
  17. keras_hub/src/models/sam/sam_image_segmenter.py +237 -0
  18. keras_hub/src/models/sam/sam_layers.py +402 -0
  19. keras_hub/src/models/sam/sam_mask_decoder.py +270 -0
  20. keras_hub/src/models/sam/sam_prompt_encoder.py +336 -0
  21. keras_hub/src/models/sam/sam_transformer.py +159 -0
  22. keras_hub/src/models/stable_diffusion_3/__init__.py +13 -0
  23. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +93 -0
  24. keras_hub/src/models/{stable_diffusion_v3 → stable_diffusion_3}/mmdit.py +351 -26
  25. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +630 -0
  26. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +151 -0
  27. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +77 -0
  28. keras_hub/src/models/{stable_diffusion_v3/t5_xxl_text_encoder.py → stable_diffusion_3/t5_encoder.py} +7 -7
  29. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +333 -0
  30. keras_hub/src/models/{stable_diffusion_v3/t5_xxl_preprocessor.py → t5/t5_preprocessor.py} +12 -3
  31. keras_hub/src/models/text_to_image.py +295 -0
  32. keras_hub/src/models/vit_det/vit_det_backbone.py +17 -12
  33. keras_hub/src/utils/timm/convert_densenet.py +107 -0
  34. keras_hub/src/utils/timm/preset_loader.py +3 -0
  35. keras_hub/src/version_utils.py +1 -1
  36. {keras_hub_nightly-0.16.1.dev202409240339.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/METADATA +1 -1
  37. {keras_hub_nightly-0.16.1.dev202409240339.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/RECORD +40 -24
  38. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +0 -93
  39. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +0 -317
  40. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +0 -126
  41. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +0 -186
  42. /keras_hub/src/models/{stable_diffusion_v3 → clip}/__init__.py +0 -0
  43. {keras_hub_nightly-0.16.1.dev202409240339.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/WHEEL +0 -0
  44. {keras_hub_nightly-0.16.1.dev202409240339.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/top_level.txt +0 -0
keras_hub/src/models/sam/sam_mask_decoder.py
@@ -0,0 +1,270 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import keras
+ from keras import ops
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.sam.sam_layers import MLP
+ from keras_hub.src.models.sam.sam_transformer import TwoWayTransformer
+
+
+ @keras_hub_export("keras_hub.layers.SAMMaskDecoder")
+ class SAMMaskDecoder(keras.layers.Layer):
+     """Mask decoder for the Segment Anything Model (SAM).
+
+     This lightweight module efficiently maps the image embedding and a set
+     of prompt embeddings to an output mask. Before applying the transformer
+     decoder, the layer first inserts a learned output token embedding into
+     the set of prompt embeddings; this token is read out at the decoder's
+     output. For simplicity, these embeddings (not including the image
+     embedding) are collectively called "tokens".
+
+     The image embeddings, positional image embeddings, and tokens are
+     passed through a transformer decoder. After running the decoder, the
+     layer upsamples the updated image embedding by 4x with two transposed
+     convolutional layers (leaving it downscaled only 4x relative to the
+     input image). The tokens then attend once more to the image embedding,
+     and the updated output token embeddings are passed to a small 3-layer
+     MLP that outputs a vector matching the channel dimension of the
+     upscaled image embedding.
+
+     Finally, a mask is predicted with a spatially point-wise product
+     between the upscaled image embedding and the MLP's output.
+
+     Args:
+         hidden_size: int. The hidden size of the TwoWayTransformer.
+         num_layers: int. The number of layers in the TwoWayTransformer.
+         intermediate_dim: int. The intermediate dimension of the
+             TwoWayTransformer.
+         num_heads: int. The number of heads in the TwoWayTransformer.
+         embedding_dim: int, optional. The number of input features to the
+             transformer decoder. Defaults to `256`.
+         num_multimask_outputs: int, optional. Number of multimask outputs.
+             The model generates this many extra masks; the total number of
+             masks generated by the model is `1 + num_multimask_outputs`.
+             Defaults to `3`.
+         iou_head_depth: int, optional. The depth of the MLP used to
+             predict the IoU confidence score. Defaults to `3`.
+         iou_head_hidden_dim: int, optional. The number of units in the
+             hidden layers of the MLP used to predict the IoU confidence
+             score. Defaults to `256`.
+         activation: str, optional. Activation to use in the mask upscaler
+             network. Defaults to `"gelu"`.
+     """
+
+     def __init__(
+         self,
+         *,
+         hidden_size,
+         num_layers,
+         intermediate_dim,
+         num_heads,
+         embedding_dim=256,
+         num_multimask_outputs=3,
+         iou_head_depth=3,
+         iou_head_hidden_dim=256,
+         activation="gelu",
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.hidden_size = hidden_size
+         self.num_layers = num_layers
+         self.intermediate_dim = intermediate_dim
+         self.num_heads = num_heads
+         self.embedding_dim = embedding_dim
+         transformer = TwoWayTransformer(
+             num_layers=num_layers,
+             hidden_size=hidden_size,
+             intermediate_dim=intermediate_dim,
+             num_heads=num_heads,
+             dtype=self.dtype_policy,
+         )
+         self.transformer = transformer
+         self.num_multimask_outputs = num_multimask_outputs
+         self.iou_head_depth = iou_head_depth
+         self.iou_head_hidden_dim = iou_head_hidden_dim
+         self.activation = activation
+
+         self.iou_token = keras.layers.Embedding(
+             1, embedding_dim, dtype=self.dtype_policy
+         )
+         self.num_mask_tokens = num_multimask_outputs + 1
+         self.mask_tokens = keras.layers.Embedding(
+             self.num_mask_tokens, embedding_dim, dtype=self.dtype_policy
+         )
+
+         self.output_upscaling = keras.models.Sequential(
+             [
+                 keras.layers.Conv2DTranspose(
+                     embedding_dim // 4,
+                     kernel_size=2,
+                     strides=2,
+                     dtype=self.dtype_policy,
+                 ),
+                 keras.layers.LayerNormalization(
+                     epsilon=1e-6, dtype=self.dtype_policy
+                 ),
+                 keras.layers.Activation(activation, dtype=self.dtype_policy),
+                 keras.layers.Conv2DTranspose(
+                     embedding_dim // 8,
+                     kernel_size=2,
+                     strides=2,
+                     dtype=self.dtype_policy,
+                 ),
+                 keras.layers.Activation(activation, dtype=self.dtype_policy),
+             ]
+         )
+
+         self.output_hypernetworks_mlps = [
+             MLP(embedding_dim, embedding_dim // 8, 3, dtype=self.dtype_policy)
+             for _ in range(self.num_mask_tokens)
+         ]
+
+         self.iou_prediction_head = MLP(
+             iou_head_hidden_dim,
+             self.num_mask_tokens,
+             iou_head_depth,
+             dtype=self.dtype_policy,
+         )
+
+     def build(self, input_shape=None, **kwargs):
+         self.transformer.build()
+         self.iou_token.build([None])
+         self.mask_tokens.build([None])
+         self.output_upscaling.build([None, None, None, self.embedding_dim])
+         for mlp in self.output_hypernetworks_mlps:
+             mlp.build([None, self.embedding_dim])
+         self.iou_prediction_head.build([None, self.embedding_dim])
+         self.built = True
+
+     def call(
+         self,
+         image_embeddings,
+         prompt_dense_positional_embeddings,
+         prompt_sparse_embeddings,
+         prompt_dense_embeddings,
+     ):
+         masks, iou_pred = self._predict_masks(
+             image_embeddings=image_embeddings,
+             image_positional_embeddings=prompt_dense_positional_embeddings,
+             prompt_sparse_embeddings=prompt_sparse_embeddings,
+             prompt_dense_embeddings=prompt_dense_embeddings,
+         )
+
+         return {"masks": masks, "iou_pred": iou_pred}
+
+     def _predict_masks(
+         self,
+         image_embeddings,
+         image_positional_embeddings,
+         prompt_sparse_embeddings,
+         prompt_dense_embeddings,
+     ):
+         indices_iou = ops.arange(1, dtype="int32")
+         indices_mask = ops.arange(self.num_mask_tokens, dtype="int32")
+
+         output_tokens = ops.concatenate(
+             [self.iou_token(indices_iou), self.mask_tokens(indices_mask)],
+             axis=0,
+         )
+         output_tokens = ops.broadcast_to(
+             output_tokens[None, ...],
+             shape=(
+                 ops.shape(prompt_sparse_embeddings)[0],
+                 ops.shape(output_tokens)[0],
+                 ops.shape(output_tokens)[1],
+             ),
+         )
+         tokens = ops.concatenate(
+             [output_tokens, prompt_sparse_embeddings], axis=1
+         )
+
+         source = ops.broadcast_to(
+             image_embeddings,
+             shape=(
+                 ops.shape(tokens)[0],
+                 ops.shape(image_embeddings)[1],
+                 ops.shape(image_embeddings)[2],
+                 ops.shape(image_embeddings)[3],
+             ),
+         )
+         source = source + prompt_dense_embeddings
+         positional_source = ops.broadcast_to(
+             image_positional_embeddings,
+             shape=(
+                 ops.shape(tokens)[0],
+                 ops.shape(image_embeddings)[1],
+                 ops.shape(image_embeddings)[2],
+                 ops.shape(image_embeddings)[3],
+             ),
+         )
+         shape = ops.shape(source)
+         batch_dim, height, width, channels = (
+             shape[0],
+             shape[1],
+             shape[2],
+             shape[3],
+         )
+
+         hidden_state, source = self.transformer(
+             source, positional_source, tokens
+         )
+         iou_token_out = hidden_state[:, 0, :]
+         mask_tokens_out = hidden_state[:, 1 : (1 + self.num_mask_tokens), :]
+
+         source = ops.reshape(source, (batch_dim, height, width, channels))
+         upscaled_embeddings = self.output_upscaling(source)
+         hyper_in_list = []
+         for i in range(self.num_mask_tokens):
+             hyper_in_list.append(
+                 self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
+             )
+         hyper_in = ops.stack(hyper_in_list, axis=1)
+         shape = ops.shape(upscaled_embeddings)
+         batch_dim, height, width, channels = (
+             shape[0],
+             shape[1],
+             shape[2],
+             shape[3],
+         )
+         upscaled_embeddings = ops.reshape(
+             ops.transpose(upscaled_embeddings, axes=(0, 3, 1, 2)),
+             (batch_dim, channels, height * width),
+         )
+         masks = ops.reshape(
+             hyper_in @ upscaled_embeddings,
+             (batch_dim, self.num_mask_tokens, height, width),
+         )
+
+         iou_pred = self.iou_prediction_head(iou_token_out)
+
+         return masks, iou_pred
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_size": self.hidden_size,
+                 "num_layers": self.num_layers,
+                 "intermediate_dim": self.intermediate_dim,
+                 "num_heads": self.num_heads,
+                 "embedding_dim": self.embedding_dim,
+                 "num_multimask_outputs": self.num_multimask_outputs,
+                 "iou_head_depth": self.iou_head_depth,
+                 "iou_head_hidden_dim": self.iou_head_hidden_dim,
+                 "activation": self.activation,
+             }
+         )
+         return config
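
The hunk above adds the new `SAMMaskDecoder` layer. As a quick orientation, here is a minimal smoke-test sketch (not part of the diff) that drives the layer with random inputs; the constructor values are illustrative SAM-style sizes chosen for this example, not published presets:

import numpy as np

from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder

# Illustrative sizes; hidden_size here matches the default embedding_dim
# of 256 so the transformer width agrees with the token width.
decoder = SAMMaskDecoder(
    hidden_size=256, num_layers=2, intermediate_dim=2048, num_heads=8
)

batch = 1
# A 64x64 image embedding plus its dense positional encoding, and a
# handful of sparse prompt tokens, all with 256 features.
image_embeddings = np.random.rand(batch, 64, 64, 256).astype("float32")
dense_positional = np.random.rand(batch, 64, 64, 256).astype("float32")
sparse_prompts = np.random.rand(batch, 2, 256).astype("float32")
dense_prompts = np.random.rand(batch, 64, 64, 256).astype("float32")

outputs = decoder(
    image_embeddings, dense_positional, sparse_prompts, dense_prompts
)
# Two 2x transposed convs upscale 64x64 -> 256x256, and the model emits
# 1 + num_multimask_outputs = 4 masks with matching IoU scores.
print(outputs["masks"].shape)  # (1, 4, 256, 256)
print(outputs["iou_pred"].shape)  # (1, 4)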
keras_hub/src/models/sam/sam_prompt_encoder.py
@@ -0,0 +1,336 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import keras
+ from keras import ops
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.sam.sam_layers import (
+     RandomFrequencyPositionalEmbeddings,
+ )
+
+
+ @keras_hub_export("keras_hub.layers.SAMPromptEncoder")
+ class SAMPromptEncoder(keras.layers.Layer):
+     """Prompt Encoder for the Segment Anything Model (SAM).
+
+     The prompt encoder generates encodings for three types of prompts:
+
+     - Point prompts: Points on the image along with a label indicating
+       whether the point is in the foreground (part of the mask) or in
+       the background (not a part of the mask).
+     - Box prompts: A batch of bounding boxes with format
+       [(x1, y1), (x2, y2)] used to determine the location of the masks
+       in the image.
+     - Masks: An input mask can be passed to refine the positional
+       embeddings for the output mask.
+
+     First, the point prompts and box prompts are concatenated, and
+     positional encodings are generated using random spatial frequencies.
+     A point is represented as the sum of a positional encoding of the
+     point's location and one of two learned embeddings that indicate
+     whether the point is in the foreground or the background. A box is
+     represented by an embedding pair: (1) the positional encoding of its
+     top-left corner summed with a learned embedding representing
+     "top-left corner", and (2) the same structure, but using a learned
+     embedding indicating "bottom-right corner".
+
+     The box and point encodings are referred to as "sparse encodings".
+     If a mask prompt is passed, a convolutional neural net is used to
+     downscale it and generate "dense encodings". If no mask prompt is
+     passed, an embedding layer is used instead to generate a "no mask"
+     embedding.
+
+     Args:
+         hidden_size: int, optional. The number of features in the output
+             embeddings. Defaults to `256`.
+         image_embedding_size: tuple[int], optional. The spatial size of
+             the image embeddings generated by an image encoder. Defaults
+             to `(64, 64)`.
+         input_image_size: tuple[int], optional. A tuple of the height and
+             width of the image being prompted. Defaults to `(1024, 1024)`.
+         mask_in_channels: int, optional. The number of channels of the
+             mask prompt. Defaults to `16`.
+         activation: str, optional. The activation to use in the mask
+             downscaler neural net. Defaults to `"gelu"`.
+     """
+
+     def __init__(
+         self,
+         *,
+         hidden_size=256,
+         image_embedding_size=(64, 64),
+         input_image_size=(1024, 1024),
+         mask_in_channels=16,
+         activation="gelu",
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.hidden_size = hidden_size
+         self.image_embedding_size = image_embedding_size
+         self.input_image_size = input_image_size
+         self.mask_in_channels = mask_in_channels
+         self.activation = activation
+
+         self.positional_embedding_layer = RandomFrequencyPositionalEmbeddings(
+             num_positional_features=self.hidden_size // 2, scale=1
+         )
+
+         self.foreground_point_embed = keras.layers.Embedding(
+             1, hidden_size, name="foreground_point_embed"
+         )
+         self.background_point_embed = keras.layers.Embedding(
+             1, hidden_size, name="background_point_embed"
+         )
+         self.top_left_corner_embed = keras.layers.Embedding(
+             1, hidden_size, name="top_left_corner_embed"
+         )
+         self.bottom_right_corner_embed = keras.layers.Embedding(
+             1, hidden_size, name="bottom_right_corner_embed"
+         )
+         self.not_a_point_embed = keras.layers.Embedding(
+             1, hidden_size, name="not_a_point_embed"
+         )
+
+         self.mask_downscaler = keras.models.Sequential(
+             [
+                 keras.layers.Conv2D(
+                     mask_in_channels // 4, kernel_size=2, strides=2
+                 ),
+                 keras.layers.LayerNormalization(epsilon=1e-6),
+                 keras.layers.Activation(activation),
+                 keras.layers.Conv2D(mask_in_channels, kernel_size=2, strides=2),
+                 keras.layers.LayerNormalization(epsilon=1e-6),
+                 keras.layers.Activation(activation),
+                 keras.layers.Conv2D(hidden_size, kernel_size=1),
+             ],
+             name="mask_downscaler",
+         )
+         self.no_mask_embed = keras.layers.Embedding(
+             1, hidden_size, name="no_mask_embed"
+         )
+
+     def build(
+         self,
+         points_shape=None,
+         labels_shape=None,
+         boxes_shape=None,
+         masks_shape=None,
+     ):
+         self.positional_embedding_layer.build()
+         for layer in [
+             self.foreground_point_embed,
+             self.background_point_embed,
+             self.top_left_corner_embed,
+             self.bottom_right_corner_embed,
+             self.not_a_point_embed,
+             self.no_mask_embed,
+         ]:
+             layer.build([None])
+         self.mask_downscaler.build(
+             [
+                 None,
+                 4 * self.image_embedding_size[0],
+                 4 * self.image_embedding_size[1],
+                 1,
+             ]
+         )
+         self.built = True
+
+     def compute_output_shape(
+         self,
+         points_shape=None,
+         labels_shape=None,
+         boxes_shape=None,
+         masks_shape=None,
+     ):
+         batch_size = None
+         for shape in (points_shape, labels_shape, boxes_shape, masks_shape):
+             if shape is not None:
+                 batch_size = shape[0]
+                 break
+         return {
+             "prompt_sparse_embeddings": (
+                 batch_size,
+                 None,
+                 self.hidden_size,
+             ),
+             "prompt_dense_embeddings": (
+                 batch_size,
+                 self.image_embedding_size[0],
+                 self.image_embedding_size[1],
+                 self.hidden_size,
+             ),
+             "prompt_dense_positional_embeddings": (
+                 batch_size,
+                 self.image_embedding_size[0],
+                 self.image_embedding_size[1],
+                 self.hidden_size,
+             ),
+         }
+
+     def _embed_points(self, points, labels):
+         points = points + 0.5
+         indices = ops.arange(1, dtype="int32")
+
+         point_embeddings = self.positional_embedding_layer.encode_coordinates(
+             points, self.input_image_size
+         )
+         labels = ops.broadcast_to(
+             labels[..., None], ops.shape(point_embeddings)
+         )
+         point_embeddings = ops.where(
+             labels == 0,
+             point_embeddings + self.background_point_embed(indices),
+             point_embeddings + self.foreground_point_embed(indices),
+         )
+         point_embeddings = ops.where(
+             labels == -1,
+             self.not_a_point_embed(indices),
+             point_embeddings,
+         )
+         return point_embeddings
+
+     def _embed_box(self, box):
+         shape = ops.shape(box)
+         batch_size, N = shape[0], shape[1]
+         box = box + 0.5
+         indices = ops.arange(1, dtype="int32")
+         corner_embedding = self.positional_embedding_layer.encode_coordinates(
+             box, self.input_image_size
+         )
+         top_left_embedding = corner_embedding[
+             :, :, 0, :
+         ] + self.top_left_corner_embed(indices)
+         bottom_right_embedding = corner_embedding[
+             :, :, 1, :
+         ] + self.bottom_right_corner_embed(indices)
+         corner_embedding = ops.stack(
+             [top_left_embedding, bottom_right_embedding], axis=2
+         )
+         return ops.reshape(
+             corner_embedding, (batch_size, N * 2, self.hidden_size)
+         )
+
+     def _embed_mask(self, mask):
+         mask_embedding = self.mask_downscaler(mask)
+         return mask_embedding
+
+     def call(
+         self, images=None, points=None, labels=None, boxes=None, masks=None
+     ):
+         # Get the batch shape based on any arbitrary input, because batch
+         # shapes must all match.
+         valid_inputs = [
+             x for x in (points, labels, boxes, masks) if x is not None
+         ]
+
+         batch_size = ops.shape(valid_inputs[0])[0]
+         if points is None:
+             points = ops.zeros((batch_size, 0, 2))
+         if labels is None:
+             labels = ops.zeros((batch_size, 0))
+         if boxes is None:
+             boxes = ops.zeros((batch_size, 0, 2, 2))
+         if masks is None:
+             masks = ops.zeros((batch_size, 0, 256, 256, 1))
+
+         # Compute point embeddings.
+         point_embeddings = self._embed_points(points, labels)
+
+         # Compute box embeddings.
+         box_embeddings = self._embed_box(boxes)
+
+         # Concatenate both into a sparse embeddings tensor.
+         sparse_embeddings = ops.concatenate(
+             [point_embeddings, box_embeddings], axis=1
+         )
+
+         # Compute the mask embeddings.
+         def _no_mask_embed():
+             reshaped_embed = ops.reshape(
+                 self.no_mask_embed(ops.arange(1, dtype="int32")),
+                 (1, 1, 1, self.hidden_size),
+             )
+             broadcasted_embed = ops.broadcast_to(
+                 reshaped_embed,
+                 shape=(
+                     batch_size,
+                     self.image_embedding_size[0],
+                     self.image_embedding_size[1],
+                     self.hidden_size,
+                 ),
+             )
+             return broadcasted_embed
+
+         def _maybe_input_mask_embed():
+             # Keras passes the masks as concrete tensors for both the
+             # true and false functions to build the output shape. So, we
+             # need to handle the case when a zero-size mask is passed and
+             # dispatch the call to `_no_mask_embed`. Note that we can't
+             # call the lambda directly since the inputs are bound to
+             # different values when called with concrete values.
+             if masks.shape[1] == 0:
+                 return ops.broadcast_to(
+                     ops.reshape(
+                         self.no_mask_embed(ops.arange(1, dtype="int32")),
+                         (1, 1, 1, self.hidden_size),
+                     ),
+                     shape=(
+                         batch_size,
+                         self.image_embedding_size[0],
+                         self.image_embedding_size[1],
+                         self.hidden_size,
+                     ),
+                 )
+             shape = ops.shape(masks)
+             BM, N, height, width, channels = (
+                 shape[0],
+                 shape[1],
+                 shape[2],
+                 shape[3],
+                 shape[4],
+             )
+             return self._embed_mask(
+                 ops.reshape(masks, (BM * N, height, width, channels))
+             )
+
+         dense_embeddings = ops.cond(
+             ops.equal(ops.size(masks), 0),
+             _no_mask_embed,
+             _maybe_input_mask_embed,
+         )
+
+         # Compute the dense positional embeddings.
+         prompt_dense_positional_embeddings = (
+             self.positional_embedding_layer.encode_image(
+                 self.image_embedding_size
+             )[None, ...]
+         )
+
+         return {
+             "prompt_sparse_embeddings": sparse_embeddings,
+             "prompt_dense_embeddings": dense_embeddings,
+             "prompt_dense_positional_embeddings": prompt_dense_positional_embeddings,
+         }
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_size": self.hidden_size,
+                 "image_embedding_size": self.image_embedding_size,
+                 "input_image_size": self.input_image_size,
+                 "mask_in_channels": self.mask_in_channels,
+                 "activation": self.activation,
+             }
+         )
+         return config
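
The hunk above adds `SAMPromptEncoder`. Its output keys line up with the call arguments of `SAMMaskDecoder`, so the two layers chain naturally. Here is a minimal sketch (not part of the diff) using the default sizes, with example point and box prompts in input-image pixel coordinates:

import numpy as np

from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder

encoder = SAMPromptEncoder()  # Defaults: hidden_size=256, (64, 64) embeddings.

# One foreground point (label 1) and one background point (label 0).
points = np.array([[[512.0, 512.0], [256.0, 128.0]]], dtype="float32")
labels = np.array([[1.0, 0.0]], dtype="float32")
# One box in [(x1, y1), (x2, y2)] format, shape (batch, num_boxes, 2, 2).
boxes = np.array([[[[128.0, 128.0], [768.0, 768.0]]]], dtype="float32")

outputs = encoder(points=points, labels=labels, boxes=boxes)
# 2 point tokens + 2 box-corner tokens = 4 sparse embeddings.
print(outputs["prompt_sparse_embeddings"].shape)  # (1, 4, 256)
# No mask prompt was passed, so this is the broadcast "no mask" embedding.
print(outputs["prompt_dense_embeddings"].shape)  # (1, 64, 64, 256)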