keras-hub 0.21.1.dev0__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects each version exactly as it appears in the public registry.
Files changed (95)
  1. keras_hub/layers/__init__.py +9 -0
  2. keras_hub/models/__init__.py +47 -0
  3. keras_hub/src/layers/modeling/transformer_encoder.py +6 -3
  4. keras_hub/src/layers/preprocessing/multi_segment_packer.py +17 -3
  5. keras_hub/src/layers/preprocessing/start_end_packer.py +24 -6
  6. keras_hub/src/models/backbone.py +13 -10
  7. keras_hub/src/models/clip/clip_backbone.py +3 -102
  8. keras_hub/src/models/clip/clip_layers.py +295 -0
  9. keras_hub/src/models/clip/clip_preprocessor.py +57 -48
  10. keras_hub/src/models/clip/clip_text_encoder.py +2 -2
  11. keras_hub/src/models/clip/clip_vision_encoder.py +3 -3
  12. keras_hub/src/models/deit/__init__.py +5 -0
  13. keras_hub/src/models/deit/deit_backbone.py +154 -0
  14. keras_hub/src/models/deit/deit_image_classifier.py +171 -0
  15. keras_hub/src/models/deit/deit_image_classifier_preprocessor.py +12 -0
  16. keras_hub/src/models/deit/deit_image_converter.py +8 -0
  17. keras_hub/src/models/deit/deit_layers.py +519 -0
  18. keras_hub/src/models/deit/deit_presets.py +49 -0
  19. keras_hub/src/models/dinov2/__init__.py +5 -0
  20. keras_hub/src/models/dinov2/dinov2_backbone.py +228 -0
  21. keras_hub/src/models/dinov2/dinov2_image_converter.py +8 -0
  22. keras_hub/src/models/dinov2/dinov2_layers.py +886 -0
  23. keras_hub/src/models/dinov2/dinov2_presets.py +89 -0
  24. keras_hub/src/models/esm/__init__.py +5 -0
  25. keras_hub/src/models/esm/esm_attention.py +95 -0
  26. keras_hub/src/models/esm/esm_backbone.py +229 -0
  27. keras_hub/src/models/esm/esm_classifier.py +184 -0
  28. keras_hub/src/models/esm/esm_classifier_preprocessor.py +135 -0
  29. keras_hub/src/models/esm/esm_encoder.py +134 -0
  30. keras_hub/src/models/esm/esm_masked_plm.py +117 -0
  31. keras_hub/src/models/esm/esm_masked_plm_preprocessor.py +143 -0
  32. keras_hub/src/models/esm/esm_presets.py +53 -0
  33. keras_hub/src/models/esm/esm_tokenizer.py +82 -0
  34. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +6 -2
  35. keras_hub/src/models/gemma/gemma_attention.py +1 -1
  36. keras_hub/src/models/gemma3/gemma3_backbone.py +2 -2
  37. keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py +1 -1
  38. keras_hub/src/models/gemma3/gemma3_presets.py +25 -0
  39. keras_hub/src/models/hgnetv2/__init__.py +5 -0
  40. keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +193 -0
  41. keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +148 -0
  42. keras_hub/src/models/hgnetv2/hgnetv2_image_classifier.py +216 -0
  43. keras_hub/src/models/hgnetv2/hgnetv2_image_classifier_preprocessor.py +14 -0
  44. keras_hub/src/models/hgnetv2/hgnetv2_image_converter.py +8 -0
  45. keras_hub/src/models/hgnetv2/hgnetv2_layers.py +918 -0
  46. keras_hub/src/models/hgnetv2/hgnetv2_presets.py +58 -0
  47. keras_hub/src/models/llama3/llama3_presets.py +3 -3
  48. keras_hub/src/models/mistral/mistral_presets.py +17 -1
  49. keras_hub/src/models/mixtral/mixtral_presets.py +2 -2
  50. keras_hub/src/models/mobilenet/mobilenet_presets.py +4 -4
  51. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +2 -2
  52. keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +2 -2
  53. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +17 -17
  54. keras_hub/src/models/qwen3/__init__.py +5 -0
  55. keras_hub/src/models/qwen3/qwen3_attention.py +369 -0
  56. keras_hub/src/models/qwen3/qwen3_backbone.py +191 -0
  57. keras_hub/src/models/qwen3/qwen3_causal_lm.py +390 -0
  58. keras_hub/src/models/qwen3/qwen3_causal_lm_preprocessor.py +10 -0
  59. keras_hub/src/models/qwen3/qwen3_decoder.py +309 -0
  60. keras_hub/src/models/qwen3/qwen3_layernorm.py +38 -0
  61. keras_hub/src/models/qwen3/qwen3_presets.py +73 -0
  62. keras_hub/src/models/qwen3/qwen3_tokenizer.py +48 -0
  63. keras_hub/src/models/qwen_moe/qwen_moe_attention.py +1 -0
  64. keras_hub/src/models/qwen_moe/qwen_moe_presets.py +2 -2
  65. keras_hub/src/models/roformer_v2/roformer_v2_attention.py +0 -2
  66. keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +16 -7
  67. keras_hub/src/models/stable_diffusion_3/mmdit.py +61 -4
  68. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +31 -32
  69. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +1 -0
  70. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +1 -0
  71. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -0
  72. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +6 -2
  73. keras_hub/src/models/vit/vit_backbone.py +31 -11
  74. keras_hub/src/models/vit/vit_image_converter.py +0 -70
  75. keras_hub/src/models/vit/vit_layers.py +33 -18
  76. keras_hub/src/models/vit/vit_presets.py +11 -11
  77. keras_hub/src/utils/keras_utils.py +17 -0
  78. keras_hub/src/utils/preset_utils.py +19 -4
  79. keras_hub/src/utils/tensor_utils.py +14 -0
  80. keras_hub/src/utils/transformers/convert_deit.py +155 -0
  81. keras_hub/src/utils/transformers/convert_dinov2.py +180 -0
  82. keras_hub/src/utils/transformers/convert_esm.py +159 -0
  83. keras_hub/src/utils/transformers/convert_llama3.py +6 -0
  84. keras_hub/src/utils/transformers/convert_qwen3.py +145 -0
  85. keras_hub/src/utils/transformers/export/gemma.py +89 -0
  86. keras_hub/src/utils/transformers/export/hf_exporter.py +98 -0
  87. keras_hub/src/utils/transformers/preset_loader.py +14 -2
  88. keras_hub/src/version.py +1 -1
  89. keras_hub/tokenizers/__init__.py +1 -0
  90. {keras_hub-0.21.1.dev0.dist-info → keras_hub-0.22.0.dist-info}/METADATA +4 -4
  91. {keras_hub-0.21.1.dev0.dist-info → keras_hub-0.22.0.dist-info}/RECORD +93 -49
  92. keras_hub/src/models/clip/clip_encoder_block.py +0 -111
  93. keras_hub/src/models/clip/clip_vision_embedding.py +0 -101
  94. {keras_hub-0.21.1.dev0.dist-info → keras_hub-0.22.0.dist-info}/WHEEL +0 -0
  95. {keras_hub-0.21.1.dev0.dist-info → keras_hub-0.22.0.dist-info}/top_level.txt +0 -0
keras_hub/src/models/qwen3/qwen3_decoder.py
@@ -0,0 +1,309 @@
+ import keras
+ from keras import ops
+
+ from keras_hub.src.layers.modeling.transformer_layer_utils import (
+     compute_causal_mask,
+ )
+ from keras_hub.src.layers.modeling.transformer_layer_utils import (
+     merge_padding_and_attention_mask,
+ )
+ from keras_hub.src.models.qwen3.qwen3_attention import Qwen3Attention
+ from keras_hub.src.models.qwen3.qwen3_layernorm import Qwen3LayerNorm
+ from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+ class Qwen3TransformerDecoder(keras.layers.Layer):
+     """A Transformer decoder layer for the Qwen3 backbone.
+
+     This layer implements a Transformer decoder block that includes
+     self-attention with optional sliding window attention and a feed-forward
+     network.
+
+     Args:
+         intermediate_dim: Output dimension of the first dense layer in the
+             feed-forward network.
+         num_query_heads: Number of query attention heads.
+         num_key_value_heads: Number of key/value attention heads (for GQA).
+         head_dim: Size of each attention head.
+         rope_max_wavelength: Maximum wavelength for RoPE (Rotary Position
+             Embedding).
+         rope_scaling_factor: Scaling factor for RoPE, used for extending
+             context length.
+         activation: Activation function to use in the feed-forward network.
+         layer_norm_epsilon: Small float added to variance to avoid dividing
+             by zero in layer norm.
+         kernel_initializer: Initializer for the kernel weights.
+         dropout: Dropout rate for attention and hidden layers.
+         sliding_window_size: Size of the sliding window for attention when
+             enabled.
+         **kwargs: Additional keyword arguments to pass to the Layer.
+     """
+
+     def __init__(
+         self,
+         intermediate_dim,
+         num_query_heads,
+         num_key_value_heads,
+         head_dim,
+         rope_max_wavelength=10000,
+         rope_scaling_factor=1.0,
+         activation="silu",
+         layer_norm_epsilon=1e-5,
+         kernel_initializer="glorot_uniform",
+         dropout=0.0,
+         sliding_window_size=None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.intermediate_dim = intermediate_dim
+         self.num_query_heads = num_query_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.head_dim = head_dim
+
+         self.rope_max_wavelength = rope_max_wavelength
+         self.rope_scaling_factor = rope_scaling_factor
+
+         self.dropout = dropout
+
+         self.sliding_window_size = sliding_window_size
+
+         self.activation = keras.activations.get(activation)
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.kernel_initializer = keras.initializers.get(kernel_initializer)
+
+         self.supports_masking = True
+
+     def build(self, decoder_sequence_shape):
+         self._decoder_sequence_shape = decoder_sequence_shape
+         self.hidden_dim = decoder_sequence_shape[-1]
+
+         # Self attention layer.
+         self._self_attention_layer = Qwen3Attention(
+             num_query_heads=self.num_query_heads,
+             num_key_value_heads=self.num_key_value_heads,
+             rope_max_wavelength=self.rope_max_wavelength,
+             head_dim=self.head_dim,
+             rope_scaling_factor=self.rope_scaling_factor,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             dropout=self.dropout,
+             sliding_window_size=self.sliding_window_size,
+             dtype=self.dtype_policy,
+             name="self_attention",
+         )
+         self._self_attention_layer.build(decoder_sequence_shape)
+
+         self._self_attention_layernorm = Qwen3LayerNorm(
+             epsilon=self.layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="self_attention_layernorm",
+         )
+
+         self._self_attention_layernorm.build(decoder_sequence_shape)
+         self._self_attention_dropout = keras.layers.Dropout(
+             rate=self.dropout,
+             dtype=self.dtype_policy,
+             name="self_attention_dropout",
+         )
+
+         # Feedforward layers.
+         self._feedforward_intermediate_dense = keras.layers.Dense(
+             self.intermediate_dim,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             use_bias=False,
+             dtype=self.dtype_policy,
+             name="feedforward_intermediate_dense",
+         )
+         self._feedforward_intermediate_dense.build(decoder_sequence_shape)
+
+         self._feedforward_gate_dense = keras.layers.Dense(
+             self.intermediate_dim,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             use_bias=False,
+             dtype=self.dtype_policy,
+             name="feedforward_gate_dense",
+         )
+         self._feedforward_gate_dense.build(decoder_sequence_shape)
+
+         self._feedforward_output_dense = keras.layers.Dense(
+             self.hidden_dim,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             use_bias=False,
+             dtype=self.dtype_policy,
+             name="feedforward_output_dense",
+         )
+
+         self._feedforward_output_dense.build(
+             self._feedforward_gate_dense.compute_output_shape(
+                 decoder_sequence_shape
+             )
+         )
+
+         self._feedforward_layernorm = Qwen3LayerNorm(
+             epsilon=self.layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="feedforward_layernorm",
+         )
+         self._feedforward_layernorm.build(decoder_sequence_shape)
+
+         self.built = True
+
+     def call(
+         self,
+         decoder_sequence,
+         decoder_padding_mask=None,
+         decoder_attention_mask=None,
+         self_attention_cache=None,
+         self_attention_cache_update_index=None,
+         training=None,
+     ):
+         """Forward pass for the decoder layer.
+
+         Args:
+             decoder_sequence: Input tensor of shape [batch_size, seq_length,
+                 hidden_size].
+             decoder_padding_mask: Mask tensor for padding tokens.
+             decoder_attention_mask: Additional attention mask.
+             self_attention_cache: Optional cached key and value tensors for
+                 self-attention.
+             self_attention_cache_update_index: Index at which to update the
+                 cache.
+             training: Boolean indicating whether in training mode.
+
+         Returns:
+             decoder_output: Output tensor after applying transformer decoder
+                 block.
+             self_attention_cache: Updated cache tensors (if cache is provided).
+         """
+         self_attention_mask = self._compute_self_attention_mask(
+             decoder_sequence=decoder_sequence,
+             decoder_padding_mask=decoder_padding_mask,
+             decoder_attention_mask=decoder_attention_mask,
+             self_attention_cache=self_attention_cache,
+             self_attention_cache_update_index=self_attention_cache_update_index,
+         )
+         residual = decoder_sequence
+
+         x = self._self_attention_layernorm(decoder_sequence)
+
+         # Self attention block.
+         x = self._self_attention_layer(
+             hidden_states=x,
+             attention_mask=self_attention_mask,
+             cache=self_attention_cache,
+             cache_update_index=self_attention_cache_update_index,
+         )
+
+         if self_attention_cache is not None:
+             x, self_attention_cache = x
+
+         x = self._self_attention_dropout(x, training=training)
+
+         x = x + residual
+         residual = x
+
+         x = self._feedforward_layernorm(x)
+         gate_output = self._feedforward_gate_dense(x)
+
+         # Note that we run the activation function in full 32-bit
+         # precision since this is what `torch.nn.functional.silu`
+         # does. Internally, `torch.nn.functional.silu` converts the
+         # inputs to float32, computes SiLU, and converts the outputs
+         # back to compute dtype.
+         # CPU Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cpu/Activation.cpp#L1221-L1235 # noqa: E501
+         # CUDA Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cuda/ActivationSiluKernel.cu # noqa: E501
+         gate_output = ops.cast(gate_output, "float32")
+         gate_output = self.activation(gate_output)
+         gate_output = ops.cast(gate_output, self.compute_dtype)
+
+         x = self._feedforward_intermediate_dense(x)
+
+         x = self._feedforward_output_dense(ops.multiply(x, gate_output))
+
+         decoder_output = x + residual
+
+         if self_attention_cache is not None:
+             return decoder_output, self_attention_cache
+         return decoder_output
+
+     def _compute_self_attention_mask(
+         self,
+         decoder_sequence,
+         decoder_padding_mask,
+         decoder_attention_mask,
+         self_attention_cache,
+         self_attention_cache_update_index,
+     ):
+         """Computes the self-attention mask combining causal, padding and
+         attention masks.
+
+         Args:
+             decoder_sequence: Input tensor.
+             decoder_padding_mask: Mask tensor for padding tokens.
+             decoder_attention_mask: Additional attention mask.
+             self_attention_cache: Optional cached key and value tensors.
+             self_attention_cache_update_index: Index at which to update the
+                 cache.
+
+         Returns:
+             Combined attention mask tensor.
+         """
+         decoder_mask = merge_padding_and_attention_mask(
+             decoder_sequence, decoder_padding_mask, decoder_attention_mask
+         )
+         batch_size = ops.shape(decoder_sequence)[0]
+         input_length = output_length = ops.shape(decoder_sequence)[1]
+         # We need to handle a rectangular causal mask when doing cached
+         # decoding. For generative inference, `decoder_sequence` will
+         # generally be length 1, and `cache` will be the full generation length.
+         if self_attention_cache is not None:
+             input_length = ops.shape(self_attention_cache)[2]
+
+         cache_update_index = (
+             0
+             if self_attention_cache_update_index is None
+             else self_attention_cache_update_index
+         )
+
+         causal_mask = compute_causal_mask(
+             batch_size, input_length, output_length, cache_update_index
+         )
+
+         return (
+             ops.minimum(decoder_mask, causal_mask)
+             if decoder_mask is not None
+             else causal_mask
+         )
+
+     def compute_output_shape(self, decoder_sequence_shape):
+         """Computes the output shape of the layer.
+
+         Args:
+             decoder_sequence_shape: Shape of the decoder sequence input.
+
+         Returns:
+             Output shape, which is the same as the input shape.
+         """
+         return decoder_sequence_shape
+
+     def get_config(self):
+         """Returns the config of the layer.
+
+         Returns:
+             Dictionary containing the parameters used to initialize this layer.
+         """
+         config = super().get_config()
+         config.update(
+             {
+                 "intermediate_dim": self.intermediate_dim,
+                 "num_query_heads": self.num_query_heads,
+                 "rope_max_wavelength": self.rope_max_wavelength,
+                 "rope_scaling_factor": self.rope_scaling_factor,
+                 "num_key_value_heads": self.num_key_value_heads,
+                 "activation": keras.activations.serialize(self.activation),
+                 "layer_norm_epsilon": self.layer_norm_epsilon,
+                 "kernel_initializer": keras.initializers.serialize(
+                     self.kernel_initializer
+                 ),
+                 "dropout": self.dropout,
+             }
+         )
+         return config
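
The `call` method above is a standard pre-norm residual block, and its feed-forward half is the gated SiLU (SwiGLU-style) MLP: `down(silu(gate(x)) * up(x))`, with the activation computed in float32 as the inline comment explains. Below is a minimal standalone sketch of just that feed-forward computation, with hypothetical toy dimensions; an illustration, not the layer itself:

    import numpy as np
    import keras
    from keras import ops

    # Hypothetical toy dimensions, for illustration only.
    batch, seq_len, hidden_dim, intermediate_dim = 2, 4, 8, 16

    x = ops.convert_to_tensor(
        np.random.rand(batch, seq_len, hidden_dim).astype("float32")
    )
    gate_dense = keras.layers.Dense(intermediate_dim, use_bias=False)
    up_dense = keras.layers.Dense(intermediate_dim, use_bias=False)
    down_dense = keras.layers.Dense(hidden_dim, use_bias=False)

    # Gate branch: SiLU runs in float32 and is cast back, mirroring call()
    # above (a no-op in this toy since everything is float32 already).
    gate = ops.cast(ops.silu(ops.cast(gate_dense(x), "float32")), "float32")
    # Elementwise product with the intermediate branch, then project down.
    out = down_dense(ops.multiply(up_dense(x), gate))
    print(out.shape)  # (2, 4, 8)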
keras_hub/src/models/qwen3/qwen3_layernorm.py
@@ -0,0 +1,38 @@
+ import keras
+ from keras import ops
+
+
+ class Qwen3LayerNorm(keras.layers.Layer):
+     """A normalization layer for Qwen that implements RMS normalization."""
+
+     def __init__(self, head_dim=None, epsilon=1e-6, **kwargs):
+         super().__init__(**kwargs)
+         self.head_dim = head_dim
+         self.epsilon = epsilon
+
+     def build(self, input_shape):
+         if self.head_dim:
+             dim = self.head_dim
+         else:
+             dim = input_shape[-1]
+
+         self.scale = self.add_weight(
+             name="scale",
+             trainable=True,
+             shape=(dim,),
+             initializer="ones",
+             dtype=self.variable_dtype,
+         )
+         self.built = True
+
+     def call(self, x):
+         input_dtype = x.dtype
+         x = ops.cast(x, "float32")
+         var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
+         x = x * ops.rsqrt(var + self.epsilon)
+         return ops.cast(x * self.scale, input_dtype)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update({"epsilon": self.epsilon})
+         return config
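
`Qwen3LayerNorm` is an RMSNorm: it skips mean-centering and rescales each feature vector by its root mean square, computed in float32 for stability. A quick NumPy check of the same arithmetic (illustrative only):

    import numpy as np

    def rms_norm(x, scale, epsilon=1e-6):
        # Matches call() above: mean of squares, reciprocal sqrt, then scale.
        var = np.mean(np.square(x), axis=-1, keepdims=True)
        return x / np.sqrt(var + epsilon) * scale

    x = np.array([[1.0, 2.0, 3.0]], dtype="float32")
    print(rms_norm(x, scale=np.ones(3, dtype="float32")))
    # RMS of [1, 2, 3] is sqrt(14/3) ≈ 2.160, so ≈ [[0.463, 0.926, 1.389]].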
keras_hub/src/models/qwen3/qwen3_presets.py
@@ -0,0 +1,73 @@
+ """Qwen3 model preset configurations."""
+
+ backbone_presets = {
+     "qwen3_0.6b_en": {
+         "metadata": {
+             "description": (
+                 "28-layer Qwen3 model with 596M parameters, optimized for "
+                 "efficiency and fast inference on resource-constrained devices."
+             ),
+             "params": 596049920,
+             "path": "qwen3",
+         },
+         "kaggle_handle": "kaggle://keras/qwen-3/keras/qwen3_0.6b_en/1",
+     },
+     "qwen3_1.7b_en": {
+         "metadata": {
+             "description": (
+                 "28-layer Qwen3 model with 1.72B parameters, offering "
+                 "a good balance between performance and resource usage."
+             ),
+             "params": 1720574976,
+             "path": "qwen3",
+         },
+         "kaggle_handle": "kaggle://keras/qwen-3/keras/qwen3_1.7b_en/1",
+     },
+     "qwen3_4b_en": {
+         "metadata": {
+             "description": (
+                 "36-layer Qwen3 model with 4.02B parameters, offering improved "
+                 "reasoning capabilities and better performance than smaller "
+                 "variants."
+             ),
+             "params": 4022468096,
+             "path": "qwen3",
+         },
+         "kaggle_handle": "kaggle://keras/qwen-3/keras/qwen3_4b_en/1",
+     },
+     "qwen3_8b_en": {
+         "metadata": {
+             "description": (
+                 "36-layer Qwen3 model with 8.19B parameters, featuring "
+                 "enhanced reasoning, coding, and instruction-following "
+                 "capabilities."
+             ),
+             "params": 8190735360,
+             "path": "qwen3",
+         },
+         "kaggle_handle": "kaggle://keras/qwen-3/keras/qwen3_8b_en/1",
+     },
+     "qwen3_14b_en": {
+         "metadata": {
+             "description": (
+                 "40-layer Qwen3 model with 14.77B parameters, featuring "
+                 "advanced reasoning, coding, and multilingual capabilities."
+             ),
+             "params": 14768307200,
+             "path": "qwen3",
+         },
+         "kaggle_handle": "kaggle://keras/qwen-3/keras/qwen3_14b_en/1",
+     },
+     "qwen3_32b_en": {
+         "metadata": {
+             "description": (
+                 "64-layer Qwen3 model with 32.76B parameters, featuring "
+                 "state-of-the-art performance across reasoning, coding, and "
+                 "general language tasks."
+             ),
+             "params": 32762123264,
+             "path": "qwen3",
+         },
+         "kaggle_handle": "kaggle://keras/qwen-3/keras/qwen3_32b_en/1",
+     },
+ }
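
These entries hook into the usual KerasHub preset flow. A usage sketch, assuming the Kaggle handles above are live and accessible from your environment:

    import keras_hub

    # Downloads weights for the smallest preset above on first use.
    causal_lm = keras_hub.models.Qwen3CausalLM.from_preset("qwen3_0.6b_en")
    print(causal_lm.generate("What is Keras?", max_length=64))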
keras_hub/src/models/qwen3/qwen3_tokenizer.py
@@ -0,0 +1,48 @@
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.qwen3.qwen3_backbone import Qwen3Backbone
+ from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
+
+
+ @keras_hub_export(
+     "keras_hub.models.Qwen3Tokenizer",
+ )
+ class Qwen3Tokenizer(BytePairTokenizer):
+     """Tokenizer for Qwen3 models.
+
+     This tokenizer implements byte-pair encoding (BPE) for Qwen3 models,
+     handling special tokens like BOS (beginning of sequence) and EOS (end of
+     sequence).
+
+     Args:
+         vocabulary: Dictionary mapping tokens to token IDs, or path to
+             vocabulary file.
+         merges: List of BPE merges, or path to merges file.
+         bos_token: Beginning of sequence token. Defaults to None.
+         eos_token: End of sequence token. Defaults to "<|endoftext|>".
+         misc_special_tokens: Set of additional special tokens. Defaults to
+             empty set.
+     """
+
+     backbone_cls = Qwen3Backbone
+
+     def __init__(
+         self,
+         vocabulary=None,
+         merges=None,
+         **kwargs,
+     ):
+         # Add EOS token
+         eos_token = "<|im_end|>"
+         self._add_special_token(eos_token, "end_token")
+
+         pad_token = "<|endoftext|>"
+         self._add_special_token(pad_token, "pad_token")
+
+         self.start_token_id = None
+         self.start_token = None
+
+         super().__init__(
+             vocabulary=vocabulary,
+             merges=merges,
+             **kwargs,
+         )
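
One wrinkle worth flagging: the docstring advertises `bos_token`, `eos_token`, and `misc_special_tokens` arguments, but the constructor hard-codes `<|im_end|>` as the end token and `<|endoftext|>` as padding. A usage sketch with a toy, hypothetical vocabulary (real presets ship their own vocabulary and merges files):

    from keras_hub.models import Qwen3Tokenizer

    # Toy vocabulary and merges, for illustration only.
    vocab = {"<|im_end|>": 0, "<|endoftext|>": 1, "a": 2, "b": 3, "ab": 4}
    merges = ["a b"]

    tokenizer = Qwen3Tokenizer(vocabulary=vocab, merges=merges)
    print(tokenizer("ab"))         # token IDs, e.g. [4]
    print(tokenizer.end_token_id)  # 0, resolved from "<|im_end|>"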
keras_hub/src/models/qwen_moe/qwen_moe_attention.py
@@ -67,6 +67,7 @@ class QwenMoeAttention(keras.layers.Layer):
          self.rope_scaling_factor = rope_scaling_factor
          self.use_sliding_window_attention = use_sliding_window_attention
          self.sliding_window_size = sliding_window_size
+         self.logit_soft_cap = None
  
      def build(self, inputs_shape):
          # Einsum variables:
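
`logit_soft_cap` is added here but left at `None`, i.e. disabled; presumably the attribute exists so a shared attention path can branch on it uniformly. For reference, attention-logit soft-capping is conventionally a scaled tanh squash applied before the softmax; a generic sketch, not the QwenMoe code path:

    from keras import ops

    def soft_cap(logits, cap):
        # Smoothly bounds logits to (-cap, cap) while keeping gradients.
        return ops.multiply(cap, ops.tanh(ops.divide(logits, cap)))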
keras_hub/src/models/qwen_moe/qwen_moe_presets.py
@@ -8,8 +8,8 @@ backbone_presets = {
                  "and 8 experts per MoE layer."
              ),
              "params": 14315784192,
-             "path": "qwen-1.5-moe",
+             "path": "qwen_moe",
          },
-         "kaggle_handle": "kaggle://keras/qwen-1.5-moe/Keras/qwen1.5_moe_2.7b_en/3",
+         "kaggle_handle": "kaggle://keras/qwen-1.5-moe/Keras/qwen1.5_moe_2.7b_en/4",
      },
  }
keras_hub/src/models/roformer_v2/roformer_v2_attention.py
@@ -179,8 +179,6 @@ class RoformerAttention(keras.layers.Layer):
          vw = ops.reshape(vw, (b, s, self.heads, self.head_size))
  
          qw, kw = self.rotary_embedding_layer([qw, kw])
-         if keras.__version__ < "3.6":
-             raise ("Please make sure your Keras version is >=3.6.")
          flash_attention = keras.config.is_flash_attention_enabled()
          attention_mask = ops.reshape(attention_mask, [b, 1, s, 1])
          if keras.config.backend() == "torch":
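
The deleted guard was doubly broken: `keras.__version__ < "3.6"` compares strings lexicographically (so "3.10" would sort before "3.6"), and `raise ("...")` attempts to raise a string, which is itself a `TypeError` in Python 3. Had the check been worth keeping, a correct form would look something like:

    import keras

    major, minor = (int(part) for part in keras.__version__.split(".")[:2])
    if (major, minor) < (3, 6):
        raise ImportError("Please make sure your Keras version is >=3.6.")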
keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py
@@ -38,7 +38,6 @@ class FlowMatchEulerDiscreteScheduler(layers.Layer):
          timesteps = ops.flip(timesteps, axis=0)
          sigmas = self._timestep_to_sigma(timesteps)
  
-         self.timesteps = ops.multiply(sigmas, num_train_timesteps)
          self.sigma_min = sigmas[-1]
          self.sigma_max = sigmas[0]
  
@@ -54,14 +53,24 @@ class FlowMatchEulerDiscreteScheduler(layers.Layer):
          )
          return sigma
  
+     def set_sigmas(self, num_steps):
+         timesteps = ops.linspace(
+             self._sigma_to_timestep(self.sigma_max),
+             self._sigma_to_timestep(self.sigma_min),
+             num_steps,
+         )
+         sigmas = self._timestep_to_sigma(timesteps)
+         sigmas = ops.concatenate([sigmas, ops.zeros((1,), dtype=sigmas.dtype)])
+         self.sigmas = sigmas
+
      def call(self, inputs, num_steps):
-         start = self._sigma_to_timestep(self.sigma_max)
-         end = self._sigma_to_timestep(self.sigma_min)
-         step_size = ops.divide(
-             ops.subtract(end, start), ops.subtract(num_steps, 1)
-         )
-         timestep = ops.add(start, ops.multiply(inputs, step_size))
-         sigma = ops.maximum(self._timestep_to_sigma(timestep), 0.0)
+         if not hasattr(self, "sigmas"):
+             self.set_sigmas(num_steps)
+
+         step = ops.expand_dims(
+             ops.convert_to_tensor(inputs, dtype="int32"), axis=0
+         )
+         sigma = ops.take(self.sigmas, step)
          timestep = self._sigma_to_timestep(sigma)
          return sigma, timestep
  
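
Net effect: instead of interpolating a timestep on every call, the scheduler now materializes the whole sigma table once, with a trailing zero so the final step lands exactly at sigma = 0, and indexes it by step. A NumPy sketch of the schedule's shape, using identity maps in place of the nonlinear `_sigma_to_timestep`/`_timestep_to_sigma` pair:

    import numpy as np

    def build_sigmas(sigma_max, sigma_min, num_steps):
        # Identity stand-ins for the timestep<->sigma conversions.
        sigmas = np.linspace(sigma_max, sigma_min, num_steps)
        # Trailing zero: index num_steps yields sigma == 0 exactly.
        return np.concatenate([sigmas, np.zeros(1, dtype=sigmas.dtype)])

    sigmas = build_sigmas(1.0, 0.01, num_steps=5)
    print(sigmas)     # [1.     0.7525 0.505  0.2575 0.01   0.    ]
    print(sigmas[2])  # sigma used at denoising step 2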
keras_hub/src/models/stable_diffusion_3/mmdit.py
@@ -10,6 +10,63 @@ from keras_hub.src.utils.keras_utils import fused_attention_op_available
  from keras_hub.src.utils.keras_utils import gelu_approximate
  from keras_hub.src.utils.keras_utils import standardize_data_format
  
+ # TODO: Deprecate this in favor of
+ # `keras.layers.RMSNormalization` once we require Keras 3.9 or later.
+ if hasattr(layers, "RMSNormalization"):
+     RMSNormalization = layers.RMSNormalization
+ else:
+
+     class RMSNormalization(layers.Layer):
+         """A normalization layer for MMDiT that implements RMS normalization."""
+
+         def __init__(self, axis=-1, epsilon=1e-6, **kwargs):
+             super().__init__(**kwargs)
+             self.axis = axis
+             self.epsilon = epsilon
+
+         def build(self, input_shape):
+             if isinstance(self.axis, list):
+                 shape = tuple([input_shape[dim] for dim in self.axis])
+             else:
+                 shape = (input_shape[self.axis],)
+                 self.axis = [self.axis]
+
+             self.scale = self.add_weight(
+                 name="scale", shape=shape, initializer="ones"
+             )
+
+             self.built = True
+
+         def call(self, x):
+             x = ops.cast(
+                 x, keras.backend.result_type(self.compute_dtype, "float32")
+             )
+             rrms = ops.rsqrt(
+                 ops.mean(ops.square(x), axis=self.axis, keepdims=True)
+                 + self.epsilon
+             )
+             return (x * rrms) * ops.cast(self.scale, x.dtype)
+
+         def compute_output_shape(self, input_shape):
+             if isinstance(self.axis, int):
+                 axes = [self.axis]
+             else:
+                 axes = self.axis
+
+             for axis in axes:
+                 if axis >= len(input_shape) or axis < -len(input_shape):
+                     raise ValueError(
+                         f"Axis {axis} is out of bounds for "
+                         f"input shape {input_shape}. "
+                         f"Received: axis={self.axis}"
+                     )
+             return input_shape
+
+         def get_config(self):
+             config = super().get_config()
+             config.update({"axis": self.axis, "epsilon": self.epsilon})
+             return config
+
  
  class AdaptiveLayerNormalization(layers.Layer):
      """Adaptive layer normalization.
@@ -402,11 +459,11 @@ def get_qk_norm(qk_norm=None, q_norm_name="q_norm", k_norm_name="k_norm"):
      if qk_norm is None:
          pass
      elif qk_norm == "rms_norm":
-         q_norm = layers.LayerNormalization(
-             epsilon=1e-6, rms_scaling=True, dtype="float32", name=q_norm_name
+         q_norm = RMSNormalization(
+             axis=-1, epsilon=1e-6, dtype="float32", name=q_norm_name
          )
-         k_norm = layers.LayerNormalization(
-             epsilon=1e-6, rms_scaling=True, dtype="float32", name=k_norm_name
+         k_norm = RMSNormalization(
+             axis=-1, epsilon=1e-6, dtype="float32", name=k_norm_name
          )
      else:
          raise NotImplementedError(
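
The shim above exists because `keras.layers.RMSNormalization` only ships with Keras 3.9+, while the `rms_scaling=True` path on `LayerNormalization` that it replaces has been deprecated upstream; `get_qk_norm` now works on either side of that boundary. A quick numerical check that the built-in layer matches the fallback's formula (assuming Keras >= 3.9 so `layers.RMSNormalization` resolves):

    import numpy as np
    import keras
    from keras import layers

    x = np.random.rand(2, 4).astype("float32")
    norm = layers.RMSNormalization(axis=-1, epsilon=1e-6)
    y = keras.ops.convert_to_numpy(norm(x))

    # Reference computation mirroring the fallback class above
    # (scale is initialized to ones, so it drops out here).
    rrms = 1.0 / np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + 1e-6)
    np.testing.assert_allclose(y, x * rrms, rtol=1e-5)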