keras-hub 0.20.0.dev1__py3-none-any.whl → 0.21.0.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +15 -33
- keras_hub/layers/__init__.py +134 -0
- keras_hub/metrics/__init__.py +11 -0
- keras_hub/models/__init__.py +642 -0
- keras_hub/samplers/__init__.py +18 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +25 -35
- keras_hub/src/layers/preprocessing/image_converter.py +1 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +1 -1
- keras_hub/src/layers/preprocessing/random_swap.py +1 -1
- keras_hub/src/models/audio_to_text.py +66 -0
- keras_hub/src/models/audio_to_text_preprocessor.py +80 -0
- keras_hub/src/models/backbone.py +5 -2
- keras_hub/src/models/cspnet/cspnet_backbone.py +51 -26
- keras_hub/src/models/cspnet/cspnet_presets.py +38 -3
- keras_hub/src/models/falcon/falcon_backbone.py +1 -1
- keras_hub/src/models/gemma/gemma_presets.py +10 -10
- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +3 -2
- keras_hub/src/models/gemma3/gemma3_presets.py +8 -8
- keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
- keras_hub/src/models/llama/llama_attention.py +24 -6
- keras_hub/src/models/llama/llama_backbone.py +50 -16
- keras_hub/src/models/llama/llama_decoder.py +20 -3
- keras_hub/src/models/llama/llama_presets.py +3 -3
- keras_hub/src/models/llama/llama_rotary_embedding.py +180 -0
- keras_hub/src/models/llama3/llama3_backbone.py +10 -2
- keras_hub/src/models/llama3/llama3_presets.py +84 -2
- keras_hub/src/models/mistral/mistral_presets.py +3 -3
- keras_hub/src/models/mixtral/__init__.py +5 -0
- keras_hub/src/models/mixtral/mixtral_attention.py +252 -0
- keras_hub/src/models/mixtral/mixtral_backbone.py +207 -0
- keras_hub/src/models/mixtral/mixtral_causal_lm.py +281 -0
- keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py +76 -0
- keras_hub/src/models/mixtral/mixtral_decoder.py +494 -0
- keras_hub/src/models/mixtral/mixtral_layer_norm.py +34 -0
- keras_hub/src/models/mixtral/mixtral_presets.py +26 -0
- keras_hub/src/models/mixtral/mixtral_tokenizer.py +21 -0
- keras_hub/src/models/moonshine/__init__.py +5 -0
- keras_hub/src/models/moonshine/moonshine_audio_converter.py +301 -0
- keras_hub/src/models/moonshine/moonshine_audio_to_text.py +383 -0
- keras_hub/src/models/moonshine/moonshine_audio_to_text_preprocessor.py +272 -0
- keras_hub/src/models/moonshine/moonshine_backbone.py +478 -0
- keras_hub/src/models/moonshine/moonshine_decoder.py +313 -0
- keras_hub/src/models/moonshine/moonshine_encoder.py +212 -0
- keras_hub/src/models/moonshine/moonshine_layers.py +239 -0
- keras_hub/src/models/moonshine/moonshine_multi_head_attention.py +355 -0
- keras_hub/src/models/moonshine/moonshine_presets.py +25 -0
- keras_hub/src/models/moonshine/moonshine_tokenizer.py +62 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +11 -11
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +1 -1
- keras_hub/src/models/qwen/__init__.py +4 -0
- keras_hub/src/models/qwen/qwen_attention.py +3 -1
- keras_hub/src/models/qwen/qwen_backbone.py +8 -1
- keras_hub/src/models/qwen/qwen_causal_lm.py +7 -0
- keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +7 -0
- keras_hub/src/models/qwen/qwen_presets.py +61 -0
- keras_hub/src/models/qwen/qwen_tokenizer.py +9 -0
- keras_hub/src/models/qwen_moe/__init__.py +5 -0
- keras_hub/src/models/qwen_moe/qwen_moe_attention.py +375 -0
- keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +373 -0
- keras_hub/src/models/qwen_moe/qwen_moe_causal_lm.py +350 -0
- keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor.py +17 -0
- keras_hub/src/models/qwen_moe/qwen_moe_decoder.py +625 -0
- keras_hub/src/models/qwen_moe/qwen_moe_layernorm.py +32 -0
- keras_hub/src/models/qwen_moe/qwen_moe_presets.py +15 -0
- keras_hub/src/models/qwen_moe/qwen_moe_tokenizer.py +46 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +0 -13
- keras_hub/src/models/retinanet/retinanet_presets.py +2 -2
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +0 -18
- keras_hub/src/models/segformer/segformer_presets.py +12 -12
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +6 -0
- keras_hub/src/models/task.py +5 -2
- keras_hub/src/models/xception/__init__.py +5 -0
- keras_hub/src/models/xception/xception_backbone.py +188 -0
- keras_hub/src/models/xception/xception_image_classifier.py +12 -0
- keras_hub/src/models/xception/xception_image_classifier_preprocessor.py +14 -0
- keras_hub/src/models/xception/xception_image_converter.py +8 -0
- keras_hub/src/models/xception/xception_presets.py +14 -0
- keras_hub/src/tests/mocks/mock_gemma3_tokenizer.py +155 -0
- keras_hub/src/utils/coco/__init__.py +0 -0
- keras_hub/src/utils/coco/coco_utils.py +133 -0
- keras_hub/src/utils/imagenet/imagenet_utils.py +36 -0
- keras_hub/src/utils/keras_utils.py +11 -0
- keras_hub/src/utils/preset_utils.py +70 -10
- keras_hub/src/utils/tensor_utils.py +27 -1
- keras_hub/src/utils/timm/convert_cspnet.py +94 -23
- keras_hub/src/utils/timm/preset_loader.py +6 -6
- keras_hub/src/utils/transformers/convert_llama3.py +21 -1
- keras_hub/src/utils/transformers/convert_mixtral.py +139 -0
- keras_hub/src/utils/transformers/convert_qwen.py +1 -0
- keras_hub/src/utils/transformers/convert_qwen_moe.py +253 -0
- keras_hub/src/utils/transformers/preset_loader.py +6 -0
- keras_hub/src/{version_utils.py → version.py} +1 -1
- keras_hub/tokenizers/__init__.py +117 -0
- keras_hub/utils/__init__.py +21 -0
- {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/METADATA +6 -20
- {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/RECORD +98 -55
- {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/WHEEL +1 -1
- keras_hub/api/__init__.py +0 -15
- keras_hub/api/layers/__init__.py +0 -86
- keras_hub/api/metrics/__init__.py +0 -11
- keras_hub/api/models/__init__.py +0 -416
- keras_hub/api/samplers/__init__.py +0 -16
- keras_hub/api/tokenizers/__init__.py +0 -58
- keras_hub/api/utils/__init__.py +0 -9
- {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/top_level.txt +0 -0
|
@@ -81,7 +81,7 @@ backbone_presets = {
|
|
|
81
81
|
"path": "pali_gemma2",
|
|
82
82
|
"model_card": "https://www.kaggle.com/models/google/paligemma-2",
|
|
83
83
|
},
|
|
84
|
-
"kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_ft_docci_10b_448/
|
|
84
|
+
"kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_ft_docci_10b_448/3",
|
|
85
85
|
},
|
|
86
86
|
"pali_gemma2_mix_3b_224": {
|
|
87
87
|
"metadata": {
|
|
@@ -126,7 +126,7 @@ backbone_presets = {
|
|
|
126
126
|
"path": "pali_gemma2",
|
|
127
127
|
"model_card": "https://www.kaggle.com/models/google/paligemma-2",
|
|
128
128
|
},
|
|
129
|
-
"kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_mix_10b_224/
|
|
129
|
+
"kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_mix_10b_224/3",
|
|
130
130
|
},
|
|
131
131
|
"pali_gemma2_mix_10b_448": {
|
|
132
132
|
"metadata": {
|
|
@@ -141,7 +141,7 @@ backbone_presets = {
|
|
|
141
141
|
"path": "pali_gemma2",
|
|
142
142
|
"model_card": "https://www.kaggle.com/models/google/paligemma-2",
|
|
143
143
|
},
|
|
144
|
-
"kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_mix_10b_448/
|
|
144
|
+
"kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_mix_10b_448/3",
|
|
145
145
|
},
|
|
146
146
|
"pali_gemma2_mix_28b_224": {
|
|
147
147
|
"metadata": {
|
|
@@ -156,7 +156,7 @@ backbone_presets = {
|
|
|
156
156
|
"path": "pali_gemma2",
|
|
157
157
|
"model_card": "https://www.kaggle.com/models/google/paligemma-2",
|
|
158
158
|
},
|
|
159
|
-
"kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_28b_mix_224/
|
|
159
|
+
"kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_28b_mix_224/3",
|
|
160
160
|
},
|
|
161
161
|
"pali_gemma2_mix_28b_448": {
|
|
162
162
|
"metadata": {
|
|
@@ -171,7 +171,7 @@ backbone_presets = {
|
|
|
171
171
|
"path": "pali_gemma2",
|
|
172
172
|
"model_card": "https://www.kaggle.com/models/google/paligemma-2",
|
|
173
173
|
},
|
|
174
|
-
"kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_28b_mix_448/
|
|
174
|
+
"kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_28b_mix_448/3",
|
|
175
175
|
},
|
|
176
176
|
"pali_gemma2_pt_3b_224": {
|
|
177
177
|
"metadata": {
|
|
@@ -231,7 +231,7 @@ backbone_presets = {
|
|
|
231
231
|
"path": "pali_gemma2",
|
|
232
232
|
"model_card": "https://www.kaggle.com/models/google/paligemma-2",
|
|
233
233
|
},
|
|
234
|
-
"kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_10b_224/
|
|
234
|
+
"kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_10b_224/3",
|
|
235
235
|
},
|
|
236
236
|
"pali_gemma2_pt_10b_448": {
|
|
237
237
|
"metadata": {
|
|
@@ -246,7 +246,7 @@ backbone_presets = {
|
|
|
246
246
|
"path": "pali_gemma2",
|
|
247
247
|
"model_card": "https://www.kaggle.com/models/google/paligemma-2",
|
|
248
248
|
},
|
|
249
|
-
"kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_10b_448/
|
|
249
|
+
"kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_10b_448/3",
|
|
250
250
|
},
|
|
251
251
|
"pali_gemma2_pt_10b_896": {
|
|
252
252
|
"metadata": {
|
|
@@ -261,7 +261,7 @@ backbone_presets = {
|
|
|
261
261
|
"path": "pali_gemma2",
|
|
262
262
|
"model_card": "https://www.kaggle.com/models/google/paligemma-2",
|
|
263
263
|
},
|
|
264
|
-
"kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_10b_896/
|
|
264
|
+
"kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_10b_896/3",
|
|
265
265
|
},
|
|
266
266
|
"pali_gemma2_pt_28b_224": {
|
|
267
267
|
"metadata": {
|
|
@@ -276,7 +276,7 @@ backbone_presets = {
|
|
|
276
276
|
"path": "pali_gemma2",
|
|
277
277
|
"model_card": "https://www.kaggle.com/models/google/paligemma-2",
|
|
278
278
|
},
|
|
279
|
-
"kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_28b_224/
|
|
279
|
+
"kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_28b_224/4",
|
|
280
280
|
},
|
|
281
281
|
"pali_gemma2_pt_28b_448": {
|
|
282
282
|
"metadata": {
|
|
@@ -291,7 +291,7 @@ backbone_presets = {
|
|
|
291
291
|
"path": "pali_gemma2",
|
|
292
292
|
"model_card": "https://www.kaggle.com/models/google/paligemma-2",
|
|
293
293
|
},
|
|
294
|
-
"kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_28b_448/
|
|
294
|
+
"kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_28b_448/3",
|
|
295
295
|
},
|
|
296
296
|
"pali_gemma2_pt_28b_896": {
|
|
297
297
|
"metadata": {
|
|
@@ -306,6 +306,6 @@ backbone_presets = {
|
|
|
306
306
|
"path": "pali_gemma2",
|
|
307
307
|
"model_card": "https://www.kaggle.com/models/google/paligemma-2",
|
|
308
308
|
},
|
|
309
|
-
"kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_28b_896/
|
|
309
|
+
"kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_28b_896/3",
|
|
310
310
|
},
|
|
311
311
|
}
|
|
@@ -329,7 +329,7 @@ class PaliGemmaVitEncoder(keras.layers.Layer):
|
|
|
329
329
|
# Fix the compatibility issue with Keras 3.1 where
|
|
330
330
|
# `compute_output_spec` fails to propagate `inputs_shape`
|
|
331
331
|
# correctly, causing it to be `None`.
|
|
332
|
-
|
|
332
|
+
return [None, None, self.hidden_dim]
|
|
333
333
|
return [
|
|
334
334
|
inputs_shape[0],
|
|
335
335
|
(inputs_shape[1] // self.patch_size) ** 2,
|
|
@@ -287,7 +287,9 @@ class QwenAttention(keras.layers.Layer):
|
|
|
287
287
|
if self.use_sliding_window_attention:
|
|
288
288
|
attention_mask = self._mask_sliding_window(
|
|
289
289
|
attention_mask,
|
|
290
|
-
cache_update_index=cache_update_index
|
|
290
|
+
cache_update_index=cache_update_index
|
|
291
|
+
if cache_update_index
|
|
292
|
+
else 0,
|
|
291
293
|
)
|
|
292
294
|
attention_scores = self._masked_softmax(
|
|
293
295
|
attention_scores, attention_mask
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import keras
|
|
2
2
|
from keras import ops
|
|
3
3
|
|
|
4
|
+
from keras_hub.src.api_export import keras_hub_export
|
|
4
5
|
from keras_hub.src.layers.modeling.reversible_embedding import (
|
|
5
6
|
ReversibleEmbedding,
|
|
6
7
|
)
|
|
@@ -13,6 +14,12 @@ def _qwen_kernel_initializer(stddev=0.02):
|
|
|
13
14
|
return keras.initializers.RandomNormal(stddev=stddev)
|
|
14
15
|
|
|
15
16
|
|
|
17
|
+
@keras_hub_export(
|
|
18
|
+
[
|
|
19
|
+
"keras_hub.models.QwenBackbone",
|
|
20
|
+
"keras_hub.models.Qwen2Backbone",
|
|
21
|
+
]
|
|
22
|
+
)
|
|
16
23
|
class QwenBackbone(Backbone):
|
|
17
24
|
"""
|
|
18
25
|
The Qwen Transformer core architecture with hyperparameters.
|
|
@@ -168,7 +175,7 @@ class QwenBackbone(Backbone):
|
|
|
168
175
|
self.layer_norm_epsilon = layer_norm_epsilon
|
|
169
176
|
self.dropout = dropout
|
|
170
177
|
self.tie_word_embeddings = tie_word_embeddings
|
|
171
|
-
self.use_sliding_window_attention =
|
|
178
|
+
self.use_sliding_window_attention = use_sliding_window_attention
|
|
172
179
|
self.sliding_window_size = sliding_window_size
|
|
173
180
|
|
|
174
181
|
def get_config(self):
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import keras
|
|
2
2
|
from keras import ops
|
|
3
3
|
|
|
4
|
+
from keras_hub.src.api_export import keras_hub_export
|
|
4
5
|
from keras_hub.src.models.causal_lm import CausalLM
|
|
5
6
|
from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone
|
|
6
7
|
from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import (
|
|
@@ -9,6 +10,12 @@ from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import (
|
|
|
9
10
|
from keras_hub.src.utils.tensor_utils import any_equal
|
|
10
11
|
|
|
11
12
|
|
|
13
|
+
@keras_hub_export(
|
|
14
|
+
[
|
|
15
|
+
"keras_hub.models.QwenCausalLM",
|
|
16
|
+
"keras_hub.models.Qwen2CausalLM",
|
|
17
|
+
]
|
|
18
|
+
)
|
|
12
19
|
class QwenCausalLM(CausalLM):
|
|
13
20
|
backbone_cls = QwenBackbone
|
|
14
21
|
preprocessor_cls = QwenCausalLMPreprocessor
|
|
@@ -1,8 +1,15 @@
|
|
|
1
|
+
from keras_hub.src.api_export import keras_hub_export
|
|
1
2
|
from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
|
|
2
3
|
from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone
|
|
3
4
|
from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer
|
|
4
5
|
|
|
5
6
|
|
|
7
|
+
@keras_hub_export(
|
|
8
|
+
[
|
|
9
|
+
"keras_hub.models.QwenCausalLMPreprocessor",
|
|
10
|
+
"keras_hub.models.Qwen2CausalLMPreprocessor",
|
|
11
|
+
]
|
|
12
|
+
)
|
|
6
13
|
class QwenCausalLMPreprocessor(CausalLMPreprocessor):
|
|
7
14
|
backbone_cls = QwenBackbone
|
|
8
15
|
tokenizer_cls = QwenTokenizer
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""Qwen preset configurations."""
|
|
2
|
+
|
|
3
|
+
backbone_presets = {
|
|
4
|
+
"qwen2.5_0.5b_en": {
|
|
5
|
+
"metadata": {
|
|
6
|
+
"description": ("24-layer Qwen model with 0.5 billion parameters."),
|
|
7
|
+
"params": 494032768,
|
|
8
|
+
"path": "qwen",
|
|
9
|
+
},
|
|
10
|
+
"kaggle_handle": "kaggle://keras/qwen/keras/qwen2.5_0.5b_en/1",
|
|
11
|
+
},
|
|
12
|
+
"qwen2.5_3b_en": {
|
|
13
|
+
"metadata": {
|
|
14
|
+
"description": ("36-layer Qwen model with 3.1 billion parameters."),
|
|
15
|
+
"params": 3085938688,
|
|
16
|
+
"path": "qwen",
|
|
17
|
+
},
|
|
18
|
+
"kaggle_handle": "kaggle://keras/qwen/keras/qwen2.5_3b_en/1",
|
|
19
|
+
},
|
|
20
|
+
"qwen2.5_7b_en": {
|
|
21
|
+
"metadata": {
|
|
22
|
+
"description": ("48-layer Qwen model with 7 billion parameters."),
|
|
23
|
+
"params": 6993420288,
|
|
24
|
+
"path": "qwen",
|
|
25
|
+
},
|
|
26
|
+
"kaggle_handle": "kaggle://keras/qwen/keras/qwen2.5_7b_en/3",
|
|
27
|
+
},
|
|
28
|
+
"qwen2.5_instruct_0.5b_en": {
|
|
29
|
+
"metadata": {
|
|
30
|
+
"description": (
|
|
31
|
+
"Instruction fine-tuned 24-layer Qwen model with 0.5 ",
|
|
32
|
+
"billion parameters.",
|
|
33
|
+
),
|
|
34
|
+
"params": 494032768,
|
|
35
|
+
"path": "qwen",
|
|
36
|
+
},
|
|
37
|
+
"kaggle_handle": "kaggle://keras/qwen/keras/qwen2.5_instruct_0.5b_en/1",
|
|
38
|
+
},
|
|
39
|
+
"qwen2.5_instruct_32b_en": {
|
|
40
|
+
"metadata": {
|
|
41
|
+
"description": (
|
|
42
|
+
"Instruction fine-tuned 64-layer Qwen model with 32 ",
|
|
43
|
+
"billion parameters.",
|
|
44
|
+
),
|
|
45
|
+
"params": 32763876352,
|
|
46
|
+
"path": "qwen",
|
|
47
|
+
},
|
|
48
|
+
"kaggle_handle": "kaggle://keras/qwen/keras/qwen2.5_instruct_32b_en/2",
|
|
49
|
+
},
|
|
50
|
+
"qwen2.5_instruct_72b_en": {
|
|
51
|
+
"metadata": {
|
|
52
|
+
"description": (
|
|
53
|
+
"Instruction fine-tuned 80-layer Qwen model with 72 ",
|
|
54
|
+
"billion parameters.",
|
|
55
|
+
),
|
|
56
|
+
"params": 72706203648,
|
|
57
|
+
"path": "qwen",
|
|
58
|
+
},
|
|
59
|
+
"kaggle_handle": "kaggle://keras/qwen/keras/qwen2.5_instruct_72b_en/2",
|
|
60
|
+
},
|
|
61
|
+
}
|
|
@@ -1,7 +1,16 @@
|
|
|
1
|
+
from keras_hub.src.api_export import keras_hub_export
|
|
1
2
|
from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone
|
|
2
3
|
from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
|
|
3
4
|
|
|
4
5
|
|
|
6
|
+
@keras_hub_export(
|
|
7
|
+
[
|
|
8
|
+
"keras_hub.tokenizers.QwenTokenizer",
|
|
9
|
+
"keras_hub.tokenizers.Qwen2Tokenizer",
|
|
10
|
+
"keras_hub.models.QwenTokenizer",
|
|
11
|
+
"keras_hub.models.Qwen2Tokenizer",
|
|
12
|
+
]
|
|
13
|
+
)
|
|
5
14
|
class QwenTokenizer(BytePairTokenizer):
|
|
6
15
|
"""Tokenizer for Qwen models.
|
|
7
16
|
|
|
@@ -0,0 +1,5 @@
|
|
|
1
|
+
from keras_hub.src.models.qwen_moe.qwen_moe_backbone import QwenMoeBackbone
|
|
2
|
+
from keras_hub.src.models.qwen_moe.qwen_moe_presets import backbone_presets
|
|
3
|
+
from keras_hub.src.utils.preset_utils import register_presets
|
|
4
|
+
|
|
5
|
+
register_presets(backbone_presets, QwenMoeBackbone)
|
|
@@ -0,0 +1,375 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
import math
|
|
3
|
+
|
|
4
|
+
import keras
|
|
5
|
+
from keras import ops
|
|
6
|
+
|
|
7
|
+
from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
|
|
8
|
+
from keras_hub.src.utils.keras_utils import clone_initializer
|
|
9
|
+
from keras_hub.src.utils.keras_utils import fused_attention_op_available
|
|
10
|
+
from keras_hub.src.utils.keras_utils import gpu_supports_fused_attention_op
|
|
11
|
+
from keras_hub.src.utils.keras_utils import running_on_gpu
|
|
12
|
+
from keras_hub.src.utils.keras_utils import running_on_tpu
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class QwenMoeAttention(keras.layers.Layer):
|
|
16
|
+
"""A multi-head attention layer for Qwen-Moe model
|
|
17
|
+
|
|
18
|
+
This attention implementation supports grouped-query attention (GQA) where
|
|
19
|
+
the number of key-value heads can be less than the number of query heads.
|
|
20
|
+
|
|
21
|
+
Args:
|
|
22
|
+
num_query_heads: Number of query heads.
|
|
23
|
+
num_key_value_heads: Number of key/value heads (for GQA).
|
|
24
|
+
rope_max_wavelength: Maximum wavelength for RoPE (Rotary Position
|
|
25
|
+
Embedding).
|
|
26
|
+
rope_scaling_factor: Scaling factor for RoPE, used for extending
|
|
27
|
+
context length.
|
|
28
|
+
kernel_initializer: Initializer for the kernel weights.
|
|
29
|
+
bias_initializer: Initializer for the bias weights.
|
|
30
|
+
dropout: Dropout rate for attention weights.
|
|
31
|
+
use_sliding_window_attention: Whether to use sliding window
|
|
32
|
+
attention.
|
|
33
|
+
sliding_window_size: Size of the sliding window for attention.
|
|
34
|
+
**kwargs: Additional keyword arguments to pass to the Layer.
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
def __init__(
|
|
38
|
+
self,
|
|
39
|
+
num_query_heads,
|
|
40
|
+
num_key_value_heads,
|
|
41
|
+
rope_max_wavelength=10000,
|
|
42
|
+
rope_scaling_factor=1,
|
|
43
|
+
kernel_initializer="glorot_uniform",
|
|
44
|
+
bias_initializer="zeros",
|
|
45
|
+
dropout=0,
|
|
46
|
+
use_sliding_window_attention=False,
|
|
47
|
+
sliding_window_size=4096,
|
|
48
|
+
**kwargs,
|
|
49
|
+
):
|
|
50
|
+
super().__init__(
|
|
51
|
+
**kwargs,
|
|
52
|
+
)
|
|
53
|
+
self.num_query_heads = num_query_heads
|
|
54
|
+
self.num_key_value_heads = num_key_value_heads
|
|
55
|
+
self.dropout = dropout
|
|
56
|
+
|
|
57
|
+
self.num_key_value_groups = num_query_heads // num_key_value_heads
|
|
58
|
+
self.rope_max_wavelength = rope_max_wavelength
|
|
59
|
+
|
|
60
|
+
self.kernel_initializer = keras.initializers.get(
|
|
61
|
+
clone_initializer(kernel_initializer)
|
|
62
|
+
)
|
|
63
|
+
self.bias_initializer = keras.initializers.get(
|
|
64
|
+
clone_initializer(bias_initializer)
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
self.rope_scaling_factor = rope_scaling_factor
|
|
68
|
+
self.use_sliding_window_attention = use_sliding_window_attention
|
|
69
|
+
self.sliding_window_size = sliding_window_size
|
|
70
|
+
|
|
71
|
+
def build(self, inputs_shape):
|
|
72
|
+
# Einsum variables:
|
|
73
|
+
# b = batch size
|
|
74
|
+
# q = query length
|
|
75
|
+
# k = key/value length
|
|
76
|
+
# m = model dim
|
|
77
|
+
# u = num query heads
|
|
78
|
+
# v = num key/value heads
|
|
79
|
+
# h = head dim
|
|
80
|
+
hidden_dim = inputs_shape[-1]
|
|
81
|
+
head_dim = hidden_dim // self.num_query_heads
|
|
82
|
+
self._inv_norm_factor = 1.0 / math.sqrt(head_dim)
|
|
83
|
+
self.query_dense = keras.layers.EinsumDense(
|
|
84
|
+
equation="bqm,muh->bquh",
|
|
85
|
+
output_shape=(None, self.num_query_heads, head_dim),
|
|
86
|
+
kernel_initializer=self.kernel_initializer,
|
|
87
|
+
bias_initializer=self.bias_initializer,
|
|
88
|
+
bias_axes="uh",
|
|
89
|
+
dtype=self.dtype_policy,
|
|
90
|
+
name="query",
|
|
91
|
+
)
|
|
92
|
+
self.query_dense.build(inputs_shape)
|
|
93
|
+
|
|
94
|
+
self.key_dense = keras.layers.EinsumDense(
|
|
95
|
+
equation="bkm,mvh->bkvh",
|
|
96
|
+
output_shape=(
|
|
97
|
+
None,
|
|
98
|
+
self.num_key_value_heads,
|
|
99
|
+
head_dim,
|
|
100
|
+
),
|
|
101
|
+
kernel_initializer=self.kernel_initializer,
|
|
102
|
+
bias_initializer=self.bias_initializer,
|
|
103
|
+
bias_axes="vh",
|
|
104
|
+
dtype=self.dtype_policy,
|
|
105
|
+
name="key",
|
|
106
|
+
)
|
|
107
|
+
self.key_dense.build(inputs_shape)
|
|
108
|
+
|
|
109
|
+
self.value_dense = keras.layers.EinsumDense(
|
|
110
|
+
equation="bkm,mvh->bkvh",
|
|
111
|
+
output_shape=(
|
|
112
|
+
None,
|
|
113
|
+
self.num_key_value_heads,
|
|
114
|
+
head_dim,
|
|
115
|
+
),
|
|
116
|
+
kernel_initializer=self.kernel_initializer,
|
|
117
|
+
bias_initializer=self.bias_initializer,
|
|
118
|
+
bias_axes="vh",
|
|
119
|
+
dtype=self.dtype_policy,
|
|
120
|
+
name="value",
|
|
121
|
+
)
|
|
122
|
+
self.value_dense.build(inputs_shape)
|
|
123
|
+
|
|
124
|
+
self._softmax = keras.layers.Softmax(
|
|
125
|
+
axis=-1,
|
|
126
|
+
dtype="float32",
|
|
127
|
+
name="attention_softmax",
|
|
128
|
+
)
|
|
129
|
+
|
|
130
|
+
self._dropout_layer = keras.layers.Dropout(
|
|
131
|
+
rate=self.dropout,
|
|
132
|
+
dtype=self.dtype_policy,
|
|
133
|
+
)
|
|
134
|
+
|
|
135
|
+
self._output_dense = keras.layers.EinsumDense(
|
|
136
|
+
equation="bquh,uhm->bqm",
|
|
137
|
+
output_shape=(None, hidden_dim),
|
|
138
|
+
kernel_initializer=self.kernel_initializer,
|
|
139
|
+
dtype=self.dtype_policy,
|
|
140
|
+
name="attention_output",
|
|
141
|
+
)
|
|
142
|
+
self._output_dense.build((None, None, self.num_query_heads, head_dim))
|
|
143
|
+
|
|
144
|
+
self.rotary_embedding_layer = RotaryEmbedding(
|
|
145
|
+
max_wavelength=self.rope_max_wavelength,
|
|
146
|
+
scaling_factor=self.rope_scaling_factor,
|
|
147
|
+
dtype=self.dtype_policy,
|
|
148
|
+
)
|
|
149
|
+
|
|
150
|
+
self._dot_product_equation = "bquh,bkuh->buqk"
|
|
151
|
+
self._combine_equation = "buqk,bkuh->bquh"
|
|
152
|
+
|
|
153
|
+
self.built = True
|
|
154
|
+
|
|
155
|
+
def call(
|
|
156
|
+
self,
|
|
157
|
+
hidden_states,
|
|
158
|
+
attention_mask=None,
|
|
159
|
+
cache=None,
|
|
160
|
+
cache_update_index=None,
|
|
161
|
+
training=None,
|
|
162
|
+
):
|
|
163
|
+
"""Applies attention mechanism to the input hidden states.
|
|
164
|
+
|
|
165
|
+
Args:
|
|
166
|
+
hidden_states: Input tensor of shape [batch_size, seq_length,
|
|
167
|
+
hidden_size].
|
|
168
|
+
attention_mask: Mask tensor of shape [batch_size, seq_length,
|
|
169
|
+
seq_length].
|
|
170
|
+
cache: Optional cached key and value tensors.
|
|
171
|
+
cache_update_index: Index at which to update the cache.
|
|
172
|
+
training: Boolean indicating whether in training mode.
|
|
173
|
+
|
|
174
|
+
Returns:
|
|
175
|
+
attention_output: Output tensor after applying attention.
|
|
176
|
+
cache: Updated cache tensors (if cache is provided).
|
|
177
|
+
"""
|
|
178
|
+
start_index = (
|
|
179
|
+
cache_update_index if cache_update_index is not None else 0
|
|
180
|
+
)
|
|
181
|
+
|
|
182
|
+
query = self.query_dense(hidden_states)
|
|
183
|
+
|
|
184
|
+
# Compute RoPE for queries
|
|
185
|
+
query = self.rotary_embedding_layer(query, start_index=start_index)
|
|
186
|
+
|
|
187
|
+
def _compute_key_value(x):
|
|
188
|
+
key, value = self.key_dense(x), self.value_dense(x)
|
|
189
|
+
# Compute RoPE for keys
|
|
190
|
+
key = self.rotary_embedding_layer(key, start_index=start_index)
|
|
191
|
+
return key, value
|
|
192
|
+
|
|
193
|
+
if cache is not None:
|
|
194
|
+
key_cache = cache[:, 0, ...]
|
|
195
|
+
value_cache = cache[:, 1, ...]
|
|
196
|
+
if cache_update_index is None:
|
|
197
|
+
key = key_cache
|
|
198
|
+
value = value_cache
|
|
199
|
+
else:
|
|
200
|
+
key_update, value_update = _compute_key_value(hidden_states)
|
|
201
|
+
start = [0, cache_update_index, 0, 0]
|
|
202
|
+
key = ops.slice_update(key_cache, start, key_update)
|
|
203
|
+
value = ops.slice_update(value_cache, start, value_update)
|
|
204
|
+
cache = ops.stack((key, value), axis=1)
|
|
205
|
+
else:
|
|
206
|
+
if cache_update_index is not None:
|
|
207
|
+
raise ValueError(
|
|
208
|
+
"`cache_update_index` should not be set if `cache` is "
|
|
209
|
+
f"`None`. Received: cache={cache}, "
|
|
210
|
+
f"cache_update_index={cache_update_index}"
|
|
211
|
+
)
|
|
212
|
+
key, value = _compute_key_value(hidden_states)
|
|
213
|
+
|
|
214
|
+
# [batch_shape, seq_len, num_key_value_heads, head_dim]
|
|
215
|
+
# -> [batch_shape, seq_len, num_heads, head_dim]
|
|
216
|
+
key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
|
|
217
|
+
value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)
|
|
218
|
+
|
|
219
|
+
attention_output = self._compute_attention(
|
|
220
|
+
query,
|
|
221
|
+
key,
|
|
222
|
+
value,
|
|
223
|
+
attention_mask,
|
|
224
|
+
cache_update_index=cache_update_index,
|
|
225
|
+
)
|
|
226
|
+
|
|
227
|
+
attention_output = self._dropout_layer(
|
|
228
|
+
attention_output, training=training
|
|
229
|
+
)
|
|
230
|
+
|
|
231
|
+
attention_output = self._output_dense(attention_output)
|
|
232
|
+
|
|
233
|
+
if cache is not None:
|
|
234
|
+
return attention_output, cache
|
|
235
|
+
return attention_output
|
|
236
|
+
|
|
237
|
+
def _masked_softmax(self, attention_scores, attention_mask=None):
|
|
238
|
+
"""Applies softmax with optional masking.
|
|
239
|
+
|
|
240
|
+
Args:
|
|
241
|
+
attention_scores: Attention score tensor.
|
|
242
|
+
attention_mask: Optional mask tensor.
|
|
243
|
+
|
|
244
|
+
Returns:
|
|
245
|
+
Masked softmax attention weights.
|
|
246
|
+
"""
|
|
247
|
+
if attention_mask is not None:
|
|
248
|
+
return self._softmax(
|
|
249
|
+
attention_scores, attention_mask[:, None, :, :]
|
|
250
|
+
)
|
|
251
|
+
return self._softmax(attention_scores)
|
|
252
|
+
|
|
253
|
+
def _use_fused_attention_op(self):
|
|
254
|
+
if not fused_attention_op_available():
|
|
255
|
+
return False
|
|
256
|
+
if self.dropout > 0.0:
|
|
257
|
+
return False
|
|
258
|
+
if running_on_gpu():
|
|
259
|
+
return gpu_supports_fused_attention_op()
|
|
260
|
+
elif running_on_tpu():
|
|
261
|
+
# TPU supports softcap with on keras >= 3.10.
|
|
262
|
+
sig = inspect.signature(ops.dot_product_attention)
|
|
263
|
+
return "attn_logits_soft_cap" in sig.parameters
|
|
264
|
+
else:
|
|
265
|
+
return False
|
|
266
|
+
|
|
267
|
+
def _compute_attention(
|
|
268
|
+
self,
|
|
269
|
+
query,
|
|
270
|
+
key,
|
|
271
|
+
value,
|
|
272
|
+
attention_mask=None,
|
|
273
|
+
cache_update_index=None,
|
|
274
|
+
**kwargs,
|
|
275
|
+
):
|
|
276
|
+
"""Computes attention using query, key, and value tensors.
|
|
277
|
+
|
|
278
|
+
Uses Flash Attention when available for better performance.
|
|
279
|
+
|
|
280
|
+
Args:
|
|
281
|
+
query: Query tensor.
|
|
282
|
+
key: Key tensor.
|
|
283
|
+
value: Value tensor.
|
|
284
|
+
attention_mask: Optional mask tensor.
|
|
285
|
+
cache_update_index: Index for sliding window computation.
|
|
286
|
+
|
|
287
|
+
Returns:
|
|
288
|
+
attention_output: Output tensor after applying attention.
|
|
289
|
+
"""
|
|
290
|
+
if self._use_fused_attention_op():
|
|
291
|
+
if attention_mask is not None:
|
|
292
|
+
attention_mask = ops.expand_dims(attention_mask, axis=1)
|
|
293
|
+
attention_mask = ops.cast(attention_mask, dtype="bool")
|
|
294
|
+
|
|
295
|
+
attention_output = ops.dot_product_attention(
|
|
296
|
+
query,
|
|
297
|
+
key,
|
|
298
|
+
value,
|
|
299
|
+
mask=attention_mask,
|
|
300
|
+
scale=self._inv_norm_factor,
|
|
301
|
+
**kwargs,
|
|
302
|
+
)
|
|
303
|
+
return attention_output
|
|
304
|
+
|
|
305
|
+
attention_scores = ops.einsum(self._dot_product_equation, query, key)
|
|
306
|
+
|
|
307
|
+
attention_scores = ops.multiply(
|
|
308
|
+
attention_scores,
|
|
309
|
+
ops.cast(self._inv_norm_factor, self.compute_dtype),
|
|
310
|
+
)
|
|
311
|
+
if self.use_sliding_window_attention:
|
|
312
|
+
attention_mask = self._mask_sliding_window(
|
|
313
|
+
attention_mask,
|
|
314
|
+
cache_update_index=cache_update_index
|
|
315
|
+
if cache_update_index
|
|
316
|
+
else 0,
|
|
317
|
+
)
|
|
318
|
+
attention_scores = self._masked_softmax(
|
|
319
|
+
attention_scores, attention_mask
|
|
320
|
+
)
|
|
321
|
+
attention_scores = ops.cast(attention_scores, self.compute_dtype)
|
|
322
|
+
attention_output = ops.einsum(
|
|
323
|
+
self._combine_equation, attention_scores, value
|
|
324
|
+
)
|
|
325
|
+
|
|
326
|
+
return attention_output
|
|
327
|
+
|
|
328
|
+
def _mask_sliding_window(
|
|
329
|
+
self,
|
|
330
|
+
attention_mask,
|
|
331
|
+
cache_update_index=0,
|
|
332
|
+
):
|
|
333
|
+
"""Creates and combines a sliding window mask with the attention mask.
|
|
334
|
+
|
|
335
|
+
Args:
|
|
336
|
+
attention_mask: Original attention mask.
|
|
337
|
+
cache_update_index: Starting index for the sliding window.
|
|
338
|
+
|
|
339
|
+
Returns:
|
|
340
|
+
Combined attention mask with sliding window constraints.
|
|
341
|
+
"""
|
|
342
|
+
_, query_len, key_len = ops.shape(attention_mask)
|
|
343
|
+
# Compute the sliding window for square attention.
|
|
344
|
+
all_ones = ops.ones((key_len, key_len), "bool")
|
|
345
|
+
sliding_mask = ops.triu(
|
|
346
|
+
all_ones, -1 * self.sliding_window_size + 1
|
|
347
|
+
) * ops.tril(all_ones, self.sliding_window_size - 1)
|
|
348
|
+
# Slice the window for short queries during generation.
|
|
349
|
+
start = (cache_update_index, 0)
|
|
350
|
+
sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len))
|
|
351
|
+
sliding_mask = ops.expand_dims(sliding_mask, 0)
|
|
352
|
+
return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool"))
|
|
353
|
+
|
|
354
|
+
def get_config(self):
|
|
355
|
+
config = super().get_config()
|
|
356
|
+
config.update(
|
|
357
|
+
{
|
|
358
|
+
"num_query_heads": self.num_query_heads,
|
|
359
|
+
"num_key_value_heads": self.num_key_value_heads,
|
|
360
|
+
"rope_max_wavelength": self.rope_max_wavelength,
|
|
361
|
+
"rope_scaling_factor": self.rope_scaling_factor,
|
|
362
|
+
"kernel_initializer": keras.initializers.serialize(
|
|
363
|
+
self.kernel_initializer
|
|
364
|
+
),
|
|
365
|
+
"bias_initializer": keras.initializers.serialize(
|
|
366
|
+
self.bias_initializer
|
|
367
|
+
),
|
|
368
|
+
"dropout": self.dropout,
|
|
369
|
+
"use_sliding_window_attention": (
|
|
370
|
+
self.use_sliding_window_attention
|
|
371
|
+
),
|
|
372
|
+
"sliding_window_size": self.sliding_window_size,
|
|
373
|
+
}
|
|
374
|
+
)
|
|
375
|
+
return config
|