keras-hub-nightly 0.21.0.dev202505280410__py3-none-any.whl → 0.22.0.dev202505300409__py3-none-any.whl

This diff shows the changes between two package versions that were publicly released to one of the supported registries, as they appear in those registries. It is provided for informational purposes only.
@@ -0,0 +1,309 @@
1
+ import keras
2
+ from keras import ops
3
+
4
+ from keras_hub.src.layers.modeling.transformer_layer_utils import (
5
+ compute_causal_mask,
6
+ )
7
+ from keras_hub.src.layers.modeling.transformer_layer_utils import (
8
+ merge_padding_and_attention_mask,
9
+ )
10
+ from keras_hub.src.models.qwen3.qwen3_attention import Qwen3Attention
11
+ from keras_hub.src.models.qwen3.qwen3_layernorm import Qwen3LayerNorm
12
+ from keras_hub.src.utils.keras_utils import clone_initializer
13
+
14
+
15
+ class Qwen3TransformerDecoder(keras.layers.Layer):
16
+ """A Transformer decoder layer for the Qwen3 backbone.
17
+
18
+ This layer implements a Transformer decoder block that includes
19
+ self-attention with optional sliding window attention and a feed-forward
20
+ network.
21
+
22
+ Args:
23
+ intermediate_dim: Output dimension of the first dense layer in the
24
+ feed-forward network.
25
+ num_query_heads: Number of query attention heads.
26
+ num_key_value_heads: Number of key/value attention heads (for GQA).
+ head_dim: Size of each attention head for queries, keys and values.
27
+ rope_max_wavelength: Maximum wavelength for RoPE (Rotary Position
28
+ Embedding).
29
+ rope_scaling_factor: Scaling factor for RoPE, used for extending
30
+ context length.
31
+ activation: Activation function to use in the feed-forward network.
32
+ layer_norm_epsilon: Small float added to variance to avoid dividing
33
+ by zero in layer norm.
34
+ kernel_initializer: Initializer for the kernel weights.
35
+ dropout: Dropout rate for attention and hidden layers.
36
+ sliding_window_size: Size of the sliding window for attention when
37
+ enabled.
38
+ **kwargs: Additional keyword arguments to pass to the Layer.
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ intermediate_dim,
44
+ num_query_heads,
45
+ num_key_value_heads,
46
+ head_dim,
47
+ rope_max_wavelength=10000,
48
+ rope_scaling_factor=1.0,
49
+ activation="silu",
50
+ layer_norm_epsilon=1e-5,
51
+ kernel_initializer="glorot_uniform",
52
+ dropout=0.0,
53
+ sliding_window_size=None,
54
+ **kwargs,
55
+ ):
56
+ super().__init__(**kwargs)
57
+ self.intermediate_dim = intermediate_dim
58
+ self.num_query_heads = num_query_heads
59
+ self.num_key_value_heads = num_key_value_heads
60
+ self.head_dim = head_dim
61
+
62
+ self.rope_max_wavelength = rope_max_wavelength
63
+ self.rope_scaling_factor = rope_scaling_factor
64
+
65
+ self.dropout = dropout
66
+
67
+ self.sliding_window_size = sliding_window_size
68
+
69
+ self.activation = keras.activations.get(activation)
70
+ self.layer_norm_epsilon = layer_norm_epsilon
71
+ self.kernel_initializer = keras.initializers.get(kernel_initializer)
72
+
73
+ self.supports_masking = True
74
+
75
+ def build(self, decoder_sequence_shape):
76
+ self._decoder_sequence_shape = decoder_sequence_shape
77
+ self.hidden_dim = decoder_sequence_shape[-1]
78
+
79
+ # Self attention layer.
80
+ self._self_attention_layer = Qwen3Attention(
81
+ num_query_heads=self.num_query_heads,
82
+ num_key_value_heads=self.num_key_value_heads,
83
+ rope_max_wavelength=self.rope_max_wavelength,
84
+ head_dim=self.head_dim,
85
+ rope_scaling_factor=self.rope_scaling_factor,
86
+ kernel_initializer=clone_initializer(self.kernel_initializer),
87
+ dropout=self.dropout,
88
+ sliding_window_size=self.sliding_window_size,
89
+ dtype=self.dtype_policy,
90
+ name="self_attention",
91
+ )
92
+ self._self_attention_layer.build(decoder_sequence_shape)
93
+
94
+ self._self_attention_layernorm = Qwen3LayerNorm(
95
+ epsilon=self.layer_norm_epsilon,
96
+ dtype=self.dtype_policy,
97
+ name="self_attention_layernorm",
98
+ )
99
+
100
+ self._self_attention_layernorm.build(decoder_sequence_shape)
101
+ self._self_attention_dropout = keras.layers.Dropout(
102
+ rate=self.dropout,
103
+ dtype=self.dtype_policy,
104
+ name="self_attention_dropout",
105
+ )
106
+
107
+ # Feedforward layers.
108
+ self._feedforward_intermediate_dense = keras.layers.Dense(
109
+ self.intermediate_dim,
110
+ kernel_initializer=clone_initializer(self.kernel_initializer),
111
+ use_bias=False,
112
+ dtype=self.dtype_policy,
113
+ name="feedforward_intermediate_dense",
114
+ )
115
+ self._feedforward_intermediate_dense.build(decoder_sequence_shape)
116
+
117
+ self._feedforward_gate_dense = keras.layers.Dense(
118
+ self.intermediate_dim,
119
+ kernel_initializer=clone_initializer(self.kernel_initializer),
120
+ use_bias=False,
121
+ dtype=self.dtype_policy,
122
+ name="feedforward_gate_dense",
123
+ )
124
+ self._feedforward_gate_dense.build(decoder_sequence_shape)
125
+
126
+ self._feedforward_output_dense = keras.layers.Dense(
127
+ self.hidden_dim,
128
+ kernel_initializer=clone_initializer(self.kernel_initializer),
129
+ use_bias=False,
130
+ dtype=self.dtype_policy,
131
+ name="feedforward_output_dense",
132
+ )
133
+
134
+ self._feedforward_output_dense.build(
135
+ self._feedforward_gate_dense.compute_output_shape(
136
+ decoder_sequence_shape
137
+ )
138
+ )
139
+
140
+ self._feedforward_layernorm = Qwen3LayerNorm(
141
+ epsilon=self.layer_norm_epsilon,
142
+ dtype=self.dtype_policy,
143
+ name="feedforward_layernorm",
144
+ )
145
+ self._feedforward_layernorm.build(decoder_sequence_shape)
146
+
147
+ self.built = True
148
+
149
+ def call(
150
+ self,
151
+ decoder_sequence,
152
+ decoder_padding_mask=None,
153
+ decoder_attention_mask=None,
154
+ self_attention_cache=None,
155
+ self_attention_cache_update_index=None,
156
+ training=None,
157
+ ):
158
+ """Forward pass for the decoder layer.
159
+
160
+ Args:
161
+ decoder_sequence: Input tensor of shape [batch_size, seq_length,
162
+ hidden_size].
163
+ decoder_padding_mask: Mask tensor for padding tokens.
164
+ decoder_attention_mask: Additional attention mask.
165
+ self_attention_cache: Optional cached key and value tensors for
166
+ self-attention.
167
+ self_attention_cache_update_index: Index at which to update the
168
+ cache.
169
+ training: Boolean indicating whether in training mode.
170
+
171
+ Returns:
172
+ decoder_output: Output tensor after applying transformer decoder
173
+ block.
174
+ self_attention_cache: Updated cache tensors (if cache is provided).
175
+ """
176
+ self_attention_mask = self._compute_self_attention_mask(
177
+ decoder_sequence=decoder_sequence,
178
+ decoder_padding_mask=decoder_padding_mask,
179
+ decoder_attention_mask=decoder_attention_mask,
180
+ self_attention_cache=self_attention_cache,
181
+ self_attention_cache_update_index=self_attention_cache_update_index,
182
+ )
183
+ residual = decoder_sequence
184
+
185
+ x = self._self_attention_layernorm(decoder_sequence)
186
+
187
+ # Self attention block.
188
+ x = self._self_attention_layer(
189
+ hidden_states=x,
190
+ attention_mask=self_attention_mask,
191
+ cache=self_attention_cache,
192
+ cache_update_index=self_attention_cache_update_index,
193
+ )
194
+
195
+ if self_attention_cache is not None:
196
+ x, self_attention_cache = x
197
+
198
+ x = self._self_attention_dropout(x, training=training)
199
+
200
+ x = x + residual
201
+ residual = x
202
+
203
+ x = self._feedforward_layernorm(x)
204
+ gate_output = self._feedforward_gate_dense(x)
205
+
206
+ # Note that we run the activation function in full 32-bit
207
+ # precision since this is what `torch.nn.functional.silu`
208
+ # does. Internally, `torch.nn.functional.silu` converts the
209
+ # inputs to float32, computes SiLU, and converts the outputs
210
+ # back to compute dtype.
211
+ # CPU Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cpu/Activation.cpp#L1221-L1235 # noqa: E501
212
+ # CUDA Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cuda/ActivationSiluKernel.cu # noqa: E501
213
+ gate_output = ops.cast(gate_output, "float32")
214
+ gate_output = self.activation(gate_output)
215
+ gate_output = ops.cast(gate_output, self.compute_dtype)
216
+
217
+ x = self._feedforward_intermediate_dense(x)
218
+
219
+ x = self._feedforward_output_dense(ops.multiply(x, gate_output))
220
+
221
+ decoder_output = x + residual
222
+
223
+ if self_attention_cache is not None:
224
+ return decoder_output, self_attention_cache
225
+ return decoder_output
226
+
227
+ def _compute_self_attention_mask(
228
+ self,
229
+ decoder_sequence,
230
+ decoder_padding_mask,
231
+ decoder_attention_mask,
232
+ self_attention_cache,
233
+ self_attention_cache_update_index,
234
+ ):
235
+ """Computes the self-attention mask combining causal, padding and
236
+ attention masks.
237
+
238
+ Args:
239
+ decoder_sequence: Input tensor.
240
+ decoder_padding_mask: Mask tensor for padding tokens.
241
+ decoder_attention_mask: Additional attention mask.
242
+ self_attention_cache: Optional cached key and value tensors.
243
+ self_attention_cache_update_index: Index at which to update the
244
+ cache.
245
+
246
+ Returns:
247
+ Combined attention mask tensor.
248
+ """
249
+ decoder_mask = merge_padding_and_attention_mask(
250
+ decoder_sequence, decoder_padding_mask, decoder_attention_mask
251
+ )
252
+ batch_size = ops.shape(decoder_sequence)[0]
253
+ input_length = output_length = ops.shape(decoder_sequence)[1]
254
+ # We need to handle a rectangular causal mask when doing cached
255
+ # decoding. For generative inference, `decoder_sequence` will
256
+ # generally be length 1, and `cache` will be the full generation length.
257
+ if self_attention_cache is not None:
258
+ input_length = ops.shape(self_attention_cache)[2]
259
+
260
+ cache_update_index = (
261
+ 0
262
+ if self_attention_cache_update_index is None
263
+ else self_attention_cache_update_index
264
+ )
265
+
266
+ causal_mask = compute_causal_mask(
267
+ batch_size, input_length, output_length, cache_update_index
268
+ )
269
+
270
+ return (
271
+ ops.minimum(decoder_mask, causal_mask)
272
+ if decoder_mask is not None
273
+ else causal_mask
274
+ )
275
+
276
+ def compute_output_shape(self, decoder_sequence_shape):
277
+ """Computes the output shape of the layer.
278
+
279
+ Args:
280
+ decoder_sequence_shape: Shape of the decoder sequence input.
281
+
282
+ Returns:
283
+ Output shape, which is the same as the input shape.
284
+ """
285
+ return decoder_sequence_shape
286
+
287
+ def get_config(self):
288
+ """Returns the config of the layer.
289
+
290
+ Returns:
291
+ Dictionary containing the parameters used to initialize this layer.
292
+ """
293
+ config = super().get_config()
294
+ config.update(
295
+ {
296
+ "intermediate_dim": self.intermediate_dim,
297
+ "num_query_heads": self.num_query_heads,
298
+ "rope_max_wavelength": self.rope_max_wavelength,
299
+ "rope_scaling_factor": self.rope_scaling_factor,
300
+ "num_key_value_heads": self.num_key_value_heads,
301
+ "activation": keras.activations.serialize(self.activation),
302
+ "layer_norm_epsilon": self.layer_norm_epsilon,
303
+ "kernel_initializer": keras.initializers.serialize(
304
+ self.kernel_initializer
305
+ ),
306
+ "dropout": self.dropout,
307
+ }
308
+ )
309
+ return config
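The decoder block above uses a pre-norm layout: RMSNorm, grouped-query self-attention, residual add, then RMSNorm, gated SiLU feed-forward, residual add. Below is a minimal, hedged sketch of exercising such a layer in isolation; the module path is assumed from the `qwen3_*` naming used elsewhere in this diff, and every size is made up.

```python
# Sketch only: the import path and all dimensions here are assumptions,
# not a documented keras-hub API.
import numpy as np
from keras_hub.src.models.qwen3.qwen3_transformer_decoder import (
    Qwen3TransformerDecoder,
)

layer = Qwen3TransformerDecoder(
    intermediate_dim=256,      # width of the gated feed-forward network
    num_query_heads=8,
    num_key_value_heads=2,     # GQA: fewer key/value heads than query heads
    head_dim=16,
    sliding_window_size=4096,  # example value, larger than the sequence below
)

hidden = np.random.rand(2, 10, 64).astype("float32")  # (batch, seq, hidden_dim)
out = layer(hidden)  # builds the layer from the input shape, full causal mask
print(out.shape)     # (2, 10, 64): the block preserves the input shape
```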
@@ -0,0 +1,38 @@
1
+ import keras
2
+ from keras import ops
3
+
4
+
5
+ class Qwen3LayerNorm(keras.layers.Layer):
6
+ """A normalization layer for Qwen that implements RMS normalization."""
7
+
8
+ def __init__(self, head_dim=None, epsilon=1e-6, **kwargs):
9
+ super().__init__(**kwargs)
10
+ self.head_dim = head_dim
11
+ self.epsilon = epsilon
12
+
13
+ def build(self, input_shape):
14
+ if self.head_dim:
15
+ dim = self.head_dim
16
+ else:
17
+ dim = input_shape[-1]
18
+
19
+ self.scale = self.add_weight(
20
+ name="scale",
21
+ trainable=True,
22
+ shape=(dim,),
23
+ initializer="ones",
24
+ dtype=self.variable_dtype,
25
+ )
26
+ self.built = True
27
+
28
+ def call(self, x):
29
+ input_dtype = x.dtype
30
+ x = ops.cast(x, "float32")
31
+ var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
32
+ x = x * ops.rsqrt(var + self.epsilon)
33
+ return ops.cast(x * self.scale, input_dtype)
34
+
35
+ def get_config(self):
36
+ config = super().get_config()
37
+ config.update({"epsilon": self.epsilon})
38
+ return config
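`Qwen3LayerNorm` is an RMSNorm: it rescales features by the reciprocal root-mean-square computed in float32 (no mean subtraction, no bias) and then applies a learned per-feature `scale`. A standalone restatement of the same arithmetic with `keras.ops`, using made-up inputs:

```python
# Standalone sketch of the RMSNorm math above; `scale` starts at ones,
# matching the layer's "ones" initializer.
import numpy as np
from keras import ops

def rms_norm(x, scale, epsilon=1e-6):
    dtype = str(x.dtype)
    x32 = ops.cast(x, "float32")
    variance = ops.mean(ops.power(x32, 2), axis=-1, keepdims=True)
    normed = x32 * ops.rsqrt(variance + epsilon)
    return ops.cast(normed * scale, dtype)

x = np.random.rand(2, 4, 8).astype("float16")
scale = np.ones((8,), dtype="float32")
y = rms_norm(x, scale)
print(y.dtype)  # float16: computed in float32, cast back to the input dtype
```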
@@ -0,0 +1,48 @@
1
+ from keras_hub.src.api_export import keras_hub_export
2
+ from keras_hub.src.models.qwen3.qwen3_backbone import Qwen3Backbone
3
+ from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
4
+
5
+
6
+ @keras_hub_export(
7
+ "keras_hub.models.Qwen3Tokenizer",
8
+ )
9
+ class Qwen3Tokenizer(BytePairTokenizer):
10
+ """Tokenizer for Qwen3 models.
11
+
12
+ This tokenizer implements byte-pair encoding (BPE) for Qwen3 models,
13
+ handling special tokens like BOS (beginning of sequence) and EOS (end of
14
+ sequence).
15
+
16
+ Args:
17
+ vocabulary: Dictionary mapping tokens to token IDs, or path to
18
+ vocabulary file.
19
+ merges: List of BPE merges, or path to merges file.
20
+ **kwargs: Additional keyword arguments passed to
21
+ `BytePairTokenizer`. The `"<|im_end|>"` end token and the
22
+ `"<|endoftext|>"` padding token are registered automatically;
23
+ no start (BOS) token is used.
24
+ """
25
+
26
+ backbone_cls = Qwen3Backbone
27
+
28
+ def __init__(
29
+ self,
30
+ vocabulary=None,
31
+ merges=None,
32
+ **kwargs,
33
+ ):
34
+ # Add EOS token
35
+ eos_token = "<|im_end|>"
36
+ self._add_special_token(eos_token, "end_token")
37
+
38
+ pad_token = "<|endoftext|>"
39
+ self._add_special_token(pad_token, "pad_token")
40
+
41
+ self.start_token_id = None
42
+ self.start_token = None
43
+
44
+ super().__init__(
45
+ vocabulary=vocabulary,
46
+ merges=merges,
47
+ **kwargs,
48
+ )
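From the constructor above, the tokenizer registers `"<|im_end|>"` as the end token and `"<|endoftext|>"` as the padding token, with no start token. A hedged usage sketch with a toy vocabulary; real presets ship their own vocabulary and merges files.

```python
# Toy vocabulary and merges purely for illustration; both special tokens must
# appear in the vocabulary so the tokenizer can resolve their ids.
from keras_hub.models import Qwen3Tokenizer

vocab = {"<|endoftext|>": 0, "<|im_end|>": 1, "a": 2, "b": 3, "ab": 4}
merges = ["a b"]

tokenizer = Qwen3Tokenizer(vocabulary=vocab, merges=merges)
print(tokenizer.end_token_id)  # id of "<|im_end|>"
print(tokenizer.pad_token_id)  # id of "<|endoftext|>"
print(tokenizer("ab"))         # BPE-encoded token ids for the input string
```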
@@ -4,8 +4,8 @@ backbone_presets = {
4
4
  "qwen1.5_moe_2.7b_en": {
5
5
  "metadata": {
6
6
  "description": (
7
- "24-layer Qwen MoE model with 2.7 billion active parameters ",
8
- "and 8 experts per MoE layer.",
7
+ "24-layer Qwen MoE model with 2.7 billion active parameters "
8
+ "and 8 experts per MoE layer."
9
9
  ),
10
10
  "params": 14315784192,
11
11
  "path": "qwen-1.5-moe",
@@ -18,10 +18,10 @@ class ViTBackbone(Backbone):
18
18
 
19
19
  Args:
20
20
  image_shape: A tuple or list of 3 integers representing the shape of the
21
- input image `(height, width, channels)`, `height` and `width` must
22
- be equal.
23
- patch_size: int. The size of each image patch, the input image will be
24
- divided into patches of shape `(patch_size, patch_size)`.
21
+ input image `(height, width, channels)`.
22
+ patch_size: int or (int, int). The size of each image patch; the input
23
+ image will be divided into patches of shape
24
+ `(patch_size_h, patch_size_w)`.
25
25
  num_layers: int. The number of transformer encoder layers.
26
26
  num_heads: int. specifying the number of attention heads in each
27
27
  Transformer encoder layer.
@@ -37,6 +37,10 @@ class ViTBackbone(Backbone):
37
37
  use_mha_bias: bool. Whether to use bias in the multi-head
38
38
  attention layers.
39
39
  use_mlp_bias: bool. Whether to use bias in the MLP layers.
40
+ use_class_token: bool. Whether to prepend a learnable class token to
41
+ the patch embeddings. Defaults to `True`.
42
+ use_patch_bias: bool. Whether to use a bias in the patch embedding
43
+ `Conv2D` layer. Defaults to `True`.
40
44
  data_format: str. `"channels_last"` or `"channels_first"`, specifying
41
45
  the data format for the input image. If `None`, defaults to
42
46
  `"channels_last"`.
@@ -58,6 +62,8 @@ class ViTBackbone(Backbone):
58
62
  layer_norm_epsilon=1e-6,
59
63
  use_mha_bias=True,
60
64
  use_mlp_bias=True,
65
+ use_class_token=True,
66
+ use_patch_bias=True,
61
67
  data_format=None,
62
68
  dtype=None,
63
69
  **kwargs,
@@ -74,24 +80,34 @@ class ViTBackbone(Backbone):
74
80
  f"at index {h_axis} (height) or {w_axis} (width). "
75
81
  f"Image shape: {image_shape}"
76
82
  )
77
- if image_shape[h_axis] != image_shape[w_axis]:
83
+
84
+ if isinstance(patch_size, int):
85
+ patch_size = (patch_size, patch_size)
86
+
87
+ if image_shape[h_axis] % patch_size[0] != 0:
88
+ raise ValueError(
89
+ f"Input height {image_shape[h_axis]} should be divisible by "
90
+ f"patch size {patch_size[0]}."
91
+ )
92
+
93
+ if image_shape[w_axis] % patch_size[1] != 0:
78
94
  raise ValueError(
79
- f"Image height and width must be equal. Found height: "
80
- f"{image_shape[h_axis]}, width: {image_shape[w_axis]} at "
81
- f"indices {h_axis} and {w_axis} respectively. Image shape: "
82
- f"{image_shape}"
95
+ f"Input width {image_shape[h_axis]} should be divisible by "
96
+ f"patch size {patch_size[1]}."
83
97
  )
84
98
 
85
99
  num_channels = image_shape[channels_axis]
86
100
 
87
101
  # === Functional Model ===
88
- inputs = keras.layers.Input(shape=image_shape)
102
+ inputs = keras.layers.Input(shape=image_shape, name="images")
89
103
 
90
104
  x = ViTPatchingAndEmbedding(
91
- image_size=image_shape[h_axis],
105
+ image_size=(image_shape[h_axis], image_shape[w_axis]),
92
106
  patch_size=patch_size,
93
107
  hidden_dim=hidden_dim,
94
108
  num_channels=num_channels,
109
+ use_class_token=use_class_token,
110
+ use_patch_bias=use_patch_bias,
95
111
  data_format=data_format,
96
112
  dtype=dtype,
97
113
  name="vit_patching_and_embedding",
@@ -130,6 +146,8 @@ class ViTBackbone(Backbone):
130
146
  self.layer_norm_epsilon = layer_norm_epsilon
131
147
  self.use_mha_bias = use_mha_bias
132
148
  self.use_mlp_bias = use_mlp_bias
149
+ self.use_class_token = use_class_token
150
+ self.use_patch_bias = use_patch_bias
133
151
  self.data_format = data_format
134
152
 
135
153
  def get_config(self):
@@ -147,6 +165,8 @@ class ViTBackbone(Backbone):
147
165
  "layer_norm_epsilon": self.layer_norm_epsilon,
148
166
  "use_mha_bias": self.use_mha_bias,
149
167
  "use_mlp_bias": self.use_mlp_bias,
168
+ "use_class_token": self.use_class_token,
169
+ "use_patch_bias": self.use_patch_bias,
150
170
  }
151
171
  )
152
172
  return config
@@ -1,78 +1,8 @@
1
1
  from keras_hub.src.api_export import keras_hub_export
2
2
  from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
3
3
  from keras_hub.src.models.vit.vit_backbone import ViTBackbone
4
- from keras_hub.src.utils.tensor_utils import preprocessing_function
5
4
 
6
5
 
7
6
  @keras_hub_export("keras_hub.layers.ViTImageConverter")
8
7
  class ViTImageConverter(ImageConverter):
9
- """Converts images to the format expected by a ViT model.
10
-
11
- This layer performs image normalization using mean and standard deviation
12
- values. By default, it uses the same normalization as the
13
- "google/vit-large-patch16-224" model on Hugging Face:
14
- `norm_mean=[0.5, 0.5, 0.5]` and `norm_std=[0.5, 0.5, 0.5]`
15
- ([reference](https://huggingface.co/google/vit-large-patch16-224/blob/main/preprocessor_config.json)).
16
- These defaults are suitable for models pretrained using this normalization.
17
-
18
- Args:
19
- norm_mean: list or tuple of floats. Mean values for image normalization.
20
- Defaults to `[0.5, 0.5, 0.5]`.
21
- norm_std: list or tuple of floats. Standard deviation values for
22
- image normalization. Defaults to `[0.5, 0.5, 0.5]`.
23
- **kwargs: Additional keyword arguments passed to
24
- `keras_hub.layers.preprocessing.ImageConverter`.
25
-
26
- Examples:
27
- ```python
28
- import keras
29
- import numpy as np
30
- from keras_hub.src.layers import ViTImageConverter
31
-
32
- # Example image (replace with your actual image data)
33
- image = np.random.rand(1, 224, 224, 3) # Example: (B, H, W, C)
34
-
35
- # Create a ViTImageConverter instance
36
- converter = ViTImageConverter(
37
- image_size=(28,28),
38
- scale=1/255.
39
- )
40
- # Preprocess the image
41
- preprocessed_image = converter(image)
42
- ```
43
- """
44
-
45
8
  backbone_cls = ViTBackbone
46
-
47
- def __init__(
48
- self, norm_mean=[0.5, 0.5, 0.5], norm_std=[0.5, 0.5, 0.5], **kwargs
49
- ):
50
- super().__init__(**kwargs)
51
- self.norm_mean = norm_mean
52
- self.norm_std = norm_std
53
-
54
- @preprocessing_function
55
- def call(self, inputs):
56
- # TODO: Remove this whole function. Why can just use scale and offset
57
- # in the base class.
58
- x = super().call(inputs)
59
- if self.norm_mean:
60
- norm_mean = self._expand_non_channel_dims(self.norm_mean, x)
61
- x, norm_mean = self._convert_types(x, norm_mean, self.compute_dtype)
62
- x = x - norm_mean
63
- if self.norm_std:
64
- norm_std = self._expand_non_channel_dims(self.norm_std, x)
65
- x, norm_std = self._convert_types(x, norm_std, x.dtype)
66
- x = x / norm_std
67
-
68
- return x
69
-
70
- def get_config(self):
71
- config = super().get_config()
72
- config.update(
73
- {
74
- "norm_mean": self.norm_mean,
75
- "norm_std": self.norm_std,
76
- }
77
- )
78
- return config
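The removed `norm_mean`/`norm_std` handling can be expressed with the base `ImageConverter`'s `scale` and `offset` arguments, which is what the removed TODO suggested: rescaling by `1/255` and then normalizing with mean 0.5 and std 0.5 is the same affine map as `scale = 1 / (255 * 0.5)` and `offset = -0.5 / 0.5`. A hedged sketch, assuming the base class applies `x * scale + offset`:

```python
# Sketch: reproducing the old ViT normalization via the base-class arguments.
import numpy as np
from keras_hub.layers import ViTImageConverter

# (x / 255 - 0.5) / 0.5  ==  x * (1 / (255 * 0.5)) + (-0.5 / 0.5)
converter = ViTImageConverter(
    image_size=(224, 224),
    scale=1 / (255 * 0.5),
    offset=-0.5 / 0.5,
)

image = np.random.randint(0, 256, size=(1, 224, 224, 3)).astype("float32")
out = np.asarray(converter(image))
print(out.min(), out.max())  # approximately -1.0 and 1.0
```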
@@ -75,12 +75,13 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
75
75
  """Patches the image and embeds the patches.
76
76
 
77
77
  Args:
78
- image_size: int. Size of the input image (height or width).
79
- Assumed to be square.
80
- patch_size: int. Size of each image patch.
78
+ image_size: (int, int). Size of the input image.
79
+ patch_size: (int, int). Size of each image patch.
81
80
  hidden_dim: int. Dimensionality of the patch embeddings.
82
81
  num_channels: int. Number of channels in the input image. Defaults to
83
82
  `3`.
83
+ use_class_token: bool. Whether to prepend a learnable class token to
84
+ the patch embeddings. Defaults to `True`.
+ use_patch_bias: bool. Whether to use a bias in the patch embedding
+ `Conv2D` layer. Defaults to `True`.
84
85
  data_format: str. `"channels_last"` or `"channels_first"`. Defaults to
85
86
  `None` (which uses `"channels_last"`).
86
87
  **kwargs: Additional keyword arguments passed to `keras.layers.Layer`
@@ -92,12 +93,15 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
92
93
  patch_size,
93
94
  hidden_dim,
94
95
  num_channels=3,
96
+ use_class_token=True,
97
+ use_patch_bias=True,
95
98
  data_format=None,
96
99
  **kwargs,
97
100
  ):
98
101
  super().__init__(**kwargs)
99
- num_patches = (image_size // patch_size) ** 2
100
- num_positions = num_patches + 1
102
+ grid_size = tuple([s // p for s, p in zip(image_size, patch_size)])
103
+ num_patches = grid_size[0] * grid_size[1]
104
+ num_positions = num_patches + 1 if use_class_token else num_patches
101
105
 
102
106
  # === Config ===
103
107
  self.image_size = image_size
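With a rectangular grid, the position-embedding length computed in `__init__` is the product of the per-axis patch counts, plus one extra slot only when the class token is enabled. A quick numeric check of that arithmetic, using example sizes only:

```python
# Quick check of the num_patches / num_positions arithmetic from __init__.
image_size = (256, 128)  # (height, width), example values
patch_size = (16, 16)

grid_size = tuple(s // p for s, p in zip(image_size, patch_size))  # (16, 8)
num_patches = grid_size[0] * grid_size[1]                          # 128

for use_class_token in (True, False):
    num_positions = num_patches + 1 if use_class_token else num_patches
    print(use_class_token, num_positions)  # True -> 129, False -> 128
```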
@@ -106,19 +110,22 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
106
110
  self.num_channels = num_channels
107
111
  self.num_patches = num_patches
108
112
  self.num_positions = num_positions
113
+ self.use_class_token = use_class_token
114
+ self.use_patch_bias = use_patch_bias
109
115
  self.data_format = standardize_data_format(data_format)
110
116
 
111
117
  def build(self, input_shape):
112
- self.class_token = self.add_weight(
113
- shape=(
114
- 1,
115
- 1,
116
- self.hidden_dim,
117
- ),
118
- initializer="random_normal",
119
- dtype=self.variable_dtype,
120
- name="class_token",
121
- )
118
+ if self.use_class_token:
119
+ self.class_token = self.add_weight(
120
+ shape=(
121
+ 1,
122
+ 1,
123
+ self.hidden_dim,
124
+ ),
125
+ initializer="random_normal",
126
+ dtype=self.variable_dtype,
127
+ name="class_token",
128
+ )
122
129
  self.patch_embedding = keras.layers.Conv2D(
123
130
  filters=self.hidden_dim,
124
131
  kernel_size=self.patch_size,
@@ -127,6 +134,7 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
127
134
  activation=None,
128
135
  dtype=self.dtype_policy,
129
136
  data_format=self.data_format,
137
+ use_bias=self.use_patch_bias,
130
138
  name="patch_embedding",
131
139
  )
132
140
  self.patch_embedding.build(input_shape)
@@ -153,10 +161,16 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
153
161
  patch_embeddings = ops.reshape(
154
162
  patch_embeddings, [embeddings_shape[0], -1, embeddings_shape[-1]]
155
163
  )
156
- class_token = ops.tile(self.class_token, (embeddings_shape[0], 1, 1))
157
164
  position_embeddings = self.position_embedding(self.position_ids)
158
- embeddings = ops.concatenate([class_token, patch_embeddings], axis=1)
159
- return ops.add(embeddings, position_embeddings)
165
+
166
+ if self.use_class_token:
167
+ class_token = ops.tile(
168
+ self.class_token, (embeddings_shape[0], 1, 1)
169
+ )
170
+ patch_embeddings = ops.concatenate(
171
+ [class_token, patch_embeddings], axis=1
172
+ )
173
+ return ops.add(patch_embeddings, position_embeddings)
160
174
 
161
175
  def compute_output_shape(self, input_shape):
162
176
  return (
@@ -175,6 +189,7 @@ class ViTPatchingAndEmbedding(keras.layers.Layer):
175
189
  "num_channels": self.num_channels,
176
190
  "num_patches": self.num_patches,
177
191
  "num_positions": self.num_positions,
192
+ "use_class_token": self.use_class_token,
178
193
  }
179
194
  )
180
195
  return config