keras-hub-nightly 0.21.0.dev202505280410__py3-none-any.whl → 0.22.0.dev202505300409__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/models/__init__.py +9 -0
- keras_hub/src/models/mixtral/mixtral_presets.py +4 -4
- keras_hub/src/models/qwen/qwen_presets.py +6 -6
- keras_hub/src/models/qwen3/qwen3_attention.py +369 -0
- keras_hub/src/models/qwen3/qwen3_backbone.py +191 -0
- keras_hub/src/models/qwen3/qwen3_causal_lm_preprocessor.py +10 -0
- keras_hub/src/models/qwen3/qwen3_decoder.py +309 -0
- keras_hub/src/models/qwen3/qwen3_layernorm.py +38 -0
- keras_hub/src/models/qwen3/qwen3_tokenizer.py +48 -0
- keras_hub/src/models/qwen_moe/qwen_moe_presets.py +2 -2
- keras_hub/src/models/vit/vit_backbone.py +31 -11
- keras_hub/src/models/vit/vit_image_converter.py +0 -70
- keras_hub/src/models/vit/vit_layers.py +33 -18
- keras_hub/src/models/vit/vit_presets.py +11 -11
- keras_hub/src/utils/transformers/convert_qwen3.py +145 -0
- keras_hub/src/utils/transformers/preset_loader.py +3 -0
- keras_hub/src/version.py +1 -1
- {keras_hub_nightly-0.21.0.dev202505280410.dist-info → keras_hub_nightly-0.22.0.dev202505300409.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.21.0.dev202505280410.dist-info → keras_hub_nightly-0.22.0.dev202505300409.dist-info}/RECORD +21 -14
- {keras_hub_nightly-0.21.0.dev202505280410.dist-info → keras_hub_nightly-0.22.0.dev202505300409.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.21.0.dev202505280410.dist-info → keras_hub_nightly-0.22.0.dev202505300409.dist-info}/top_level.txt +0 -0
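The headline change is the new Qwen3 architecture: attention, backbone, decoder, RMS layer norm, tokenizer, and a Transformers checkpoint converter wired into the preset loader. The ViT changes add `use_class_token` and `use_patch_bias` options with non-square patch support, and drop `ViTImageConverter`'s bespoke normalization. Below is a minimal sketch of how the new pieces could be loaded; the `hf://Qwen/Qwen3-0.6B` handle is an illustrative assumption, not a preset this diff registers:

```python
import keras_hub

# Hypothetical handle: the new convert_qwen3.py plus the preset_loader
# change suggest Hugging Face Qwen3 checkpoints can be converted on load.
backbone = keras_hub.models.Qwen3Backbone.from_preset("hf://Qwen/Qwen3-0.6B")
tokenizer = keras_hub.models.Qwen3Tokenizer.from_preset("hf://Qwen/Qwen3-0.6B")
token_ids = tokenizer("What is Keras?")
```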
keras_hub/src/models/qwen3/qwen3_decoder.py

```diff
@@ -0,0 +1,309 @@
+import keras
+from keras import ops
+
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
+    compute_causal_mask,
+)
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
+    merge_padding_and_attention_mask,
+)
+from keras_hub.src.models.qwen3.qwen3_attention import Qwen3Attention
+from keras_hub.src.models.qwen3.qwen3_layernorm import Qwen3LayerNorm
+from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+class Qwen3TransformerDecoder(keras.layers.Layer):
+    """A Transformer decoder layer for the Qwen3 backbone.
+
+    This layer implements a Transformer decoder block that includes
+    self-attention with optional sliding window attention and a feed-forward
+    network.
+
+    Args:
+        intermediate_dim: Output dimension of the first dense layer in the
+            feed-forward network.
+        num_query_heads: Number of query attention heads.
+        num_key_value_heads: Number of key/value attention heads (for GQA).
+        rope_max_wavelength: Maximum wavelength for RoPE (Rotary Position
+            Embedding).
+        rope_scaling_factor: Scaling factor for RoPE, used for extending
+            context length.
+        activation: Activation function to use in the feed-forward network.
+        layer_norm_epsilon: Small float added to variance to avoid dividing
+            by zero in layer norm.
+        kernel_initializer: Initializer for the kernel weights.
+        dropout: Dropout rate for attention and hidden layers.
+        sliding_window_size: Size of the sliding window for attention when
+            enabled.
+        **kwargs: Additional keyword arguments to pass to the Layer.
+    """
+
+    def __init__(
+        self,
+        intermediate_dim,
+        num_query_heads,
+        num_key_value_heads,
+        head_dim,
+        rope_max_wavelength=10000,
+        rope_scaling_factor=1.0,
+        activation="silu",
+        layer_norm_epsilon=1e-5,
+        kernel_initializer="glorot_uniform",
+        dropout=0.0,
+        sliding_window_size=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.intermediate_dim = intermediate_dim
+        self.num_query_heads = num_query_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.head_dim = head_dim
+
+        self.rope_max_wavelength = rope_max_wavelength
+        self.rope_scaling_factor = rope_scaling_factor
+
+        self.dropout = dropout
+
+        self.sliding_window_size = sliding_window_size
+
+        self.activation = keras.activations.get(activation)
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.kernel_initializer = keras.initializers.get(kernel_initializer)
+
+        self.supports_masking = True
+
+    def build(self, decoder_sequence_shape):
+        self._decoder_sequence_shape = decoder_sequence_shape
+        self.hidden_dim = decoder_sequence_shape[-1]
+
+        # Self attention layer.
+        self._self_attention_layer = Qwen3Attention(
+            num_query_heads=self.num_query_heads,
+            num_key_value_heads=self.num_key_value_heads,
+            rope_max_wavelength=self.rope_max_wavelength,
+            head_dim=self.head_dim,
+            rope_scaling_factor=self.rope_scaling_factor,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            dropout=self.dropout,
+            sliding_window_size=self.sliding_window_size,
+            dtype=self.dtype_policy,
+            name="self_attention",
+        )
+        self._self_attention_layer.build(decoder_sequence_shape)
+
+        self._self_attention_layernorm = Qwen3LayerNorm(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="self_attention_layernorm",
+        )
+
+        self._self_attention_layernorm.build(decoder_sequence_shape)
+        self._self_attention_dropout = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+            name="self_attention_dropout",
+        )
+
+        # Feedforward layers.
+        self._feedforward_intermediate_dense = keras.layers.Dense(
+            self.intermediate_dim,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            use_bias=False,
+            dtype=self.dtype_policy,
+            name="feedforward_intermediate_dense",
+        )
+        self._feedforward_intermediate_dense.build(decoder_sequence_shape)
+
+        self._feedforward_gate_dense = keras.layers.Dense(
+            self.intermediate_dim,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            use_bias=False,
+            dtype=self.dtype_policy,
+            name="feedforward_gate_dense",
+        )
+        self._feedforward_gate_dense.build(decoder_sequence_shape)
+
+        self._feedforward_output_dense = keras.layers.Dense(
+            self.hidden_dim,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            use_bias=False,
+            dtype=self.dtype_policy,
+            name="feedforward_output_dense",
+        )
+
+        self._feedforward_output_dense.build(
+            self._feedforward_gate_dense.compute_output_shape(
+                decoder_sequence_shape
+            )
+        )
+
+        self._feedforward_layernorm = Qwen3LayerNorm(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="feedforward_layernorm",
+        )
+        self._feedforward_layernorm.build(decoder_sequence_shape)
+
+        self.built = True
+
+    def call(
+        self,
+        decoder_sequence,
+        decoder_padding_mask=None,
+        decoder_attention_mask=None,
+        self_attention_cache=None,
+        self_attention_cache_update_index=None,
+        training=None,
+    ):
+        """Forward pass for the decoder layer.
+
+        Args:
+            decoder_sequence: Input tensor of shape [batch_size, seq_length,
+                hidden_size].
+            decoder_padding_mask: Mask tensor for padding tokens.
+            decoder_attention_mask: Additional attention mask.
+            self_attention_cache: Optional cached key and value tensors for
+                self-attention.
+            self_attention_cache_update_index: Index at which to update the
+                cache.
+            training: Boolean indicating whether in training mode.
+
+        Returns:
+            decoder_output: Output tensor after applying transformer decoder
+                block.
+            self_attention_cache: Updated cache tensors (if cache is provided).
+        """
+        self_attention_mask = self._compute_self_attention_mask(
+            decoder_sequence=decoder_sequence,
+            decoder_padding_mask=decoder_padding_mask,
+            decoder_attention_mask=decoder_attention_mask,
+            self_attention_cache=self_attention_cache,
+            self_attention_cache_update_index=self_attention_cache_update_index,
+        )
+        residual = decoder_sequence
+
+        x = self._self_attention_layernorm(decoder_sequence)
+
+        # Self attention block.
+        x = self._self_attention_layer(
+            hidden_states=x,
+            attention_mask=self_attention_mask,
+            cache=self_attention_cache,
+            cache_update_index=self_attention_cache_update_index,
+        )
+
+        if self_attention_cache is not None:
+            x, self_attention_cache = x
+
+        x = self._self_attention_dropout(x, training=training)
+
+        x = x + residual
+        residual = x
+
+        x = self._feedforward_layernorm(x)
+        gate_output = self._feedforward_gate_dense(x)
+
+        # Note that we run the activation function in full 32-bit
+        # precision since this is what `torch.nn.functional.silu`
+        # does. Internally, `torch.nn.functional.silu` converts the
+        # inputs to float32, computes SiLU, and converts the outputs
+        # back to compute dtype.
+        # CPU Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cpu/Activation.cpp#L1221-L1235 # noqa: E501
+        # CUDA Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cuda/ActivationSiluKernel.cu # noqa: E501
+        gate_output = ops.cast(gate_output, "float32")
+        gate_output = self.activation(gate_output)
+        gate_output = ops.cast(gate_output, self.compute_dtype)
+
+        x = self._feedforward_intermediate_dense(x)
+
+        x = self._feedforward_output_dense(ops.multiply(x, gate_output))
+
+        decoder_output = x + residual
+
+        if self_attention_cache is not None:
+            return decoder_output, self_attention_cache
+        return decoder_output
+
+    def _compute_self_attention_mask(
+        self,
+        decoder_sequence,
+        decoder_padding_mask,
+        decoder_attention_mask,
+        self_attention_cache,
+        self_attention_cache_update_index,
+    ):
+        """Computes the self-attention mask combining causal, padding and
+        attention masks.
+
+        Args:
+            decoder_sequence: Input tensor.
+            decoder_padding_mask: Mask tensor for padding tokens.
+            decoder_attention_mask: Additional attention mask.
+            self_attention_cache: Optional cached key and value tensors.
+            self_attention_cache_update_index: Index at which to update the
+                cache.
+
+        Returns:
+            Combined attention mask tensor.
+        """
+        decoder_mask = merge_padding_and_attention_mask(
+            decoder_sequence, decoder_padding_mask, decoder_attention_mask
+        )
+        batch_size = ops.shape(decoder_sequence)[0]
+        input_length = output_length = ops.shape(decoder_sequence)[1]
+        # We need to handle a rectangular causal mask when doing cached
+        # decoding. For generative inference, `decoder_sequence` will
+        # generally be length 1, and `cache` will be the full generation length.
+        if self_attention_cache is not None:
+            input_length = ops.shape(self_attention_cache)[2]
+
+        cache_update_index = (
+            0
+            if self_attention_cache_update_index is None
+            else self_attention_cache_update_index
+        )
+
+        causal_mask = compute_causal_mask(
+            batch_size, input_length, output_length, cache_update_index
+        )
+
+        return (
+            ops.minimum(decoder_mask, causal_mask)
+            if decoder_mask is not None
+            else causal_mask
+        )
+
+    def compute_output_shape(self, decoder_sequence_shape):
+        """Computes the output shape of the layer.
+
+        Args:
+            decoder_sequence_shape: Shape of the decoder sequence input.
+
+        Returns:
+            Output shape, which is the same as the input shape.
+        """
+        return decoder_sequence_shape
+
+    def get_config(self):
+        """Returns the config of the layer.
+
+        Returns:
+            Dictionary containing the parameters used to initialize this layer.
+        """
+        config = super().get_config()
+        config.update(
+            {
+                "intermediate_dim": self.intermediate_dim,
+                "num_query_heads": self.num_query_heads,
+                "rope_max_wavelength": self.rope_max_wavelength,
+                "rope_scaling_factor": self.rope_scaling_factor,
+                "num_key_value_heads": self.num_key_value_heads,
+                "activation": keras.activations.serialize(self.activation),
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "kernel_initializer": keras.initializers.serialize(
+                    self.kernel_initializer
+                ),
+                "dropout": self.dropout,
+            }
+        )
+        return config
```
keras_hub/src/models/qwen3/qwen3_layernorm.py

```diff
@@ -0,0 +1,38 @@
+import keras
+from keras import ops
+
+
+class Qwen3LayerNorm(keras.layers.Layer):
+    """A normalization layer for Qwen that implements RMS normalization."""
+
+    def __init__(self, head_dim=None, epsilon=1e-6, **kwargs):
+        super().__init__(**kwargs)
+        self.head_dim = head_dim
+        self.epsilon = epsilon
+
+    def build(self, input_shape):
+        if self.head_dim:
+            dim = self.head_dim
+        else:
+            dim = input_shape[-1]
+
+        self.scale = self.add_weight(
+            name="scale",
+            trainable=True,
+            shape=(dim,),
+            initializer="ones",
+            dtype=self.variable_dtype,
+        )
+        self.built = True
+
+    def call(self, x):
+        input_dtype = x.dtype
+        x = ops.cast(x, "float32")
+        var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
+        x = x * ops.rsqrt(var + self.epsilon)
+        return ops.cast(x * self.scale, input_dtype)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"epsilon": self.epsilon})
+        return config
```
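This RMS variant skips mean-centering and has no bias term, and the optional `head_dim` lets it build a per-head scale, presumably for the query/key normalization in `qwen3_attention.py`. A quick numeric check of `x * rsqrt(mean(x**2) + eps) * scale`:

```python
import numpy as np

x = np.array([3.0, 4.0], dtype="float32")
var = np.mean(x**2)            # (9 + 16) / 2 = 12.5
out = x / np.sqrt(var + 1e-6)  # ~[0.8485, 1.1314]; scale starts at ones
```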
keras_hub/src/models/qwen3/qwen3_tokenizer.py

```diff
@@ -0,0 +1,48 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.qwen3.qwen3_backbone import Qwen3Backbone
+from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
+
+
+@keras_hub_export(
+    "keras_hub.models.Qwen3Tokenizer",
+)
+class Qwen3Tokenizer(BytePairTokenizer):
+    """Tokenizer for Qwen3 models.
+
+    This tokenizer implements byte-pair encoding (BPE) for Qwen3 models,
+    handling special tokens like BOS (beginning of sequence) and EOS (end of
+    sequence).
+
+    Args:
+        vocabulary: Dictionary mapping tokens to token IDs, or path to
+            vocabulary file.
+        merges: List of BPE merges, or path to merges file.
+        bos_token: Beginning of sequence token. Defaults to None.
+        eos_token: End of sequence token. Defaults to "<|endoftext|>".
+        misc_special_tokens: Set of additional special tokens. Defaults to
+            empty set.
+    """
+
+    backbone_cls = Qwen3Backbone
+
+    def __init__(
+        self,
+        vocabulary=None,
+        merges=None,
+        **kwargs,
+    ):
+        # Add EOS token
+        eos_token = "<|im_end|>"
+        self._add_special_token(eos_token, "end_token")
+
+        pad_token = "<|endoftext|>"
+        self._add_special_token(pad_token, "pad_token")
+
+        self.start_token_id = None
+        self.start_token = None
+
+        super().__init__(
+            vocabulary=vocabulary,
+            merges=merges,
+            **kwargs,
+        )
```
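Note that the added tokenizer hard-codes `<|im_end|>` as the end token and `<|endoftext|>` as padding, with no start token. A toy construction under the usual `BytePairTokenizer` vocabulary/merges format; the vocabulary and merge rule here are made up, and real ones come from the converted checkpoint:

```python
# Both special tokens must be present in the vocabulary.
vocab = {"<|im_end|>": 0, "<|endoftext|>": 1, "hell": 2, "o": 3, "hello": 4}
merges = ["hell o"]  # illustrative merge rule
tokenizer = Qwen3Tokenizer(vocabulary=vocab, merges=merges)
print(tokenizer.end_token_id)  # 0
```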
keras_hub/src/models/qwen_moe/qwen_moe_presets.py

```diff
@@ -4,8 +4,8 @@ backbone_presets = {
     "qwen1.5_moe_2.7b_en": {
         "metadata": {
             "description": (
-                "24-layer Qwen MoE model with 2.7 billion active parameters "
-                "and 8 experts per MoE layer."
+                "24-layer Qwen MoE model with 2.7 billion active parameters "
+                "and 8 experts per MoE layer."
             ),
             "params": 14315784192,
             "path": "qwen-1.5-moe",
```
keras_hub/src/models/vit/vit_backbone.py

```diff
@@ -18,10 +18,10 @@ class ViTBackbone(Backbone):
 
     Args:
         image_shape: A tuple or list of 3 integers representing the shape of the
-            input image `(height, width, channels)
-
-
-
+            input image `(height, width, channels)`.
+        patch_size: int or (int, int). The size of each image patch, the input
+            image will be divided into patches of shape
+            `(patch_size_h, patch_size_w)`.
         num_layers: int. The number of transformer encoder layers.
         num_heads: int. specifying the number of attention heads in each
             Transformer encoder layer.
@@ -37,6 +37,10 @@ class ViTBackbone(Backbone):
         use_mha_bias: bool. Whether to use bias in the multi-head
             attention layers.
         use_mlp_bias: bool. Whether to use bias in the MLP layers.
+        use_class_token: bool. Whether to use class token to be part of
+            patch embedding. Defaults to `True`.
+        use_patch_bias: bool. Whether to use bias in Conv2d of patch embedding
+            layer. Defaults to `True`.
         data_format: str. `"channels_last"` or `"channels_first"`, specifying
             the data format for the input image. If `None`, defaults to
             `"channels_last"`.
@@ -58,6 +62,8 @@ class ViTBackbone(Backbone):
         layer_norm_epsilon=1e-6,
         use_mha_bias=True,
         use_mlp_bias=True,
+        use_class_token=True,
+        use_patch_bias=True,
         data_format=None,
         dtype=None,
         **kwargs,
@@ -74,24 +80,34 @@ class ViTBackbone(Backbone):
                 f"at index {h_axis} (height) or {w_axis} (width). "
                 f"Image shape: {image_shape}"
             )
-
+
+        if isinstance(patch_size, int):
+            patch_size = (patch_size, patch_size)
+
+        if image_shape[h_axis] % patch_size[0] != 0:
+            raise ValueError(
+                f"Input height {image_shape[h_axis]} should be divisible by "
+                f"patch size {patch_size[0]}."
+            )
+
+        if image_shape[w_axis] % patch_size[1] != 0:
             raise ValueError(
-                f"
-                f"
-                f"indices {h_axis} and {w_axis} respectively. Image shape: "
-                f"{image_shape}"
+                f"Input width {image_shape[h_axis]} should be divisible by "
+                f"patch size {patch_size[1]}."
             )
 
         num_channels = image_shape[channels_axis]
 
         # === Functional Model ===
-        inputs = keras.layers.Input(shape=image_shape)
+        inputs = keras.layers.Input(shape=image_shape, name="images")
 
         x = ViTPatchingAndEmbedding(
-            image_size=image_shape[h_axis],
+            image_size=(image_shape[h_axis], image_shape[w_axis]),
             patch_size=patch_size,
             hidden_dim=hidden_dim,
             num_channels=num_channels,
+            use_class_token=use_class_token,
+            use_patch_bias=use_patch_bias,
             data_format=data_format,
             dtype=dtype,
             name="vit_patching_and_embedding",
@@ -130,6 +146,8 @@ class ViTBackbone(Backbone):
         self.layer_norm_epsilon = layer_norm_epsilon
         self.use_mha_bias = use_mha_bias
         self.use_mlp_bias = use_mlp_bias
+        self.use_class_token = use_class_token
+        self.use_patch_bias = use_patch_bias
         self.data_format = data_format
 
     def get_config(self):
@@ -147,6 +165,8 @@ class ViTBackbone(Backbone):
                 "layer_norm_epsilon": self.layer_norm_epsilon,
                 "use_mha_bias": self.use_mha_bias,
                 "use_mlp_bias": self.use_mlp_bias,
+                "use_class_token": self.use_class_token,
+                "use_patch_bias": self.use_patch_bias,
             }
         )
         return config
```
keras_hub/src/models/vit/vit_image_converter.py

````diff
@@ -1,78 +1,8 @@
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
 from keras_hub.src.models.vit.vit_backbone import ViTBackbone
-from keras_hub.src.utils.tensor_utils import preprocessing_function
 
 
 @keras_hub_export("keras_hub.layers.ViTImageConverter")
 class ViTImageConverter(ImageConverter):
-    """Converts images to the format expected by a ViT model.
-
-    This layer performs image normalization using mean and standard deviation
-    values. By default, it uses the same normalization as the
-    "google/vit-large-patch16-224" model on Hugging Face:
-    `norm_mean=[0.5, 0.5, 0.5]` and `norm_std=[0.5, 0.5, 0.5]`
-    ([reference](https://huggingface.co/google/vit-large-patch16-224/blob/main/preprocessor_config.json)).
-    These defaults are suitable for models pretrained using this normalization.
-
-    Args:
-        norm_mean: list or tuple of floats. Mean values for image normalization.
-            Defaults to `[0.5, 0.5, 0.5]`.
-        norm_std: list or tuple of floats. Standard deviation values for
-            image normalization. Defaults to `[0.5, 0.5, 0.5]`.
-        **kwargs: Additional keyword arguments passed to
-            `keras_hub.layers.preprocessing.ImageConverter`.
-
-    Examples:
-    ```python
-    import keras
-    import numpy as np
-    from keras_hub.src.layers import ViTImageConverter
-
-    # Example image (replace with your actual image data)
-    image = np.random.rand(1, 224, 224, 3) # Example: (B, H, W, C)
-
-    # Create a ViTImageConverter instance
-    converter = ViTImageConverter(
-        image_size=(28,28),
-        scale=1/255.
-    )
-    # Preprocess the image
-    preprocessed_image = converter(image)
-    ```
-    """
-
     backbone_cls = ViTBackbone
-
-    def __init__(
-        self, norm_mean=[0.5, 0.5, 0.5], norm_std=[0.5, 0.5, 0.5], **kwargs
-    ):
-        super().__init__(**kwargs)
-        self.norm_mean = norm_mean
-        self.norm_std = norm_std
-
-    @preprocessing_function
-    def call(self, inputs):
-        # TODO: Remove this whole function. Why can just use scale and offset
-        # in the base class.
-        x = super().call(inputs)
-        if self.norm_mean:
-            norm_mean = self._expand_non_channel_dims(self.norm_mean, x)
-            x, norm_mean = self._convert_types(x, norm_mean, self.compute_dtype)
-            x = x - norm_mean
-        if self.norm_std:
-            norm_std = self._expand_non_channel_dims(self.norm_std, x)
-            x, norm_std = self._convert_types(x, norm_std, x.dtype)
-            x = x / norm_std
-
-        return x
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "norm_mean": self.norm_mean,
-                "norm_std": self.norm_std,
-            }
-        )
-        return config
````
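The removed `call` applied `(x - mean) / std` after the base class's rescaling, which the deleted TODO already flagged as redundant: the same normalization folds into the base `ImageConverter`'s per-channel `scale` and `offset`. The arithmetic for the old defaults:

```python
# (x / 255 - mean) / std  ==  x * scale + offset, per channel.
norm_mean, norm_std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
scale = [1 / (255.0 * s) for s in norm_std]             # ~0.00784
offset = [-m / s for m, s in zip(norm_mean, norm_std)]  # -1.0
```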
keras_hub/src/models/vit/vit_layers.py

```diff
@@ -75,12 +75,13 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
     """Patches the image and embeds the patches.
 
     Args:
-        image_size: int. Size of the input image
-
-        patch_size: int. Size of each image patch.
+        image_size: (int, int). Size of the input image.
+        patch_size: (int, int). Size of each image patch.
         hidden_dim: int. Dimensionality of the patch embeddings.
         num_channels: int. Number of channels in the input image. Defaults to
             `3`.
+        use_class_token: bool. Whether to use class token to be part of
+            patch embedding. Defaults to `True`.
         data_format: str. `"channels_last"` or `"channels_first"`. Defaults to
             `None` (which uses `"channels_last"`).
         **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
@@ -92,12 +93,15 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
         patch_size,
         hidden_dim,
         num_channels=3,
+        use_class_token=True,
+        use_patch_bias=True,
         data_format=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
-
-
+        grid_size = tuple([s // p for s, p in zip(image_size, patch_size)])
+        num_patches = grid_size[0] * grid_size[1]
+        num_positions = num_patches + 1 if use_class_token else num_patches
 
         # === Config ===
         self.image_size = image_size
```
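The new constructor math drops the old square-image assumption: the patch grid is computed per axis, and the class token only contributes a position when enabled. For a standard ViT-B/16 setup:

```python
# 224x224 input, 16x16 patches -> 14x14 grid, 196 patches, 197 positions.
image_size, patch_size = (224, 224), (16, 16)
grid_size = tuple(s // p for s, p in zip(image_size, patch_size))  # (14, 14)
num_patches = grid_size[0] * grid_size[1]  # 196
num_positions = num_patches + 1            # 197 with the class token
```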
```diff
@@ -106,19 +110,22 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
         self.num_channels = num_channels
         self.num_patches = num_patches
         self.num_positions = num_positions
+        self.use_class_token = use_class_token
+        self.use_patch_bias = use_patch_bias
         self.data_format = standardize_data_format(data_format)
 
     def build(self, input_shape):
-
-
-
-
-
-
-
-
-
-
+        if self.use_class_token:
+            self.class_token = self.add_weight(
+                shape=(
+                    1,
+                    1,
+                    self.hidden_dim,
+                ),
+                initializer="random_normal",
+                dtype=self.variable_dtype,
+                name="class_token",
+            )
         self.patch_embedding = keras.layers.Conv2D(
             filters=self.hidden_dim,
             kernel_size=self.patch_size,
@@ -127,6 +134,7 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
             activation=None,
             dtype=self.dtype_policy,
             data_format=self.data_format,
+            use_bias=self.use_patch_bias,
             name="patch_embedding",
         )
         self.patch_embedding.build(input_shape)
@@ -153,10 +161,16 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
         patch_embeddings = ops.reshape(
             patch_embeddings, [embeddings_shape[0], -1, embeddings_shape[-1]]
         )
-        class_token = ops.tile(self.class_token, (embeddings_shape[0], 1, 1))
         position_embeddings = self.position_embedding(self.position_ids)
-
-
+
+        if self.use_class_token:
+            class_token = ops.tile(
+                self.class_token, (embeddings_shape[0], 1, 1)
+            )
+            patch_embeddings = ops.concatenate(
+                [class_token, patch_embeddings], axis=1
+            )
+        return ops.add(patch_embeddings, position_embeddings)
 
     def compute_output_shape(self, input_shape):
         return (
@@ -175,6 +189,7 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
                 "num_channels": self.num_channels,
                 "num_patches": self.num_patches,
                 "num_positions": self.num_positions,
+                "use_class_token": self.use_class_token,
             }
         )
         return config
```