keras-hub-nightly 0.23.0.dev202510080414__py3-none-any.whl → 0.24.0.dev202511080419__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +6 -0
- keras_hub/models/__init__.py +36 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +6 -0
- keras_hub/src/models/causal_lm.py +5 -0
- keras_hub/src/models/depth_anything/depth_anything_presets.py +38 -1
- keras_hub/src/models/dinov2/dinov2_layers.py +3 -1
- keras_hub/src/models/dinov3/__init__.py +5 -0
- keras_hub/src/models/dinov3/dinov3_backbone.py +263 -0
- keras_hub/src/models/dinov3/dinov3_image_converter.py +8 -0
- keras_hub/src/models/dinov3/dinov3_layers.py +1013 -0
- keras_hub/src/models/dinov3/dinov3_presets.py +4 -0
- keras_hub/src/models/gemma/gemma_presets.py +22 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +39 -0
- keras_hub/src/models/image_to_image.py +5 -0
- keras_hub/src/models/inpaint.py +5 -0
- keras_hub/src/models/mobilenetv5/__init__.py +9 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py +699 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_backbone.py +396 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py +890 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py +436 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py +157 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_converter.py +10 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py +462 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_presets.py +15 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_utils.py +146 -0
- keras_hub/src/models/parseq/__init__.py +5 -0
- keras_hub/src/models/parseq/parseq_presets.py +15 -0
- keras_hub/src/models/qwen3_moe/__init__.py +5 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +30 -0
- keras_hub/src/models/siglip/siglip_presets.py +15 -0
- keras_hub/src/models/smollm3/smollm3_backbone.py +211 -0
- keras_hub/src/models/smollm3/smollm3_causal_lm.py +310 -0
- keras_hub/src/models/smollm3/smollm3_causal_lm_preprocessor.py +84 -0
- keras_hub/src/models/smollm3/smollm3_layers.py +757 -0
- keras_hub/src/models/smollm3/smollm3_tokenizer.py +60 -0
- keras_hub/src/models/smollm3/smollm3_utils.py +56 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
- keras_hub/src/models/text_to_image.py +5 -0
- keras_hub/src/utils/preset_utils.py +9 -2
- keras_hub/src/utils/tensor_utils.py +3 -1
- keras_hub/src/utils/timm/convert_mobilenetv5.py +321 -0
- keras_hub/src/utils/timm/preset_loader.py +8 -4
- keras_hub/src/utils/transformers/convert_dinov3.py +106 -0
- keras_hub/src/utils/transformers/convert_smollm3.py +139 -0
- keras_hub/src/utils/transformers/preset_loader.py +6 -0
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +6 -0
- {keras_hub_nightly-0.23.0.dev202510080414.dist-info → keras_hub_nightly-0.24.0.dev202511080419.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.23.0.dev202510080414.dist-info → keras_hub_nightly-0.24.0.dev202511080419.dist-info}/RECORD +52 -24
- {keras_hub_nightly-0.23.0.dev202510080414.dist-info → keras_hub_nightly-0.24.0.dev202511080419.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.23.0.dev202510080414.dist-info → keras_hub_nightly-0.24.0.dev202511080419.dist-info}/top_level.txt +0 -0
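Most of what is listed above surfaces through the existing KerasHub preset API rather than through new entry points. As a rough, hedged sketch of how the newly added model families would typically be loaded: the MobileNetV5 handle below is a placeholder, not a preset name verified against this release, and the SmolLM3 checkpoint is assumed to be covered by the new convert_smollm3.py converter.

import keras_hub

# Placeholder handle: the concrete preset names shipped in this nightly are
# defined in the *_presets.py files listed above (e.g. mobilenetv5_presets.py).
backbone = keras_hub.models.Backbone.from_preset("mobilenetv5_preset_placeholder")

# Assumed workflow: load a SmolLM3 checkpoint directly from the Hugging Face
# Hub through the new convert_smollm3.py converter.
causal_lm = keras_hub.models.CausalLM.from_preset("hf://HuggingFaceTB/SmolLM3-3B")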
keras_hub/src/models/dinov3/dinov3_layers.py
@@ -0,0 +1,1013 @@
import math

from keras import initializers
from keras import layers
from keras import ops
from keras import random

from keras_hub.src.utils.keras_utils import standardize_data_format


class DINOV3PatchEmbedding(layers.Layer):
    """A layer that converts images into patches.

    Args:
        hidden_dim: int. The number of units in the hidden layers.
        patch_size: int. The size of one side of each patch.
        data_format: `None` or str. If specified, either `"channels_last"` or
            `"channels_first"`. The ordering of the dimensions in the
            inputs. `"channels_last"` corresponds to inputs with shape
            `(batch_size, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch_size, channels, height, width)`. It defaults to the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`. If you never set it, then it will be
            `"channels_last"`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(self, hidden_dim, patch_size, data_format=None, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.patch_size = int(patch_size)
        self.data_format = standardize_data_format(data_format)

        self.projection = layers.Conv2D(
            hidden_dim,
            kernel_size=patch_size,
            strides=patch_size,
            data_format=data_format,
            kernel_initializer=initializers.TruncatedNormal(stddev=0.02),
            dtype=self.dtype_policy,
            name="projection",
        )

    def build(self, input_shape):
        self.projection.build(input_shape)

    def call(self, inputs, training=None):
        batch_size = ops.shape(inputs)[0]
        embeddings = self.projection(inputs, training=training)
        if self.data_format == "channels_last":
            embeddings = ops.reshape(
                embeddings, (batch_size, -1, self.hidden_dim)
            )
        else:
            embeddings = ops.reshape(
                embeddings, (batch_size, self.hidden_dim, -1)
            )
            embeddings = ops.transpose(embeddings, (0, 2, 1))
        return embeddings

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "patch_size": self.patch_size,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        output_shape = [input_shape[0], None, self.hidden_dim]
        if self.data_format == "channels_last":
            if input_shape[1] is not None and input_shape[2] is not None:
                patch_num = input_shape[1] // self.patch_size
                output_shape[1] = patch_num**2
        else:
            if input_shape[2] is not None and input_shape[3] is not None:
                patch_num = input_shape[2] // self.patch_size
                output_shape[1] = patch_num**2
        return output_shape


class DINOV3Embedding(layers.Layer):
    """A layer that converts images into patch embeddings.

    This layer adds all the necessary tokens to the embeddings, including
    the class token, register tokens and mask token if specified.

    Args:
        hidden_dim: int. The number of units in the hidden layers.
        patch_size: int. The size of one side of each patch.
        num_register_tokens: int. The number of register tokens to add to the
            embeddings. Defaults to `0`.
        use_mask_token: bool. Whether to use a mask token. Defaults to `True`.
        initializer_range: float. The standard deviation of the truncated
            normal initializer. Defaults to `0.02`.
        data_format: `None` or str. If specified, either `"channels_last"` or
            `"channels_first"`. The ordering of the dimensions in the
            inputs. `"channels_last"` corresponds to inputs with shape
            `(batch_size, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch_size, channels, height, width)`. It defaults to the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`. If you never set it, then it will be
            `"channels_last"`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(
        self,
        hidden_dim,
        patch_size,
        num_register_tokens=0,
        use_mask_token=True,
        initializer_range=0.02,
        data_format=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.patch_size = int(patch_size)
        self.num_register_tokens = int(num_register_tokens)
        self.use_mask_token = bool(use_mask_token)
        self.initializer_range = float(initializer_range)
        self.data_format = standardize_data_format(data_format)

        self.patch_embeddings = DINOV3PatchEmbedding(
            hidden_dim,
            patch_size,
            data_format=data_format,
            dtype=self.dtype_policy,
            name="patch_embeddings",
        )

    def build(self, input_shape):
        self.cls_token = self.add_weight(
            shape=(1, 1, self.hidden_dim),
            initializer=initializers.TruncatedNormal(
                stddev=self.initializer_range
            ),
            trainable=True,
            name="cls_token",
        )
        if self.use_mask_token:
            self.mask_token = self.add_weight(
                shape=(1, 1, self.hidden_dim),
                initializer="zeros",
                trainable=True,
                name="mask_token",
            )
        if self.num_register_tokens > 0:
            self.register_tokens = self.add_weight(
                shape=(1, self.num_register_tokens, self.hidden_dim),
                initializer=initializers.TruncatedNormal(
                    stddev=self.initializer_range
                ),
                trainable=True,
                name="register_tokens",
            )
        self.patch_embeddings.build(input_shape)

    def call(self, inputs, masks=None, training=None):
        batch_size = ops.shape(inputs)[0]
        embeddings = self.patch_embeddings(inputs, training=training)

        if masks is not None and self.use_mask_token:
            mask_token = ops.cast(self.mask_token, embeddings.dtype)
            embeddings = ops.where(
                ops.expand_dims(masks, axis=-1),
                mask_token,
                embeddings,
            )

        cls_tokens = ops.tile(self.cls_token, (batch_size, 1, 1))
        embeddings = ops.concatenate((cls_tokens, embeddings), axis=1)

        if self.num_register_tokens > 0:
            register_tokens = ops.tile(self.register_tokens, (batch_size, 1, 1))
            embeddings = ops.concatenate(
                (
                    embeddings[:, :1, ...],
                    register_tokens,
                    embeddings[:, 1:, ...],
                ),
                axis=1,
            )
        return embeddings

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "patch_size": self.patch_size,
                "num_register_tokens": self.num_register_tokens,
                "use_mask_token": self.use_mask_token,
                "initializer_range": self.initializer_range,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        output_shape = [input_shape[0], None, self.hidden_dim]
        if self.data_format == "channels_last":
            if input_shape[1] is not None and input_shape[2] is not None:
                patch_num = input_shape[1] // self.patch_size
                output_shape[1] = 1 + self.num_register_tokens + patch_num**2
        else:
            if input_shape[2] is not None and input_shape[3] is not None:
                patch_num = input_shape[2] // self.patch_size
                output_shape[1] = 1 + self.num_register_tokens + patch_num**2
        return output_shape


class DINOV3RopePositionEmbedding(layers.Layer):
    """A layer that implements Rotary Position Embedding.

    Args:
        hidden_dim: int. The number of units in the hidden layers.
        num_heads: int. Number of attention heads.
        rope_theta: float. The base period of the rotary position embeddings.
        patch_size: int. The size of one side of each patch.
        data_format: `None` or str. If specified, either `"channels_last"` or
            `"channels_first"`. The ordering of the dimensions in the
            inputs. `"channels_last"` corresponds to inputs with shape
            `(batch_size, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch_size, channels, height, width)`. It defaults to the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`. If you never set it, then it will be
            `"channels_last"`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(
        self,
        hidden_dim,
        num_heads,
        rope_theta,
        patch_size,
        data_format=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.num_heads = int(num_heads)
        self.rope_theta = float(rope_theta)
        self.patch_size = int(patch_size)
        self.data_format = standardize_data_format(data_format)
        self.head_dim = hidden_dim // num_heads
        self.inv_freq = 1.0 / (
            rope_theta ** (ops.arange(0, 1, 4 / self.head_dim, dtype="float32"))
        )

    def _get_patches_center_coordinates(
        self, num_patches_h, num_patches_w, dtype="float32"
    ):
        """A helper function to get the center coordinates of the patches."""
        coords_h = ops.arange(0.5, num_patches_h, dtype=dtype)
        coords_w = ops.arange(0.5, num_patches_w, dtype=dtype)

        coords_h = coords_h / num_patches_h
        coords_w = coords_w / num_patches_w

        coords_h = ops.expand_dims(coords_h, axis=1)
        coords_w = ops.expand_dims(coords_w, axis=0)

        coords_h = ops.repeat(coords_h, num_patches_w, axis=1)
        coords_w = ops.repeat(coords_w, num_patches_h, axis=0)

        coords = ops.stack([coords_h, coords_w], axis=-1)
        coords = ops.reshape(coords, (-1, 2))
        coords = 2.0 * coords - 1.0
        return coords

    def call(self, inputs):
        shape = ops.shape(inputs)
        if self.data_format == "channels_last":
            height, width = shape[1], shape[2]
        else:
            height, width = shape[2], shape[3]
        num_patches_h = height // self.patch_size
        num_patches_w = width // self.patch_size

        patch_coords = self._get_patches_center_coordinates(
            num_patches_h, num_patches_w, dtype="float32"
        )
        angles = (
            2
            * math.pi
            * ops.expand_dims(patch_coords, axis=-1)
            * ops.expand_dims(ops.expand_dims(self.inv_freq, axis=0), axis=0)
        )
        angles = ops.reshape(angles, (ops.shape(angles)[0], -1))
        angles = ops.tile(angles, (1, 2))

        cos = ops.cast(ops.cos(angles), inputs.dtype)
        sin = ops.cast(ops.sin(angles), inputs.dtype)
        return cos, sin

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "num_heads": self.num_heads,
                "rope_theta": self.rope_theta,
                "patch_size": self.patch_size,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        output_shape = input_shape
        if self.data_format == "channels_last":
            height, width = input_shape[1], input_shape[2]
        else:
            height, width = input_shape[2], input_shape[3]
        num_patches_h = height // self.patch_size
        num_patches_w = width // self.patch_size
        seq_len = num_patches_h * num_patches_w
        output_shape = (seq_len, self.head_dim)
        return output_shape, output_shape


class DINOV3Attention(layers.Layer):
    """A multi-head attention layer with dropout.

    Args:
        hidden_dim: int. The number of units in the hidden layers.
        num_heads: int. Number of attention heads.
        dropout_rate: float. The dropout rate to use. Defaults to `0.0`.
        use_query_bias: bool. Whether to use a bias for the query projection.
        use_key_bias: bool. Whether to use a bias for the key projection.
        use_value_bias: bool. Whether to use a bias for the value projection.
        use_proj_bias: bool. Whether to use a bias for the output projection.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(
        self,
        hidden_dim,
        num_heads,
        dropout_rate=0.0,
        use_query_bias=True,
        use_key_bias=True,
        use_value_bias=True,
        use_proj_bias=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.num_heads = int(num_heads)
        self.dropout_rate = float(dropout_rate)
        self.use_query_bias = bool(use_query_bias)
        self.use_key_bias = bool(use_key_bias)
        self.use_value_bias = bool(use_value_bias)
        self.use_proj_bias = bool(use_proj_bias)
        self.head_dim = hidden_dim // num_heads
        self.scale = self.head_dim**-0.5

        self.query_dense = layers.Dense(
            hidden_dim,
            use_bias=use_query_bias,
            dtype=self.dtype_policy,
            name="q_proj",
        )
        self.key_dense = layers.Dense(
            hidden_dim,
            use_bias=use_key_bias,
            dtype=self.dtype_policy,
            name="k_proj",
        )
        self.value_dense = layers.Dense(
            hidden_dim,
            use_bias=use_value_bias,
            dtype=self.dtype_policy,
            name="v_proj",
        )
        self.output_dense = layers.Dense(
            hidden_dim,
            use_bias=use_proj_bias,
            dtype=self.dtype_policy,
            name="o_proj",
        )
        self.dropout = layers.Dropout(
            dropout_rate, dtype=self.dtype_policy, name="dropout"
        )

    def build(self, input_shape):
        self.query_dense.build(input_shape)
        self.key_dense.build(input_shape)
        self.value_dense.build(input_shape)
        self.output_dense.build(input_shape)

    def _apply_rotary(self, q, k, cos, sin, num_prefix_tokens):
        """Apply rotary position embedding to query and key."""

        def _rotate_half(x):
            """A helper function to rotate half of the features."""
            x1 = x[..., : ops.shape(x)[-1] // 2]
            x2 = x[..., ops.shape(x)[-1] // 2 :]
            return ops.concatenate([-x2, x1], axis=-1)

        q_prefix_tokens = q[:, :num_prefix_tokens, :, :]
        q_patches = q[:, num_prefix_tokens:, :, :]
        k_prefix_tokens = k[:, :num_prefix_tokens, :, :]
        k_patches = k[:, num_prefix_tokens:, :, :]
        cos = ops.expand_dims(ops.expand_dims(cos, axis=0), axis=2)
        sin = ops.expand_dims(ops.expand_dims(sin, axis=0), axis=2)

        q_patches = (q_patches * cos) + (_rotate_half(q_patches) * sin)
        k_patches = (k_patches * cos) + (_rotate_half(k_patches) * sin)
        q = ops.concatenate([q_prefix_tokens, q_patches], axis=-3)
        k = ops.concatenate([k_prefix_tokens, k_patches], axis=-3)
        return q, k

    def call(
        self,
        inputs,
        attention_mask=None,
        position_embeddings=None,
        num_prefix_tokens=0,
        training=None,
    ):
        batch_size, seq_len, _ = ops.shape(inputs)
        q = self.query_dense(inputs, training=training)
        k = self.key_dense(inputs, training=training)
        v = self.value_dense(inputs, training=training)
        q = ops.reshape(q, (batch_size, seq_len, self.num_heads, self.head_dim))
        k = ops.reshape(k, (batch_size, seq_len, self.num_heads, self.head_dim))
        v = ops.reshape(v, (batch_size, seq_len, self.num_heads, self.head_dim))
        if position_embeddings is not None:
            cos, sin = position_embeddings
            q, k = self._apply_rotary(q, k, cos, sin, num_prefix_tokens)

        attn_output = ops.nn.dot_product_attention(
            q,
            k,
            v,
            mask=attention_mask,
            scale=self.scale,
            is_causal=False,
        )
        attn_output = ops.reshape(attn_output, (batch_size, seq_len, -1))
        attn_output = self.dropout(attn_output, training=training)
        return self.output_dense(attn_output, training=training)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "num_heads": self.num_heads,
                "dropout_rate": self.dropout_rate,
                "use_query_bias": self.use_query_bias,
                "use_key_bias": self.use_key_bias,
                "use_value_bias": self.use_value_bias,
                "use_proj_bias": self.use_proj_bias,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        return input_shape


class DINOV3LayerScale(layers.Layer):
    """A layer scale.

    Args:
        hidden_dim: int. The number of units in the hidden layers.
        init_values: float. The initial value for the scale. Defaults to `1.0`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(self, hidden_dim, init_values=1.0, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.init_values = float(init_values)

    def build(self, input_shape):
        self.lambda1 = self.add_weight(
            shape=(self.hidden_dim,),
            initializer=initializers.Constant(self.init_values),
            trainable=True,
            name="lambda1",
        )

    def call(self, inputs, training=None):
        return ops.multiply(inputs, self.lambda1)

    def get_config(self):
        config = super().get_config()
        config.update(
            {"hidden_dim": self.hidden_dim, "init_values": self.init_values}
        )
        return config


class DINOV3DropPath(layers.Layer):
    """A drop path layer.

    Args:
        rate: float. The drop path rate to use. Defaults to `0.0`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(self, rate=0.0, **kwargs):
        super().__init__(**kwargs)
        self.rate = float(rate)

    def build(self, input_shape):
        self.noise_shape = (input_shape[0],) + (1,) * (len(input_shape) - 1)

    def call(self, inputs, training=None):
        if not training or self.rate == 0.0:
            return inputs

        keep_prob = 1.0 - self.rate
        random_tensor = random.uniform(self.noise_shape, dtype=inputs.dtype)
        random_tensor = ops.floor(ops.add(random_tensor, keep_prob))
        return ops.multiply(ops.divide(inputs, keep_prob), random_tensor)

    def get_config(self):
        config = super().get_config()
        config.update({"rate": self.rate})
        return config

    def compute_output_shape(self, input_shape):
        return input_shape


class DINOV3MLP(layers.Layer):
    """A DINOV3 MLP block.

    Args:
        hidden_dim: int. The number of units in the output layer.
        intermediate_dim: int. The output dimension of the first Dense layer.
        activation: str or callable. Activation to use in the intermediate
            layer. Defaults to `"gelu"`.
        use_bias: bool. Whether to use a bias for the dense layers.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(
        self,
        hidden_dim,
        intermediate_dim,
        activation="gelu",
        use_bias=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.intermediate_dim = int(intermediate_dim)
        self.activation = activation
        self.use_bias = bool(use_bias)

        self.up_proj = layers.Dense(
            intermediate_dim,
            activation=activation,
            use_bias=use_bias,
            dtype=self.dtype_policy,
            name="up_proj",
        )
        self.down_proj = layers.Dense(
            hidden_dim,
            use_bias=use_bias,
            dtype=self.dtype_policy,
            name="down_proj",
        )

    def build(self, input_shape):
        self.up_proj.build(input_shape)
        input_shape = self.up_proj.compute_output_shape(input_shape)
        self.down_proj.build(input_shape)

    def call(self, inputs, training=None):
        x = self.up_proj(inputs, training=training)
        return self.down_proj(x, training=training)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "activation": self.activation,
                "use_bias": self.use_bias,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        output_shape = list(input_shape)
        output_shape[-1] = self.hidden_dim
        return output_shape


class DINOV3GatedMLP(layers.Layer):
    """A DINOV3 Gated MLP block.

    Args:
        hidden_dim: int. The number of units in the output layer.
        intermediate_dim: int. The output dimension of the first Dense layer.
        activation: str or callable. Activation to use in the intermediate
            layer. Defaults to `"gelu"`.
        use_bias: bool. Whether to use a bias for the dense layers.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(
        self,
        hidden_dim,
        intermediate_dim,
        activation="gelu",
        use_bias=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.intermediate_dim = int(intermediate_dim)
        self.activation = activation
        self.use_bias = bool(use_bias)

        self.gate_proj = layers.Dense(
            intermediate_dim,
            activation=activation,
            use_bias=use_bias,
            dtype=self.dtype_policy,
            name="gate_proj",
        )
        self.up_proj = layers.Dense(
            intermediate_dim,
            use_bias=use_bias,
            dtype=self.dtype_policy,
            name="up_proj",
        )
        self.down_proj = layers.Dense(
            hidden_dim,
            use_bias=use_bias,
            dtype=self.dtype_policy,
            name="down_proj",
        )

    def build(self, input_shape):
        self.gate_proj.build(input_shape)
        self.up_proj.build(input_shape)
        input_shape = self.up_proj.compute_output_shape(input_shape)
        self.down_proj.build(input_shape)

    def call(self, inputs, training=None):
        x = ops.multiply(
            self.gate_proj(inputs, training=training),
            self.up_proj(inputs, training=training),
        )
        return self.down_proj(x, training=training)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "activation": self.activation,
                "use_bias": self.use_bias,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        output_shape = list(input_shape)
        output_shape[-1] = self.hidden_dim
        return output_shape


class DINOV3Layer(layers.Layer):
    """A DINOV3 encoder layer.

    Args:
        hidden_dim: int. The number of units in the hidden layers.
        num_heads: int. Number of attention heads.
        intermediate_dim: int. The output dimension of the first Dense layer in
            a two-layer feedforward network for each transformer.
        layer_scale_init_value: float. The initial value for the scale.
            Defaults to `1.0`.
        hidden_activation: str or callable. Activation to use in the MLP.
            Defaults to `"gelu"`.
        use_gated_mlp: bool. Whether to use Gated MLP layers. Defaults to
            `False`.
        use_query_bias: bool. Whether to use a bias for the query projection.
        use_key_bias: bool. Whether to use a bias for the key projection.
        use_value_bias: bool. Whether to use a bias for the value projection.
        use_proj_bias: bool. Whether to use a bias for the output projection.
        use_mlp_bias: bool. Whether to use a bias for the MLP layers.
        attention_dropout: float. The dropout rate for the attention
            probabilities. Defaults to `0.0`.
        drop_path_rate: float. The drop path rate to use. Defaults to `0.0`.
        layer_norm_eps: float. The epsilon for layer normalization.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(
        self,
        hidden_dim,
        num_heads,
        intermediate_dim,
        layer_scale_init_value=1.0,
        hidden_activation="gelu",
        use_gated_mlp=False,
        use_query_bias=True,
        use_key_bias=True,
        use_value_bias=True,
        use_proj_bias=True,
        use_mlp_bias=True,
        attention_dropout=0.0,
        drop_path_rate=0.0,
        layer_norm_eps=1e-6,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.num_heads = int(num_heads)
        self.intermediate_dim = int(intermediate_dim)
        self.layer_scale_init_value = float(layer_scale_init_value)
        self.hidden_activation = hidden_activation
        self.use_gated_mlp = bool(use_gated_mlp)
        self.use_query_bias = bool(use_query_bias)
        self.use_key_bias = bool(use_key_bias)
        self.use_value_bias = bool(use_value_bias)
        self.use_proj_bias = bool(use_proj_bias)
        self.use_mlp_bias = bool(use_mlp_bias)
        self.attention_dropout = float(attention_dropout)
        self.drop_path_rate = float(drop_path_rate)
        self.layer_norm_eps = float(layer_norm_eps)

        self.norm1 = layers.LayerNormalization(
            epsilon=layer_norm_eps, dtype=self.dtype_policy, name="norm1"
        )
        self.attention = DINOV3Attention(
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            dropout_rate=attention_dropout,
            use_query_bias=use_query_bias,
            use_key_bias=use_key_bias,
            use_value_bias=use_value_bias,
            use_proj_bias=use_proj_bias,
            dtype=self.dtype_policy,
            name="attention",
        )
        self.layer_scale1 = DINOV3LayerScale(
            hidden_dim,
            init_values=layer_scale_init_value,
            dtype=self.dtype_policy,
            name="layer_scale1",
        )
        self.drop_path = (
            DINOV3DropPath(drop_path_rate, dtype=self.dtype_policy)
            if drop_path_rate > 0.0
            else layers.Identity(dtype=self.dtype_policy)
        )
        self.norm2 = layers.LayerNormalization(
            epsilon=layer_norm_eps, dtype=self.dtype_policy, name="norm2"
        )
        if use_gated_mlp:
            self.mlp = DINOV3GatedMLP(
                hidden_dim,
                intermediate_dim,
                activation=hidden_activation,
                use_bias=use_mlp_bias,
                dtype=self.dtype_policy,
                name="mlp",
            )
        else:
            self.mlp = DINOV3MLP(
                hidden_dim,
                intermediate_dim,
                activation=hidden_activation,
                use_bias=use_mlp_bias,
                dtype=self.dtype_policy,
                name="mlp",
            )
        self.layer_scale2 = DINOV3LayerScale(
            hidden_dim,
            init_values=layer_scale_init_value,
            dtype=self.dtype_policy,
            name="layer_scale2",
        )

    def build(self, input_shape):
        self.norm1.build(input_shape)
        self.attention.build(input_shape)
        input_shape = self.attention.compute_output_shape(input_shape)
        self.layer_scale1.build(input_shape)
        self.drop_path.build(input_shape)
        self.norm2.build(input_shape)
        self.mlp.build(input_shape)
        input_shape = self.mlp.compute_output_shape(input_shape)
        self.layer_scale2.build(input_shape)

    def call(
        self,
        inputs,
        attention_mask=None,
        position_embeddings=None,
        num_prefix_tokens=0,
        training=None,
    ):
        residual = inputs
        hidden_states = self.norm1(inputs)
        hidden_states = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            num_prefix_tokens=num_prefix_tokens,
            training=training,
        )
        hidden_states = self.layer_scale1(hidden_states, training=training)
        hidden_states = (
            self.drop_path(hidden_states, training=training) + residual
        )

        residual = hidden_states
        hidden_states = self.norm2(hidden_states, training=training)
        hidden_states = self.mlp(hidden_states, training=training)
        hidden_states = self.layer_scale2(hidden_states, training=training)
        return self.drop_path(hidden_states, training=training) + residual

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "num_heads": self.num_heads,
                "intermediate_dim": self.intermediate_dim,
                "layer_scale_init_value": self.layer_scale_init_value,
                "hidden_activation": self.hidden_activation,
                "use_gated_mlp": self.use_gated_mlp,
                "use_query_bias": self.use_query_bias,
                "use_key_bias": self.use_key_bias,
                "use_value_bias": self.use_value_bias,
                "use_proj_bias": self.use_proj_bias,
                "use_mlp_bias": self.use_mlp_bias,
                "attention_dropout": self.attention_dropout,
                "drop_path_rate": self.drop_path_rate,
                "layer_norm_eps": self.layer_norm_eps,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        return input_shape


class DINOV3Encoder(layers.Layer):
    """A DINOV3 encoder.

    Args:
        num_layers: int. The number of transformer layers.
        hidden_dim: int. The number of units in the hidden layers.
        num_heads: int. Number of attention heads.
        intermediate_dim: int. The output dimension of the first Dense layer in
            a two-layer feedforward network for each transformer.
        layer_scale_init_value: float. The initial value for the scale.
            Defaults to `1.0`.
        hidden_activation: str or callable. Activation to use in the MLP.
            Defaults to `"gelu"`.
        use_gated_mlp: bool. Whether to use Gated MLP layers. Defaults to
            `False`.
        use_query_bias: bool. Whether to use a bias for the query projection.
            Defaults to `True`.
        use_key_bias: bool. Whether to use a bias for the key projection.
            Defaults to `True`.
        use_value_bias: bool. Whether to use a bias for the value projection.
            Defaults to `True`.
        use_proj_bias: bool. Whether to use a bias for the output projection.
            Defaults to `True`.
        use_mlp_bias: bool. Whether to use a bias for the dense layers in MLP.
            Defaults to `True`.
        attention_dropout: float. The dropout rate for the attention
            probabilities. Defaults to `0.0`.
        drop_path_rate: float. The drop path rate to use. Defaults to `0.0`.
        layer_norm_eps: float. The epsilon for layer normalization. Defaults to
            `1e-5`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(
        self,
        num_layers,
        hidden_dim,
        num_heads,
        intermediate_dim,
        layer_scale_init_value=1.0,
        hidden_activation="gelu",
        use_gated_mlp=False,
        use_query_bias=True,
        use_key_bias=True,
        use_value_bias=True,
        use_proj_bias=True,
        use_mlp_bias=True,
        attention_dropout=0.0,
        drop_path_rate=0.0,
        layer_norm_eps=1e-5,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_layers = int(num_layers)
        self.hidden_dim = int(hidden_dim)
        self.num_heads = int(num_heads)
        self.intermediate_dim = int(intermediate_dim)
        self.layer_scale_init_value = float(layer_scale_init_value)
        self.hidden_activation = hidden_activation
        self.use_gated_mlp = bool(use_gated_mlp)
        self.use_query_bias = bool(use_query_bias)
        self.use_key_bias = bool(use_key_bias)
        self.use_value_bias = bool(use_value_bias)
        self.use_proj_bias = bool(use_proj_bias)
        self.use_mlp_bias = bool(use_mlp_bias)
        self.attention_dropout = float(attention_dropout)
        self.drop_path_rate = float(drop_path_rate)
        self.layer_norm_eps = float(layer_norm_eps)

        dpr = [x for x in ops.linspace(0.0, drop_path_rate, num_layers)]
        self.layers = [
            DINOV3Layer(
                hidden_dim=hidden_dim,
                num_heads=num_heads,
                intermediate_dim=intermediate_dim,
                layer_scale_init_value=layer_scale_init_value,
                hidden_activation=hidden_activation,
                use_gated_mlp=use_gated_mlp,
                use_query_bias=use_query_bias,
                use_key_bias=use_key_bias,
                use_value_bias=use_value_bias,
                use_proj_bias=use_proj_bias,
                use_mlp_bias=use_mlp_bias,
                attention_dropout=attention_dropout,
                drop_path_rate=dpr[i],
                layer_norm_eps=layer_norm_eps,
                dtype=self.dtype_policy,
                name=f"layers_{i}",
            )
            for i in range(num_layers)
        ]

    def build(self, input_shape):
        for layer in self.layers:
            layer.build(input_shape)
            input_shape = layer.compute_output_shape(input_shape)

    def call(
        self,
        inputs,
        attention_mask=None,
        position_embeddings=None,
        num_prefix_tokens=0,
        training=None,
    ):
        pyramid_outputs = {}
        x = inputs
        for layer_index, layer in enumerate(self.layers, start=1):
            x = layer(
                x,
                attention_mask=attention_mask,
                position_embeddings=position_embeddings,
                num_prefix_tokens=num_prefix_tokens,
                training=training,
            )
            pyramid_outputs[f"stage{str(layer_index)}"] = x
        return x, pyramid_outputs

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_layers": self.num_layers,
                "hidden_dim": self.hidden_dim,
                "num_heads": self.num_heads,
                "intermediate_dim": self.intermediate_dim,
                "layer_scale_init_value": self.layer_scale_init_value,
                "hidden_activation": self.hidden_activation,
                "use_gated_mlp": self.use_gated_mlp,
                "use_query_bias": self.use_query_bias,
                "use_key_bias": self.use_key_bias,
                "use_value_bias": self.use_value_bias,
                "use_proj_bias": self.use_proj_bias,
                "use_mlp_bias": self.use_mlp_bias,
                "attention_dropout": self.attention_dropout,
                "drop_path_rate": self.drop_path_rate,
                "layer_norm_eps": self.layer_norm_eps,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        pyramid_outputs = {}
        for layer_index in range(1, len(self.layers) + 1):
            pyramid_outputs[f"stage{str(layer_index)}"] = input_shape
        return input_shape, pyramid_outputs
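For orientation, the sketch below wires the new layers from this file into a single forward pass. It is a minimal example under stated assumptions: the module path is taken from the diff above, while the hyperparameter values are small illustrative numbers rather than values from any released DINOv3 preset; dinov3_backbone.py is the supported entry point for real use.

import numpy as np
from keras import ops

from keras_hub.src.models.dinov3.dinov3_layers import DINOV3Embedding
from keras_hub.src.models.dinov3.dinov3_layers import DINOV3Encoder
from keras_hub.src.models.dinov3.dinov3_layers import DINOV3RopePositionEmbedding

# Illustrative hyperparameters, not preset values.
hidden_dim, num_heads, patch_size = 64, 4, 16
num_register_tokens = 4
images = np.random.uniform(size=(2, 224, 224, 3)).astype("float32")

embedding = DINOV3Embedding(
    hidden_dim=hidden_dim,
    patch_size=patch_size,
    num_register_tokens=num_register_tokens,
)
rope = DINOV3RopePositionEmbedding(
    hidden_dim=hidden_dim,
    num_heads=num_heads,
    rope_theta=100.0,
    patch_size=patch_size,
)
encoder = DINOV3Encoder(
    num_layers=2,
    hidden_dim=hidden_dim,
    num_heads=num_heads,
    intermediate_dim=128,
)

# (2, 1 + 4 + 196, 64): class token, register tokens, then 14x14 patches.
tokens = embedding(images)
# RoPE tables are computed from the image resolution for the patch tokens only.
position_embeddings = rope(images)
outputs, pyramid_outputs = encoder(
    tokens,
    position_embeddings=position_embeddings,
    num_prefix_tokens=1 + num_register_tokens,
)
print(ops.shape(outputs))  # (2, 201, 64)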