keras-hub-nightly 0.16.1.dev202409240339__py3-none-any.whl → 0.16.1.dev202409260340__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. keras_hub/api/layers/__init__.py +5 -0
  2. keras_hub/api/models/__init__.py +19 -0
  3. keras_hub/api/tokenizers/__init__.py +1 -0
  4. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_encoder_block.py +8 -2
  5. keras_hub/src/models/clip/clip_preprocessor.py +147 -0
  6. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_text_encoder.py +60 -57
  7. keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_tokenizer.py +69 -30
  8. keras_hub/src/models/densenet/__init__.py +6 -0
  9. keras_hub/src/models/densenet/densenet_backbone.py +11 -8
  10. keras_hub/src/models/densenet/densenet_image_classifier.py +27 -4
  11. keras_hub/src/models/densenet/densenet_image_classifier_preprocessor.py +27 -0
  12. keras_hub/src/models/densenet/densenet_image_converter.py +23 -0
  13. keras_hub/src/models/densenet/densenet_presets.py +56 -0
  14. keras_hub/src/models/image_segmenter.py +86 -0
  15. keras_hub/src/models/sam/__init__.py +13 -0
  16. keras_hub/src/models/sam/sam_backbone.py +153 -0
  17. keras_hub/src/models/sam/sam_image_segmenter.py +237 -0
  18. keras_hub/src/models/sam/sam_layers.py +402 -0
  19. keras_hub/src/models/sam/sam_mask_decoder.py +270 -0
  20. keras_hub/src/models/sam/sam_prompt_encoder.py +336 -0
  21. keras_hub/src/models/sam/sam_transformer.py +159 -0
  22. keras_hub/src/models/stable_diffusion_3/__init__.py +13 -0
  23. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +93 -0
  24. keras_hub/src/models/{stable_diffusion_v3 → stable_diffusion_3}/mmdit.py +351 -26
  25. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +630 -0
  26. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +151 -0
  27. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +77 -0
  28. keras_hub/src/models/{stable_diffusion_v3/t5_xxl_text_encoder.py → stable_diffusion_3/t5_encoder.py} +7 -7
  29. keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +333 -0
  30. keras_hub/src/models/{stable_diffusion_v3/t5_xxl_preprocessor.py → t5/t5_preprocessor.py} +12 -3
  31. keras_hub/src/models/text_to_image.py +295 -0
  32. keras_hub/src/models/vit_det/vit_det_backbone.py +17 -12
  33. keras_hub/src/utils/timm/convert_densenet.py +107 -0
  34. keras_hub/src/utils/timm/preset_loader.py +3 -0
  35. keras_hub/src/version_utils.py +1 -1
  36. {keras_hub_nightly-0.16.1.dev202409240339.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/METADATA +1 -1
  37. {keras_hub_nightly-0.16.1.dev202409240339.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/RECORD +40 -24
  38. keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +0 -93
  39. keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +0 -317
  40. keras_hub/src/models/stable_diffusion_v3/vae_attention.py +0 -126
  41. keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +0 -186
  42. /keras_hub/src/models/{stable_diffusion_v3 → clip}/__init__.py +0 -0
  43. {keras_hub_nightly-0.16.1.dev202409240339.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/WHEEL +0 -0
  44. {keras_hub_nightly-0.16.1.dev202409240339.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/top_level.txt +0 -0
keras_hub/src/models/sam/sam_image_segmenter.py
@@ -0,0 +1,237 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import numpy as np
+ from keras import ops
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.image_segmenter import ImageSegmenter
+ from keras_hub.src.models.sam.sam_backbone import SAMBackbone
+
+
+ @keras_hub_export("keras_hub.models.SAMImageSegmenter")
+ class SAMImageSegmenter(ImageSegmenter):
+     """The Segment Anything (SAM) image segmenter model.
+
+     SAM works by prompting the input images. There are three ways to prompt:
+     (1) Labelled Points: Foreground points (points with label 1) are encoded
+         such that the output masks generated by the mask decoder contain them,
+         and background points (points with label 0) are encoded such that the
+         generated masks don't contain them.
+     (2) Box: A box tells the model which part/crop of the image to segment.
+     (3) Mask: An input mask can be used to refine the output of the mask
+         decoder.
+     These prompts can be mixed and matched, but at least one of the prompts
+     must be present. To turn off a particular prompt, simply exclude it from
+     the inputs to the model.
+     (1) For point prompts, the expected shape is `(batch, num_points, 2)`.
+         The labels must have a corresponding shape of `(batch, num_points)`.
+     (2) For box prompts, the expected shape is `(batch, 1, 2, 2)`.
+     (3) Similarly, mask prompts have shape `(batch, 1, H, W, 1)`.
+
+
+     Args:
+         backbone: A `keras_hub.models.SAMBackbone` instance.
+
+     Example:
+     Load a pretrained model using `from_preset`.
+
+     ```python
+     image_size = 128
+     batch_size = 2
+     input_data = {
+         "images": np.ones(
+             (batch_size, image_size, image_size, 3),
+             dtype="float32",
+         ),
+         "points": np.ones((batch_size, 1, 2), dtype="float32"),
+         "labels": np.ones((batch_size, 1), dtype="float32"),
+         "boxes": np.ones((batch_size, 1, 2, 2), dtype="float32"),
+         "masks": np.zeros(
+             (batch_size, 0, image_size, image_size, 1)
+         ),
+     }
+     # todo: update preset name
+     sam = keras_hub.models.SAMImageSegmenter.from_preset("sam_base")
+     sam(input_data)
+     ```
+
+     Load a Segment Anything image segmenter with a custom backbone:
+
+     ```python
+     image_size = 128
+     batch_size = 2
+     images = np.ones(
+         (batch_size, image_size, image_size, 3),
+         dtype="float32",
+     )
+     image_encoder = ViTDetBackbone(
+         hidden_size=16,
+         num_layers=16,
+         intermediate_dim=16 * 4,
+         num_heads=16,
+         global_attention_layer_indices=[2, 5, 8, 11],
+         patch_size=16,
+         num_output_channels=8,
+         window_size=2,
+         image_shape=(image_size, image_size, 3),
+     )
+     prompt_encoder = SAMPromptEncoder(
+         hidden_size=8,
+         image_embedding_size=(8, 8),
+         input_image_size=(
+             image_size,
+             image_size,
+         ),
+         mask_in_channels=16,
+     )
+     mask_decoder = SAMMaskDecoder(
+         num_layers=2,
+         hidden_size=8,
+         intermediate_dim=32,
+         num_heads=8,
+         embedding_dim=8,
+         num_multimask_outputs=3,
+         iou_head_depth=3,
+         iou_head_hidden_dim=8,
+     )
+     backbone = SAMBackbone(
+         image_encoder=image_encoder,
+         prompt_encoder=prompt_encoder,
+         mask_decoder=mask_decoder,
+         image_shape=(image_size, image_size, 3),
+     )
+     sam = SAMImageSegmenter(
+         backbone=backbone
+     )
+     ```
+
+     For example, to pass in all the prompts, do:
+
+     ```python
+
+     points = np.array([[[512., 512.], [100., 100.]]])
+     # For labels: 1 means foreground point, 0 means background.
+     labels = np.array([[1., 0.]])
+     box = np.array([[[[384., 384.], [640., 640.]]]])
+     input_mask = np.ones((1, 1, 256, 256, 1))
+     # Prepare an input dictionary:
+     inputs = {
+         "images": image,
+         "points": points,
+         "labels": labels,
+         "boxes": box,
+         "masks": input_mask
+     }
+     outputs = sam.predict(inputs)
+     masks, iou_pred = outputs["masks"], outputs["iou_pred"]
+     ```
+
+     The first mask in the output `masks` (i.e. `masks[:, 0, ...]`) is the best
+     mask predicted by the model based on the prompts. Other `masks`
+     (i.e. `masks[:, 1:, ...]`) are alternate predictions that can be used if
+     they are desired over the first one.
+     Now, in case of only points and box prompts, simply exclude the masks:
+
+     ```python
+     inputs = {
+         "images": image,
+         "points": points,
+         "labels": labels,
+         "boxes": box,
+     }
+
+     outputs = sam.predict(inputs)
+     masks, iou_pred = outputs["masks"], outputs["iou_pred"]
+     ```
+
+     Another example is when only point prompts are present.
+     Note that if point prompts are present but no box prompt is present, the
+     points must be padded using a zero point and -1 label:
+
+     ```python
+     padded_points = np.concatenate(
+         [points, np.zeros((1, 1, 2))], axis=1
+     )
+
+     padded_labels = np.concatenate(
+         [labels, -np.ones((1, 1))], axis=1
+     )
+     inputs = {
+         "images": image,
+         "points": padded_points,
+         "labels": padded_labels,
+     }
+     outputs = sam.predict(inputs)
+     masks, iou_pred = outputs["masks"], outputs["iou_pred"]
+     ```
+     """
+
+     backbone_cls = SAMBackbone
+     preprocessor_cls = None
+
+     def __init__(self, backbone, preprocessor=None, **kwargs):
+         # The implementation has been adapted from the [Segment Anything
+         # paper](https://arxiv.org/abs/2304.02643) and [Segment Anything
+         # GitHub](https://github.com/facebookresearch/segment-anything) and
+         # [Detectron2](https://github.com/facebookresearch/detectron2).
+         # === Layers ===
+         self.backbone = backbone
+         # === Functional Model ===
+         inputs = self.backbone.input
+         x = self.backbone(inputs)
+         outputs = self.backbone.mask_decoder(**x)
+         super().__init__(inputs=inputs, outputs=outputs, **kwargs)
+
+     def predict_step(self, *args, **kwargs):
+         if len(args) == 2:
+             args = (args[0], self._add_placeholder_prompts(args[-1]))
+         else:
+             args = (self._add_placeholder_prompts(args[0]),)
+
+         return super().predict_step(*args, **kwargs)
+
+     def fit(self, *args, **kwargs):
+         raise NotImplementedError(
+             "Segment Anything Model only supports inference for now. Training"
+             " the model isn't supported yet."
+         )
+
+     def _add_placeholder_prompts(self, inputs):
+         """Adds placeholder prompt inputs for a call to SAM.
+
+         Because SAM is a functional subclass model, all inputs must be specified
+         in calls to the model. However, prompt inputs are all optional, so we
+         have to add placeholders when they're not specified by the user.
+         """
+         inputs = inputs.copy()
+
+         # Get the batch shape based on the image input.
+         batch_size = ops.shape(inputs["images"])[0]
+
+         # The type of the placeholders must match the existing inputs with
+         # respect to whether or not they are tensors (as opposed to Numpy arrays).
+         zeros = ops.zeros if ops.is_tensor(inputs["images"]) else np.zeros
+
+         # Fill in missing inputs.
+         if "points" not in inputs:
+             inputs["points"] = zeros((batch_size, 0, 2))
+         if "labels" not in inputs:
+             inputs["labels"] = zeros((batch_size, 0))
+         if "boxes" not in inputs:
+             inputs["boxes"] = zeros((batch_size, 0, 2, 2))
+         if "masks" not in inputs:
+             inputs["masks"] = zeros((batch_size, 0, 256, 256, 1))
+
+         return inputs
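
Side note on the placeholder logic in `_add_placeholder_prompts` above: since SAM is built as a functional model, every prompt input must be present at call time, so any prompt the caller omits is filled with a zero-length placeholder. A minimal NumPy-only sketch of that idea (the helper name and example shapes below are illustrative, not code from the package):

```python
import numpy as np

def add_placeholder_prompts(inputs, mask_size=256):
    """Fill any prompt the caller omitted with a zero-length placeholder."""
    inputs = dict(inputs)
    batch_size = inputs["images"].shape[0]
    placeholder_shapes = {
        "points": (batch_size, 0, 2),
        "labels": (batch_size, 0),
        "boxes": (batch_size, 0, 2, 2),
        "masks": (batch_size, 0, mask_size, mask_size, 1),
    }
    for name, shape in placeholder_shapes.items():
        if name not in inputs:
            inputs[name] = np.zeros(shape, dtype="float32")
    return inputs

# A point-only prompt: boxes and masks get empty (zero-length) placeholders.
inputs = {
    "images": np.ones((1, 128, 128, 3), dtype="float32"),
    "points": np.array([[[64.0, 64.0]]], dtype="float32"),
    "labels": np.ones((1, 1), dtype="float32"),
}
padded = add_placeholder_prompts(inputs)
print({name: value.shape for name, value in padded.items()})
# {'images': (1, 128, 128, 3), 'points': (1, 1, 2), 'labels': (1, 1),
#  'boxes': (1, 0, 2, 2), 'masks': (1, 0, 256, 256, 1)}
```

The zero-length prompt axis is what indicates "no prompt of this kind" while keeping the model's input signature fixed.
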
keras_hub/src/models/sam/sam_layers.py
@@ -0,0 +1,402 @@
+ # Copyright 2024 The KerasHub Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+
+ import keras
+ from keras import ops
+
+
+ class MLP(keras.layers.Layer):
+     """An MLP block with architecture
+
+     `input_dim -> [hidden_dim] * (num_layers - 1) -> output_dim`.
+
+     Args:
+         hidden_dim: int. The number of units in the hidden layers.
+         output_dim: int. The number of units in the output layer.
+         num_layers: int. The total number of dense layers to use.
+         activation: str. Activation to use in the hidden layers.
+             Default is `"relu"`.
+     """
+
+     def __init__(
+         self, hidden_dim, output_dim, num_layers, activation="relu", **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.hidden_dim = hidden_dim
+         self.output_dim = output_dim
+         self.num_layers = num_layers
+         self.activation = activation
+         h = [hidden_dim] * (num_layers - 1)
+         self.mlp_block = []
+         for hidden_dim in h:
+             self.mlp_block.append(
+                 keras.layers.Dense(hidden_dim, dtype=self.dtype_policy)
+             )
+             self.mlp_block.append(
+                 keras.layers.Activation(activation, dtype=self.dtype_policy)
+             )
+         self.mlp_block.append(
+             keras.layers.Dense(output_dim, dtype=self.dtype_policy)
+         )
+         self.mlp_block = keras.models.Sequential(self.mlp_block)
+
+     def build(self, input_shape):
+         self.mlp_block.build(input_shape)
+         self.built = True
+
+     def call(self, x):
+         return self.mlp_block(x)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_dim": self.hidden_dim,
+                 "output_dim": self.output_dim,
+                 "num_layers": self.num_layers,
+                 "activation": self.activation,
+             }
+         )
+         return config
+
+
+ class MultiHeadAttentionWithDownsampling(keras.layers.Layer):
+     """Multi-Head Attention with downsampling.
+
+     An attention layer that allows for downscaling the size of the embedding
+     after projection to queries, keys, and values.
+     This layer first downscales the features of input queries, keys, and
+     values using a dense layer. Multi-head attention is then performed
+     and the attention map is projected back (upscaled) to the number of
+     input features.
+
+     Args:
+         num_heads: int. Number of attention heads.
+         key_dim: int. Size of each attention head for query, key, and
+             value.
+         downsample_rate: int, optional. The factor by which to downscale the
+             input features, i.e. the input features of size `key_dim` are
+             projected down to `key_dim // downsample_rate`.
+     """
+
+     def __init__(self, num_heads, key_dim, downsample_rate=1, **kwargs):
+         super().__init__(**kwargs)
+         self.num_heads = num_heads
+         self.key_dim = key_dim
+         self.downsample_rate = downsample_rate
+         self.internal_dims = key_dim // downsample_rate
+
+         # Downsample
+         self.query_proj = keras.layers.Dense(
+             self.internal_dims * self.num_heads, dtype=self.dtype_policy
+         )
+         self.key_proj = keras.layers.Dense(
+             self.internal_dims * self.num_heads, dtype=self.dtype_policy
+         )
+         self.value_proj = keras.layers.Dense(
+             self.internal_dims * self.num_heads, dtype=self.dtype_policy
+         )
+
+         # Upsample
+         self.out_proj = keras.layers.Dense(
+             self.key_dim * self.num_heads, dtype=self.dtype_policy
+         )
+
+     def build(self, input_shape=None):
+         self.query_proj.build([None, None, self.num_heads * self.key_dim])
+         self.key_proj.build([None, None, self.num_heads * self.key_dim])
+         self.value_proj.build([None, None, self.num_heads * self.key_dim])
+         self.out_proj.build([None, None, self.internal_dims * self.num_heads])
+         self.built = True
+
+     def _separate_heads(self, x):
+         shape = ops.shape(x)
+         batch_size, N, channels = shape[0], shape[1], shape[2]
+         x = ops.reshape(
+             x, (batch_size, N, self.num_heads, channels // self.num_heads)
+         )
+         return ops.transpose(x, axes=(0, 2, 1, 3))
+
+     def _recombine_heads(self, x):
+         shape = ops.shape(x)
+         batch_size, num_heads, N_T, channels_per_head = (
+             shape[0],
+             shape[1],
+             shape[2],
+             shape[3],
+         )
+         x = ops.transpose(x, axes=(0, 2, 1, 3))
+         return ops.reshape(x, (batch_size, N_T, num_heads * channels_per_head))
+
+     def call(self, query, value, key):
+         query = self.query_proj(query)
+         key = self.key_proj(key)
+         value = self.value_proj(value)
+
+         # Separate into heads
+         query = self._separate_heads(query)
+         key = self._separate_heads(key)
+         value = self._separate_heads(value)
+
+         # Attention
+         channels_per_head = ops.shape(query)[-1]
+         out = ops.matmul(query, ops.transpose(key, (0, 1, 3, 2)))
+         out = out / ops.sqrt(
+             ops.cast(channels_per_head, dtype=self.compute_dtype)
+         )
+         out = ops.softmax(out, axis=-1)
+
+         # Get output
+         attention_map = out @ value
+         attention_map = self._recombine_heads(attention_map)
+         return self.out_proj(attention_map)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "num_heads": self.num_heads,
+                 "key_dim": self.key_dim,
+                 "downsample_rate": self.downsample_rate,
+             }
+         )
+         return config
+
+
+ class TwoWayMultiHeadAttention(keras.layers.Layer):
+     """Two-way multi-head attention layer.
+
+     Args:
+         num_heads: int. Number of attention heads.
+         key_dim: int. Size of each attention head for query, key, and
+             value.
+         intermediate_dim: int. Number of hidden dims to use in the MLP block.
+         skip_first_layer_pos_embedding: bool. Whether to skip the
+             first layer positional embeddings.
+         attention_downsample_rate: int, optional. The downsample rate to use
+             in the attention layers. Defaults to 2.
+         activation: str, optional. The activation for the MLP block's output
+             layer. Defaults to "relu".
+     """
+
+     def __init__(
+         self,
+         num_heads,
+         key_dim,
+         intermediate_dim,
+         skip_first_layer_pos_embedding,
+         attention_downsample_rate=2,
+         activation="relu",
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.num_heads = num_heads
+         self.key_dim = key_dim
+         self.intermediate_dim = intermediate_dim
+         self.skip_first_layer_pos_embedding = skip_first_layer_pos_embedding
+         self.attention_downsample_rate = attention_downsample_rate
+         self.activation = activation
+
+         self.self_attention = MultiHeadAttentionWithDownsampling(
+             num_heads=num_heads, key_dim=key_dim, dtype=self.dtype_policy
+         )
+         self.layer_norm1 = keras.layers.LayerNormalization(
+             epsilon=1e-5, dtype=self.dtype_policy
+         )
+         self.cross_attention_token_to_image = (
+             MultiHeadAttentionWithDownsampling(
+                 num_heads=num_heads,
+                 key_dim=key_dim,
+                 downsample_rate=attention_downsample_rate,
+                 dtype=self.dtype_policy,
+             )
+         )
+         self.layer_norm2 = keras.layers.LayerNormalization(
+             epsilon=1e-5, dtype=self.dtype_policy
+         )
+
+         self.mlp_block = MLP(
+             intermediate_dim,
+             key_dim * num_heads,
+             num_layers=2,
+             activation=activation,
+             dtype=self.dtype_policy,
+         )
+
+         self.layer_norm3 = keras.layers.LayerNormalization(
+             epsilon=1e-5, dtype=self.dtype_policy
+         )
+         self.cross_attention_image_to_token = (
+             MultiHeadAttentionWithDownsampling(
+                 num_heads=num_heads,
+                 key_dim=key_dim,
+                 downsample_rate=attention_downsample_rate,
+                 dtype=self.dtype_policy,
+             )
+         )
+         self.layer_norm4 = keras.layers.LayerNormalization(
+             epsilon=1e-5, dtype=self.dtype_policy
+         )
+
+     def build(self, input_shape=None):
+         self.self_attention.build()
+         self.layer_norm1.build([None, None, self.num_heads * self.key_dim])
+         self.cross_attention_token_to_image.build()
+         self.layer_norm2.build([None, None, self.num_heads * self.key_dim])
+         self.mlp_block.build([None, None, self.num_heads * self.key_dim])
+         self.layer_norm3.build([None, None, self.num_heads * self.key_dim])
+         self.cross_attention_image_to_token.build()
+         self.layer_norm4.build([None, None, self.num_heads * self.key_dim])
+         self.built = True
+
+     def call(self, queries, keys, query_pos_embedding, key_pos_embedding):
+         if self.skip_first_layer_pos_embedding:
+             queries = self.self_attention(
+                 query=queries, value=queries, key=queries
+             )
+         else:
+             queries_with_pos_embedding = queries + query_pos_embedding
+             attention_map = self.self_attention(
+                 query=queries_with_pos_embedding,
+                 key=queries_with_pos_embedding,
+                 value=queries,
+             )
+             queries = queries + attention_map
+         queries = self.layer_norm1(queries)
+
+         queries_with_pos_embedding = queries + query_pos_embedding
+         keys_with_pos_embedding = keys + key_pos_embedding
+         attention_map = self.cross_attention_token_to_image(
+             query=queries_with_pos_embedding,
+             key=keys_with_pos_embedding,
+             value=keys,
+         )
+         queries = queries + attention_map
+         queries = self.layer_norm2(queries)
+
+         mlp_out = self.mlp_block(queries)
+         queries = queries + mlp_out
+         queries = self.layer_norm3(queries)
+
+         queries_with_pos_embedding = queries + query_pos_embedding
+         keys_with_pos_embedding = keys + key_pos_embedding
+         attention_map = self.cross_attention_image_to_token(
+             query=keys_with_pos_embedding,
+             key=queries_with_pos_embedding,
+             value=queries,
+         )
+         keys = keys + attention_map
+         keys = self.layer_norm4(keys)
+
+         return queries, keys
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "num_heads": self.num_heads,
+                 "key_dim": self.key_dim,
+                 "intermediate_dim": self.intermediate_dim,
+                 "skip_first_layer_pos_embedding": self.skip_first_layer_pos_embedding,
+                 "attention_downsample_rate": self.attention_downsample_rate,
+                 "activation": self.activation,
+             }
+         )
+         return config
+
+
+ class RandomFrequencyPositionalEmbeddings(keras.layers.Layer):
+     """Positional encoding using random spatial frequencies.
+
+     This layer maps coordinates/points in 2D space to positional
+     encodings using random spatial frequencies.
+
+     Args:
+         num_positional_features: int. Number of positional features
+             in the output.
+         scale: float. The standard deviation of the random frequencies.
+     """
+
+     def __init__(self, num_positional_features, scale, **kwargs):
+         super().__init__(**kwargs)
+         self.num_positional_features = num_positional_features
+         self.scale = scale
+         self.positional_encoding_gaussian_matrix = self.add_weight(
+             name="positional_encoding_gaussian_matrix",
+             shape=(2, self.num_positional_features),
+             dtype=self.variable_dtype,
+             trainable=False,
+             initializer=keras.initializers.get("normal"),
+         )
+
+     def build(self, input_shape=None):
+         self.built = True
+
+     def _positional_encodings(self, coords):
+         coords = coords * 2 - 1
+         coords = coords @ ops.cast(
+             self.positional_encoding_gaussian_matrix, dtype=self.compute_dtype
+         )
+         coords = coords * (2 * math.pi)
+         return ops.concatenate([ops.sin(coords), ops.cos(coords)], axis=-1)
+
+     def call(self, size):
+         return self.encode_image(size)
+
+     def encode_image(self, size):
+         """Generate a positional encoding for an image of any given size.
+         Args:
+             size: tuple[int, int]. The size of the image.
+         Returns:
+             tensor: Positional encoding of the image.
+         """
+         height, width = size
+         grid = ops.ones(shape=(height, width), dtype=self.compute_dtype)
+         y_embed = ops.cumsum(grid, axis=0) - 0.5
+         x_embed = ops.cumsum(grid, axis=1) - 0.5
+         y_embed = y_embed / ops.cast(height, self.compute_dtype)
+         x_embed = x_embed / ops.cast(width, self.compute_dtype)
+         return self._positional_encodings(
+             ops.stack([x_embed, y_embed], axis=-1)
+         )
+
+     def encode_coordinates(self, coords_input, image_size):
+         """Positionally encode points that are not normalized to `[0, 1]`.
+         Args:
+             coords_input: tensor. 2D coordinates/points to map.
+             image_size: tuple[int, int]. Height and width of the image
+                 being prompted.
+         Returns:
+             tensor: Positional encodings of the normalized coordinates.
+         """
+         coords_normalized = ops.stack(
+             [
+                 coords_input[..., 0] / image_size[1],
+                 coords_input[..., 1] / image_size[0],
+             ],
+             axis=-1,
+         )
+         return self._positional_encodings(coords_normalized)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "num_positional_features": self.num_positional_features,
+                 "scale": self.scale,
+             }
+         )
+         return config
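
To make the `RandomFrequencyPositionalEmbeddings` math above concrete: coordinates are normalized to `[0, 1]`, rescaled to `[-1, 1]`, projected through a random `(2, F)` Gaussian matrix, scaled by `2 * pi`, and encoded as sine and cosine features, giving `2 * F` channels. A standalone NumPy sketch of that mapping (sizes, seed, and variable names are illustrative assumptions, not the layer's actual weights):

```python
import numpy as np

num_positional_features = 4   # "F": the layer emits 2 * F channels (sin and cos).
scale = 1.0                   # documented as the std of the random frequencies
rng = np.random.default_rng(0)
gaussian_matrix = scale * rng.normal(size=(2, num_positional_features))

def positional_encodings(coords):
    """Map coordinates in [0, 1]^2 to 2 * F Fourier features."""
    coords = coords * 2 - 1                 # rescale to [-1, 1]
    coords = coords @ gaussian_matrix       # (..., 2) -> (..., F)
    coords = coords * (2 * np.pi)
    return np.concatenate([np.sin(coords), np.cos(coords)], axis=-1)

# Dense grid encoding, mirroring `encode_image`: pixel centers in [0, 1].
height, width = 8, 8
x = (np.arange(width) + 0.5) / width
y = (np.arange(height) + 0.5) / height
grid = np.stack(np.meshgrid(x, y, indexing="xy"), axis=-1)    # (8, 8, 2)
print(positional_encodings(grid).shape)                       # (8, 8, 8)

# Sparse point encoding, mirroring `encode_coordinates`: x / width, y / height.
points = np.array([[[512.0, 512.0], [100.0, 100.0]]])         # (batch, num_points, 2)
image_size = (1024, 1024)
normalized = points / np.array([image_size[1], image_size[0]])
print(positional_encodings(normalized).shape)                 # (1, 2, 8)
```

The same `gaussian_matrix` is reused for the image grid and for the prompt points, which is what lets the prompt encoder place point and box embeddings in the same positional space as the image features.
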