keras-hub-nightly 0.16.1.dev202409240339__py3-none-any.whl → 0.16.1.dev202409260340__py3-none-any.whl
- keras_hub/api/layers/__init__.py +5 -0
- keras_hub/api/models/__init__.py +19 -0
- keras_hub/api/tokenizers/__init__.py +1 -0
- keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_encoder_block.py +8 -2
- keras_hub/src/models/clip/clip_preprocessor.py +147 -0
- keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_text_encoder.py +60 -57
- keras_hub/src/models/{stable_diffusion_v3 → clip}/clip_tokenizer.py +69 -30
- keras_hub/src/models/densenet/__init__.py +6 -0
- keras_hub/src/models/densenet/densenet_backbone.py +11 -8
- keras_hub/src/models/densenet/densenet_image_classifier.py +27 -4
- keras_hub/src/models/densenet/densenet_image_classifier_preprocessor.py +27 -0
- keras_hub/src/models/densenet/densenet_image_converter.py +23 -0
- keras_hub/src/models/densenet/densenet_presets.py +56 -0
- keras_hub/src/models/image_segmenter.py +86 -0
- keras_hub/src/models/sam/__init__.py +13 -0
- keras_hub/src/models/sam/sam_backbone.py +153 -0
- keras_hub/src/models/sam/sam_image_segmenter.py +237 -0
- keras_hub/src/models/sam/sam_layers.py +402 -0
- keras_hub/src/models/sam/sam_mask_decoder.py +270 -0
- keras_hub/src/models/sam/sam_prompt_encoder.py +336 -0
- keras_hub/src/models/sam/sam_transformer.py +159 -0
- keras_hub/src/models/stable_diffusion_3/__init__.py +13 -0
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +93 -0
- keras_hub/src/models/{stable_diffusion_v3 → stable_diffusion_3}/mmdit.py +351 -26
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +630 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +151 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +77 -0
- keras_hub/src/models/{stable_diffusion_v3/t5_xxl_text_encoder.py → stable_diffusion_3/t5_encoder.py} +7 -7
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +333 -0
- keras_hub/src/models/{stable_diffusion_v3/t5_xxl_preprocessor.py → t5/t5_preprocessor.py} +12 -3
- keras_hub/src/models/text_to_image.py +295 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +17 -12
- keras_hub/src/utils/timm/convert_densenet.py +107 -0
- keras_hub/src/utils/timm/preset_loader.py +3 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202409240339.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.16.1.dev202409240339.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/RECORD +40 -24
- keras_hub/src/models/stable_diffusion_v3/clip_preprocessor.py +0 -93
- keras_hub/src/models/stable_diffusion_v3/mmdit_block.py +0 -317
- keras_hub/src/models/stable_diffusion_v3/vae_attention.py +0 -126
- keras_hub/src/models/stable_diffusion_v3/vae_image_decoder.py +0 -186
- /keras_hub/src/models/{stable_diffusion_v3 → clip}/__init__.py +0 -0
- {keras_hub_nightly-0.16.1.dev202409240339.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.16.1.dev202409240339.dist-info → keras_hub_nightly-0.16.1.dev202409260340.dist-info}/top_level.txt +0 -0
keras_hub/src/models/sam/sam_image_segmenter.py (new file)
@@ -0,0 +1,237 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from keras import ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.image_segmenter import ImageSegmenter
from keras_hub.src.models.sam.sam_backbone import SAMBackbone


@keras_hub_export("keras_hub.models.SAMImageSegmenter")
class SAMImageSegmenter(ImageSegmenter):
    """The Segment Anything (SAM) image segmenter model.

    SAM works by prompting the input images. There are three ways to prompt:
    (1) Labelled points: foreground points (points with label 1) are encoded
        such that the output masks generated by the mask decoder contain them,
        and background points (points with label 0) are encoded such that the
        generated masks don't contain them.
    (2) Box: a box tells the model which part/crop of the image to segment.
    (3) Mask: an input mask can be used to refine the output of the mask
        decoder.
    These prompts can be mixed and matched, but at least one of the prompts
    must be present. To turn off a particular prompt, simply exclude it from
    the inputs to the model.
    (1) For points prompts, the expected shape is `(batch, num_points, 2)`.
        The labels must have a corresponding shape of `(batch, num_points)`.
    (2) For box prompts, the expected shape is `(batch, 1, 2, 2)`.
    (3) Similarly, mask prompts have shape `(batch, 1, H, W, 1)`.

    Args:
        backbone: A `keras_hub.models.SAMBackbone` instance.

    Example:
    Load a pretrained model using `from_preset`.

    ```python
    image_size = 128
    batch_size = 2
    input_data = {
        "images": np.ones(
            (batch_size, image_size, image_size, 3),
            dtype="float32",
        ),
        "points": np.ones((batch_size, 1, 2), dtype="float32"),
        "labels": np.ones((batch_size, 1), dtype="float32"),
        "boxes": np.ones((batch_size, 1, 2, 2), dtype="float32"),
        "masks": np.zeros(
            (batch_size, 0, image_size, image_size, 1)
        ),
    }
    # todo: update preset name
    sam = keras_hub.models.SAMImageSegmenter.from_preset("sam_base")
    sam(input_data)
    ```

    Load a Segment Anything image segmenter with a custom backbone:

    ```python
    image_size = 128
    batch_size = 2
    images = np.ones(
        (batch_size, image_size, image_size, 3),
        dtype="float32",
    )
    image_encoder = ViTDetBackbone(
        hidden_size=16,
        num_layers=16,
        intermediate_dim=16 * 4,
        num_heads=16,
        global_attention_layer_indices=[2, 5, 8, 11],
        patch_size=16,
        num_output_channels=8,
        window_size=2,
        image_shape=(image_size, image_size, 3),
    )
    prompt_encoder = SAMPromptEncoder(
        hidden_size=8,
        image_embedding_size=(8, 8),
        input_image_size=(
            image_size,
            image_size,
        ),
        mask_in_channels=16,
    )
    mask_decoder = SAMMaskDecoder(
        num_layers=2,
        hidden_size=8,
        intermediate_dim=32,
        num_heads=8,
        embedding_dim=8,
        num_multimask_outputs=3,
        iou_head_depth=3,
        iou_head_hidden_dim=8,
    )
    backbone = SAMBackbone(
        image_encoder=image_encoder,
        prompt_encoder=prompt_encoder,
        mask_decoder=mask_decoder,
        image_shape=(image_size, image_size, 3),
    )
    sam = SAMImageSegmenter(
        backbone=backbone
    )
    ```

    For example, to pass in all the prompts, do:

    ```python
    points = np.array([[[512., 512.], [100., 100.]]])
    # For labels: 1 means foreground point, 0 means background.
    labels = np.array([[1., 0.]])
    box = np.array([[[[384., 384.], [640., 640.]]]])
    input_mask = np.ones((1, 1, 256, 256, 1))
    # Prepare an input dictionary:
    inputs = {
        "images": image,
        "points": points,
        "labels": labels,
        "boxes": box,
        "masks": input_mask,
    }
    outputs = sam.predict(inputs)
    masks, iou_pred = outputs["masks"], outputs["iou_pred"]
    ```

    The first mask in the output `masks` (i.e. `masks[:, 0, ...]`) is the best
    mask predicted by the model based on the prompts. Other `masks`
    (i.e. `masks[:, 1:, ...]`) are alternate predictions that can be used if
    they are desired over the first one.
    If only points and box prompts are present, simply exclude the masks:

    ```python
    inputs = {
        "images": image,
        "points": points,
        "labels": labels,
        "boxes": box,
    }

    outputs = sam.predict(inputs)
    masks, iou_pred = outputs["masks"], outputs["iou_pred"]
    ```

    In another example, only point prompts are present.
    Note that if point prompts are present but no box prompt is present, the
    points must be padded using a zero point and -1 label:

    ```python
    padded_points = np.concatenate(
        [points, np.zeros((1, 1, 2))], axis=1
    )

    padded_labels = np.concatenate(
        [labels, -np.ones((1, 1))], axis=1
    )
    inputs = {
        "images": image,
        "points": padded_points,
        "labels": padded_labels,
    }
    outputs = sam.predict(inputs)
    masks, iou_pred = outputs["masks"], outputs["iou_pred"]
    ```
    """

    backbone_cls = SAMBackbone
    preprocessor_cls = None

    def __init__(self, backbone, preprocessor=None, **kwargs):
        # The implementation has been adapted from the [Segment Anything
        # paper](https://arxiv.org/abs/2304.02643), the [Segment Anything
        # GitHub](https://github.com/facebookresearch/segment-anything) and
        # [Detectron2](https://github.com/facebookresearch/detectron2).
        # === Layers ===
        self.backbone = backbone
        # === Functional Model ===
        inputs = self.backbone.input
        x = self.backbone(inputs)
        outputs = self.backbone.mask_decoder(**x)
        super().__init__(inputs=inputs, outputs=outputs, **kwargs)

    def predict_step(self, *args, **kwargs):
        if len(args) == 2:
            args = (args[0], self._add_placeholder_prompts(args[-1]))
        else:
            args = (self._add_placeholder_prompts(args[0]),)

        return super().predict_step(*args, **kwargs)

    def fit(self, *args, **kwargs):
        raise NotImplementedError(
            "Segment Anything Model only supports inference for now. Training"
            " the model isn't supported yet."
        )

    def _add_placeholder_prompts(self, inputs):
        """Adds placeholder prompt inputs for a call to SAM.

        Because SAM is a functional subclass model, all inputs must be
        specified in calls to the model. However, prompt inputs are all
        optional, so we have to add placeholders when they're not specified
        by the user.
        """
        inputs = inputs.copy()

        # Get the batch shape based on the image input.
        batch_size = ops.shape(inputs["images"])[0]

        # The type of the placeholders must match the existing inputs with
        # respect to whether or not they are tensors (as opposed to Numpy
        # arrays).
        zeros = ops.zeros if ops.is_tensor(inputs["images"]) else np.zeros

        # Fill in missing inputs.
        if "points" not in inputs:
            inputs["points"] = zeros((batch_size, 0, 2))
        if "labels" not in inputs:
            inputs["labels"] = zeros((batch_size, 0))
        if "boxes" not in inputs:
            inputs["boxes"] = zeros((batch_size, 0, 2, 2))
        if "masks" not in inputs:
            inputs["masks"] = zeros((batch_size, 0, 256, 256, 1))

        return inputs

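The `predict_step` override above routes inputs through
`_add_placeholder_prompts`, so optional prompts can simply be omitted at
inference time. Below is a minimal sketch of that behavior, assuming a `sam`
model built as in the custom-backbone example from the docstring (the 128x128
image size comes from that example):

```python
import numpy as np

# Only image and point prompts are supplied; zero-length "boxes" and "masks"
# placeholders are filled in internally before the functional model is called.
# The extra (0, 0) point with label -1 is the padding described in the
# docstring for the points-only case.
inputs = {
    "images": np.ones((1, 128, 128, 3), dtype="float32"),
    "points": np.array([[[32.0, 32.0], [0.0, 0.0]]], dtype="float32"),
    "labels": np.array([[1.0, -1.0]], dtype="float32"),
}
outputs = sam.predict(inputs)
masks, iou_pred = outputs["masks"], outputs["iou_pred"]
```
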
keras_hub/src/models/sam/sam_layers.py (new file)
@@ -0,0 +1,402 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import keras
from keras import ops


class MLP(keras.layers.Layer):
    """An MLP block with architecture
    `input_dim -> [hidden_dim] * (num_layers - 1) -> output_dim`.

    Args:
        hidden_dim: int. The number of units in the hidden layers.
        output_dim: int. The number of units in the output layer.
        num_layers: int. The total number of dense layers to use.
        activation: str. Activation to use in the hidden layers.
            Default is `"relu"`.
    """

    def __init__(
        self, hidden_dim, output_dim, num_layers, activation="relu", **kwargs
    ):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.activation = activation
        h = [hidden_dim] * (num_layers - 1)
        self.mlp_block = []
        for hidden_dim in h:
            self.mlp_block.append(
                keras.layers.Dense(hidden_dim, dtype=self.dtype_policy)
            )
            self.mlp_block.append(
                keras.layers.Activation(activation, dtype=self.dtype_policy)
            )
        self.mlp_block.append(
            keras.layers.Dense(output_dim, dtype=self.dtype_policy)
        )
        self.mlp_block = keras.models.Sequential(self.mlp_block)

    def build(self, input_shape):
        self.mlp_block.build(input_shape)
        self.built = True

    def call(self, x):
        return self.mlp_block(x)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "output_dim": self.output_dim,
                "num_layers": self.num_layers,
                "activation": self.activation,
            }
        )
        return config

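
# Illustrative usage sketch (not part of the shipped file): the MLP above
# stacks `num_layers - 1` hidden Dense layers of width `hidden_dim`, each
# followed by the activation, then a final Dense projection to `output_dim`.
mlp = MLP(hidden_dim=32, output_dim=8, num_layers=3)
y = mlp(ops.ones((2, 16, 64)))  # (batch, tokens, features) -> (2, 16, 8)
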
class MultiHeadAttentionWithDownsampling(keras.layers.Layer):
    """Multi-Head Attention with downsampling.

    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.
    This layer first downscales the features of input queries, keys, and
    values using a dense layer. Multi-head attention is then performed
    and the attention map is projected back (upscaled) to the number of
    input features.

    Args:
        num_heads: int. Number of attention heads.
        key_dim: int. Size of each attention head for query, key, and
            value.
        downsample_rate: int, optional. The factor by which to downscale the
            input features i.e. the input features of size `key_dim` are
            projected down to `key_dim // downsample_rate`.
    """

    def __init__(self, num_heads, key_dim, downsample_rate=1, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.downsample_rate = downsample_rate
        self.internal_dims = key_dim // downsample_rate

        # Downsample
        self.query_proj = keras.layers.Dense(
            self.internal_dims * self.num_heads, dtype=self.dtype_policy
        )
        self.key_proj = keras.layers.Dense(
            self.internal_dims * self.num_heads, dtype=self.dtype_policy
        )
        self.value_proj = keras.layers.Dense(
            self.internal_dims * self.num_heads, dtype=self.dtype_policy
        )

        # Upsample
        self.out_proj = keras.layers.Dense(
            self.key_dim * self.num_heads, dtype=self.dtype_policy
        )

    def build(self, input_shape=None):
        self.query_proj.build([None, None, self.num_heads * self.key_dim])
        self.key_proj.build([None, None, self.num_heads * self.key_dim])
        self.value_proj.build([None, None, self.num_heads * self.key_dim])
        self.out_proj.build([None, None, self.internal_dims * self.num_heads])
        self.built = True

    def _separate_heads(self, x):
        shape = ops.shape(x)
        batch_size, N, channels = shape[0], shape[1], shape[2]
        x = ops.reshape(
            x, (batch_size, N, self.num_heads, channels // self.num_heads)
        )
        return ops.transpose(x, axes=(0, 2, 1, 3))

    def _recombine_heads(self, x):
        shape = ops.shape(x)
        batch_size, num_heads, N_T, channels_per_head = (
            shape[0],
            shape[1],
            shape[2],
            shape[3],
        )
        x = ops.transpose(x, axes=(0, 2, 1, 3))
        return ops.reshape(x, (batch_size, N_T, num_heads * channels_per_head))

    def call(self, query, value, key):
        query = self.query_proj(query)
        key = self.key_proj(key)
        value = self.value_proj(value)

        # Separate into heads
        query = self._separate_heads(query)
        key = self._separate_heads(key)
        value = self._separate_heads(value)

        # Attention
        channels_per_head = ops.shape(query)[-1]
        out = ops.matmul(query, ops.transpose(key, (0, 1, 3, 2)))
        out = out / ops.sqrt(
            ops.cast(channels_per_head, dtype=self.compute_dtype)
        )
        out = ops.softmax(out, axis=-1)

        # Get output
        attention_map = out @ value
        attention_map = self._recombine_heads(attention_map)
        return self.out_proj(attention_map)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_heads": self.num_heads,
                "key_dim": self.key_dim,
                "downsample_rate": self.downsample_rate,
            }
        )
        return config

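
# Illustrative usage sketch (not part of the shipped file): inputs carry
# embeddings of size `num_heads * key_dim`; with `downsample_rate=2` the
# attention runs at `key_dim // 2` features per head, and the output
# projection maps back up to `num_heads * key_dim`.
attention = MultiHeadAttentionWithDownsampling(
    num_heads=4, key_dim=16, downsample_rate=2
)
queries = ops.ones((1, 8, 4 * 16))  # (batch, query_tokens, 64)
keys = ops.ones((1, 64, 4 * 16))  # (batch, key_tokens, 64)
out = attention(query=queries, value=keys, key=keys)  # -> (1, 8, 64)
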
class TwoWayMultiHeadAttention(keras.layers.Layer):
    """Two-way multi-head attention layer.

    Args:
        num_heads: int. Number of attention heads.
        key_dim: int. Size of each attention head for query, key, and
            value.
        intermediate_dim: int. Number of hidden dims to use in the mlp block.
        skip_first_layer_pos_embedding: bool. Whether to skip the positional
            embeddings in the first layer.
        attention_downsample_rate: int, optional. The downsample rate to use
            in the attention layers. Defaults to 2.
        activation: str, optional. The activation for the mlp block's output
            layer. Defaults to "relu".
    """

    def __init__(
        self,
        num_heads,
        key_dim,
        intermediate_dim,
        skip_first_layer_pos_embedding,
        attention_downsample_rate=2,
        activation="relu",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.intermediate_dim = intermediate_dim
        self.skip_first_layer_pos_embedding = skip_first_layer_pos_embedding
        self.attention_downsample_rate = attention_downsample_rate
        self.activation = activation

        self.self_attention = MultiHeadAttentionWithDownsampling(
            num_heads=num_heads, key_dim=key_dim, dtype=self.dtype_policy
        )
        self.layer_norm1 = keras.layers.LayerNormalization(
            epsilon=1e-5, dtype=self.dtype_policy
        )
        self.cross_attention_token_to_image = (
            MultiHeadAttentionWithDownsampling(
                num_heads=num_heads,
                key_dim=key_dim,
                downsample_rate=attention_downsample_rate,
                dtype=self.dtype_policy,
            )
        )
        self.layer_norm2 = keras.layers.LayerNormalization(
            epsilon=1e-5, dtype=self.dtype_policy
        )

        self.mlp_block = MLP(
            intermediate_dim,
            key_dim * num_heads,
            num_layers=2,
            activation=activation,
            dtype=self.dtype_policy,
        )

        self.layer_norm3 = keras.layers.LayerNormalization(
            epsilon=1e-5, dtype=self.dtype_policy
        )
        self.cross_attention_image_to_token = (
            MultiHeadAttentionWithDownsampling(
                num_heads=num_heads,
                key_dim=key_dim,
                downsample_rate=attention_downsample_rate,
                dtype=self.dtype_policy,
            )
        )
        self.layer_norm4 = keras.layers.LayerNormalization(
            epsilon=1e-5, dtype=self.dtype_policy
        )

    def build(self, input_shape=None):
        self.self_attention.build()
        self.layer_norm1.build([None, None, self.num_heads * self.key_dim])
        self.cross_attention_token_to_image.build()
        self.layer_norm2.build([None, None, self.num_heads * self.key_dim])
        self.mlp_block.build([None, None, self.num_heads * self.key_dim])
        self.layer_norm3.build([None, None, self.num_heads * self.key_dim])
        self.cross_attention_image_to_token.build()
        self.layer_norm4.build([None, None, self.num_heads * self.key_dim])
        self.built = True

    def call(self, queries, keys, query_pos_embedding, key_pos_embedding):
        if self.skip_first_layer_pos_embedding:
            queries = self.self_attention(
                query=queries, value=queries, key=queries
            )
        else:
            queries_with_pos_embedding = queries + query_pos_embedding
            attention_map = self.self_attention(
                query=queries_with_pos_embedding,
                key=queries_with_pos_embedding,
                value=queries,
            )
            queries = queries + attention_map
        queries = self.layer_norm1(queries)

        queries_with_pos_embedding = queries + query_pos_embedding
        keys_with_pos_embedding = keys + key_pos_embedding
        attention_map = self.cross_attention_token_to_image(
            query=queries_with_pos_embedding,
            key=keys_with_pos_embedding,
            value=keys,
        )
        queries = queries + attention_map
        queries = self.layer_norm2(queries)

        mlp_out = self.mlp_block(queries)
        queries = queries + mlp_out
        queries = self.layer_norm3(queries)

        queries_with_pos_embedding = queries + query_pos_embedding
        keys_with_pos_embedding = keys + key_pos_embedding
        attention_map = self.cross_attention_image_to_token(
            query=keys_with_pos_embedding,
            key=queries_with_pos_embedding,
            value=queries,
        )
        keys = keys + attention_map
        keys = self.layer_norm4(keys)

        return queries, keys

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_heads": self.num_heads,
                "key_dim": self.key_dim,
                "intermediate_dim": self.intermediate_dim,
                "skip_first_layer_pos_embedding": self.skip_first_layer_pos_embedding,
                "attention_downsample_rate": self.attention_downsample_rate,
                "activation": self.activation,
            }
        )
        return config

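
# Illustrative usage sketch (not part of the shipped file): prompt tokens
# (queries) and flattened image embeddings (keys) attend to each other; both
# use embeddings of size `num_heads * key_dim`.
block = TwoWayMultiHeadAttention(
    num_heads=4,
    key_dim=16,
    intermediate_dim=128,
    skip_first_layer_pos_embedding=True,
)
queries = ops.ones((1, 6, 64))  # sparse prompt tokens
keys = ops.ones((1, 64, 64))  # flattened image embeddings
query_pos = ops.zeros((1, 6, 64))
key_pos = ops.zeros((1, 64, 64))
queries, keys = block(queries, keys, query_pos, key_pos)  # shapes unchanged
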
class RandomFrequencyPositionalEmbeddings(keras.layers.Layer):
    """Positional encoding using random spatial frequencies.

    This layer maps coordinates/points in 2D space to positional
    encodings using random spatial frequencies.

    Args:
        num_positional_features: int. Number of positional features
            in the output.
        scale: float. The standard deviation of the random frequencies.
    """

    def __init__(self, num_positional_features, scale, **kwargs):
        super().__init__(**kwargs)
        self.num_positional_features = num_positional_features
        self.scale = scale
        self.positional_encoding_gaussian_matrix = self.add_weight(
            name="positional_encoding_gaussian_matrix",
            shape=(2, self.num_positional_features),
            dtype=self.variable_dtype,
            trainable=False,
            initializer=keras.initializers.get("normal"),
        )

    def build(self, input_shape=None):
        self.built = True

    def _positional_encodings(self, coords):
        coords = coords * 2 - 1
        coords = coords @ ops.cast(
            self.positional_encoding_gaussian_matrix, dtype=self.compute_dtype
        )
        coords = coords * (2 * math.pi)
        return ops.concatenate([ops.sin(coords), ops.cos(coords)], axis=-1)

    def call(self, size):
        return self.encode_image(size)

    def encode_image(self, size):
        """Generate a positional encoding for an image of any given size.
        Args:
            size: tuple[int, int]. The size of the image.
        Returns:
            tensor: Positional encoding of the image.
        """
        height, width = size
        grid = ops.ones(shape=(height, width), dtype=self.compute_dtype)
        y_embed = ops.cumsum(grid, axis=0) - 0.5
        x_embed = ops.cumsum(grid, axis=1) - 0.5
        y_embed = y_embed / ops.cast(height, self.compute_dtype)
        x_embed = x_embed / ops.cast(width, self.compute_dtype)
        return self._positional_encodings(
            ops.stack([x_embed, y_embed], axis=-1)
        )

    def encode_coordinates(self, coords_input, image_size):
        """Positionally encode points that are not normalized to `[0, 1]`.
        Args:
            coords_input: tensor. 2D coordinates/points to map.
            image_size: tuple[int, int]. Height and width of the image
                being prompted.
        Returns:
            tensor: Positional encodings of the normalized coordinates.
        """
        coords_normalized = ops.stack(
            [
                coords_input[..., 0] / image_size[1],
                coords_input[..., 1] / image_size[0],
            ],
            axis=-1,
        )
        return self._positional_encodings(coords_normalized)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_positional_features": self.num_positional_features,
                "scale": self.scale,
            }
        )
        return config

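A short usage sketch for `RandomFrequencyPositionalEmbeddings` (an
illustration with assumed hyperparameters, not code from the package):
`encode_image` produces a positional grid for a given image-embedding size,
`encode_coordinates` maps raw pixel coordinates from the prompted image into
the same space, and both return `2 * num_positional_features` channels
because sine and cosine components are concatenated.

```python
from keras import ops

from keras_hub.src.models.sam.sam_layers import (
    RandomFrequencyPositionalEmbeddings,
)

pe = RandomFrequencyPositionalEmbeddings(num_positional_features=64, scale=1.0)

# Dense positional grid for an 8x8 image embedding: shape (8, 8, 128).
grid_encoding = pe.encode_image((8, 8))

# Encodings for two prompt points on a 128x128 input image: shape (1, 2, 128).
points = ops.convert_to_tensor([[[16.0, 16.0], [100.0, 40.0]]])
point_encoding = pe.encode_coordinates(points, (128, 128))
```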