keras-hub-nightly 0.22.0.dev202507150421__py3-none-any.whl → 0.22.0.dev202507170424__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +3 -0
- keras_hub/models/__init__.py +3 -0
- keras_hub/src/models/clip/clip_backbone.py +3 -102
- keras_hub/src/models/clip/clip_layers.py +295 -0
- keras_hub/src/models/clip/clip_preprocessor.py +57 -48
- keras_hub/src/models/clip/clip_text_encoder.py +2 -2
- keras_hub/src/models/clip/clip_vision_encoder.py +3 -3
- keras_hub/src/models/dinov2/__init__.py +5 -0
- keras_hub/src/models/dinov2/dinov2_backbone.py +228 -0
- keras_hub/src/models/dinov2/dinov2_image_converter.py +8 -0
- keras_hub/src/models/dinov2/dinov2_layers.py +886 -0
- keras_hub/src/models/dinov2/dinov2_presets.py +4 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +6 -2
- keras_hub/src/models/hgnetv2/__init__.py +5 -0
- keras_hub/src/models/hgnetv2/hgnetv2_presets.py +5 -5
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +16 -7
- keras_hub/src/models/stable_diffusion_3/mmdit.py +61 -4
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +23 -32
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +1 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +1 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +6 -2
- keras_hub/src/utils/preset_utils.py +4 -1
- keras_hub/src/utils/transformers/convert_dinov2.py +180 -0
- keras_hub/src/utils/transformers/export/gemma.py +89 -0
- keras_hub/src/utils/transformers/export/hf_exporter.py +98 -0
- keras_hub/src/utils/transformers/preset_loader.py +4 -1
- keras_hub/src/version.py +1 -1
- {keras_hub_nightly-0.22.0.dev202507150421.dist-info → keras_hub_nightly-0.22.0.dev202507170424.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.22.0.dev202507150421.dist-info → keras_hub_nightly-0.22.0.dev202507170424.dist-info}/RECORD +32 -25
- keras_hub/src/models/clip/clip_encoder_block.py +0 -111
- keras_hub/src/models/clip/clip_vision_embedding.py +0 -101
- {keras_hub_nightly-0.22.0.dev202507150421.dist-info → keras_hub_nightly-0.22.0.dev202507170424.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.22.0.dev202507150421.dist-info → keras_hub_nightly-0.22.0.dev202507170424.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,886 @@
from keras import backend
from keras import config
from keras import initializers
from keras import layers
from keras import ops
from keras import random

from keras_hub.src.utils.keras_utils import standardize_data_format

class DINOV2PatchEmbedding(layers.Layer):
    """A layer that converts images into patches.

    Args:
        hidden_dim: int. The number of units in the hidden layers.
        patch_size: int. The size of one side of each patch.
        data_format: `None` or str. If specified, either `"channels_last"` or
            `"channels_first"`. The ordering of the dimensions in the
            inputs. `"channels_last"` corresponds to inputs with shape
            `(batch_size, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch_size, channels, height, width)`. It defaults to the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`. If you never set it, then it will be
            `"channels_last"`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(self, hidden_dim, patch_size, data_format=None, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.patch_size = int(patch_size)
        self.data_format = standardize_data_format(data_format)

        self.projection = layers.Conv2D(
            hidden_dim,
            kernel_size=patch_size,
            strides=patch_size,
            data_format=data_format,
            kernel_initializer=initializers.TruncatedNormal(stddev=0.02),
            dtype=self.dtype_policy,
            name="projection",
        )

    def build(self, input_shape):
        self.projection.build(input_shape)

    def call(self, inputs, training=None):
        batch_size = ops.shape(inputs)[0]
        embeddings = self.projection(inputs, training=training)
        if self.data_format == "channels_last":
            embeddings = ops.reshape(
                embeddings, (batch_size, -1, self.hidden_dim)
            )
        else:
            embeddings = ops.reshape(
                embeddings, (batch_size, self.hidden_dim, -1)
            )
            embeddings = ops.transpose(embeddings, (0, 2, 1))
        return embeddings

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "patch_size": self.patch_size,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        output_shape = [input_shape[0], None, self.hidden_dim]
        if self.data_format == "channels_last":
            if input_shape[1] is not None and input_shape[2] is not None:
                patch_num = input_shape[1] // self.patch_size
                output_shape[1] = patch_num**2
        else:
            if input_shape[2] is not None and input_shape[3] is not None:
                patch_num = input_shape[2] // self.patch_size
                output_shape[1] = patch_num**2
        return output_shape

class DINOV2Embedding(layers.Layer):
    """A layer that converts images into DINOV2 token embeddings.

    This layer adds all the necessary tokens to the embeddings, including
    the class token, register tokens and mask token if specified. Finally, a
    position embedding will be added.

    This layer supports the interpolation of the position embeddings to enable
    the model to work with images of different sizes. Please refer to
    `_interpolate_position_embeddings` for more details.

    The saving and loading of this layer will automatically handle the position
    embeddings interpolation. Please refer to `save_own_variables` and
    `load_own_variables` for more details.

    Args:
        hidden_dim: int. The number of units in the hidden layers.
        patch_size: int. The size of one side of each patch.
        image_shape: tuple of ints. The (height, width) of the input images.
        num_register_tokens: int. The number of register tokens to add to the
            embeddings. Defaults to `0`.
        use_mask_token: bool. Whether to use a mask token. Defaults to `True`.
        dropout_rate: float. The dropout rate to use. Defaults to `0.0`.
        position_embedding_shape: tuple. The original input shape used to
            train the position embeddings. This is used to interpolate the
            position embeddings to the actual input shape. Defaults to
            `(518, 518)`.
        antialias_in_interpolation: bool. Whether to use antialiasing in the
            interpolation of the position embeddings. Defaults to `False`.
        data_format: `None` or str. If specified, either `"channels_last"` or
            `"channels_first"`. The ordering of the dimensions in the
            inputs. `"channels_last"` corresponds to inputs with shape
            `(batch_size, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch_size, channels, height, width)`. It defaults to the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`. If you never set it, then it will be
            `"channels_last"`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(
        self,
        hidden_dim,
        patch_size,
        image_shape,
        num_register_tokens=0,
        use_mask_token=True,
        dropout_rate=0.0,
        position_embedding_shape=(518, 518),
        antialias_in_interpolation=False,
        data_format=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.patch_size = int(patch_size)
        self.image_shape = (int(image_shape[0]), int(image_shape[1]))
        self.position_embedding_shape = (
            int(position_embedding_shape[0]),
            int(position_embedding_shape[1]),
        )
        self.num_register_tokens = int(num_register_tokens)
        self.use_mask_token = bool(use_mask_token)
        self.dropout_rate = float(dropout_rate)
        self.antialias_in_interpolation = bool(antialias_in_interpolation)
        self.data_format = standardize_data_format(data_format)
        self.interpolated_num_patches = (
            self.image_shape[0] // self.patch_size
        ) * (self.image_shape[1] // self.patch_size)
        self.num_patches = (
            self.position_embedding_shape[0] // self.patch_size
        ) * (self.position_embedding_shape[1] // self.patch_size)

        self.patch_embeddings = DINOV2PatchEmbedding(
            hidden_dim,
            patch_size,
            data_format=data_format,
            dtype=self.dtype_policy,
            name="patch_embeddings",
        )
        self.dropout = layers.Dropout(
            rate=self.dropout_rate,
            dtype=self.dtype_policy,
            name="dropout",
        )

    def build(self, input_shape):
        self.cls_token = self.add_weight(
            shape=(1, 1, self.hidden_dim),
            initializer=initializers.TruncatedNormal(stddev=0.02),
            trainable=True,
            name="cls_token",
        )
        if self.use_mask_token:
            self.mask_token = self.add_weight(
                shape=(1, self.hidden_dim),
                initializer="zeros",
                trainable=True,
                name="mask_token",
            )
        if self.num_register_tokens > 0:
            self.register_tokens = self.add_weight(
                shape=(1, self.num_register_tokens, self.hidden_dim),
                initializer="zeros",
                trainable=True,
                name="register_tokens",
            )
        self.patch_embeddings.build(input_shape)

        # Note that there are two position embeddings:
        # `self.interpolated_position_embeddings` is used for the image inputs
        # during both training and inference.
        # `self.position_embeddings` is used to load pretrained weights and
        # remains unchanged during training and inference. It will be updated
        # during saving once `self.interpolated_position_embeddings` is
        # modified.
        self.position_embeddings = self.add_weight(
            shape=(1, self.num_patches + 1, self.hidden_dim),
            initializer=initializers.TruncatedNormal(stddev=0.02),
            trainable=False,
            name="position_embeddings",
        )
        self.interpolated_position_embeddings = self.add_weight(
            shape=(1, self.interpolated_num_patches + 1, self.hidden_dim),
            initializer="zeros",  # Will be initialized by interpolation.
            trainable=True,
            name="interpolated_position_embeddings",
        )

        # Initialize the interpolated position embeddings.
        self.interpolated_position_embeddings.assign(
            self._interpolate_position_embeddings(
                self.position_embeddings,
                patch_size=self.patch_size,
                source_shape=self.position_embedding_shape,
                target_shape=self.image_shape,
                antialias=self.antialias_in_interpolation,
            )
        )

    def call(self, inputs, masks=None, training=None):
        batch_size = ops.shape(inputs)[0]
        embeddings = self.patch_embeddings(inputs, training=training)

        # Replace the embeddings with the mask tokens if specified.
        # Basically, this is only used during training.
        if masks is not None and self.use_mask_token:
            masks = ops.expand_dims(masks, axis=-1)
            mask_token = ops.cast(
                ops.expand_dims(self.mask_token, axis=0), embeddings.dtype
            )
            embeddings = ops.where(masks, mask_token, embeddings)

        # Add the [CLS] token to the embedded patch tokens.
        cls_tokens = ops.tile(self.cls_token, (batch_size, 1, 1))
        embeddings = ops.concatenate((cls_tokens, embeddings), axis=1)

        # Add positional encoding to each token.
        embeddings = ops.add(embeddings, self.interpolated_position_embeddings)

        # Add register tokens if specified.
        if self.num_register_tokens > 0:
            register_tokens = ops.tile(self.register_tokens, (batch_size, 1, 1))
            embeddings = ops.concatenate(
                (
                    embeddings[:, :1, ...],
                    register_tokens,
                    embeddings[:, 1:, ...],
                ),
                axis=1,
            )

        embeddings = self.dropout(embeddings)
        return embeddings

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "patch_size": self.patch_size,
                "image_shape": self.image_shape,
                "num_register_tokens": self.num_register_tokens,
                "use_mask_token": self.use_mask_token,
                "dropout_rate": self.dropout_rate,
                "position_embedding_shape": self.position_embedding_shape,
                "antialias_in_interpolation": self.antialias_in_interpolation,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        output_shape = [input_shape[0], None, self.hidden_dim]
        if self.data_format == "channels_last":
            if input_shape[1] is not None and input_shape[2] is not None:
                patch_num = input_shape[1] // self.patch_size
                # 1 is for cls token.
                output_shape[1] = 1 + self.num_register_tokens + patch_num**2
        else:
            if input_shape[2] is not None and input_shape[3] is not None:
                patch_num = input_shape[2] // self.patch_size
                # 1 is for cls token.
                output_shape[1] = 1 + self.num_register_tokens + patch_num**2
        return output_shape

    @staticmethod
    def _interpolate_position_embeddings(
        position_embeddings,
        patch_size,
        source_shape,
        target_shape,
        antialias=False,
    ):
        """Interpolate position embeddings to match the target image shape.

        Reference:
        - https://github.com/huggingface/transformers/blob/main/src/transformers/models/dinov2/modeling_dinov2.py
        """
        position_embeddings = ops.convert_to_tensor(position_embeddings)
        patch_size = int(patch_size)
        source_shape = (int(source_shape[0]), int(source_shape[1]))
        target_shape = (int(target_shape[0]), int(target_shape[1]))
        hidden_dim = int(position_embeddings.shape[-1])

        if (
            source_shape[0] == target_shape[0]
            and source_shape[1] == target_shape[1]
        ):
            # No need to interpolate if the image size is the same as the
            # position embedding image size.
            return ops.copy(position_embeddings)

        num_positions = int(position_embeddings.shape[1]) - 1

        # Handle class token and patch embeddings separately.
        class_position_embeddings = position_embeddings[:, :1, ...]
        patch_position_embeddings = position_embeddings[:, 1:, ...]

        # Calculate new dimensions.
        new_height = target_shape[0] // patch_size
        new_width = target_shape[1] // patch_size

        # Reshape for interpolation.
        sqrt_num_positions = int(num_positions**0.5)
        patch_position_embeddings = ops.reshape(
            patch_position_embeddings,
            (1, sqrt_num_positions, sqrt_num_positions, hidden_dim),
        )

        # Interpolate at float32 precision.
        original_dtype = backend.standardize_dtype(
            patch_position_embeddings.dtype
        )
        interpolated_patch_position_embeddings = ops.image.resize(
            ops.cast(patch_position_embeddings, "float32"),
            size=(new_height, new_width),
            interpolation="bicubic",
            antialias=antialias,
            data_format="channels_last",
        )
        interpolated_patch_position_embeddings = ops.cast(
            interpolated_patch_position_embeddings, original_dtype
        )

        # Reshape back to the original format.
        interpolated_patch_position_embeddings = ops.reshape(
            interpolated_patch_position_embeddings, (1, -1, hidden_dim)
        )
        interpolated_position_embeddings = ops.concatenate(
            (class_position_embeddings, interpolated_patch_position_embeddings),
            axis=1,
        )
        return interpolated_position_embeddings

    def _is_interpolated_position_embeddings_updated(self):
        """Check if the interpolated position embeddings are updated."""
        original_interpolated_position_embeddings = (
            self._interpolate_position_embeddings(
                self.position_embeddings,
                patch_size=self.patch_size,
                source_shape=self.position_embedding_shape,
                target_shape=self.image_shape,
                antialias=self.antialias_in_interpolation,
            )
        )
        diff = ops.sum(
            ops.subtract(
                original_interpolated_position_embeddings,
                self.interpolated_position_embeddings,
            )
        )
        return ops.cond(
            ops.greater(diff, config.epsilon()), lambda: True, lambda: False
        )

    def save_own_variables(self, store):
        if self._is_interpolated_position_embeddings_updated():
            self.position_embeddings.assign(
                self._interpolate_position_embeddings(
                    self.interpolated_position_embeddings,
                    patch_size=self.patch_size,
                    source_shape=self.image_shape,
                    target_shape=self.position_embedding_shape,
                    antialias=self.antialias_in_interpolation,
                )
            )
        super().save_own_variables(store)

    def load_own_variables(self, store):
        all_vars = self._trainable_variables + self._non_trainable_variables
        for i, v in enumerate(all_vars):
            if v is self.interpolated_position_embeddings:
                continue
            v.assign(store[f"{i}"])
        self.interpolated_position_embeddings.assign(
            self._interpolate_position_embeddings(
                self.position_embeddings,
                patch_size=self.patch_size,
                source_shape=self.position_embedding_shape,
                target_shape=self.image_shape,
                antialias=self.antialias_in_interpolation,
            )
        )

class DINOV2Attention(layers.Layer):
    """A multi-head attention layer with dropout.

    Args:
        hidden_dim: int. The number of units in the hidden layers.
        num_heads: int. Number of attention heads.
        dropout_rate: float. The dropout rate to use. Defaults to `0.0`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(self, hidden_dim, num_heads, dropout_rate=0.0, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.num_heads = int(num_heads)
        self.dropout_rate = float(dropout_rate)

        self.attention = layers.MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=self.hidden_dim // self.num_heads,
            dropout=self.dropout_rate,
            dtype=self.dtype_policy,
            name="attention",
        )
        self.dropout = layers.Dropout(
            rate=self.dropout_rate,
            dtype=self.dtype_policy,
            name="dropout",
        )

    def build(self, input_shape):
        self.attention.build(input_shape, input_shape)

    def call(self, inputs, training=None):
        attention_output = self.attention(
            query=inputs,
            value=inputs,
            key=inputs,
            training=training,
            use_causal_mask=False,
        )
        outputs = self.dropout(attention_output, training=training)
        return outputs

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "num_heads": self.num_heads,
                "dropout_rate": self.dropout_rate,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        return input_shape

class DINOV2LayerScale(layers.Layer):
    """A layer scale.

    Args:
        hidden_dim: int. The number of units in the hidden layers.
        init_values: float. The initial value for the scale. Defaults to `1.0`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(self, hidden_dim, init_values=1.0, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.init_values = float(init_values)

    def build(self, input_shape):
        self.lambda1 = self.add_weight(
            shape=(self.hidden_dim,),
            initializer=initializers.Constant(self.init_values),
            trainable=True,
            name="lambda1",
        )

    def call(self, inputs, training=None):
        return ops.multiply(inputs, self.lambda1)

    def get_config(self):
        config = super().get_config()
        config.update({"hidden_dim": self.hidden_dim})
        return config

    def compute_output_shape(self, input_shape):
        return input_shape

class DINOV2DropPath(layers.Layer):
    """A drop path layer.

    Args:
        rate: float. The drop path rate to use. Defaults to `0.0`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(self, rate=0.0, **kwargs):
        super().__init__(**kwargs)
        self.rate = float(rate)

    def build(self, input_shape):
        self.noise_shape = (input_shape[0],) + (1,) * (len(input_shape) - 1)

    def call(self, inputs, training=None):
        if not training or self.rate == 0.0:
            return inputs

        keep_prob = 1.0 - self.rate
        random_tensor = random.uniform(self.noise_shape, dtype=inputs.dtype)
        random_tensor = ops.add(random_tensor, keep_prob)
        return ops.multiply(ops.divide(inputs, keep_prob), random_tensor)

    def get_config(self):
        config = super().get_config()
        config.update({"rate": self.rate})
        return config

    def compute_output_shape(self, input_shape):
        return input_shape

class DINOV2MLP(layers.Layer):
    """A DINOV2 MLP block.

    Args:
        hidden_dim: int. The number of units in the output layer.
        intermediate_dim: int. The output dimension of the first Dense layer.
        activation: str or callable. Activation to use in the intermediate
            layer. Defaults to `"gelu"`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(
        self, hidden_dim, intermediate_dim, activation="gelu", **kwargs
    ):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.intermediate_dim = int(intermediate_dim)
        self.activation = activation

        self.fc1 = layers.Dense(
            self.intermediate_dim,
            activation=activation,
            kernel_initializer=initializers.TruncatedNormal(stddev=0.02),
            dtype=self.dtype_policy,
            name="fc1",
        )
        self.fc2 = layers.Dense(
            self.hidden_dim,
            kernel_initializer=initializers.TruncatedNormal(stddev=0.02),
            dtype=self.dtype_policy,
            name="fc2",
        )

    def build(self, input_shape):
        self.fc1.build(input_shape)
        input_shape = self.fc1.compute_output_shape(input_shape)
        self.fc2.build(input_shape)

    def call(self, inputs, training=None):
        x = self.fc1(inputs, training=training)
        x = self.fc2(x, training=training)
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "activation": self.activation,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        output_shape = list(input_shape)
        output_shape[-1] = self.hidden_dim
        return output_shape

class DINOV2SwiGLUFFN(layers.Layer):
    """A DINOV2 SwiGLU Feed-Forward Network layer.

    Please refer to [GLU Variants Improve Transformer](
    https://arxiv.org/abs/2002.05202) for more details on SwiGLU.

    Args:
        hidden_dim: int. The number of units in the output layer.
        intermediate_dim: int. The output dimension of the first Dense layer.
            Note that this value will be multiplied by `2 / 3` and rounded up
            to the nearest multiple of `8`. The reason for this is that
            SwiGLUFFN achieves similar or better performance with fewer
            parameters compared to the original FFN implementation.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(self, hidden_dim, intermediate_dim, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.intermediate_dim = int(intermediate_dim)
        self.actual_intermediate_dim = (
            (int(intermediate_dim * 2 / 3) + 7) // 8 * 8
        )

        self.weights_in = layers.Dense(
            2 * self.actual_intermediate_dim,
            kernel_initializer=initializers.TruncatedNormal(stddev=0.02),
            dtype=self.dtype_policy,
            name="weights_in",
        )
        self.weights_out = layers.Dense(
            self.hidden_dim,
            kernel_initializer=initializers.TruncatedNormal(stddev=0.02),
            dtype=self.dtype_policy,
            name="weights_out",
        )

    def build(self, input_shape):
        self.weights_in.build(input_shape)
        input_shape = list(input_shape)
        input_shape[-1] = self.actual_intermediate_dim
        self.weights_out.build(input_shape)

    def call(self, inputs, training=None):
        x = self.weights_in(inputs, training=training)
        x1, x2 = ops.split(x, 2, axis=-1)
        x = ops.multiply(ops.silu(x1), x2)
        x = self.weights_out(x, training=training)
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        output_shape = list(input_shape)
        output_shape[-1] = self.hidden_dim
        return output_shape

class DINOV2Layer(layers.Layer):
    """A DINOV2 encoder layer.

    Args:
        hidden_dim: int. The number of units in the hidden layers.
        num_heads: int. Number of attention heads.
        layer_scale_init_value: float. The initial value for the scale.
            Defaults to `1.0`.
        intermediate_dim: int. The output dimension of the first Dense layer in
            a two-layer feedforward network for each transformer.
        use_swiglu_ffn: bool. Whether to use SwiGLUFFN instead of MLP.
            Defaults to `False`.
        dropout_rate: float. The dropout rate to use. Defaults to `0.0`.
        drop_path_rate: float. The drop path rate to use. Defaults to `0.0`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(
        self,
        hidden_dim,
        num_heads,
        intermediate_dim,
        layer_scale_init_value=1.0,
        use_swiglu_ffn=False,
        dropout_rate=0.0,
        drop_path_rate=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.num_heads = int(num_heads)
        self.intermediate_dim = int(intermediate_dim)
        self.layer_scale_init_value = float(layer_scale_init_value)
        self.use_swiglu_ffn = bool(use_swiglu_ffn)
        self.dropout_rate = float(dropout_rate)
        self.drop_path_rate = float(drop_path_rate)

        self.norm1 = layers.LayerNormalization(
            epsilon=1e-6, dtype=self.dtype_policy, name="norm1"
        )
        self.attention = DINOV2Attention(
            hidden_dim=self.hidden_dim,
            num_heads=self.num_heads,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype_policy,
            name="attention",
        )
        self.layer_scale1 = DINOV2LayerScale(
            hidden_dim=self.hidden_dim,
            init_values=self.layer_scale_init_value,
            dtype=self.dtype_policy,
            name="layer_scale1",
        )
        self.drop_path = (
            DINOV2DropPath(
                rate=self.drop_path_rate,
                dtype=self.dtype_policy,
                name="drop_path",
            )
            if self.drop_path_rate > 0
            else layers.Identity(dtype=self.dtype_policy, name="drop_path")
        )
        self.norm2 = layers.LayerNormalization(
            epsilon=1e-6, dtype=self.dtype_policy, name="norm2"
        )
        if self.use_swiglu_ffn:
            self.mlp = DINOV2SwiGLUFFN(
                hidden_dim=self.hidden_dim,
                intermediate_dim=self.intermediate_dim,
                dtype=self.dtype_policy,
                name="mlp",
            )
        else:
            self.mlp = DINOV2MLP(
                hidden_dim=self.hidden_dim,
                intermediate_dim=self.intermediate_dim,
                activation="gelu",
                dtype=self.dtype_policy,
                name="mlp",
            )
        self.layer_scale2 = DINOV2LayerScale(
            hidden_dim=self.hidden_dim,
            init_values=self.layer_scale_init_value,
            dtype=self.dtype_policy,
            name="layer_scale2",
        )

    def build(self, input_shape):
        self.norm1.build(input_shape)
        self.attention.build(input_shape)
        input_shape = self.attention.compute_output_shape(input_shape)
        self.layer_scale1.build(input_shape)
        self.drop_path.build(input_shape)
        self.norm2.build(input_shape)
        self.mlp.build(input_shape)
        input_shape = self.mlp.compute_output_shape(input_shape)
        self.layer_scale2.build(input_shape)

    def call(self, inputs, training=None):
        x = inputs
        x = self.norm1(x, training=training)
        x = self.attention(x, training=training)
        x = self.layer_scale1(x, training=training)

        # First residual connection.
        hidden_states = ops.add(self.drop_path(x, training=training), inputs)
        x = self.norm2(hidden_states, training=training)
        x = self.mlp(x, training=training)
        x = self.layer_scale2(x, training=training)

        # Second residual connection.
        return ops.add(self.drop_path(x, training=training), hidden_states)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "num_heads": self.num_heads,
                "intermediate_dim": self.intermediate_dim,
                "layer_scale_init_value": self.layer_scale_init_value,
                "use_swiglu_ffn": self.use_swiglu_ffn,
                "dropout_rate": self.dropout_rate,
                "drop_path_rate": self.drop_path_rate,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        return input_shape

class DINOV2Encoder(layers.Layer):
    """A DINOV2 encoder.

    Args:
        num_layers: int. The number of transformer layers.
        hidden_dim: int. The number of units in the hidden layers.
        num_heads: int. Number of attention heads.
        intermediate_dim: int. The output dimension of the first Dense layer in
            a two-layer feedforward network for each transformer.
        layer_scale_init_value: float. The initial value for the scale.
            Defaults to `1.0`.
        use_swiglu_ffn: bool. Whether to use SwiGLUFFN instead of MLP.
            Defaults to `False`.
        dropout_rate: float. The dropout rate to use. Defaults to `0.0`.
        drop_path_rate: float. The drop path rate to use. Defaults to `0.0`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(
        self,
        num_layers,
        hidden_dim,
        num_heads,
        intermediate_dim,
        layer_scale_init_value=1.0,
        use_swiglu_ffn=False,
        dropout_rate=0.0,
        drop_path_rate=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_layers = int(num_layers)
        self.hidden_dim = int(hidden_dim)
        self.num_heads = int(num_heads)
        self.intermediate_dim = int(intermediate_dim)
        self.layer_scale_init_value = float(layer_scale_init_value)
        self.use_swiglu_ffn = bool(use_swiglu_ffn)
        self.dropout_rate = float(dropout_rate)
        self.drop_path_rate = float(drop_path_rate)

        self.layers = [
            DINOV2Layer(
                hidden_dim=self.hidden_dim,
                num_heads=self.num_heads,
                intermediate_dim=self.intermediate_dim,
                layer_scale_init_value=self.layer_scale_init_value,
                use_swiglu_ffn=self.use_swiglu_ffn,
                dropout_rate=self.dropout_rate,
                drop_path_rate=self.drop_path_rate,
                dtype=self.dtype_policy,
                name=f"layers_{i}",
            )
            for i in range(self.num_layers)
        ]

    def build(self, input_shape):
        for layer in self.layers:
            layer.build(input_shape)
            input_shape = layer.compute_output_shape(input_shape)

    def call(self, inputs, training=None):
        x = inputs
        for layer in self.layers:
            x = layer(x, training=training)
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_layers": self.num_layers,
                "hidden_dim": self.hidden_dim,
                "num_heads": self.num_heads,
                "intermediate_dim": self.intermediate_dim,
                "layer_scale_init_value": self.layer_scale_init_value,
                "use_swiglu_ffn": self.use_swiglu_ffn,
                "dropout_rate": self.dropout_rate,
                "drop_path_rate": self.drop_path_rate,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        return input_shape
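To make the token bookkeeping in the new `DINOV2Embedding` concrete, here is a minimal sketch of the shape arithmetic it performs, using hypothetical sizes (a 224x224 input, patch size 14, 4 register tokens; these values are illustrative and not taken from the presets in this release):

    # Hypothetical sizes, chosen only for illustration.
    image_shape = (224, 224)
    position_embedding_shape = (518, 518)  # default used by DINOV2Embedding
    patch_size = 14
    num_register_tokens = 4

    # Patch grid seen at runtime vs. the grid the stored position table covers.
    interpolated_num_patches = (image_shape[0] // patch_size) * (
        image_shape[1] // patch_size
    )  # 16 * 16 = 256
    num_patches = (position_embedding_shape[0] // patch_size) * (
        position_embedding_shape[1] // patch_size
    )  # 37 * 37 = 1369

    # The stored table has 1369 + 1 rows; it is bicubically resized down to
    # 256 + 1 rows, and the final sequence is [CLS] + register tokens + patches.
    sequence_length = 1 + num_register_tokens + interpolated_num_patches  # 261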
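Similarly, the `intermediate_dim` note in `DINOV2SwiGLUFFN` works out as follows for a hypothetical `intermediate_dim` of 4096 (again, an illustrative value rather than one from the diff):

    intermediate_dim = 4096  # hypothetical value
    # Multiply by 2/3, then round up to the nearest multiple of 8, as in __init__.
    actual_intermediate_dim = (int(intermediate_dim * 2 / 3) + 7) // 8 * 8  # 2736
    # `weights_in` then has 2 * 2736 units: one half is gated by silu() of the other.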
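Finally, a rough end-to-end sketch of how these layers compose, assuming this nightly wheel is installed and importing from the module path added above; the configuration is arbitrary and much smaller than the published presets:

    import numpy as np

    from keras_hub.src.models.dinov2.dinov2_layers import (
        DINOV2Embedding,
        DINOV2Encoder,
    )

    # Tiny, illustrative configuration.
    embedding = DINOV2Embedding(
        hidden_dim=64,
        patch_size=14,
        image_shape=(224, 224),
        num_register_tokens=4,
    )
    encoder = DINOV2Encoder(
        num_layers=2,
        hidden_dim=64,
        num_heads=4,
        intermediate_dim=256,
    )

    images = np.zeros((1, 224, 224, 3), dtype="float32")
    tokens = embedding(images)  # (1, 261, 64): [CLS] + 4 registers + 256 patches
    features = encoder(tokens)  # same shape; this is what dinov2_backbone.py builds on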