keras-hub-nightly 0.22.0.dev202507160421__py3-none-any.whl → 0.22.0.dev202507170424__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. keras_hub/layers/__init__.py +3 -0
  2. keras_hub/models/__init__.py +3 -0
  3. keras_hub/src/models/clip/clip_backbone.py +3 -102
  4. keras_hub/src/models/clip/clip_layers.py +295 -0
  5. keras_hub/src/models/clip/clip_preprocessor.py +57 -48
  6. keras_hub/src/models/clip/clip_text_encoder.py +2 -2
  7. keras_hub/src/models/clip/clip_vision_encoder.py +3 -3
  8. keras_hub/src/models/dinov2/__init__.py +5 -0
  9. keras_hub/src/models/dinov2/dinov2_backbone.py +228 -0
  10. keras_hub/src/models/dinov2/dinov2_image_converter.py +8 -0
  11. keras_hub/src/models/dinov2/dinov2_layers.py +886 -0
  12. keras_hub/src/models/dinov2/dinov2_presets.py +4 -0
  13. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +6 -2
  14. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +16 -7
  15. keras_hub/src/models/stable_diffusion_3/mmdit.py +61 -4
  16. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +23 -32
  17. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +1 -0
  18. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +1 -0
  19. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -0
  20. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +6 -2
  21. keras_hub/src/utils/preset_utils.py +4 -1
  22. keras_hub/src/utils/transformers/convert_dinov2.py +180 -0
  23. keras_hub/src/utils/transformers/export/gemma.py +89 -0
  24. keras_hub/src/utils/transformers/export/hf_exporter.py +98 -0
  25. keras_hub/src/utils/transformers/preset_loader.py +4 -1
  26. keras_hub/src/version.py +1 -1
  27. {keras_hub_nightly-0.22.0.dev202507160421.dist-info → keras_hub_nightly-0.22.0.dev202507170424.dist-info}/METADATA +1 -1
  28. {keras_hub_nightly-0.22.0.dev202507160421.dist-info → keras_hub_nightly-0.22.0.dev202507170424.dist-info}/RECORD +30 -23
  29. keras_hub/src/models/clip/clip_encoder_block.py +0 -111
  30. keras_hub/src/models/clip/clip_vision_embedding.py +0 -101
  31. {keras_hub_nightly-0.22.0.dev202507160421.dist-info → keras_hub_nightly-0.22.0.dev202507170424.dist-info}/WHEEL +0 -0
  32. {keras_hub_nightly-0.22.0.dev202507160421.dist-info → keras_hub_nightly-0.22.0.dev202507170424.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,886 @@
1
+ from keras import backend
2
+ from keras import config
3
+ from keras import initializers
4
+ from keras import layers
5
+ from keras import ops
6
+ from keras import random
7
+
8
+ from keras_hub.src.utils.keras_utils import standardize_data_format
9
+
10
+
11
class DINOV2PatchEmbedding(layers.Layer):
    """A layer that converts images into a sequence of patch embeddings.

    The image is projected with a convolution whose kernel size and stride
    both equal `patch_size`, so each output position corresponds to one
    non-overlapping patch. The resulting feature map is then flattened to
    shape `(batch_size, num_patches, hidden_dim)`.

    Args:
        hidden_dim: int. The number of units in the hidden layers.
        patch_size: int. The size of one side of each patch.
        data_format: `None` or str. If specified, either `"channels_last"` or
            `"channels_first"`. The ordering of the dimensions in the
            inputs. `"channels_last"` corresponds to inputs with shape
            `(batch_size, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch_size, channels, height, width)`. It defaults to the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`. If you never set it, then it will be
            `"channels_last"`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(self, hidden_dim, patch_size, data_format=None, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.patch_size = int(patch_size)
        self.data_format = standardize_data_format(data_format)

        self.projection = layers.Conv2D(
            hidden_dim,
            kernel_size=patch_size,
            strides=patch_size,
            data_format=data_format,
            kernel_initializer=initializers.TruncatedNormal(stddev=0.02),
            dtype=self.dtype_policy,
            name="projection",
        )

    def build(self, input_shape):
        self.projection.build(input_shape)

    def call(self, inputs, training=None):
        batch_size = ops.shape(inputs)[0]
        embeddings = self.projection(inputs, training=training)
        if self.data_format == "channels_last":
            # (batch, h, w, hidden_dim) -> (batch, h * w, hidden_dim)
            embeddings = ops.reshape(
                embeddings, (batch_size, -1, self.hidden_dim)
            )
        else:
            # (batch, hidden_dim, h, w) -> (batch, hidden_dim, h * w)
            # -> (batch, h * w, hidden_dim)
            embeddings = ops.reshape(
                embeddings, (batch_size, self.hidden_dim, -1)
            )
            embeddings = ops.transpose(embeddings, (0, 2, 1))
        return embeddings

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "patch_size": self.patch_size,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        """Return `[batch_size, num_patches, hidden_dim]`.

        `num_patches` is `None` when either spatial dimension is unknown.
        """
        output_shape = [input_shape[0], None, self.hidden_dim]
        if self.data_format == "channels_last":
            height, width = input_shape[1], input_shape[2]
        else:
            height, width = input_shape[2], input_shape[3]
        if height is not None and width is not None:
            # Count patches along each axis independently so that non-square
            # inputs report the same sequence length that `call` produces.
            output_shape[1] = (height // self.patch_size) * (
                width // self.patch_size
            )
        return output_shape
84
+
85
+
86
class DINOV2Embedding(layers.Layer):
    """A layer that converts images into DINOV2 token embeddings.

    This layer embeds the image patches and adds all the necessary tokens,
    including the class token, the register tokens and the mask token if
    specified. Position embeddings are added to the class and patch tokens;
    the register tokens are inserted right after the class token afterwards.

    This layer supports the interpolation of the position embeddings to enable
    the model to work with images of different sizes. Please refer to
    `_interpolate_position_embeddings` for more details.

    The saving and loading of this layer will automatically handle the position
    embeddings interpolation. Please refer to `save_own_variables` and
    `load_own_variables` for more details.

    Args:
        hidden_dim: int. The number of units in the hidden layers.
        patch_size: int. The size of one side of each patch.
        image_shape: tuple of ints. The (height, width) of the input images.
        num_register_tokens: int. The number of register tokens to add to the
            embeddings. Defaults to `0`.
        use_mask_token: bool. Whether to use a mask token. Defaults to `True`.
        dropout_rate: float. The dropout rate to use. Defaults to `0.0`.
        position_embedding_shape: tuple. The original input shape used to
            train the position embeddings. This is used to interpolate the
            position embeddings to the actual input shape. Defaults to
            `(518, 518)`.
        antialias_in_interpolation: bool. Whether to use antialiasing in the
            interpolation of the position embeddings. Defaults to `False`.
        data_format: `None` or str. If specified, either `"channels_last"` or
            `"channels_first"`. The ordering of the dimensions in the
            inputs. `"channels_last"` corresponds to inputs with shape
            `(batch_size, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch_size, channels, height, width)`. It defaults to the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`. If you never set it, then it will be
            `"channels_last"`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(
        self,
        hidden_dim,
        patch_size,
        image_shape,
        num_register_tokens=0,
        use_mask_token=True,
        dropout_rate=0.0,
        position_embedding_shape=(518, 518),
        antialias_in_interpolation=False,
        data_format=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.patch_size = int(patch_size)
        self.image_shape = (int(image_shape[0]), int(image_shape[1]))
        self.position_embedding_shape = (
            int(position_embedding_shape[0]),
            int(position_embedding_shape[1]),
        )
        self.num_register_tokens = int(num_register_tokens)
        self.use_mask_token = bool(use_mask_token)
        self.dropout_rate = float(dropout_rate)
        self.antialias_in_interpolation = bool(antialias_in_interpolation)
        self.data_format = standardize_data_format(data_format)
        # Number of patches for the actual input images.
        self.interpolated_num_patches = (
            self.image_shape[0] // self.patch_size
        ) * (self.image_shape[1] // self.patch_size)
        # Number of patches the stored position embeddings were defined for.
        self.num_patches = (
            self.position_embedding_shape[0] // self.patch_size
        ) * (self.position_embedding_shape[1] // self.patch_size)

        self.patch_embeddings = DINOV2PatchEmbedding(
            hidden_dim,
            patch_size,
            data_format=data_format,
            dtype=self.dtype_policy,
            name="patch_embeddings",
        )
        self.dropout = layers.Dropout(
            rate=self.dropout_rate,
            dtype=self.dtype_policy,
            name="dropout",
        )

    def build(self, input_shape):
        self.cls_token = self.add_weight(
            shape=(1, 1, self.hidden_dim),
            initializer=initializers.TruncatedNormal(stddev=0.02),
            trainable=True,
            name="cls_token",
        )
        if self.use_mask_token:
            self.mask_token = self.add_weight(
                shape=(1, self.hidden_dim),
                initializer="zeros",
                trainable=True,
                name="mask_token",
            )
        if self.num_register_tokens > 0:
            self.register_tokens = self.add_weight(
                shape=(1, self.num_register_tokens, self.hidden_dim),
                initializer="zeros",
                trainable=True,
                name="register_tokens",
            )
        self.patch_embeddings.build(input_shape)

        # Note that there are two position embeddings:
        # `self.interpolated_position_embeddings` is used for the image inputs
        # during both training and inference.
        # `self.position_embeddings` is used to load pretrained weights and
        # remains unchanged during training and inference. It will be updated
        # during saving once `self.interpolated_position_embeddings` is
        # modified.
        self.position_embeddings = self.add_weight(
            shape=(1, self.num_patches + 1, self.hidden_dim),
            initializer=initializers.TruncatedNormal(stddev=0.02),
            trainable=False,
            name="position_embeddings",
        )
        self.interpolated_position_embeddings = self.add_weight(
            shape=(1, self.interpolated_num_patches + 1, self.hidden_dim),
            initializer="zeros",  # Will be initialized by interpolation.
            trainable=True,
            name="interpolated_position_embeddings",
        )

        # Initialize the interpolated position embeddings.
        self.interpolated_position_embeddings.assign(
            self._interpolate_position_embeddings(
                self.position_embeddings,
                patch_size=self.patch_size,
                source_shape=self.position_embedding_shape,
                target_shape=self.image_shape,
                antialias=self.antialias_in_interpolation,
            )
        )

    def call(self, inputs, masks=None, training=None):
        batch_size = ops.shape(inputs)[0]
        embeddings = self.patch_embeddings(inputs, training=training)

        # Replace the embeddings with the mask tokens if specified.
        # Basically, this is only used during training.
        if masks is not None and self.use_mask_token:
            masks = ops.expand_dims(masks, axis=-1)
            mask_token = ops.cast(
                ops.expand_dims(self.mask_token, axis=0), embeddings.dtype
            )
            embeddings = ops.where(masks, mask_token, embeddings)

        # Add the [CLS] token to the embedded patch tokens.
        cls_tokens = ops.tile(self.cls_token, (batch_size, 1, 1))
        embeddings = ops.concatenate((cls_tokens, embeddings), axis=1)

        # Add positional encoding to each token.
        embeddings = ops.add(embeddings, self.interpolated_position_embeddings)

        # Add register tokens if specified. They are inserted between the
        # [CLS] token and the patch tokens and carry no position embedding.
        if self.num_register_tokens > 0:
            register_tokens = ops.tile(self.register_tokens, (batch_size, 1, 1))
            embeddings = ops.concatenate(
                (
                    embeddings[:, :1, ...],
                    register_tokens,
                    embeddings[:, 1:, ...],
                ),
                axis=1,
            )

        # Pass `training` explicitly, consistent with the other sublayers.
        embeddings = self.dropout(embeddings, training=training)
        return embeddings

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "patch_size": self.patch_size,
                "image_shape": self.image_shape,
                "num_register_tokens": self.num_register_tokens,
                "use_mask_token": self.use_mask_token,
                "dropout_rate": self.dropout_rate,
                "position_embedding_shape": self.position_embedding_shape,
                "antialias_in_interpolation": self.antialias_in_interpolation,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        """Return `[batch_size, num_tokens, hidden_dim]`.

        `num_tokens` counts the class token, the register tokens and the
        patch tokens. It is `None` when either spatial dimension is unknown.
        """
        output_shape = [input_shape[0], None, self.hidden_dim]
        if self.data_format == "channels_last":
            height, width = input_shape[1], input_shape[2]
        else:
            height, width = input_shape[2], input_shape[3]
        if height is not None and width is not None:
            # Count patches per axis so non-square inputs are handled.
            num_patches = (height // self.patch_size) * (
                width // self.patch_size
            )
            # 1 is for cls token.
            output_shape[1] = 1 + self.num_register_tokens + num_patches
        return output_shape

    @staticmethod
    def _interpolate_position_embeddings(
        position_embeddings,
        patch_size,
        source_shape,
        target_shape,
        antialias=False,
    ):
        """Interpolate position embeddings to match the target image shape.

        Args:
            position_embeddings: tensor or variable of shape
                `(1, 1 + num_source_patches, hidden_dim)`, where the first
                position belongs to the class token.
            patch_size: int. The size of one side of each patch.
            source_shape: tuple of ints. The (height, width) of the images
                the given position embeddings correspond to.
            target_shape: tuple of ints. The (height, width) of the images
                the returned position embeddings should correspond to.
            antialias: bool. Whether to use antialiasing when resizing.

        Returns:
            A tensor of shape `(1, 1 + num_target_patches, hidden_dim)`.

        Reference:
        - https://github.com/huggingface/transformers/blob/main/src/transformers/models/dinov2/modeling_dinov2.py
        """
        position_embeddings = ops.convert_to_tensor(position_embeddings)
        patch_size = int(patch_size)
        source_shape = (int(source_shape[0]), int(source_shape[1]))
        target_shape = (int(target_shape[0]), int(target_shape[1]))
        hidden_dim = int(position_embeddings.shape[-1])

        if (
            source_shape[0] == target_shape[0]
            and source_shape[1] == target_shape[1]
        ):
            # No need to interpolate if the image size is the same as the
            # position embedding image size.
            return ops.copy(position_embeddings)

        num_positions = int(position_embeddings.shape[1]) - 1

        # Handle class token and patch embeddings separately.
        class_position_embeddings = position_embeddings[:, :1, ...]
        patch_position_embeddings = position_embeddings[:, 1:, ...]

        # Calculate new dimensions
        new_height = target_shape[0] // patch_size
        new_width = target_shape[1] // patch_size

        # Recover the 2D layout of the source grid from `source_shape` so
        # that non-square grids (e.g. when interpolating back from a
        # non-square `image_shape` during saving) are reshaped correctly.
        source_height = source_shape[0] // patch_size
        source_width = source_shape[1] // patch_size
        if source_height * source_width != num_positions:
            # `source_shape` does not describe the given embeddings; fall
            # back to assuming a square grid.
            source_height = source_width = int(num_positions**0.5)

        # Reshape for interpolation
        patch_position_embeddings = ops.reshape(
            patch_position_embeddings,
            (1, source_height, source_width, hidden_dim),
        )

        # Interpolate at float32 precision.
        original_dtype = backend.standardize_dtype(
            patch_position_embeddings.dtype
        )
        interpolated_patch_position_embeddings = ops.image.resize(
            ops.cast(patch_position_embeddings, "float32"),
            size=(new_height, new_width),
            interpolation="bicubic",
            antialias=antialias,
            data_format="channels_last",
        )
        interpolated_patch_position_embeddings = ops.cast(
            interpolated_patch_position_embeddings, original_dtype
        )

        # Reshape back to the original format
        interpolated_patch_position_embeddings = ops.reshape(
            interpolated_patch_position_embeddings, (1, -1, hidden_dim)
        )
        interpolated_position_embeddings = ops.concatenate(
            (class_position_embeddings, interpolated_patch_position_embeddings),
            axis=1,
        )
        return interpolated_position_embeddings

    def _is_interpolated_position_embeddings_updated(self):
        """Check if the interpolated position embeddings are updated.

        The check re-interpolates `self.position_embeddings` and compares the
        result with the current `self.interpolated_position_embeddings`.
        """
        original_interpolated_position_embeddings = (
            self._interpolate_position_embeddings(
                self.position_embeddings,
                patch_size=self.patch_size,
                source_shape=self.position_embedding_shape,
                target_shape=self.image_shape,
                antialias=self.antialias_in_interpolation,
            )
        )
        # Use the largest absolute element-wise difference. A signed sum
        # would miss updates whose differences are negative or cancel out.
        diff = ops.max(
            ops.abs(
                ops.subtract(
                    original_interpolated_position_embeddings,
                    self.interpolated_position_embeddings,
                )
            )
        )
        return ops.cond(
            ops.greater(diff, config.epsilon()), lambda: True, lambda: False
        )

    def save_own_variables(self, store):
        """Save the variables of this layer.

        If the interpolated position embeddings have been modified, they are
        first interpolated back to `position_embedding_shape` and written to
        `self.position_embeddings` before the variables are stored.
        """
        if self._is_interpolated_position_embeddings_updated():
            self.position_embeddings.assign(
                self._interpolate_position_embeddings(
                    self.interpolated_position_embeddings,
                    patch_size=self.patch_size,
                    source_shape=self.image_shape,
                    target_shape=self.position_embedding_shape,
                    antialias=self.antialias_in_interpolation,
                )
            )
        super().save_own_variables(store)

    def load_own_variables(self, store):
        """Load the variables of this layer.

        Every variable except `self.interpolated_position_embeddings` is read
        from `store`; the interpolated position embeddings are recomputed
        from the loaded `self.position_embeddings` afterwards.
        """
        all_vars = self._trainable_variables + self._non_trainable_variables
        for i, v in enumerate(all_vars):
            # The index `i` still advances for the skipped variable so the
            # remaining keys stay aligned with the stored ones.
            if v is self.interpolated_position_embeddings:
                continue
            v.assign(store[f"{i}"])
        self.interpolated_position_embeddings.assign(
            self._interpolate_position_embeddings(
                self.position_embeddings,
                patch_size=self.patch_size,
                source_shape=self.position_embedding_shape,
                target_shape=self.image_shape,
                antialias=self.antialias_in_interpolation,
            )
        )
411
+
412
+
413
class DINOV2Attention(layers.Layer):
    """Multi-head self-attention followed by dropout.

    The same tensor is used as query, key and value, and no causal mask is
    applied.

    Args:
        hidden_dim: int. The number of units in the hidden layers.
        num_heads: int. Number of attention heads.
        dropout_rate: float. The dropout rate to use. Defaults to `0.0`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(self, hidden_dim, num_heads, dropout_rate=0.0, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.num_heads = int(num_heads)
        self.dropout_rate = float(dropout_rate)

        # Each head works on an equal slice of the hidden dimension.
        head_dim = self.hidden_dim // self.num_heads
        self.attention = layers.MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=head_dim,
            dropout=self.dropout_rate,
            dtype=self.dtype_policy,
            name="attention",
        )
        self.dropout = layers.Dropout(
            rate=self.dropout_rate,
            dtype=self.dtype_policy,
            name="dropout",
        )

    def build(self, input_shape):
        # Query and value share the same shape in self-attention.
        query_shape = value_shape = input_shape
        self.attention.build(query_shape, value_shape)

    def call(self, inputs, training=None):
        attended = self.attention(
            query=inputs,
            value=inputs,
            key=inputs,
            training=training,
            use_causal_mask=False,
        )
        return self.dropout(attended, training=training)

    def get_config(self):
        base_config = super().get_config()
        return {
            **base_config,
            "hidden_dim": self.hidden_dim,
            "num_heads": self.num_heads,
            "dropout_rate": self.dropout_rate,
        }

    def compute_output_shape(self, input_shape):
        # Self-attention preserves the input shape.
        return input_shape
470
+
471
+
472
class DINOV2LayerScale(layers.Layer):
    """A layer scale.

    Multiplies the inputs by a learnable per-channel scale vector of shape
    `(hidden_dim,)`.

    Args:
        hidden_dim: int. The number of units in the hidden layers.
        init_values: float. The initial value for the scale. Defaults to `1.0`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(self, hidden_dim, init_values=1.0, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.init_values = float(init_values)

    def build(self, input_shape):
        self.lambda1 = self.add_weight(
            shape=(self.hidden_dim,),
            initializer=initializers.Constant(self.init_values),
            trainable=True,
            name="lambda1",
        )

    def call(self, inputs, training=None):
        return ops.multiply(inputs, self.lambda1)

    def get_config(self):
        config = super().get_config()
        # `init_values` must be serialized as well; otherwise a layer
        # recreated from its config silently falls back to the default `1.0`.
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "init_values": self.init_values,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        return input_shape
505
+
506
+
507
class DINOV2DropPath(layers.Layer):
    """A drop path (stochastic depth) layer.

    During training, each sample in the batch is either zeroed out entirely
    with probability `rate` or kept and rescaled by `1 / (1 - rate)`. The
    layer is the identity when not training or when `rate` is `0.0`.

    Args:
        rate: float. The drop path rate to use. Defaults to `0.0`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(self, rate=0.0, **kwargs):
        super().__init__(**kwargs)
        self.rate = float(rate)

    def build(self, input_shape):
        # Static noise shape kept as an attribute for reference. Its batch
        # entry may be `None`, so `call` derives the shape from the inputs.
        self.noise_shape = (input_shape[0],) + (1,) * (len(input_shape) - 1)

    def call(self, inputs, training=None):
        if not training or self.rate == 0.0:
            return inputs

        keep_prob = 1.0 - self.rate
        # One random value per sample, broadcast over all other axes. The
        # batch size is read at call time so a dynamic batch dim works.
        batch_size = ops.shape(inputs)[0]
        noise_shape = (batch_size,) + (1,) * (len(inputs.shape) - 1)
        random_tensor = random.uniform(noise_shape, dtype=inputs.dtype)
        # `uniform + keep_prob` lies in `[keep_prob, 1 + keep_prob)`; flooring
        # binarizes it to 1 with probability `keep_prob` and 0 otherwise.
        random_tensor = ops.floor(ops.add(random_tensor, keep_prob))
        return ops.multiply(ops.divide(inputs, keep_prob), random_tensor)

    def get_config(self):
        config = super().get_config()
        config.update({"rate": self.rate})
        return config

    def compute_output_shape(self, input_shape):
        return input_shape
539
+
540
+
541
class DINOV2MLP(layers.Layer):
    """The two-layer feed-forward block used by DINOV2.

    The inputs go through a Dense layer with `intermediate_dim` units and the
    given activation, then through a Dense layer with `hidden_dim` units.

    Args:
        hidden_dim: int. The number of units in the output layer.
        intermediate_dim: int. The output dimension of the first Dense layer.
        activation: str or callable. Activation to use in the intermediate
            layer. Defaults to `"gelu"`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(
        self, hidden_dim, intermediate_dim, activation="gelu", **kwargs
    ):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.intermediate_dim = int(intermediate_dim)
        self.activation = activation

        # Expansion projection; the activation is applied here.
        self.fc1 = layers.Dense(
            self.intermediate_dim,
            activation=activation,
            kernel_initializer=initializers.TruncatedNormal(stddev=0.02),
            dtype=self.dtype_policy,
            name="fc1",
        )
        # Projection back to the hidden dimension.
        self.fc2 = layers.Dense(
            self.hidden_dim,
            kernel_initializer=initializers.TruncatedNormal(stddev=0.02),
            dtype=self.dtype_policy,
            name="fc2",
        )

    def build(self, input_shape):
        self.fc1.build(input_shape)
        intermediate_shape = self.fc1.compute_output_shape(input_shape)
        self.fc2.build(intermediate_shape)

    def call(self, inputs, training=None):
        expanded = self.fc1(inputs, training=training)
        return self.fc2(expanded, training=training)

    def get_config(self):
        base_config = super().get_config()
        return {
            **base_config,
            "hidden_dim": self.hidden_dim,
            "intermediate_dim": self.intermediate_dim,
            "activation": self.activation,
        }

    def compute_output_shape(self, input_shape):
        # Only the last axis changes.
        return [*input_shape[:-1], self.hidden_dim]
600
+
601
+
602
class DINOV2SwiGLUFFN(layers.Layer):
    """The SwiGLU feed-forward network used by DINOV2.

    A single Dense layer produces two halves of equal width; the first half
    is passed through SiLU and multiplied element-wise with the second half
    before the output projection.

    Please refer to [GLU Variants Improve Transformer](
    https://arxiv.org/abs/2002.05202) for more details on SwiGLU.

    Args:
        hidden_dim: int. The number of units in the output layer.
        intermediate_dim: int. The output dimension of the first Dense layer.
            Note that this value will be multiplied by `2 / 3` and rounded up to
            the nearest multiple of `8`. The reason for this is that SwiGLUFFN
            achieves similar or better performance with fewer parameters
            compared to the original FFN implementation.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(self, hidden_dim, intermediate_dim, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.intermediate_dim = int(intermediate_dim)
        # Scale by 2/3, then round up to the next multiple of 8.
        two_thirds_dim = int(intermediate_dim * 2 / 3)
        self.actual_intermediate_dim = (two_thirds_dim + 7) // 8 * 8

        # Produces both halves of the gated unit in one projection.
        self.weights_in = layers.Dense(
            2 * self.actual_intermediate_dim,
            kernel_initializer=initializers.TruncatedNormal(stddev=0.02),
            dtype=self.dtype_policy,
            name="weights_in",
        )
        self.weights_out = layers.Dense(
            self.hidden_dim,
            kernel_initializer=initializers.TruncatedNormal(stddev=0.02),
            dtype=self.dtype_policy,
            name="weights_out",
        )

    def build(self, input_shape):
        self.weights_in.build(input_shape)
        # After gating, the last axis has `actual_intermediate_dim` units.
        gated_shape = list(input_shape)
        gated_shape[-1] = self.actual_intermediate_dim
        self.weights_out.build(gated_shape)

    def call(self, inputs, training=None):
        projected = self.weights_in(inputs, training=training)
        first_half, second_half = ops.split(projected, 2, axis=-1)
        gated = ops.multiply(ops.silu(first_half), second_half)
        return self.weights_out(gated, training=training)

    def get_config(self):
        base_config = super().get_config()
        return {
            **base_config,
            "hidden_dim": self.hidden_dim,
            "intermediate_dim": self.intermediate_dim,
        }

    def compute_output_shape(self, input_shape):
        # Only the last axis changes.
        return [*input_shape[:-1], self.hidden_dim]
667
+
668
+
669
class DINOV2Layer(layers.Layer):
    """A single DINOV2 transformer encoder layer.

    The layer applies two residual branches in sequence:
    `norm1 -> attention -> layer_scale1 -> drop_path` and
    `norm2 -> mlp -> layer_scale2 -> drop_path`.

    Args:
        hidden_dim: int. The number of units in the hidden layers.
        num_heads: int. Number of attention heads.
        intermediate_dim: int. The output dimension of the first Dense layer in
            a two-layer feedforward network for each transformer.
        layer_scale_init_value: float. The initial value for the scale.
            Defaults to `1.0`.
        use_swiglu_ffn: bool. Whether to use SwigLUFFN instead of MLP.
            Defaults to `False`.
        dropout_rate: float. The dropout rate to use. Defaults to `0.0`.
        drop_path_rate: float. The drop path rate to use. Defaults to `0.0`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(
        self,
        hidden_dim,
        num_heads,
        intermediate_dim,
        layer_scale_init_value=1.0,
        use_swiglu_ffn=False,
        dropout_rate=0.0,
        drop_path_rate=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.num_heads = int(num_heads)
        self.intermediate_dim = int(intermediate_dim)
        self.layer_scale_init_value = float(layer_scale_init_value)
        self.use_swiglu_ffn = bool(use_swiglu_ffn)
        self.dropout_rate = float(dropout_rate)
        self.drop_path_rate = float(drop_path_rate)

        # Attention branch.
        self.norm1 = layers.LayerNormalization(
            epsilon=1e-6, dtype=self.dtype_policy, name="norm1"
        )
        self.attention = DINOV2Attention(
            hidden_dim=self.hidden_dim,
            num_heads=self.num_heads,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype_policy,
            name="attention",
        )
        self.layer_scale1 = DINOV2LayerScale(
            hidden_dim=self.hidden_dim,
            init_values=self.layer_scale_init_value,
            dtype=self.dtype_policy,
            name="layer_scale1",
        )

        # The same drop path layer is shared by both residual branches. It
        # degenerates to an identity layer when the rate is not positive.
        if self.drop_path_rate > 0:
            self.drop_path = DINOV2DropPath(
                rate=self.drop_path_rate,
                dtype=self.dtype_policy,
                name="drop_path",
            )
        else:
            self.drop_path = layers.Identity(
                dtype=self.dtype_policy, name="drop_path"
            )

        # Feed-forward branch.
        self.norm2 = layers.LayerNormalization(
            epsilon=1e-6, dtype=self.dtype_policy, name="norm2"
        )
        if self.use_swiglu_ffn:
            self.mlp = DINOV2SwiGLUFFN(
                hidden_dim=self.hidden_dim,
                intermediate_dim=self.intermediate_dim,
                dtype=self.dtype_policy,
                name="mlp",
            )
        else:
            self.mlp = DINOV2MLP(
                hidden_dim=self.hidden_dim,
                intermediate_dim=self.intermediate_dim,
                activation="gelu",
                dtype=self.dtype_policy,
                name="mlp",
            )
        self.layer_scale2 = DINOV2LayerScale(
            hidden_dim=self.hidden_dim,
            init_values=self.layer_scale_init_value,
            dtype=self.dtype_policy,
            name="layer_scale2",
        )

    def build(self, input_shape):
        # Attention branch.
        self.norm1.build(input_shape)
        self.attention.build(input_shape)
        attention_shape = self.attention.compute_output_shape(input_shape)
        self.layer_scale1.build(attention_shape)
        self.drop_path.build(attention_shape)

        # Feed-forward branch.
        self.norm2.build(attention_shape)
        self.mlp.build(attention_shape)
        mlp_shape = self.mlp.compute_output_shape(attention_shape)
        self.layer_scale2.build(mlp_shape)

    def call(self, inputs, training=None):
        # Attention branch with the first residual connection.
        attention_branch = self.norm1(inputs, training=training)
        attention_branch = self.attention(attention_branch, training=training)
        attention_branch = self.layer_scale1(
            attention_branch, training=training
        )
        hidden_states = ops.add(
            self.drop_path(attention_branch, training=training), inputs
        )

        # Feed-forward branch with the second residual connection.
        mlp_branch = self.norm2(hidden_states, training=training)
        mlp_branch = self.mlp(mlp_branch, training=training)
        mlp_branch = self.layer_scale2(mlp_branch, training=training)
        return ops.add(
            self.drop_path(mlp_branch, training=training), hidden_states
        )

    def get_config(self):
        base_config = super().get_config()
        return {
            **base_config,
            "hidden_dim": self.hidden_dim,
            "num_heads": self.num_heads,
            "intermediate_dim": self.intermediate_dim,
            "layer_scale_init_value": self.layer_scale_init_value,
            "use_swiglu_ffn": self.use_swiglu_ffn,
            "dropout_rate": self.dropout_rate,
            "drop_path_rate": self.drop_path_rate,
        }

    def compute_output_shape(self, input_shape):
        # Both residual branches preserve the input shape.
        return input_shape
800
+
801
+
802
class DINOV2Encoder(layers.Layer):
    """A stack of `num_layers` identically configured `DINOV2Layer`s.

    Args:
        num_layers: int. The number of transformer layers.
        hidden_dim: int. The number of units in the hidden layers.
        num_heads: int. Number of attention heads.
        intermediate_dim: int. The output dimension of the first Dense layer in
            a two-layer feedforward network for each transformer.
        layer_scale_init_value: float. The initial value for the scale.
            Defaults to `1.0`.
        use_swiglu_ffn: bool. Whether to use SwigLUFFN instead of MLP.
            Defaults to `False`.
        dropout_rate: float. The dropout rate to use. Defaults to `0.0`.
        drop_path_rate: float. The drop path rate to use. Defaults to `0.0`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `dtype` etc.
    """

    def __init__(
        self,
        num_layers,
        hidden_dim,
        num_heads,
        intermediate_dim,
        layer_scale_init_value=1.0,
        use_swiglu_ffn=False,
        dropout_rate=0.0,
        drop_path_rate=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_layers = int(num_layers)
        self.hidden_dim = int(hidden_dim)
        self.num_heads = int(num_heads)
        self.intermediate_dim = int(intermediate_dim)
        self.layer_scale_init_value = float(layer_scale_init_value)
        self.use_swiglu_ffn = bool(use_swiglu_ffn)
        self.dropout_rate = float(dropout_rate)
        self.drop_path_rate = float(drop_path_rate)

        # Every layer shares the same hyperparameters and differs only by
        # its name.
        encoder_layers = []
        for index in range(self.num_layers):
            encoder_layers.append(
                DINOV2Layer(
                    hidden_dim=self.hidden_dim,
                    num_heads=self.num_heads,
                    intermediate_dim=self.intermediate_dim,
                    layer_scale_init_value=self.layer_scale_init_value,
                    use_swiglu_ffn=self.use_swiglu_ffn,
                    dropout_rate=self.dropout_rate,
                    drop_path_rate=self.drop_path_rate,
                    dtype=self.dtype_policy,
                    name=f"layers_{index}",
                )
            )
        self.layers = encoder_layers

    def build(self, input_shape):
        current_shape = input_shape
        for encoder_layer in self.layers:
            encoder_layer.build(current_shape)
            current_shape = encoder_layer.compute_output_shape(current_shape)

    def call(self, inputs, training=None):
        hidden_states = inputs
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states, training=training)
        return hidden_states

    def get_config(self):
        base_config = super().get_config()
        return {
            **base_config,
            "num_layers": self.num_layers,
            "hidden_dim": self.hidden_dim,
            "num_heads": self.num_heads,
            "intermediate_dim": self.intermediate_dim,
            "layer_scale_init_value": self.layer_scale_init_value,
            "use_swiglu_ffn": self.use_swiglu_ffn,
            "dropout_rate": self.dropout_rate,
            "drop_path_rate": self.drop_path_rate,
        }

    def compute_output_shape(self, input_shape):
        # Every encoder layer preserves the input shape.
        return input_shape