keras_hub-0.20.0.dev1-py3-none-any.whl → keras_hub-0.21.0.dev1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. keras_hub/__init__.py +15 -33
  2. keras_hub/layers/__init__.py +134 -0
  3. keras_hub/metrics/__init__.py +11 -0
  4. keras_hub/models/__init__.py +642 -0
  5. keras_hub/samplers/__init__.py +18 -0
  6. keras_hub/src/layers/modeling/reversible_embedding.py +25 -35
  7. keras_hub/src/layers/preprocessing/image_converter.py +1 -0
  8. keras_hub/src/layers/preprocessing/random_deletion.py +1 -1
  9. keras_hub/src/layers/preprocessing/random_swap.py +1 -1
  10. keras_hub/src/models/audio_to_text.py +66 -0
  11. keras_hub/src/models/audio_to_text_preprocessor.py +80 -0
  12. keras_hub/src/models/backbone.py +5 -2
  13. keras_hub/src/models/cspnet/cspnet_backbone.py +51 -26
  14. keras_hub/src/models/cspnet/cspnet_presets.py +38 -3
  15. keras_hub/src/models/falcon/falcon_backbone.py +1 -1
  16. keras_hub/src/models/gemma/gemma_presets.py +10 -10
  17. keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +3 -2
  18. keras_hub/src/models/gemma3/gemma3_presets.py +8 -8
  19. keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
  20. keras_hub/src/models/llama/llama_attention.py +24 -6
  21. keras_hub/src/models/llama/llama_backbone.py +50 -16
  22. keras_hub/src/models/llama/llama_decoder.py +20 -3
  23. keras_hub/src/models/llama/llama_presets.py +3 -3
  24. keras_hub/src/models/llama/llama_rotary_embedding.py +180 -0
  25. keras_hub/src/models/llama3/llama3_backbone.py +10 -2
  26. keras_hub/src/models/llama3/llama3_presets.py +84 -2
  27. keras_hub/src/models/mistral/mistral_presets.py +3 -3
  28. keras_hub/src/models/mixtral/__init__.py +5 -0
  29. keras_hub/src/models/mixtral/mixtral_attention.py +252 -0
  30. keras_hub/src/models/mixtral/mixtral_backbone.py +207 -0
  31. keras_hub/src/models/mixtral/mixtral_causal_lm.py +281 -0
  32. keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py +76 -0
  33. keras_hub/src/models/mixtral/mixtral_decoder.py +494 -0
  34. keras_hub/src/models/mixtral/mixtral_layer_norm.py +34 -0
  35. keras_hub/src/models/mixtral/mixtral_presets.py +26 -0
  36. keras_hub/src/models/mixtral/mixtral_tokenizer.py +21 -0
  37. keras_hub/src/models/moonshine/__init__.py +5 -0
  38. keras_hub/src/models/moonshine/moonshine_audio_converter.py +301 -0
  39. keras_hub/src/models/moonshine/moonshine_audio_to_text.py +383 -0
  40. keras_hub/src/models/moonshine/moonshine_audio_to_text_preprocessor.py +272 -0
  41. keras_hub/src/models/moonshine/moonshine_backbone.py +478 -0
  42. keras_hub/src/models/moonshine/moonshine_decoder.py +313 -0
  43. keras_hub/src/models/moonshine/moonshine_encoder.py +212 -0
  44. keras_hub/src/models/moonshine/moonshine_layers.py +239 -0
  45. keras_hub/src/models/moonshine/moonshine_multi_head_attention.py +355 -0
  46. keras_hub/src/models/moonshine/moonshine_presets.py +25 -0
  47. keras_hub/src/models/moonshine/moonshine_tokenizer.py +62 -0
  48. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +11 -11
  49. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +1 -1
  50. keras_hub/src/models/qwen/__init__.py +4 -0
  51. keras_hub/src/models/qwen/qwen_attention.py +3 -1
  52. keras_hub/src/models/qwen/qwen_backbone.py +8 -1
  53. keras_hub/src/models/qwen/qwen_causal_lm.py +7 -0
  54. keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +7 -0
  55. keras_hub/src/models/qwen/qwen_presets.py +61 -0
  56. keras_hub/src/models/qwen/qwen_tokenizer.py +9 -0
  57. keras_hub/src/models/qwen_moe/__init__.py +5 -0
  58. keras_hub/src/models/qwen_moe/qwen_moe_attention.py +375 -0
  59. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +373 -0
  60. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm.py +350 -0
  61. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor.py +17 -0
  62. keras_hub/src/models/qwen_moe/qwen_moe_decoder.py +625 -0
  63. keras_hub/src/models/qwen_moe/qwen_moe_layernorm.py +32 -0
  64. keras_hub/src/models/qwen_moe/qwen_moe_presets.py +15 -0
  65. keras_hub/src/models/qwen_moe/qwen_moe_tokenizer.py +46 -0
  66. keras_hub/src/models/retinanet/retinanet_image_converter.py +0 -13
  67. keras_hub/src/models/retinanet/retinanet_presets.py +2 -2
  68. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +0 -18
  69. keras_hub/src/models/segformer/segformer_presets.py +12 -12
  70. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +6 -0
  71. keras_hub/src/models/task.py +5 -2
  72. keras_hub/src/models/xception/__init__.py +5 -0
  73. keras_hub/src/models/xception/xception_backbone.py +188 -0
  74. keras_hub/src/models/xception/xception_image_classifier.py +12 -0
  75. keras_hub/src/models/xception/xception_image_classifier_preprocessor.py +14 -0
  76. keras_hub/src/models/xception/xception_image_converter.py +8 -0
  77. keras_hub/src/models/xception/xception_presets.py +14 -0
  78. keras_hub/src/tests/mocks/mock_gemma3_tokenizer.py +155 -0
  79. keras_hub/src/utils/coco/__init__.py +0 -0
  80. keras_hub/src/utils/coco/coco_utils.py +133 -0
  81. keras_hub/src/utils/imagenet/imagenet_utils.py +36 -0
  82. keras_hub/src/utils/keras_utils.py +11 -0
  83. keras_hub/src/utils/preset_utils.py +70 -10
  84. keras_hub/src/utils/tensor_utils.py +27 -1
  85. keras_hub/src/utils/timm/convert_cspnet.py +94 -23
  86. keras_hub/src/utils/timm/preset_loader.py +6 -6
  87. keras_hub/src/utils/transformers/convert_llama3.py +21 -1
  88. keras_hub/src/utils/transformers/convert_mixtral.py +139 -0
  89. keras_hub/src/utils/transformers/convert_qwen.py +1 -0
  90. keras_hub/src/utils/transformers/convert_qwen_moe.py +253 -0
  91. keras_hub/src/utils/transformers/preset_loader.py +6 -0
  92. keras_hub/src/{version_utils.py → version.py} +1 -1
  93. keras_hub/tokenizers/__init__.py +117 -0
  94. keras_hub/utils/__init__.py +21 -0
  95. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/METADATA +6 -20
  96. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/RECORD +98 -55
  97. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/WHEEL +1 -1
  98. keras_hub/api/__init__.py +0 -15
  99. keras_hub/api/layers/__init__.py +0 -86
  100. keras_hub/api/metrics/__init__.py +0 -11
  101. keras_hub/api/models/__init__.py +0 -416
  102. keras_hub/api/samplers/__init__.py +0 -16
  103. keras_hub/api/tokenizers/__init__.py +0 -58
  104. keras_hub/api/utils/__init__.py +0 -9
  105. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/top_level.txt +0 -0

keras_hub/src/models/pali_gemma/pali_gemma_presets.py
@@ -81,7 +81,7 @@ backbone_presets = {
             "path": "pali_gemma2",
             "model_card": "https://www.kaggle.com/models/google/paligemma-2",
         },
-        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_ft_docci_10b_448/2",
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_ft_docci_10b_448/3",
     },
     "pali_gemma2_mix_3b_224": {
         "metadata": {
@@ -126,7 +126,7 @@ backbone_presets = {
             "path": "pali_gemma2",
             "model_card": "https://www.kaggle.com/models/google/paligemma-2",
         },
-        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_mix_10b_224/2",
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_mix_10b_224/3",
     },
     "pali_gemma2_mix_10b_448": {
         "metadata": {
@@ -141,7 +141,7 @@ backbone_presets = {
             "path": "pali_gemma2",
             "model_card": "https://www.kaggle.com/models/google/paligemma-2",
         },
-        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_mix_10b_448/2",
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_mix_10b_448/3",
     },
     "pali_gemma2_mix_28b_224": {
         "metadata": {
@@ -156,7 +156,7 @@ backbone_presets = {
             "path": "pali_gemma2",
             "model_card": "https://www.kaggle.com/models/google/paligemma-2",
         },
-        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_28b_mix_224/2",
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_28b_mix_224/3",
     },
     "pali_gemma2_mix_28b_448": {
         "metadata": {
@@ -171,7 +171,7 @@ backbone_presets = {
             "path": "pali_gemma2",
             "model_card": "https://www.kaggle.com/models/google/paligemma-2",
         },
-        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_28b_mix_448/2",
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_28b_mix_448/3",
     },
     "pali_gemma2_pt_3b_224": {
         "metadata": {
@@ -231,7 +231,7 @@ backbone_presets = {
             "path": "pali_gemma2",
             "model_card": "https://www.kaggle.com/models/google/paligemma-2",
         },
-        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_10b_224/2",
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_10b_224/3",
     },
     "pali_gemma2_pt_10b_448": {
         "metadata": {
@@ -246,7 +246,7 @@ backbone_presets = {
             "path": "pali_gemma2",
             "model_card": "https://www.kaggle.com/models/google/paligemma-2",
         },
-        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_10b_448/2",
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_10b_448/3",
     },
     "pali_gemma2_pt_10b_896": {
         "metadata": {
@@ -261,7 +261,7 @@ backbone_presets = {
             "path": "pali_gemma2",
             "model_card": "https://www.kaggle.com/models/google/paligemma-2",
         },
-        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_10b_896/2",
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_10b_896/3",
     },
     "pali_gemma2_pt_28b_224": {
         "metadata": {
@@ -276,7 +276,7 @@ backbone_presets = {
             "path": "pali_gemma2",
             "model_card": "https://www.kaggle.com/models/google/paligemma-2",
         },
-        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_28b_224/3",
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_28b_224/4",
     },
     "pali_gemma2_pt_28b_448": {
         "metadata": {
@@ -291,7 +291,7 @@ backbone_presets = {
             "path": "pali_gemma2",
             "model_card": "https://www.kaggle.com/models/google/paligemma-2",
         },
-        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_28b_448/2",
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_28b_448/3",
     },
     "pali_gemma2_pt_28b_896": {
         "metadata": {
@@ -306,6 +306,6 @@ backbone_presets = {
             "path": "pali_gemma2",
             "model_card": "https://www.kaggle.com/models/google/paligemma-2",
         },
-        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_28b_896/2",
+        "kaggle_handle": "kaggle://keras/paligemma2/keras/pali_gemma2_pt_28b_896/3",
     },
 }

keras_hub/src/models/pali_gemma/pali_gemma_vit.py
@@ -329,7 +329,7 @@ class PaliGemmaVitEncoder(keras.layers.Layer):
             # Fix the compatibility issue with Keras 3.1 where
             # `compute_output_spec` fails to propagate `inputs_shape`
             # correctly, causing it to be `None`.
-            inputs_shape = [None, None, None]
+            return [None, None, self.hidden_dim]
         return [
             inputs_shape[0],
             (inputs_shape[1] // self.patch_size) ** 2,

keras_hub/src/models/qwen/__init__.py
@@ -1 +1,5 @@
 from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone
+from keras_hub.src.models.qwen.qwen_presets import backbone_presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, QwenBackbone)
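
For context, `register_presets` wires the preset names defined in `qwen_presets.py` (added later in this diff) to `QwenBackbone`, so they become resolvable through the standard `from_preset` entry point. A minimal usage sketch, assuming a keras-hub 0.21.0.dev1 install with access to the Kaggle handles the presets point at:

    import keras_hub

    # Resolves the registered name to its Kaggle handle and downloads the
    # weights on first use; "qwen2.5_0.5b_en" comes from qwen_presets.py.
    backbone = keras_hub.models.QwenBackbone.from_preset("qwen2.5_0.5b_en")
    print(backbone.count_params())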

keras_hub/src/models/qwen/qwen_attention.py
@@ -287,7 +287,9 @@ class QwenAttention(keras.layers.Layer):
         if self.use_sliding_window_attention:
             attention_mask = self._mask_sliding_window(
                 attention_mask,
-                cache_update_index=cache_update_index,
+                cache_update_index=cache_update_index
+                if cache_update_index
+                else 0,
             )
         attention_scores = self._masked_softmax(
             attention_scores, attention_mask

keras_hub/src/models/qwen/qwen_backbone.py
@@ -1,6 +1,7 @@
 import keras
 from keras import ops

+from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.modeling.reversible_embedding import (
     ReversibleEmbedding,
 )
@@ -13,6 +14,12 @@ def _qwen_kernel_initializer(stddev=0.02):
     return keras.initializers.RandomNormal(stddev=stddev)


+@keras_hub_export(
+    [
+        "keras_hub.models.QwenBackbone",
+        "keras_hub.models.Qwen2Backbone",
+    ]
+)
 class QwenBackbone(Backbone):
     """
     The Qwen Transformer core architecture with hyperparameters.
@@ -168,7 +175,7 @@ class QwenBackbone(Backbone):
         self.layer_norm_epsilon = layer_norm_epsilon
         self.dropout = dropout
         self.tie_word_embeddings = tie_word_embeddings
-        self.use_sliding_window_attention = (use_sliding_window_attention,)
+        self.use_sliding_window_attention = use_sliding_window_attention
         self.sliding_window_size = sliding_window_size

     def get_config(self):
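
For context on the one-line fix above: the old assignment's trailing comma turned the flag into a single-element tuple, which is always truthy, so truthiness checks on the attribute and the value serialized by `get_config()` no longer matched the boolean that was passed in. A plain-Python illustration of the pitfall (not keras-hub code):

    use_sliding_window_attention = False

    # Old form: the trailing comma builds a 1-tuple, and any tuple is truthy.
    wrapped = (use_sliding_window_attention,)
    assert wrapped == (False,)
    assert bool(wrapped) is True  # an `if` check on it would still pass

    # Fixed form: the boolean is stored as-is.
    plain = use_sliding_window_attention
    assert bool(plain) is False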

keras_hub/src/models/qwen/qwen_causal_lm.py
@@ -1,6 +1,7 @@
 import keras
 from keras import ops

+from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.causal_lm import CausalLM
 from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone
 from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import (
@@ -9,6 +10,12 @@ from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import (
 from keras_hub.src.utils.tensor_utils import any_equal


+@keras_hub_export(
+    [
+        "keras_hub.models.QwenCausalLM",
+        "keras_hub.models.Qwen2CausalLM",
+    ]
+)
 class QwenCausalLM(CausalLM):
     backbone_cls = QwenBackbone
     preprocessor_cls = QwenCausalLMPreprocessor
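
A note on the `keras_hub_export` calls added throughout these Qwen files: passing a list of paths registers the same class under every name, which is how the `Qwen2*` aliases are introduced next to the existing `Qwen*` symbols. A small sketch of what this implies for user code, assuming the regenerated 0.21.0.dev1 API surface:

    import keras_hub

    # Both public names are expected to resolve to the same class object;
    # "Qwen2CausalLM" is an alias added by the multi-path export above.
    assert keras_hub.models.QwenCausalLM is keras_hub.models.Qwen2CausalLM
    assert keras_hub.models.QwenBackbone is keras_hub.models.Qwen2Backbone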

keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py
@@ -1,8 +1,15 @@
+from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
 from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone
 from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer


+@keras_hub_export(
+    [
+        "keras_hub.models.QwenCausalLMPreprocessor",
+        "keras_hub.models.Qwen2CausalLMPreprocessor",
+    ]
+)
 class QwenCausalLMPreprocessor(CausalLMPreprocessor):
     backbone_cls = QwenBackbone
     tokenizer_cls = QwenTokenizer

keras_hub/src/models/qwen/qwen_presets.py
@@ -0,0 +1,61 @@
+"""Qwen preset configurations."""
+
+backbone_presets = {
+    "qwen2.5_0.5b_en": {
+        "metadata": {
+            "description": ("24-layer Qwen model with 0.5 billion parameters."),
+            "params": 494032768,
+            "path": "qwen",
+        },
+        "kaggle_handle": "kaggle://keras/qwen/keras/qwen2.5_0.5b_en/1",
+    },
+    "qwen2.5_3b_en": {
+        "metadata": {
+            "description": ("36-layer Qwen model with 3.1 billion parameters."),
+            "params": 3085938688,
+            "path": "qwen",
+        },
+        "kaggle_handle": "kaggle://keras/qwen/keras/qwen2.5_3b_en/1",
+    },
+    "qwen2.5_7b_en": {
+        "metadata": {
+            "description": ("48-layer Qwen model with 7 billion parameters."),
+            "params": 6993420288,
+            "path": "qwen",
+        },
+        "kaggle_handle": "kaggle://keras/qwen/keras/qwen2.5_7b_en/3",
+    },
+    "qwen2.5_instruct_0.5b_en": {
+        "metadata": {
+            "description": (
+                "Instruction fine-tuned 24-layer Qwen model with 0.5 ",
+                "billion parameters.",
+            ),
+            "params": 494032768,
+            "path": "qwen",
+        },
+        "kaggle_handle": "kaggle://keras/qwen/keras/qwen2.5_instruct_0.5b_en/1",
+    },
+    "qwen2.5_instruct_32b_en": {
+        "metadata": {
+            "description": (
+                "Instruction fine-tuned 64-layer Qwen model with 32 ",
+                "billion parameters.",
+            ),
+            "params": 32763876352,
+            "path": "qwen",
+        },
+        "kaggle_handle": "kaggle://keras/qwen/keras/qwen2.5_instruct_32b_en/2",
+    },
+    "qwen2.5_instruct_72b_en": {
+        "metadata": {
+            "description": (
+                "Instruction fine-tuned 80-layer Qwen model with 72 ",
+                "billion parameters.",
+            ),
+            "params": 72706203648,
+            "path": "qwen",
+        },
+        "kaggle_handle": "kaggle://keras/qwen/keras/qwen2.5_instruct_72b_en/2",
+    },
+}

keras_hub/src/models/qwen/qwen_tokenizer.py
@@ -1,7 +1,16 @@
+from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone
 from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer


+@keras_hub_export(
+    [
+        "keras_hub.tokenizers.QwenTokenizer",
+        "keras_hub.tokenizers.Qwen2Tokenizer",
+        "keras_hub.models.QwenTokenizer",
+        "keras_hub.models.Qwen2Tokenizer",
+    ]
+)
 class QwenTokenizer(BytePairTokenizer):
     """Tokenizer for Qwen models.


keras_hub/src/models/qwen_moe/__init__.py
@@ -0,0 +1,5 @@
+from keras_hub.src.models.qwen_moe.qwen_moe_backbone import QwenMoeBackbone
+from keras_hub.src.models.qwen_moe.qwen_moe_presets import backbone_presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, QwenMoeBackbone)

keras_hub/src/models/qwen_moe/qwen_moe_attention.py
@@ -0,0 +1,375 @@
+import inspect
+import math
+
+import keras
+from keras import ops
+
+from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+from keras_hub.src.utils.keras_utils import clone_initializer
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
+from keras_hub.src.utils.keras_utils import gpu_supports_fused_attention_op
+from keras_hub.src.utils.keras_utils import running_on_gpu
+from keras_hub.src.utils.keras_utils import running_on_tpu
+
+
+class QwenMoeAttention(keras.layers.Layer):
+    """A multi-head attention layer for Qwen-Moe model
+
+    This attention implementation supports grouped-query attention (GQA) where
+    the number of key-value heads can be less than the number of query heads.
+
+    Args:
+        num_query_heads: Number of query heads.
+        num_key_value_heads: Number of key/value heads (for GQA).
+        rope_max_wavelength: Maximum wavelength for RoPE (Rotary Position
+            Embedding).
+        rope_scaling_factor: Scaling factor for RoPE, used for extending
+            context length.
+        kernel_initializer: Initializer for the kernel weights.
+        bias_initializer: Initializer for the bias weights.
+        dropout: Dropout rate for attention weights.
+        use_sliding_window_attention: Whether to use sliding window
+            attention.
+        sliding_window_size: Size of the sliding window for attention.
+        **kwargs: Additional keyword arguments to pass to the Layer.
+    """
+
+    def __init__(
+        self,
+        num_query_heads,
+        num_key_value_heads,
+        rope_max_wavelength=10000,
+        rope_scaling_factor=1,
+        kernel_initializer="glorot_uniform",
+        bias_initializer="zeros",
+        dropout=0,
+        use_sliding_window_attention=False,
+        sliding_window_size=4096,
+        **kwargs,
+    ):
+        super().__init__(
+            **kwargs,
+        )
+        self.num_query_heads = num_query_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.dropout = dropout
+
+        self.num_key_value_groups = num_query_heads // num_key_value_heads
+        self.rope_max_wavelength = rope_max_wavelength
+
+        self.kernel_initializer = keras.initializers.get(
+            clone_initializer(kernel_initializer)
+        )
+        self.bias_initializer = keras.initializers.get(
+            clone_initializer(bias_initializer)
+        )
+
+        self.rope_scaling_factor = rope_scaling_factor
+        self.use_sliding_window_attention = use_sliding_window_attention
+        self.sliding_window_size = sliding_window_size
+
+    def build(self, inputs_shape):
+        # Einsum variables:
+        # b = batch size
+        # q = query length
+        # k = key/value length
+        # m = model dim
+        # u = num query heads
+        # v = num key/value heads
+        # h = head dim
+        hidden_dim = inputs_shape[-1]
+        head_dim = hidden_dim // self.num_query_heads
+        self._inv_norm_factor = 1.0 / math.sqrt(head_dim)
+        self.query_dense = keras.layers.EinsumDense(
+            equation="bqm,muh->bquh",
+            output_shape=(None, self.num_query_heads, head_dim),
+            kernel_initializer=self.kernel_initializer,
+            bias_initializer=self.bias_initializer,
+            bias_axes="uh",
+            dtype=self.dtype_policy,
+            name="query",
+        )
+        self.query_dense.build(inputs_shape)
+
+        self.key_dense = keras.layers.EinsumDense(
+            equation="bkm,mvh->bkvh",
+            output_shape=(
+                None,
+                self.num_key_value_heads,
+                head_dim,
+            ),
+            kernel_initializer=self.kernel_initializer,
+            bias_initializer=self.bias_initializer,
+            bias_axes="vh",
+            dtype=self.dtype_policy,
+            name="key",
+        )
+        self.key_dense.build(inputs_shape)
+
+        self.value_dense = keras.layers.EinsumDense(
+            equation="bkm,mvh->bkvh",
+            output_shape=(
+                None,
+                self.num_key_value_heads,
+                head_dim,
+            ),
+            kernel_initializer=self.kernel_initializer,
+            bias_initializer=self.bias_initializer,
+            bias_axes="vh",
+            dtype=self.dtype_policy,
+            name="value",
+        )
+        self.value_dense.build(inputs_shape)
+
+        self._softmax = keras.layers.Softmax(
+            axis=-1,
+            dtype="float32",
+            name="attention_softmax",
+        )
+
+        self._dropout_layer = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+        )
+
+        self._output_dense = keras.layers.EinsumDense(
+            equation="bquh,uhm->bqm",
+            output_shape=(None, hidden_dim),
+            kernel_initializer=self.kernel_initializer,
+            dtype=self.dtype_policy,
+            name="attention_output",
+        )
+        self._output_dense.build((None, None, self.num_query_heads, head_dim))
+
+        self.rotary_embedding_layer = RotaryEmbedding(
+            max_wavelength=self.rope_max_wavelength,
+            scaling_factor=self.rope_scaling_factor,
+            dtype=self.dtype_policy,
+        )
+
+        self._dot_product_equation = "bquh,bkuh->buqk"
+        self._combine_equation = "buqk,bkuh->bquh"
+
+        self.built = True
+
+    def call(
+        self,
+        hidden_states,
+        attention_mask=None,
+        cache=None,
+        cache_update_index=None,
+        training=None,
+    ):
+        """Applies attention mechanism to the input hidden states.
+
+        Args:
+            hidden_states: Input tensor of shape [batch_size, seq_length,
+                hidden_size].
+            attention_mask: Mask tensor of shape [batch_size, seq_length,
+                seq_length].
+            cache: Optional cached key and value tensors.
+            cache_update_index: Index at which to update the cache.
+            training: Boolean indicating whether in training mode.
+
+        Returns:
+            attention_output: Output tensor after applying attention.
+            cache: Updated cache tensors (if cache is provided).
+        """
+        start_index = (
+            cache_update_index if cache_update_index is not None else 0
+        )
+
+        query = self.query_dense(hidden_states)
+
+        # Compute RoPE for queries
+        query = self.rotary_embedding_layer(query, start_index=start_index)
+
+        def _compute_key_value(x):
+            key, value = self.key_dense(x), self.value_dense(x)
+            # Compute RoPE for keys
+            key = self.rotary_embedding_layer(key, start_index=start_index)
+            return key, value
+
+        if cache is not None:
+            key_cache = cache[:, 0, ...]
+            value_cache = cache[:, 1, ...]
+            if cache_update_index is None:
+                key = key_cache
+                value = value_cache
+            else:
+                key_update, value_update = _compute_key_value(hidden_states)
+                start = [0, cache_update_index, 0, 0]
+                key = ops.slice_update(key_cache, start, key_update)
+                value = ops.slice_update(value_cache, start, value_update)
+                cache = ops.stack((key, value), axis=1)
+        else:
+            if cache_update_index is not None:
+                raise ValueError(
+                    "`cache_update_index` should not be set if `cache` is "
+                    f"`None`. Received: cache={cache}, "
+                    f"cache_update_index={cache_update_index}"
+                )
+            key, value = _compute_key_value(hidden_states)
+
+        # [batch_shape, seq_len, num_key_value_heads, head_dim]
+        # -> [batch_shape, seq_len, num_heads, head_dim]
+        key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
+        value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)
+
+        attention_output = self._compute_attention(
+            query,
+            key,
+            value,
+            attention_mask,
+            cache_update_index=cache_update_index,
+        )
+
+        attention_output = self._dropout_layer(
+            attention_output, training=training
+        )
+
+        attention_output = self._output_dense(attention_output)
+
+        if cache is not None:
+            return attention_output, cache
+        return attention_output
+
+    def _masked_softmax(self, attention_scores, attention_mask=None):
+        """Applies softmax with optional masking.
+
+        Args:
+            attention_scores: Attention score tensor.
+            attention_mask: Optional mask tensor.
+
+        Returns:
+            Masked softmax attention weights.
+        """
+        if attention_mask is not None:
+            return self._softmax(
+                attention_scores, attention_mask[:, None, :, :]
+            )
+        return self._softmax(attention_scores)
+
+    def _use_fused_attention_op(self):
+        if not fused_attention_op_available():
+            return False
+        if self.dropout > 0.0:
+            return False
+        if running_on_gpu():
+            return gpu_supports_fused_attention_op()
+        elif running_on_tpu():
+            # TPU supports softcap with on keras >= 3.10.
+            sig = inspect.signature(ops.dot_product_attention)
+            return "attn_logits_soft_cap" in sig.parameters
+        else:
+            return False
+
+    def _compute_attention(
+        self,
+        query,
+        key,
+        value,
+        attention_mask=None,
+        cache_update_index=None,
+        **kwargs,
+    ):
+        """Computes attention using query, key, and value tensors.
+
+        Uses Flash Attention when available for better performance.
+
+        Args:
+            query: Query tensor.
+            key: Key tensor.
+            value: Value tensor.
+            attention_mask: Optional mask tensor.
+            cache_update_index: Index for sliding window computation.
+
+        Returns:
+            attention_output: Output tensor after applying attention.
+        """
+        if self._use_fused_attention_op():
+            if attention_mask is not None:
+                attention_mask = ops.expand_dims(attention_mask, axis=1)
+                attention_mask = ops.cast(attention_mask, dtype="bool")
+
+            attention_output = ops.dot_product_attention(
+                query,
+                key,
+                value,
+                mask=attention_mask,
+                scale=self._inv_norm_factor,
+                **kwargs,
+            )
+            return attention_output
+
+        attention_scores = ops.einsum(self._dot_product_equation, query, key)
+
+        attention_scores = ops.multiply(
+            attention_scores,
+            ops.cast(self._inv_norm_factor, self.compute_dtype),
+        )
+        if self.use_sliding_window_attention:
+            attention_mask = self._mask_sliding_window(
+                attention_mask,
+                cache_update_index=cache_update_index
+                if cache_update_index
+                else 0,
+            )
+        attention_scores = self._masked_softmax(
+            attention_scores, attention_mask
+        )
+        attention_scores = ops.cast(attention_scores, self.compute_dtype)
+        attention_output = ops.einsum(
+            self._combine_equation, attention_scores, value
+        )
+
+        return attention_output
+
+    def _mask_sliding_window(
+        self,
+        attention_mask,
+        cache_update_index=0,
+    ):
+        """Creates and combines a sliding window mask with the attention mask.
+
+        Args:
+            attention_mask: Original attention mask.
+            cache_update_index: Starting index for the sliding window.
+
+        Returns:
+            Combined attention mask with sliding window constraints.
+        """
+        _, query_len, key_len = ops.shape(attention_mask)
+        # Compute the sliding window for square attention.
+        all_ones = ops.ones((key_len, key_len), "bool")
+        sliding_mask = ops.triu(
+            all_ones, -1 * self.sliding_window_size + 1
+        ) * ops.tril(all_ones, self.sliding_window_size - 1)
+        # Slice the window for short queries during generation.
+        start = (cache_update_index, 0)
+        sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len))
+        sliding_mask = ops.expand_dims(sliding_mask, 0)
+        return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool"))
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_query_heads": self.num_query_heads,
+                "num_key_value_heads": self.num_key_value_heads,
+                "rope_max_wavelength": self.rope_max_wavelength,
+                "rope_scaling_factor": self.rope_scaling_factor,
+                "kernel_initializer": keras.initializers.serialize(
+                    self.kernel_initializer
+                ),
+                "bias_initializer": keras.initializers.serialize(
+                    self.bias_initializer
+                ),
+                "dropout": self.dropout,
+                "use_sliding_window_attention": (
+                    self.use_sliding_window_attention
+                ),
+                "sliding_window_size": self.sliding_window_size,
+            }
+        )
+        return config
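
As a standalone illustration of the banded mask built by `_mask_sliding_window` above, the sketch below reproduces the same `keras.ops` calls outside of any layer. The shapes, window size, and causal starting mask are made-up example values, not keras-hub code:

    import numpy as np
    from keras import ops

    query_len, key_len = 6, 6
    sliding_window_size = 3
    cache_update_index = 0

    # Start from a causal mask shaped [batch, query_len, key_len].
    causal = np.tril(np.ones((query_len, key_len), dtype=bool))[None, ...]
    attention_mask = ops.convert_to_tensor(causal)

    # Band of width 2 * window - 1 around the diagonal: a key position is
    # kept only if it lies within sliding_window_size - 1 of the query.
    all_ones = ops.ones((key_len, key_len), "bool")
    sliding_mask = ops.triu(all_ones, -1 * sliding_window_size + 1) * ops.tril(
        all_ones, sliding_window_size - 1
    )

    # During cached generation the query block starts at cache_update_index,
    # so slice out the matching rows before combining with the causal mask.
    start = (cache_update_index, 0)
    sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len))
    sliding_mask = ops.expand_dims(sliding_mask, 0)
    combined = ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool"))
    print(ops.convert_to_numpy(combined)[0].astype(int))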