keras-hub-nightly 0.22.0.dev202505300409__py3-none-any.whl → 0.22.0.dev202505310408__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +3 -0
- keras_hub/models/__init__.py +7 -0
- keras_hub/src/models/deit/__init__.py +0 -0
- keras_hub/src/models/deit/deit_backbone.py +154 -0
- keras_hub/src/models/deit/deit_image_classifier.py +171 -0
- keras_hub/src/models/deit/deit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/deit/deit_image_converter.py +8 -0
- keras_hub/src/models/deit/deit_layers.py +519 -0
- keras_hub/src/models/deit/deit_presets.py +49 -0
- keras_hub/src/utils/transformers/convert_deit.py +155 -0
- keras_hub/src/utils/transformers/preset_loader.py +4 -1
- keras_hub/src/version.py +1 -1
- {keras_hub_nightly-0.22.0.dev202505300409.dist-info → keras_hub_nightly-0.22.0.dev202505310408.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.22.0.dev202505300409.dist-info → keras_hub_nightly-0.22.0.dev202505310408.dist-info}/RECORD +16 -8
- {keras_hub_nightly-0.22.0.dev202505300409.dist-info → keras_hub_nightly-0.22.0.dev202505310408.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.22.0.dev202505300409.dist-info → keras_hub_nightly-0.22.0.dev202505310408.dist-info}/top_level.txt +0 -0
keras_hub/layers/__init__.py
CHANGED
@@ -78,6 +78,9 @@ from keras_hub.src.models.cspnet.cspnet_image_converter import (
 from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import (
     DeepLabV3ImageConverter as DeepLabV3ImageConverter,
 )
+from keras_hub.src.models.deit.deit_image_converter import (
+    DeiTImageConverter as DeiTImageConverter,
+)
 from keras_hub.src.models.densenet.densenet_image_converter import (
     DenseNetImageConverter as DenseNetImageConverter,
 )
keras_hub/models/__init__.py
CHANGED
@@ -141,6 +141,13 @@ from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor import (
 from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import (
     DeepLabV3ImageSegmenter as DeepLabV3ImageSegmenter,
 )
+from keras_hub.src.models.deit.deit_backbone import DeiTBackbone as DeiTBackbone
+from keras_hub.src.models.deit.deit_image_classifier import (
+    DeiTImageClassifier as DeiTImageClassifier,
+)
+from keras_hub.src.models.deit.deit_image_classifier_preprocessor import (
+    DeiTImageClassifierPreprocessor as DeiTImageClassifierPreprocessor,
+)
 from keras_hub.src.models.densenet.densenet_backbone import (
     DenseNetBackbone as DenseNetBackbone,
 )
keras_hub/src/models/deit/__init__.py
ADDED (empty file)
keras_hub/src/models/deit/deit_backbone.py
ADDED
@@ -0,0 +1,154 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.deit.deit_layers import DeiTEmbeddings
+from keras_hub.src.models.deit.deit_layers import DeiTEncoder
+from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+@keras_hub_export("keras_hub.models.DeiTBackbone")
+class DeiTBackbone(Backbone):
+    """DeiT backbone.
+
+    This backbone implements the Data-efficient Image Transformer (DeiT)
+    architecture as described in [Training data-efficient image
+    transformers & distillation through attention]
+    (https://arxiv.org/abs/2012.12877).
+
+    Args:
+        image_shape: A tuple or list of 3 integers representing the shape of the
+            input image `(height, width, channels)`.
+        patch_size: tuple or int. The size of each image patch. If an int is
+            provided, it will be used for both height and width. The input image
+            will be split into patches of shape `(patch_size_h, patch_size_w)`.
+        num_layers: int. The number of transformer encoder layers.
+        num_heads: int. The number of attention heads in each Transformer
+            encoder layer.
+        hidden_dim: int. The dimensionality of the hidden representations.
+        intermediate_dim: int. The dimensionality of the intermediate MLP layer
+            in each Transformer encoder layer.
+        dropout_rate: float. The dropout rate for the Transformer encoder
+            layers.
+        attention_dropout: float. The dropout rate for the attention mechanism
+            in each Transformer encoder layer.
+        layer_norm_epsilon: float. Value used for numerical stability in layer
+            normalization.
+        use_mha_bias: bool. Whether to use bias in the multi-head attention
+            layers.
+        data_format: str. `"channels_last"` or `"channels_first"`, specifying
+            the data format for the input image. If `None`, defaults to
+            `"channels_last"`.
+        dtype: The dtype of the layer weights. Defaults to `None`.
+        **kwargs: Additional keyword arguments to be passed to the parent
+            `Backbone` class.
+    """
+
+    def __init__(
+        self,
+        image_shape,
+        patch_size,
+        num_layers,
+        num_heads,
+        hidden_dim,
+        intermediate_dim,
+        dropout_rate=0.0,
+        attention_dropout=0.0,
+        layer_norm_epsilon=1e-6,
+        use_mha_bias=True,
+        data_format=None,
+        dtype=None,
+        **kwargs,
+    ):
+        # === Layers ===
+        data_format = standardize_data_format(data_format)
+        if isinstance(patch_size, int):
+            patch_size = (patch_size, patch_size)
+        h_axis, w_axis, channels_axis = (
+            (-3, -2, -1) if data_format == "channels_last" else (-2, -1, -3)
+        )
+        # Check that the input image is well specified.
+        if image_shape[h_axis] is None or image_shape[w_axis] is None:
+            raise ValueError(
+                f"Image shape must have defined height and width. Found `None` "
+                f"at index {h_axis} (height) or {w_axis} (width). "
+                f"Image shape: {image_shape}"
+            )
+        # Check that the image dimensions are divisible by the patch size.
+        if image_shape[h_axis] % patch_size[0] != 0:
+            raise ValueError(
+                f"Input height {image_shape[h_axis]} should be divisible by "
+                f"patch size {patch_size}."
+            )
+        if image_shape[w_axis] % patch_size[1] != 0:
+            raise ValueError(
+                f"Input width {image_shape[w_axis]} should be divisible by "
+                f"patch size {patch_size}."
+            )
+
+        num_channels = image_shape[channels_axis]
+
+        # === Functional Model ===
+        inputs = keras.layers.Input(shape=image_shape)
+
+        x = DeiTEmbeddings(
+            image_size=(image_shape[h_axis], image_shape[w_axis]),
+            patch_size=patch_size,
+            hidden_dim=hidden_dim,
+            num_channels=num_channels,
+            data_format=data_format,
+            dropout_rate=dropout_rate,
+            dtype=dtype,
+            name="deit_patching_and_embedding",
+        )(inputs)
+
+        output, _, _ = DeiTEncoder(
+            num_layers=num_layers,
+            num_heads=num_heads,
+            hidden_dim=hidden_dim,
+            intermediate_dim=intermediate_dim,
+            use_mha_bias=use_mha_bias,
+            dropout_rate=dropout_rate,
+            attention_dropout=attention_dropout,
+            layer_norm_epsilon=layer_norm_epsilon,
+            dtype=dtype,
+            name="deit_encoder",
+        )(x)
+
+        super().__init__(
+            inputs=inputs,
+            outputs=output,
+            dtype=dtype,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.image_shape = image_shape
+        self.patch_size = patch_size
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.dropout_rate = dropout_rate
+        self.attention_dropout = attention_dropout
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.use_mha_bias = use_mha_bias
+        self.data_format = data_format
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "image_shape": self.image_shape,
+                "patch_size": self.patch_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout_rate": self.dropout_rate,
+                "attention_dropout": self.attention_dropout,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "use_mha_bias": self.use_mha_bias,
+            }
+        )
+        return config
keras_hub/src/models/deit/deit_image_classifier.py
ADDED
@@ -0,0 +1,171 @@
+import keras
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.deit.deit_backbone import DeiTBackbone
+from keras_hub.src.models.deit.deit_image_classifier_preprocessor import (
+    DeiTImageClassifierPreprocessor,
+)
+from keras_hub.src.models.image_classifier import ImageClassifier
+from keras_hub.src.models.task import Task
+
+
+@keras_hub_export("keras_hub.models.DeiTImageClassifier")
+class DeiTImageClassifier(ImageClassifier):
+    """DeiT image classification task.
+
+    `DeiTImageClassifier` tasks wrap a `keras_hub.models.DeiTBackbone` and
+    a `keras_hub.models.Preprocessor` to create a model that can be used for
+    image classification. `DeiTImageClassifier` tasks take an additional
+    `num_classes` argument, controlling the number of predicted output classes.
+
+    To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
+    labels where `x` is an image and `y` is an integer from `[0, num_classes)`.
+
+    Note that unlike `keras_hub.models.ImageClassifier`, `DeiTImageClassifier`
+    plucks out the `cls_token`, the first token in the sequence, from the
+    backbone output.
+
+    Args:
+        backbone: A `keras_hub.models.DeiTBackbone` instance or a `keras.Model`.
+        num_classes: int. The number of classes to predict.
+        preprocessor: `None`, a `keras_hub.models.Preprocessor` instance,
+            a `keras.Layer` instance, or a callable. If `None` no preprocessing
+            will be applied to the inputs.
+        pooling: String specifying the classification strategy. The choice
+            impacts the dimensionality and nature of the feature vector used for
+            classification.
+            `"token"`: A single vector (class token) representing the
+                overall image features.
+            `"gap"`: A single vector representing the average features
+                across the spatial dimensions.
+        activation: `None`, str, or callable. The activation function to use on
+            the `Dense` layer. Set `activation=None` to return the output
+            logits. Defaults to `None`.
+        head_dtype: `None`, str, or `keras.mixed_precision.DTypePolicy`. The
+            dtype to use for the classification head's computations and weights.
+
+    Examples:
+
+    Call `predict()` to run inference.
+    ```python
+    # Load preset and run inference.
+    images = np.random.randint(0, 256, size=(2, 384, 384, 3))
+    classifier = keras_hub.models.DeiTImageClassifier.from_preset(
+        "hf://facebook/deit-base-distilled-patch16-384"
+    )
+    classifier.predict(images)
+    ```
+
+    Call `fit()` on a single batch.
+    ```python
+    # Load preset and train.
+    images = np.random.randint(0, 256, size=(2, 384, 384, 3))
+    labels = [0, 3]
+    classifier = keras_hub.models.DeiTImageClassifier.from_preset(
+        "hf://facebook/deit-base-distilled-patch16-384"
+    )
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+
+    Call `fit()` with custom loss, optimizer and backbone.
+    ```python
+    classifier = keras_hub.models.DeiTImageClassifier.from_preset(
+        "hf://facebook/deit-base-distilled-patch16-384"
+    )
+    classifier.compile(
+        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        optimizer=keras.optimizers.Adam(5e-5),
+    )
+    classifier.backbone.trainable = False
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+
+    Custom backbone.
+    ```python
+    images = np.random.randint(0, 256, size=(2, 384, 384, 3))
+    labels = [0, 3]
+    backbone = keras_hub.models.DeiTBackbone(
+        image_shape=(384, 384, 3),
+        patch_size=16,
+        num_layers=6,
+        num_heads=3,
+        hidden_dim=768,
+        intermediate_dim=2048,
+    )
+    classifier = keras_hub.models.DeiTImageClassifier(
+        backbone=backbone,
+        num_classes=4,
+    )
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+    """
+
+    backbone_cls = DeiTBackbone
+    preprocessor_cls = DeiTImageClassifierPreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        num_classes,
+        preprocessor=None,
+        pooling="token",
+        activation=None,
+        dropout=0.0,
+        head_dtype=None,
+        **kwargs,
+    ):
+        head_dtype = head_dtype or backbone.dtype_policy
+
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+        self.dropout = keras.layers.Dropout(
+            rate=dropout,
+            dtype=head_dtype,
+            name="output_dropout",
+        )
+
+        self.output_dense = keras.layers.Dense(
+            num_classes,
+            activation=activation,
+            dtype=head_dtype,
+            name="predictions",
+        )
+
+        # === Functional Model ===
+        inputs = self.backbone.input
+        x = self.backbone(inputs)
+        if pooling == "token":
+            x = x[:, 0]
+        elif pooling == "gap":
+            ndim = len(ops.shape(x))
+            x = ops.mean(x, axis=list(range(1, ndim - 1)))  # (1,) or (1, 2)
+
+        outputs = self.output_dense(x)
+
+        # Skip the parent class functional model.
+        Task.__init__(
+            self,
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.num_classes = num_classes
+        self.pooling = pooling
+        self.activation = activation
+        self.dropout = dropout
+
+    def get_config(self):
+        # Backbone serialized in `super`
+        config = super().get_config()
+        config.update(
+            {
+                "num_classes": self.num_classes,
+                "pooling": self.pooling,
+                "activation": self.activation,
+                "dropout": self.dropout,
+            }
+        )
+        return config
keras_hub/src/models/deit/deit_image_classifier_preprocessor.py
ADDED
@@ -0,0 +1,12 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.deit.deit_backbone import DeiTBackbone
+from keras_hub.src.models.deit.deit_image_converter import DeiTImageConverter
+from keras_hub.src.models.image_classifier_preprocessor import (
+    ImageClassifierPreprocessor,
+)
+
+
+@keras_hub_export("keras_hub.models.DeiTImageClassifierPreprocessor")
+class DeiTImageClassifierPreprocessor(ImageClassifierPreprocessor):
+    backbone_cls = DeiTBackbone
+    image_converter_cls = DeiTImageConverter
keras_hub/src/models/deit/deit_image_converter.py
ADDED
@@ -0,0 +1,8 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+from keras_hub.src.models.deit.deit_backbone import DeiTBackbone
+
+
+@keras_hub_export("keras_hub.layers.DeiTImageConverter")
+class DeiTImageConverter(ImageConverter):
+    backbone_cls = DeiTBackbone
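The converter adds no options of its own; resizing and rescaling behavior all comes from the base `ImageConverter`. A hedged sketch of standalone use — `image_size` and `scale` are base-class arguments, with values chosen here purely for illustration (presets would normally supply them):

```python
import numpy as np

import keras_hub

# Build a converter by hand instead of loading it from a preset.
converter = keras_hub.layers.DeiTImageConverter(
    image_size=(384, 384),
    scale=1 / 255.0,
)
raw = np.random.randint(0, 256, size=(2, 512, 512, 3)).astype("float32")
# Resized to (2, 384, 384, 3) and rescaled to [0, 1].
images = converter(raw)
```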
keras_hub/src/models/deit/deit_layers.py
ADDED
@@ -0,0 +1,519 @@
+import keras
+from keras import ops
+
+from keras_hub.src.utils.keras_utils import standardize_data_format
+
+
+class DeiTEmbeddings(keras.layers.Layer):
+    """Patches the image and embeds the patches.
+
+    Args:
+        image_size: tuple. Size of the input image (height, width).
+        patch_size: tuple. The size of each image patch as
+            (patch_height, patch_width).
+        hidden_dim: int. Dimensionality of the patch embeddings.
+        num_channels: int. Number of channels in the input image. Defaults to
+            `3`.
+        data_format: str. `"channels_last"` or `"channels_first"`. Defaults to
+            `None` (which uses `"channels_last"`).
+        use_mask_token: bool. Whether to use a mask token. Defaults to `False`.
+        dropout_rate: float. Dropout rate. Between 0 and 1. Defaults to
+            `0.0`.
+        **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
+    """
+
+    def __init__(
+        self,
+        image_size,
+        patch_size,
+        hidden_dim,
+        num_channels=3,
+        data_format=None,
+        use_mask_token=False,
+        dropout_rate=0.0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        num_patches = (image_size[0] // patch_size[0]) * (
+            image_size[1] // patch_size[1]
+        )
+        num_positions = num_patches + 2
+
+        # === Config ===
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.hidden_dim = hidden_dim
+        self.num_channels = num_channels
+        self.num_patches = num_patches
+        self.num_positions = num_positions
+        self.data_format = standardize_data_format(data_format)
+        self.use_mask_token = use_mask_token
+        self.dropout_rate = dropout_rate
+
+    def build(self, input_shape):
+        if self.use_mask_token:
+            self.mask_token = self.add_weight(
+                shape=(1, 1, self.hidden_dim),
+                initializer="zeros",
+                dtype=self.variable_dtype,
+                name="mask_token",
+            )
+        self.class_token = self.add_weight(
+            shape=(
+                1,
+                1,
+                self.hidden_dim,
+            ),
+            initializer="zeros",
+            dtype=self.variable_dtype,
+            name="class_token",
+        )
+        self.distillation_token = self.add_weight(
+            shape=(
+                1,
+                1,
+                self.hidden_dim,
+            ),
+            initializer="zeros",
+            dtype=self.variable_dtype,
+            name="distillation_token",
+        )
+        self.patch_embedding = keras.layers.Conv2D(
+            filters=self.hidden_dim,
+            kernel_size=self.patch_size,
+            strides=self.patch_size,
+            padding="valid",
+            activation=None,
+            dtype=self.dtype_policy,
+            data_format=self.data_format,
+            name="patch_embedding",
+        )
+        self.patch_embedding.build(input_shape)
+        self.position_embedding = self.add_weight(
+            shape=(
+                1,
+                self.num_positions,
+                self.hidden_dim,
+            ),  # Matches the shape in PyTorch
+            initializer=keras.initializers.RandomNormal(
+                stddev=0.02
+            ),  # Equivalent to torch.randn()
+            dtype=self.variable_dtype,
+            trainable=True,
+            name="position_embedding",
+        )
+        self.dropout = keras.layers.Dropout(
+            self.dropout_rate, dtype=self.dtype_policy, name="dropout"
+        )
+
+        self.built = True
+
+    def call(self, inputs, bool_masked_pos=None):
+        patch_embeddings = self.patch_embedding(inputs)
+        if self.data_format == "channels_first":
+            patch_embeddings = ops.transpose(
+                patch_embeddings, axes=(0, 2, 3, 1)
+            )
+        embeddings_shape = ops.shape(patch_embeddings)
+        patch_embeddings = ops.reshape(
+            patch_embeddings, [embeddings_shape[0], -1, embeddings_shape[-1]]
+        )
+
+        if bool_masked_pos is not None and self.use_mask_token:
+            # Expand dimensions to match the embeddings
+            bool_masked_pos_expanded = ops.expand_dims(
+                bool_masked_pos, axis=-1
+            )  # (batch_size, num_patches, 1)
+            # `mask_token` has shape (1, 1, hidden_dim) and broadcasts
+            # against (batch_size, num_patches, hidden_dim).
+            mask_token = ops.cast(self.mask_token, patch_embeddings.dtype)
+            # Replace masked patch embeddings with the mask token.
+            patch_embeddings = ops.where(
+                bool_masked_pos_expanded, mask_token, patch_embeddings
+            )
+
+        class_token = ops.tile(self.class_token, (embeddings_shape[0], 1, 1))
+        distillation_token = ops.tile(
+            self.distillation_token, (embeddings_shape[0], 1, 1)
+        )
+        embeddings = ops.concatenate(
+            [class_token, distillation_token, patch_embeddings], axis=1
+        )
+        position_embedding = self.position_embedding
+        embeddings = ops.add(embeddings, position_embedding)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+    def compute_output_shape(self, input_shape):
+        return (
+            input_shape[0],
+            self.num_positions,
+            self.hidden_dim,
+        )
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "image_size": self.image_size,
+                "patch_size": self.patch_size,
+                "hidden_dim": self.hidden_dim,
+                "num_channels": self.num_channels,
+                "num_patches": self.num_patches,
+                "num_positions": self.num_positions,
+                "use_mask_token": self.use_mask_token,
+                "dropout_rate": self.dropout_rate,
+            }
+        )
+        return config
+
+
+class DeiTIntermediate(keras.layers.Layer):
+    """DeiTIntermediate block.
+
+    Args:
+        intermediate_dim: int. Dimensionality of the intermediate MLP layer.
+        **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
+    """
+
+    def __init__(
+        self,
+        intermediate_dim,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        # === Config ===
+        self.intermediate_dim = intermediate_dim
+
+    def build(self, input_shape):
+        self.dense = keras.layers.Dense(
+            units=self.intermediate_dim,
+            activation="gelu",
+            dtype=self.dtype_policy,
+            name="dense",
+        )
+        self.dense.build(input_shape)
+        self.built = True
+
+    def call(self, inputs):
+        out = self.dense(inputs)
+        return out
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "intermediate_dim": self.intermediate_dim,
+            }
+        )
+        return config
+
+
+class DeiTOutput(keras.layers.Layer):
+    """DeiT Output layer implementation.
+
+    Args:
+        hidden_dim: int. Dimensionality of the patch embeddings.
+        dropout_rate: float. Dropout rate. Between 0 and 1. Defaults to
+            `0.1`.
+        **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
+    """
+
+    def __init__(self, hidden_dim, dropout_rate=0.1, **kwargs):
+        super().__init__(**kwargs)
+        self.hidden_dim = hidden_dim
+        self.dropout_rate = dropout_rate
+
+    def build(self, input_shape):
+        self.dense = keras.layers.Dense(
+            self.hidden_dim, dtype=self.dtype_policy, name="output"
+        )
+        self.dense.build(input_shape)
+
+        self.dropout = keras.layers.Dropout(
+            self.dropout_rate, dtype=self.dtype_policy, name="dropout"
+        )
+        # Mark this layer as built
+        self.built = True
+
+    def call(self, hidden_states, input_tensor):
+        hidden_states = self.dense(hidden_states)  # Linear transformation
+        hidden_states = self.dropout(hidden_states)  # Apply dropout
+        hidden_states = hidden_states + input_tensor  # Residual connection
+        return hidden_states
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "dropout_rate": self.dropout_rate,
+            }
+        )
+        return config
+
+
+class DeiTEncoderBlock(keras.layers.Layer):
+    """DeiT encoder block.
+
+    Args:
+        num_heads: int. Number of attention heads.
+        hidden_dim: int. Dimensionality of the hidden representations.
+        intermediate_dim: int. Dimensionality of the intermediate MLP layer.
+        use_mha_bias: bool. Whether to use bias in the multi-head attention
+            layer. Defaults to `True`.
+        dropout_rate: float. Dropout rate. Between 0 and 1. Defaults to
+            `0.0`.
+        attention_dropout: float. Dropout rate for the attention mechanism.
+            Between 0 and 1. Defaults to `0.0`.
+        layer_norm_epsilon: float. Small float value for layer normalization
+            stability. Defaults to `1e-6`.
+        **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
+    """
+
+    def __init__(
+        self,
+        num_heads,
+        hidden_dim,
+        intermediate_dim,
+        use_mha_bias=True,
+        dropout_rate=0.0,
+        attention_dropout=0.0,
+        layer_norm_epsilon=1e-6,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        key_dim = hidden_dim // num_heads
+
+        # === Config ===
+        self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.key_dim = key_dim
+        self.use_mha_bias = use_mha_bias
+        self.dropout_rate = dropout_rate
+        self.attention_dropout = attention_dropout
+        self.layer_norm_epsilon = layer_norm_epsilon
+
+    def build(self, input_shape):
+        # Attention block
+        self.layer_norm_1 = keras.layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            name="ln_1",
+            dtype=self.dtype_policy,
+        )
+        self.layer_norm_1.build(input_shape)
+        self.mha = keras.layers.MultiHeadAttention(
+            num_heads=self.num_heads,
+            key_dim=self.key_dim,
+            use_bias=self.use_mha_bias,
+            dropout=self.attention_dropout,
+            name="mha",
+            dtype=self.dtype_policy,
+        )
+        self.mha.build(input_shape, input_shape)
+
+        # MLP block
+        self.layer_norm_2 = keras.layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            name="ln_2",
+            dtype=self.dtype_policy,
+        )
+        self.layer_norm_2.build((None, None, self.hidden_dim))
+
+        # Intermediate Layer
+        self.mlp = DeiTIntermediate(
+            self.intermediate_dim, dtype=self.dtype_policy, name="mlp"
+        )
+        self.mlp.build((None, None, self.hidden_dim))
+
+        # Output Layer
+        self.output_layer = DeiTOutput(
+            self.hidden_dim,
+            self.dropout_rate,
+            dtype=self.dtype_policy,
+            name="output_layer",
+        )
+
+        self.output_layer.build((None, None, self.intermediate_dim))
+
+        self.built = True
+
+    def call(
+        self,
+        hidden_states,
+        attention_mask=None,
+        return_attention_scores=False,
+    ):
+        attention_scores = None
+        x = self.layer_norm_1(hidden_states)
+        if return_attention_scores:
+            x, attention_scores = self.mha(
+                x,
+                x,
+                attention_mask=attention_mask,
+                return_attention_scores=return_attention_scores,
+            )
+        else:
+            x = self.mha(
+                x,
+                x,
+                attention_mask=attention_mask,
+            )
+
+        x = x + hidden_states
+        y = self.layer_norm_2(x)
+        y = self.mlp(y)
+        y = self.output_layer(y, x)
+
+        return y, attention_scores
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "key_dim": self.key_dim,
+                "use_mha_bias": self.use_mha_bias,
+                "dropout_rate": self.dropout_rate,
+                "attention_dropout": self.attention_dropout,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+            }
+        )
+        return config
+
+
+class DeiTEncoder(keras.layers.Layer):
+    """DeiT Encoder class.
+
+    Args:
+        num_layers: int. Number of Transformer encoder blocks.
+        num_heads: int. Number of attention heads.
+        hidden_dim: int. Dimensionality of the hidden representations.
+        intermediate_dim: int. Dimensionality of the intermediate MLP layer.
+        use_mha_bias: bool. Whether to use bias in the multi-head attention
+            layer. Defaults to `True`.
+        dropout_rate: float. Dropout rate. Between 0 and 1. Defaults to
+            `0.0`.
+        attention_dropout: float. Dropout rate for the attention mechanism.
+            Between 0 and 1. Defaults to `0.0`.
+        layer_norm_epsilon: float. Small float value for layer normalization
+            stability. Defaults to `1e-6`.
+        **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
+    """
+
+    def __init__(
+        self,
+        num_layers,
+        num_heads,
+        hidden_dim,
+        intermediate_dim,
+        use_mha_bias=True,
+        dropout_rate=0.0,
+        attention_dropout=0.0,
+        layer_norm_epsilon=1e-6,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        # === Config ===
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.use_mha_bias = use_mha_bias
+        self.dropout_rate = dropout_rate
+        self.attention_dropout = attention_dropout
+        self.layer_norm_epsilon = layer_norm_epsilon
+
+    def build(self, input_shape):
+        self.encoder_layers = []
+        for i in range(self.num_layers):
+            encoder_block = DeiTEncoderBlock(
+                num_heads=self.num_heads,
+                hidden_dim=self.hidden_dim,
+                intermediate_dim=self.intermediate_dim,
+                use_mha_bias=self.use_mha_bias,
+                dropout_rate=self.dropout_rate,
+                attention_dropout=self.attention_dropout,
+                layer_norm_epsilon=self.layer_norm_epsilon,
+                dtype=self.dtype_policy,
+                name=f"transformer_block_{i + 1}",
+            )
+            encoder_block.build((None, None, self.hidden_dim))
+            self.encoder_layers.append(encoder_block)
+
+        self.layer_norm = keras.layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="ln",
+        )
+        self.layer_norm.build((None, None, self.hidden_dim))
+
+        self.built = True
+
+    def call(
+        self,
+        hidden_states,
+        attention_masks=None,
+        output_hidden_states=False,
+        return_attention_scores=False,
+    ):
+        seq_len = ops.shape(hidden_states)[1]  # Sequence length
+        hidden_dim = ops.shape(hidden_states)[2]  # Hidden size
+
+        # Ensure valid tensor output even if disabled
+        all_hidden_states = (
+            ops.empty(shape=(0, seq_len, hidden_dim), dtype=hidden_states.dtype)
+            if not output_hidden_states
+            else ()
+        )
+
+        all_self_attentions_scores = (
+            ops.empty(
+                shape=(0, self.num_heads, seq_len, seq_len),
+                dtype=hidden_states.dtype,
+            )
+            if not return_attention_scores
+            else ()
+        )
+
+        for i in range(self.num_layers):
+            attention_mask = (
+                attention_masks[i] if attention_masks is not None else None
+            )
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            hidden_states, scores = self.encoder_layers[i](
+                hidden_states,
+                attention_mask=attention_mask,
+                return_attention_scores=return_attention_scores,
+            )
+            if return_attention_scores:
+                all_self_attentions_scores = all_self_attentions_scores + (
+                    scores,
+                )
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        hidden_states = self.layer_norm(hidden_states)
+
+        return hidden_states, all_hidden_states, all_self_attentions_scores
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "use_mha_bias": self.use_mha_bias,
+                "dropout_rate": self.dropout_rate,
+                "attention_dropout": self.attention_dropout,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+            }
+        )
+        return config
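The layers compose as: `DeiTEmbeddings` produces a `(batch, num_patches + 2, hidden_dim)` sequence, and `DeiTEncoder` runs it through `num_layers` pre-norm attention/MLP blocks plus a final layer norm. A small standalone shape check against the embedding layer, with hypothetical toy sizes and assuming the private import path stays as added above:

```python
import numpy as np

from keras_hub.src.models.deit.deit_layers import DeiTEmbeddings

# A 64x64 image with 16x16 patches yields 16 patches; the class and
# distillation tokens prepended in `call` bring the sequence to 18.
layer = DeiTEmbeddings(
    image_size=(64, 64),
    patch_size=(16, 16),
    hidden_dim=8,
)
images = np.zeros((2, 64, 64, 3), dtype="float32")
embeddings = layer(images)
print(embeddings.shape)  # (2, 18, 8)
```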
keras_hub/src/models/deit/deit_presets.py
ADDED
@@ -0,0 +1,49 @@
+"""DeiT model preset configurations."""
+
+# Metadata for loading pretrained model weights.
+backbone_presets = {
+    "deit-base-distilled-patch16-384_imagenet": {
+        "metadata": {
+            "description": (
+                "DeiT-B16 model pre-trained on the ImageNet 1k dataset with "
+                "image resolution of 384x384"
+            ),
+            "params": 86092032,
+            "path": "deit",
+        },
+        "kaggle_handle": "kaggle://keras/deit/keras/deit_base_distilled_patch16_384_imagenet/1",
+    },
+    "deit-base-distilled-patch16-224_imagenet": {
+        "metadata": {
+            "description": (
+                "DeiT-B16 model pre-trained on the ImageNet 1k dataset with "
+                "image resolution of 224x224"
+            ),
+            "params": 85800192,
+            "path": "deit",
+        },
+        "kaggle_handle": "kaggle://keras/deit/keras/deit_base_distilled_patch16_224_imagenet/1",
+    },
+    "deit-tiny-distilled-patch16-224_imagenet": {
+        "metadata": {
+            "description": (
+                "DeiT-T16 model pre-trained on the ImageNet 1k dataset with "
+                "image resolution of 224x224"
+            ),
+            "params": 5524800,
+            "path": "deit",
+        },
+        "kaggle_handle": "kaggle://keras/deit/keras/deit_tiny_distilled_patch16_224_imagenet/1",
+    },
+    "deit-small-distilled-patch16-224_imagenet": {
+        "metadata": {
+            "description": (
+                "DeiT-S16 model pre-trained on the ImageNet 1k dataset with "
+                "image resolution of 224x224"
+            ),
+            "params": 21666432,
+            "path": "deit",
+        },
+        "kaggle_handle": "kaggle://keras/deit/keras/deit_small_distilled_patch16_224_imagenet/1",
+    },
+}
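Each preset name resolves through its `kaggle_handle` to converted weights, so any of the four entries above can be loaded by name. A minimal sketch, assuming the handles have been published to Kaggle as listed:

```python
import keras_hub

# Smallest of the four registered presets (about 5.5M parameters).
backbone = keras_hub.models.DeiTBackbone.from_preset(
    "deit-tiny-distilled-patch16-224_imagenet"
)
```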
keras_hub/src/utils/transformers/convert_deit.py
ADDED
@@ -0,0 +1,155 @@
+import numpy as np
+
+from keras_hub.src.models.deit.deit_backbone import DeiTBackbone
+
+backbone_cls = DeiTBackbone
+
+
+def convert_backbone_config(transformers_config):
+    image_size = transformers_config["image_size"]
+    return {
+        "image_shape": (image_size, image_size, 3),
+        "patch_size": transformers_config["patch_size"],
+        "num_layers": transformers_config["num_hidden_layers"],
+        "num_heads": transformers_config["num_attention_heads"],
+        "hidden_dim": transformers_config["hidden_size"],
+        "intermediate_dim": transformers_config["intermediate_size"],
+        "dropout_rate": transformers_config["hidden_dropout_prob"],
+        "attention_dropout": transformers_config[
+            "attention_probs_dropout_prob"
+        ],
+        "layer_norm_epsilon": transformers_config["layer_norm_eps"],
+    }
+
+
+def convert_weights(backbone, loader, transformers_config):
+    def port_ln(keras_variable, weight_key):
+        loader.port_weight(keras_variable.gamma, f"{weight_key}.weight")
+        loader.port_weight(keras_variable.beta, f"{weight_key}.bias")
+
+    def port_dense(keras_variable, weight_key):
+        loader.port_weight(
+            keras_variable.kernel,
+            f"{weight_key}.weight",
+            hook_fn=lambda x, _: x.T,
+        )
+        if keras_variable.bias is not None:
+            loader.port_weight(keras_variable.bias, f"{weight_key}.bias")
+
+    def port_mha(keras_variable, weight_key, num_heads, hidden_dim):
+        # query
+        loader.port_weight(
+            keras_variable.query_dense.kernel,
+            f"{weight_key}.attention.query.weight",
+            hook_fn=lambda x, _: np.reshape(
+                x.T, (hidden_dim, num_heads, hidden_dim // num_heads)
+            ),
+        )
+        loader.port_weight(
+            keras_variable.query_dense.bias,
+            f"{weight_key}.attention.query.bias",
+            hook_fn=lambda x, _: np.reshape(
+                x, (num_heads, hidden_dim // num_heads)
+            ),
+        )
+        # key
+        loader.port_weight(
+            keras_variable.key_dense.kernel,
+            f"{weight_key}.attention.key.weight",
+            hook_fn=lambda x, _: np.reshape(
+                x.T, (hidden_dim, num_heads, hidden_dim // num_heads)
+            ),
+        )
+        loader.port_weight(
+            keras_variable.key_dense.bias,
+            f"{weight_key}.attention.key.bias",
+            hook_fn=lambda x, _: np.reshape(
+                x, (num_heads, hidden_dim // num_heads)
+            ),
+        )
+        # value
+        loader.port_weight(
+            keras_variable.value_dense.kernel,
+            f"{weight_key}.attention.value.weight",
+            hook_fn=lambda x, _: np.reshape(
+                x.T, (hidden_dim, num_heads, hidden_dim // num_heads)
+            ),
+        )
+        loader.port_weight(
+            keras_variable.value_dense.bias,
+            f"{weight_key}.attention.value.bias",
+            hook_fn=lambda x, _: np.reshape(
+                x, (num_heads, hidden_dim // num_heads)
+            ),
+        )
+        # output
+        loader.port_weight(
+            keras_variable.output_dense.kernel,
+            f"{weight_key}.output.dense.weight",
+            hook_fn=lambda x, _: np.reshape(
+                x.T, (num_heads, hidden_dim // num_heads, hidden_dim)
+            ),
+        )
+        loader.port_weight(
+            keras_variable.output_dense.bias, f"{weight_key}.output.dense.bias"
+        )
+
+    loader.port_weight(
+        keras_variable=backbone.layers[1].patch_embedding.kernel,
+        hf_weight_key="deit.embeddings.patch_embeddings.projection.weight",
+        hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
+    )
+
+    loader.port_weight(
+        backbone.layers[1].patch_embedding.bias,
+        "deit.embeddings.patch_embeddings.projection.bias",
+    )
+
+    loader.port_weight(
+        backbone.layers[1].class_token,
+        "deit.embeddings.cls_token",
+    )
+
+    loader.port_weight(
+        backbone.layers[1].distillation_token,
+        "deit.embeddings.distillation_token",
+    )
+
+    loader.port_weight(
+        backbone.layers[1].position_embedding,
+        "deit.embeddings.position_embeddings",
+    )
+
+    encoder_layers = backbone.layers[2].encoder_layers
+    for i, encoder_block in enumerate(encoder_layers):
+        prefix = "deit.encoder.layer"
+        num_heads = encoder_block.num_heads
+        hidden_dim = encoder_block.hidden_dim
+
+        port_mha(
+            encoder_block.mha,
+            f"{prefix}.{i}.attention",
+            num_heads,
+            hidden_dim,
+        )
+        port_ln(encoder_block.layer_norm_1, f"{prefix}.{i}.layernorm_before")
+        port_ln(encoder_block.layer_norm_2, f"{prefix}.{i}.layernorm_after")
+
+        port_dense(encoder_block.mlp.dense, f"{prefix}.{i}.intermediate.dense")
+        port_dense(
+            encoder_block.output_layer.dense, f"{prefix}.{i}.output.dense"
+        )
+    port_ln(backbone.layers[2].layer_norm, "deit.layernorm")
+
+
+def convert_head(task, loader, transformers_config):
+    prefix = "cls_classifier."
+    loader.port_weight(
+        task.output_dense.kernel,
+        hf_weight_key=prefix + "weight",
+        hook_fn=lambda x, _: x.T,
+    )
+    loader.port_weight(
+        task.output_dense.bias,
+        hf_weight_key=prefix + "bias",
+    )
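`convert_backbone_config` is a pure mapping from Hugging Face config keys to `DeiTBackbone` constructor arguments, which makes it easy to sanity-check in isolation. A sketch with a hypothetical config dict — the values mirror a DeiT-tiny checkpoint, and only the keys read above are included:

```python
from keras_hub.src.utils.transformers.convert_deit import (
    convert_backbone_config,
)

# Hypothetical subset of a Hugging Face DeiT config.
hf_config = {
    "image_size": 224,
    "patch_size": 16,
    "num_hidden_layers": 12,
    "num_attention_heads": 3,
    "hidden_size": 192,
    "intermediate_size": 768,
    "hidden_dropout_prob": 0.0,
    "attention_probs_dropout_prob": 0.0,
    "layer_norm_eps": 1e-6,
}
kwargs = convert_backbone_config(hf_config)
print(kwargs["image_shape"])  # (224, 224, 3)
print(kwargs["num_heads"])  # 3
```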
keras_hub/src/utils/transformers/preset_loader.py
CHANGED
@@ -6,6 +6,7 @@ from keras_hub.src.utils.preset_utils import jax_memory_cleanup
 from keras_hub.src.utils.transformers import convert_albert
 from keras_hub.src.utils.transformers import convert_bart
 from keras_hub.src.utils.transformers import convert_bert
+from keras_hub.src.utils.transformers import convert_deit
 from keras_hub.src.utils.transformers import convert_distilbert
 from keras_hub.src.utils.transformers import convert_gemma
 from keras_hub.src.utils.transformers import convert_gpt2
@@ -30,6 +31,8 @@ class TransformersPresetLoader(PresetLoader):
             self.converter = convert_bart
         elif model_type == "bert":
             self.converter = convert_bert
+        elif model_type == "deit":
+            self.converter = convert_deit
         elif model_type == "distilbert":
             self.converter = convert_distilbert
         elif model_type == "gemma" or model_type == "gemma2":
@@ -82,7 +85,7 @@ class TransformersPresetLoader(PresetLoader):
             cls, load_weights, load_task_weights, **kwargs
         )
         # Support loading the classification head for classifier models.
-        if
+        if "ForImageClassification" in architecture:
             kwargs["num_classes"] = len(self.config["id2label"])
         task = super().load_task(cls, load_weights, load_task_weights, **kwargs)
         if load_task_weights:
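With the `deit` dispatch in place, a Hugging Face checkpoint loads end to end: the DeiT architecture names contain `"ForImageClassification"`, so `num_classes` is filled in from the config's `id2label` map and the head weights are ported by `convert_head` rather than randomly initialized. A hedged sketch, assuming this nightly build is installed:

```python
import keras_hub

# Loads the config and safetensors from the Hub; `num_classes` comes from
# `id2label` (1000 ImageNet classes for this checkpoint).
classifier = keras_hub.models.DeiTImageClassifier.from_preset(
    "hf://facebook/deit-base-distilled-patch16-384"
)
print(classifier.num_classes)  # 1000
```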
keras_hub/src/version.py
CHANGED
-__version__ = "0.22.0.dev202505300409"
+__version__ = "0.22.0.dev202505310408"

{keras_hub_nightly-0.22.0.dev202505300409.dist-info → keras_hub_nightly-0.22.0.dev202505310408.dist-info}/RECORD
CHANGED
@@ -1,11 +1,11 @@
 keras_hub/__init__.py,sha256=bJbUZkqwhZvTb1Tqx1fbkq6mzBYiEyq-Hin3oQIkhdE,558
-keras_hub/layers/__init__.py,sha256=
+keras_hub/layers/__init__.py,sha256=YQ4bW0_mI39Jqj2yoc8xcnynqoaXV2FBjHJviA9Ffas,5190
 keras_hub/metrics/__init__.py,sha256=KYalsMPBnfwim9BdGHFfJ5WxUKFXOQ1QoKIMT_0lwlM,439
-keras_hub/models/__init__.py,sha256=
+keras_hub/models/__init__.py,sha256=7MhCw7S-uIPcko-R6g5a-Jy1idKe7BwlI836PfekhHc,27076
 keras_hub/samplers/__init__.py,sha256=aFQIkiqbZpi8vjrPp2MVII4QUfE-eQjra5fMeHsoy7k,886
 keras_hub/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_hub/src/api_export.py,sha256=9pQZK27JObxWZ96QPLBp1OBsjWigh1iuV6RglPGMRk0,1499
-keras_hub/src/version.py,sha256=
+keras_hub/src/version.py,sha256=A_oYO8DhCB-uOrecxZt2B7NMyEpt94fhLGZT7-dbdBg,222
 keras_hub/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_hub/src/layers/modeling/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_hub/src/layers/modeling/alibi_bias.py,sha256=1XBTHI52L_iJDhN_w5ydu_iMhCuTgQAxEPwcLA6BPuk,4411
@@ -135,6 +135,13 @@ keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py,sha256
 keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py,sha256=mz9nG55gdXSTDE96AXgeTCwUFB95DIpTuqrvWIt5Lco,7840
 keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py,sha256=ZKYY8A7mV2QvwXwjDUd9xAbVHo58-Hgj_IqNUbuyCIU,625
 keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py,sha256=pubi30sPJKLOpz9fRQff2FZt_53KBvwf2uyaJ5YL7J8,3726
+keras_hub/src/models/deit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+keras_hub/src/models/deit/deit_backbone.py,sha256=R5pBOqe8vcvD8VaRnsy_zIRIz6BLnUbkTeKUOoGNHPA,5942
+keras_hub/src/models/deit/deit_image_classifier.py,sha256=pUS2638yBAxEBxcJoHyLABsgjCWv_Y0Mj_8u0YgDPdI,5758
+keras_hub/src/models/deit/deit_image_classifier_preprocessor.py,sha256=s5pTcsUjlt1oIXFWIu-9gf2-sBesAyrjJIYmFOB96Xs,514
+keras_hub/src/models/deit/deit_image_converter.py,sha256=wEGCLHS_i4wF9WA4m7uUXcHNbwf6TYgvPoM6C_t0rpM,330
+keras_hub/src/models/deit/deit_layers.py,sha256=A80-UTHEUV8g5rEG-fr8OQpGe3HeoYlYwpoDCtq71ZU,17278
+keras_hub/src/models/deit/deit_presets.py,sha256=0c2jm2DIznOr6ciQoLM6QYopQTLiMx4jONGLaXvtt6g,1778
 keras_hub/src/models/densenet/__init__.py,sha256=r7StyamnWeeZxOk9r4ZYNbS_YVhu9YGPyXhNxljvdPg,269
 keras_hub/src/models/densenet/densenet_backbone.py,sha256=f2nfsXyXQert2aYHq-L-JZtp8inq1fs1K47rzZQ9nTI,6744
 keras_hub/src/models/densenet/densenet_image_classifier.py,sha256=ye-Ix3oU42pfsDoh-h1PG4di1kzldO0ZO7Nj304p_X4,544
@@ -494,6 +501,7 @@ keras_hub/src/utils/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRk
 keras_hub/src/utils/transformers/convert_albert.py,sha256=VdKclZpCxtDWq3UbUUQZf4fR9DJK_JYZ73B4O_G9skg,7695
 keras_hub/src/utils/transformers/convert_bart.py,sha256=Tk4h9Md9rwN5wjQbGIVrC7qzDpF8kI8qm-FKL8HlUok,14411
 keras_hub/src/utils/transformers/convert_bert.py,sha256=4gQqXCJzC9QWdLPDUAq741K8t_kjPIET050YjUnLeDA,5977
+keras_hub/src/utils/transformers/convert_deit.py,sha256=ubcqYzMlhWTCE2S_TsXICCMmqjN9RsQPaw_70vArnjo,5306
 keras_hub/src/utils/transformers/convert_distilbert.py,sha256=SlfIRhSRk5c1ir2HGiDPiXa5XdOId_DbcnZO9lbwyZ8,6498
 keras_hub/src/utils/transformers/convert_gemma.py,sha256=ElCgwBpSN5Q7rV5PJawTsoytPzs5ZjuwoY60YAe8y_A,6533
 keras_hub/src/utils/transformers/convert_gpt2.py,sha256=HCeHN_-GiQJRxLCM9OCJJ1watPVpIBF8ujS8pGbBOWc,5703
@@ -505,11 +513,11 @@ keras_hub/src/utils/transformers/convert_qwen.py,sha256=WUxMAEFVqRs7TRw7QU5TH3_e
 keras_hub/src/utils/transformers/convert_qwen3.py,sha256=LIormvCMWPq6X9Wo2eNbADjtFZ0nI7tFGZFBxmo4GKw,5700
 keras_hub/src/utils/transformers/convert_qwen_moe.py,sha256=a7R28aln-PdAcNuKAXdrtzvslho2Co6GypChxLMKPpc,10618
 keras_hub/src/utils/transformers/convert_vit.py,sha256=9SUZ9utNJhW_5cj3acMn9cRy47u2eIcDsrhmzj77o9k,5187
-keras_hub/src/utils/transformers/preset_loader.py,sha256=
+keras_hub/src/utils/transformers/preset_loader.py,sha256=K5FzDAtCuXS9rmZc0Zj7UCwbz5J9_pf7ozWov1qRAfg,4495
 keras_hub/src/utils/transformers/safetensor_utils.py,sha256=CYUHyA4y-B61r7NDnCsFb4t_UmSwZ1k9L-8gzEd6KRg,3339
 keras_hub/tokenizers/__init__.py,sha256=uMjjm0mzUkRb0e4Ac_JK8aJ9cKGUi5UqmzWoWAFJprE,4164
 keras_hub/utils/__init__.py,sha256=jXPqVGBpJr_PpYmqD8aDG-fRMlxH-ulqCR2SZMn288Y,646
-keras_hub_nightly-0.22.0.dev202505300409.dist-info/METADATA,sha256=
-keras_hub_nightly-0.22.0.dev202505300409.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-keras_hub_nightly-0.22.0.dev202505300409.dist-info/top_level.txt,sha256=N4J6piIWBKa38A4uV-CnIopnOEf8mHAbkNXafXm_CuA,10
-keras_hub_nightly-0.22.0.dev202505300409.dist-info/RECORD,,
+keras_hub_nightly-0.22.0.dev202505310408.dist-info/METADATA,sha256=v4Rvzln90tKecsbiwiU29ZFrct9xpLCV10RQDme4-DI,7393
+keras_hub_nightly-0.22.0.dev202505310408.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+keras_hub_nightly-0.22.0.dev202505310408.dist-info/top_level.txt,sha256=N4J6piIWBKa38A4uV-CnIopnOEf8mHAbkNXafXm_CuA,10
+keras_hub_nightly-0.22.0.dev202505310408.dist-info/RECORD,,
{keras_hub_nightly-0.22.0.dev202505300409.dist-info → keras_hub_nightly-0.22.0.dev202505310408.dist-info}/WHEEL
File without changes

{keras_hub_nightly-0.22.0.dev202505300409.dist-info → keras_hub_nightly-0.22.0.dev202505310408.dist-info}/top_level.txt
File without changes