keras-hub-nightly 0.24.0.dev202511260427__py3-none-any.whl → 0.26.0.dev202601010440__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (65)
  1. keras_hub/models/__init__.py +12 -0
  2. keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
  3. keras_hub/src/layers/modeling/rotary_embedding.py +188 -14
  4. keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -3
  5. keras_hub/src/models/albert/albert_backbone.py +1 -3
  6. keras_hub/src/models/bart/bart_backbone.py +1 -3
  7. keras_hub/src/models/bert/bert_backbone.py +1 -3
  8. keras_hub/src/models/bloom/bloom_backbone.py +1 -3
  9. keras_hub/src/models/causal_lm.py +23 -1
  10. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
  11. keras_hub/src/models/dinov3/dinov3_presets.py +90 -1
  12. keras_hub/src/models/electra/electra_backbone.py +1 -3
  13. keras_hub/src/models/esm/esm_attention.py +11 -4
  14. keras_hub/src/models/f_net/f_net_backbone.py +1 -3
  15. keras_hub/src/models/falcon/falcon_backbone.py +1 -3
  16. keras_hub/src/models/gemma/gemma_backbone.py +1 -3
  17. keras_hub/src/models/gemma/gemma_causal_lm.py +16 -0
  18. keras_hub/src/models/gemma3/gemma3_backbone.py +1 -3
  19. keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +8 -3
  20. keras_hub/src/models/gemma3/gemma3_presets.py +12 -0
  21. keras_hub/src/models/gemma3/gemma3_tokenizer.py +20 -8
  22. keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
  23. keras_hub/src/models/gpt2/gpt2_causal_lm.py +17 -0
  24. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
  25. keras_hub/src/models/gpt_oss/__init__.py +5 -0
  26. keras_hub/src/models/gpt_oss/gpt_oss_attention.py +330 -0
  27. keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +219 -0
  28. keras_hub/src/models/gpt_oss/gpt_oss_causal_lm.py +284 -0
  29. keras_hub/src/models/gpt_oss/gpt_oss_causal_lm_preprocessor.py +79 -0
  30. keras_hub/src/models/gpt_oss/gpt_oss_decoder.py +444 -0
  31. keras_hub/src/models/gpt_oss/gpt_oss_layer_norm.py +34 -0
  32. keras_hub/src/models/gpt_oss/gpt_oss_presets.py +51 -0
  33. keras_hub/src/models/gpt_oss/gpt_oss_tokenizer.py +39 -0
  34. keras_hub/src/models/llama/llama_backbone.py +1 -3
  35. keras_hub/src/models/llama3/llama3_presets.py +1 -1
  36. keras_hub/src/models/masked_lm.py +22 -0
  37. keras_hub/src/models/mistral/mistral_backbone.py +1 -3
  38. keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
  39. keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
  40. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
  41. keras_hub/src/models/parseq/parseq_decoder.py +21 -9
  42. keras_hub/src/models/phi3/phi3_backbone.py +1 -3
  43. keras_hub/src/models/qwen/qwen_backbone.py +1 -3
  44. keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
  45. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
  46. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
  47. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
  48. keras_hub/src/models/siglip/siglip_layers.py +1 -3
  49. keras_hub/src/models/smollm3/__init__.py +5 -0
  50. keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
  51. keras_hub/src/models/smollm3/smollm3_presets.py +16 -0
  52. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +1 -1
  53. keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
  54. keras_hub/src/models/t5/t5_backbone.py +1 -3
  55. keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
  56. keras_hub/src/tests/test_case.py +1 -3
  57. keras_hub/src/utils/transformers/convert_gemma3.py +353 -0
  58. keras_hub/src/utils/transformers/convert_gpt_oss.py +302 -0
  59. keras_hub/src/utils/transformers/preset_loader.py +12 -0
  60. keras_hub/src/version.py +1 -1
  61. keras_hub/tokenizers/__init__.py +3 -0
  62. {keras_hub_nightly-0.24.0.dev202511260427.dist-info → keras_hub_nightly-0.26.0.dev202601010440.dist-info}/METADATA +4 -5
  63. {keras_hub_nightly-0.24.0.dev202511260427.dist-info → keras_hub_nightly-0.26.0.dev202601010440.dist-info}/RECORD +65 -52
  64. {keras_hub_nightly-0.24.0.dev202511260427.dist-info → keras_hub_nightly-0.26.0.dev202601010440.dist-info}/WHEEL +0 -0
  65. {keras_hub_nightly-0.24.0.dev202511260427.dist-info → keras_hub_nightly-0.26.0.dev202601010440.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,93 @@
 """DINOV3 model preset configurations."""
 
 # Metadata for loading pretrained model weights.
-backbone_presets = {}
+backbone_presets = {
+    "dinov3_vit_small_lvd1689m": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (small-sized model) trained on LVD-1689M "
+                "using DINOv3."
+            ),
+            "params": 21_600_000,
+            "path": "dinov3",
+        },
+        "kaggle_handle": "kaggle://keras/dinov3/keras/dinov3_vit_small_lvd1689m/1",
+    },
+    "dinov3_vit_small_plus_lvd1689m": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (small-plus-sized model) trained on "
+                "LVD-1689M using DINOv3."
+            ),
+            "params": 29_000_000,
+            "path": "dinov3",
+        },
+        "kaggle_handle": "kaggle://keras/dinov3/keras/dinov3_vit_small_plus_lvd1689m/1",
+    },
+    "dinov3_vit_base_lvd1689m": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (base-sized model) trained on LVD-1689M "
+                "using DINOv3."
+            ),
+            "params": 86_000_000,
+            "path": "dinov3",
+        },
+        "kaggle_handle": "kaggle://keras/dinov3/keras/dinov3_vit_base_lvd1689m/1",
+    },
+    "dinov3_vit_large_lvd1689m": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (large-sized model) trained on LVD-1689M "
+                "using DINOv3."
+            ),
+            "params": 300_000_000,
+            "path": "dinov3",
+        },
+        "kaggle_handle": "kaggle://keras/dinov3/keras/dinov3_vit_large_lvd1689m/1",
+    },
+    "dinov3_vit_huge_plus_lvd1689m": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (huge-plus-sized model) trained on "
+                "LVD-1689M using DINOv3."
+            ),
+            "params": 840_000_000,
+            "path": "dinov3",
+        },
+        "kaggle_handle": "kaggle://keras/dinov3/keras/dinov3_vit_huge_plus_lvd1689m/1",
+    },
+    "dinov3_vit_7b_lvd1689m": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (7B-sized model) trained on LVD-1689M "
+                "using DINOv3."
+            ),
+            "params": 6_700_000_000,
+            "path": "dinov3",
+        },
+        "kaggle_handle": "kaggle://keras/dinov3/keras/dinov3_vit_7b_lvd1689m/1",
+    },
+    "dinov3_vit_large_sat493m": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (large-sized model) trained on SAT-493M "
+                "using DINOv3."
+            ),
+            "params": 300_000_000,
+            "path": "dinov3",
+        },
+        "kaggle_handle": "kaggle://keras/dinov3/keras/dinov3_vit_large_sat493m/1",
+    },
+    "dinov3_vit_7b_sat493m": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (7B-sized model) trained on SAT-493M "
+                "using DINOv3."
+            ),
+            "params": 6_700_000_000,
+            "path": "dinov3",
+        },
+        "kaggle_handle": "kaggle://keras/dinov3/keras/dinov3_vit_7b_sat493m/1",
+    },
+}
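
For orientation, a minimal sketch of loading one of the presets registered above through the generic `from_preset` entry point (the preset name comes from the diff; network access to the registered Kaggle handle is assumed):

    import keras_hub

    # Downloads config and weights from the Kaggle handle registered above.
    backbone = keras_hub.models.Backbone.from_preset("dinov3_vit_small_lvd1689m")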
@@ -1,10 +1,8 @@
 import keras
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.utils.keras_utils import gelu_approximate
@@ -14,7 +14,8 @@ class ESMRotaryEmbedding(RotaryEmbedding):
         inv_freq = self.scaling_factor / (
             self.max_wavelength ** (ops.arange(0, dim, 2, dtype=x.dtype) / dim)
         )
-        t = ops.arange(x.shape[position], dtype=x.dtype)
+        # Use ops.shape for dynamic shape compatibility with TFLite.
+        t = ops.arange(ops.shape(x)[position], dtype=x.dtype)
         freqs = ops.outer(t, inv_freq)
         emb = ops.concatenate((freqs, freqs), axis=-1)
 
@@ -32,11 +33,17 @@ class ESMRotaryEmbedding(RotaryEmbedding):
 
     def rotate_half(self, x):
        x1, x2 = ops.split(x, 2, -1)
-        return ops.concatenate((-x2, x1), axis=-1)
+        # Avoid `ops.concatenate` to prevent XLA compilation issues on JAX
+        # backend. Use the stack + reshape approach from base RotaryEmbedding.
+        half_rot_x = ops.stack((-x2, x1), axis=-2)
+        half_rot_x = ops.reshape(half_rot_x, ops.shape(x))
+        return half_rot_x
 
     def apply_rotary_pos_emb(self, x, cos, sin):
-        cos = cos[:, : x.shape[1], :, :]
-        sin = sin[:, : x.shape[1], :, :]
+        # Use ops.shape for dynamic shape compatibility with TFLite.
+        seq_len = ops.shape(x)[1]
+        cos = cos[:, :seq_len, :, :]
+        sin = sin[:, :seq_len, :, :]
 
         return (x * cos) + (self.rotate_half(x) * sin)
 
@@ -1,11 +1,9 @@
 import keras
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.modeling.f_net_encoder import FNetEncoder
 from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.utils.keras_utils import gelu_approximate
 
@@ -1,9 +1,7 @@
 import keras
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.falcon.falcon_transformer_decoder import (
     FalconTransformerDecoder,
@@ -1,10 +1,8 @@
 import keras
 from keras import ops
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.gemma.gemma_decoder_block import GemmaDecoderBlock
 from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
@@ -431,3 +431,19 @@ class GemmaCausalLM(CausalLM):
         )
         per_token_loss = per_token_loss_fn(target_ids, logits)
         return per_token_loss
+
+    def get_quantization_layer_structure(self, mode):
+        if mode != "gptq":
+            return None
+
+        # Wrap embedding + scaling.
+        backbone = self.backbone
+        inputs = keras.Input(shape=(None,), dtype="int32")
+        x = backbone.token_embedding(inputs)
+        x = x * ops.cast(ops.sqrt(backbone.hidden_dim), x.dtype)
+        pre_processor = keras.Model(inputs=inputs, outputs=x)
+
+        return {
+            "pre_block_layers": [pre_processor],
+            "sequential_blocks": backbone.transformer_layers,
+        }
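
A hedged sketch of how a GPTQ driver might consume this structure; `causal_lm`, `calibration_token_ids`, and the `quantize_block` helper are hypothetical stand-ins for illustration, not keras-hub API:

    # `causal_lm` is a GemmaCausalLM; `calibration_token_ids` is a small
    # batch of int32 token IDs (both assumed for illustration).
    structure = causal_lm.get_quantization_layer_structure("gptq")
    # Embeddings plus Gemma's sqrt(hidden_dim) scaling.
    x = structure["pre_block_layers"][0](calibration_token_ids)
    for block in structure["sequential_blocks"]:
        quantize_block(block, x)  # hypothetical: calibrate + quantize
        x = block(x)  # propagate activations to the next decoder block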
@@ -1,10 +1,8 @@
 import keras
 from keras import ops
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
 from keras_hub.src.models.gemma3.gemma3_decoder_block import Gemma3DecoderBlock
@@ -283,9 +283,14 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
         # is `None`.
         self.text_only_model = self.image_converter is None
 
-        self.image_placeholder = self.tokenizer.image_placeholder
-        self.start_of_image_token = self.tokenizer.start_of_image_token
-        self.end_of_image_token = self.tokenizer.end_of_image_token
+        if self.text_only_model:
+            self.image_placeholder = None
+            self.start_of_image_token = None
+            self.end_of_image_token = None
+        else:
+            self.image_placeholder = self.tokenizer.image_placeholder
+            self.start_of_image_token = self.tokenizer.start_of_image_token
+            self.end_of_image_token = self.tokenizer.end_of_image_token
 
     def build(self, input_shape):
         # Defer packer creation to `build()` so that we can be sure tokenizer
@@ -220,4 +220,16 @@ backbone_presets = {
         },
         "kaggle_handle": "kaggle://keras/medgemma/keras/medgemma_instruct_27b_text/1",
     },
+    "function_gemma_instruct_270m": {
+        "metadata": {
+            "description": (
+                "A 270 million parameter, text-only model based on Gemma 3. "
+                "This model is trained specifically for function calling "
+                "improvements."
+            ),
+            "params": 268098176,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/function-gemma/keras/function_gemma_instruct_270m/1",
+    },
 }
@@ -77,20 +77,32 @@ class Gemma3Tokenizer(SentencePieceTokenizer):
 
     backbone_cls = Gemma3Backbone
 
-    def __init__(self, proto, **kwargs):
+    def __init__(self, proto, has_vision_tokens=True, **kwargs):
         # Add special tokens.
 
+        self.has_vision_tokens = has_vision_tokens
         # The usual tokens.
         self._add_special_token("<bos>", "start_token")
         self._add_special_token("<eos>", "end_token")
         self._add_special_token("<pad>", "pad_token")
 
-        # Image placeholder token.
-        self._add_special_token("<img>", "image_placeholder")
-
-        # Some tokens which are used in the preprocessor. We need to keep them
-        # here so that the preprocessor works with `tf.data`.
-        self._add_special_token("<start_of_image>", "start_of_image_token")
-        self._add_special_token("<end_of_image>", "end_of_image_token")
+        if has_vision_tokens:
+            # Image placeholder token.
+            self._add_special_token("<img>", "image_placeholder")
+            # Some tokens which are used in the preprocessor. We need to
+            # keep them here so that the preprocessor works with
+            # `tf.data`.
+            self._add_special_token("<start_of_image>", "start_of_image_token")
+            self._add_special_token("<end_of_image>", "end_of_image_token")
+        else:
+            # For text-only models, set the vision token IDs to -1.
+            self.start_of_image_token_id = -1
+            self.image_placeholder_token_id = -1
+            self.end_of_image_token_id = -1
 
         super().__init__(proto=proto, **kwargs)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"has_vision_tokens": self.has_vision_tokens})
+        return config
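
A short sketch of the new text-only path (the proto path below is a placeholder for illustration):

    from keras_hub.tokenizers import Gemma3Tokenizer

    tokenizer = Gemma3Tokenizer(
        proto="path/to/gemma3_vocab.spm",  # placeholder path
        has_vision_tokens=False,  # vision token IDs are set to -1
    )
    # `has_vision_tokens` round-trips through `get_config()`, so a saved
    # text-only tokenizer reloads without the vision special tokens.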
@@ -1,10 +1,8 @@
 import keras
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.layers.modeling.transformer_decoder import TransformerDecoder
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.utils.keras_utils import gelu_approximate
@@ -420,3 +420,20 @@ class GPT2CausalLM(CausalLM):
         )
         per_token_loss = per_token_loss_fn(target_ids, logits)
         return per_token_loss
+
+    def get_quantization_layer_structure(self, mode):
+        if mode != "gptq":
+            return None
+
+        backbone = self.backbone
+        token_ids = keras.Input(shape=(None,), dtype="int32")
+        tokens = backbone.token_embedding(token_ids)
+        positions = backbone.position_embedding(tokens)
+        x = backbone.embeddings_add((tokens, positions))
+        x = backbone.embeddings_dropout(x)
+        pre_processor = keras.Model(inputs=token_ids, outputs=x)
+
+        return {
+            "pre_block_layers": [pre_processor],
+            "sequential_blocks": backbone.transformer_layers,
+        }
@@ -1,9 +1,7 @@
 import keras
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.gpt_neo_x.gpt_neo_x_decoder import GPTNeoXDecoder
 from keras_hub.src.utils.keras_utils import gelu_approximate
@@ -0,0 +1,5 @@
+from keras_hub.src.models.gpt_oss.gpt_oss_backbone import GptOssBackbone
+from keras_hub.src.models.gpt_oss.gpt_oss_presets import backbone_presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, GptOssBackbone)
@@ -0,0 +1,330 @@
+import math
+
+import keras
+from keras import ops
+
+from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+class GptOssAttention(keras.layers.Layer):
+    """A cached attention layer with sliding window and sink tokens.
+
+    This layer implements the attention mechanism described in the GPT-OSS
+    paper. It includes grouped-query attention, rotary position embeddings,
+    sliding window attention, and sink tokens for improved performance on
+    long sequences.
+
+    Args:
+        num_query_heads: int. The number of query attention heads.
+        num_key_value_heads: int. The number of key and value attention
+            heads.
+        rope_max_wavelength: int. The maximum wavelength for the
+            rotary position embedding. Defaults to 10000.
+        rope_scaling_factor: float. The scaling factor for the
+            rotary position embedding. Defaults to 1.0.
+        kernel_initializer: str. The initializer for the kernel
+            weights. Defaults to "glorot_uniform".
+        sliding_window: int. The size of the sliding window.
+            Defaults to 4096.
+        dropout: float. The dropout rate. Defaults to 0.
+        head_dim: int. Head dimension for attention. If None,
+            calculated as hidden_dim // num_query_heads. Defaults to None.
+    """
+
+    def __init__(
+        self,
+        num_query_heads,
+        num_key_value_heads,
+        rope_max_wavelength=10000,
+        rope_scaling_factor=1.0,
+        kernel_initializer="glorot_uniform",
+        sliding_window=4096,
+        dropout=0,
+        head_dim=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.num_query_heads = num_query_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.sliding_window = sliding_window
+        self.dropout = dropout
+        self.head_dim = head_dim
+        self.rope_max_wavelength = rope_max_wavelength
+        self.rope_scaling_factor = rope_scaling_factor
+        self.num_key_value_groups = num_query_heads // num_key_value_heads
+        self._kernel_initializer = keras.initializers.get(
+            clone_initializer(kernel_initializer)
+        )
+
+    def build(self, inputs_shape):
+        # Einsum variables:
+        # b = batch size
+        # q = query length
+        # k = key/value length
+        # m = the model's hidden_dim
+        # u = num query heads
+        # v = num key/value heads
+        # h = head dim
+        self._hidden_dim = inputs_shape[-1]
+
+        if self.head_dim is not None:
+            self._head_dim = self.head_dim
+        else:
+            self._head_dim = self._hidden_dim // self.num_query_heads
+        self._inv_norm_factor = 1.0 / math.sqrt(self._head_dim)
+
+        self._rotary_dim = (self._head_dim // 2) * 2
+
+        self.query_dense = keras.layers.EinsumDense(
+            equation="bqm,muh->bquh",
+            output_shape=(None, self.num_query_heads, self._head_dim),
+            bias_axes="uh",
+            kernel_initializer=self._kernel_initializer,
+            bias_initializer="zeros",
+            dtype=self.dtype_policy,
+            name="query",
+        )
+        self.query_dense.build(inputs_shape)
+
+        self.key_dense = keras.layers.EinsumDense(
+            equation="bkm,mvh->bkvh",
+            output_shape=(
+                None,
+                self.num_key_value_heads,
+                self._head_dim,
+            ),
+            bias_axes="vh",
+            kernel_initializer=self._kernel_initializer,
+            bias_initializer="zeros",
+            dtype=self.dtype_policy,
+            name="key",
+        )
+        self.key_dense.build(inputs_shape)
+
+        self.value_dense = keras.layers.EinsumDense(
+            equation="bkm,mvh->bkvh",
+            output_shape=(
+                None,
+                self.num_key_value_heads,
+                self._head_dim,
+            ),
+            bias_axes="vh",
+            kernel_initializer=self._kernel_initializer,
+            bias_initializer="zeros",
+            dtype=self.dtype_policy,
+            name="value",
+        )
+        self.value_dense.build(inputs_shape)
+
+        self.dropout_layer = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+        )
+
+        self.output_dense = keras.layers.EinsumDense(
+            equation="bquh,uhm->bqm",
+            output_shape=(None, self._hidden_dim),
+            bias_axes="m",
+            kernel_initializer=self._kernel_initializer,
+            bias_initializer="zeros",
+            dtype=self.dtype_policy,
+            name="attention_output",
+        )
+        self.output_dense.build(
+            (None, None, self.num_query_heads, self._head_dim)
+        )
+
+        self.rotary_embedding_layer = RotaryEmbedding(
+            max_wavelength=self.rope_max_wavelength,
+            scaling_factor=self.rope_scaling_factor,  # YaRN scaling factor
+            rope_type="yarn",
+            beta_fast=32.0,
+            beta_slow=1.0,
+            original_max_position_embeddings=4096,
+            dtype=self.dtype_policy,
+        )
+
+        self.sinks = self.add_weight(
+            shape=(self.num_query_heads,),
+            initializer="random_normal",
+            dtype=self.dtype,
+            name="sinks",
+        )
+
+        self._dot_product_equation = "bquh,bkuh->buqk"
+        self._combine_equation = "buqk,bkuh->bquh"
+
+        self.built = True
+
+    def call(
+        self,
+        hidden_states,
+        attention_mask=None,
+        cache=None,
+        cache_update_index=None,
+        training=None,
+    ):
+        start_index = (
+            cache_update_index if cache_update_index is not None else 0
+        )
+
+        query = self.query_dense(hidden_states)
+
+        # Apply RoPE to queries (only to the first _rotary_dim
+        # dimensions).
+        if self._rotary_dim < self._head_dim:
+            query_rot = query[..., : self._rotary_dim]
+            query_rot = self.rotary_embedding_layer(
+                query_rot, start_index=start_index
+            )
+            query = ops.concatenate(
+                [query_rot, query[..., self._rotary_dim :]], axis=-1
+            )
+        else:
+            query = self.rotary_embedding_layer(query, start_index=start_index)
+
+        def _compute_key_value(x):
+            key, value = self.key_dense(x), self.value_dense(x)
+            # Apply RoPE to keys (only the first _rotary_dim dimensions).
+            if self._rotary_dim < self._head_dim:
+                key_rot = key[..., : self._rotary_dim]
+                key_rot = self.rotary_embedding_layer(
+                    key_rot, start_index=start_index
+                )
+                key = ops.concatenate(
+                    [key_rot, key[..., self._rotary_dim :]], axis=-1
+                )
+            else:
+                key = self.rotary_embedding_layer(key, start_index=start_index)
+            return key, value
+
+        if cache is not None:
+            key_cache = cache[:, 0, ...]
+            value_cache = cache[:, 1, ...]
+            if cache_update_index is None:
+                key = key_cache
+                value = value_cache
+            else:
+                key_update, value_update = _compute_key_value(hidden_states)
+                start = [0, cache_update_index, 0, 0]
+                key = ops.slice_update(key_cache, start, key_update)
+                value = ops.slice_update(value_cache, start, value_update)
+                cache = ops.stack((key, value), axis=1)
+        else:
+            if cache_update_index is not None:
+                raise ValueError(
+                    "`cache_update_index` should not be set if `cache` is "
+                    f"`None`. Received: cache={cache}, "
+                    f"cache_update_index={cache_update_index}"
+                )
+            key, value = _compute_key_value(hidden_states)
+
+        # [batch_shape, seq_len, num_key_value_heads, head_dim]
+        # -> [batch_shape, seq_len, num_heads, head_dim]
+        key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
+        value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)
+
+        attention_output = self._compute_attention(
+            query, key, value, attention_mask, start_index
+        )
+
+        attention_output = self.dropout_layer(
+            attention_output, training=training
+        )
+
+        attention_output = self.output_dense(attention_output)
+
+        if cache is not None:
+            return attention_output, cache
+        return attention_output
+
+    def _compute_attention(
+        self, query, key, value, attention_mask=None, start_index=0
+    ):
+        attention_scores = ops.einsum(self._dot_product_equation, query, key)
+        attention_scores = ops.multiply(
+            attention_scores,
+            ops.cast(self._inv_norm_factor, self.compute_dtype),
+        )
+
+        # Apply the sliding window mask if specified.
+        if self.sliding_window is not None and self.sliding_window > 0:
+            q_len = ops.shape(attention_scores)[-2]
+            kv_len = ops.shape(attention_scores)[-1]
+
+            # Query positions are offset by start_index during generation.
+            q_positions = ops.arange(q_len) + start_index
+            kv_positions = ops.arange(kv_len)
+
+            # The mask is True for positions inside the sliding window;
+            # for causal attention, drop kv_pos < q_pos - sliding_window.
+            mask = (
+                kv_positions[None, :]
+                >= q_positions[:, None] - self.sliding_window
+            )
+            if self.compute_dtype == "float32":
+                sliding_adder = ops.cast(-1e9, self.compute_dtype)
+            else:
+                sliding_adder = ops.cast(-1e4, self.compute_dtype)
+            attention_scores = ops.where(
+                mask[None, None, :, :], attention_scores, sliding_adder
+            )
+
+        if attention_mask is not None:
+            # The attention mask is boolean, True for positions to keep.
+            # Masked-out positions receive a large negative value so that
+            # they vanish after the softmax.
+            if self.compute_dtype == "float32":
+                adder = ops.cast(-1e9, self.compute_dtype)
+            else:
+                adder = ops.cast(-1e4, self.compute_dtype)
+            attention_scores = ops.where(
+                attention_mask[:, None, :, :], attention_scores, adder
+            )
+
+        # Handle sink tokens by concatenating them to the logits.
+        b = ops.shape(attention_scores)[0]
+        q = ops.shape(attention_scores)[2]
+
+        sinks = ops.reshape(self.sinks, (1, self.num_query_heads, 1, 1))
+        sinks = ops.broadcast_to(sinks, (b, self.num_query_heads, q, 1))
+        # attention_scores shape: [b, num_heads, q, k]
+        # sinks shape: [b, num_heads, q, 1]
+        # Concatenate along the last dimension.
+        combined_logits = ops.concatenate([attention_scores, sinks], axis=-1)
+
+        # Stabilize logits before softmax for numerical stability.
+        max_logits = ops.max(combined_logits, axis=-1, keepdims=True)
+        max_logits = ops.stop_gradient(max_logits)
+        combined_logits = combined_logits - max_logits
+
+        probs = ops.softmax(combined_logits, axis=-1)
+
+        # Remove the sink probabilities before computing the output.
+        attention_scores = probs[..., :-1]
+        attention_scores = ops.cast(attention_scores, self.compute_dtype)
+
+        attention_output = ops.einsum(
+            self._combine_equation, attention_scores, value
+        )
+
+        return attention_output
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_query_heads": self.num_query_heads,
+                "num_key_value_heads": self.num_key_value_heads,
+                "rope_max_wavelength": self.rope_max_wavelength,
+                "rope_scaling_factor": self.rope_scaling_factor,
+                "kernel_initializer": keras.initializers.serialize(
+                    self._kernel_initializer
+                ),
+                "sliding_window": self.sliding_window,
+                "dropout": self.dropout,
+                "head_dim": self.head_dim,
+            }
+        )
+        return config
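
To make the sink-token mechanism in `_compute_attention` concrete, here is a standalone NumPy illustration (not keras-hub code): a learned per-head sink logit is appended to each row of attention scores before the softmax, and its column is dropped afterwards, so the weights over real tokens can sum to less than one:

    import numpy as np

    scores = np.array([[2.0, 1.0, 0.5]])  # one query row over 3 keys
    sink = np.array([[3.0]])  # learned sink logit for this head
    combined = np.concatenate([scores, sink], axis=-1)
    combined -= combined.max(axis=-1, keepdims=True)  # stabilize
    probs = np.exp(combined) / np.exp(combined).sum(axis=-1, keepdims=True)
    token_weights = probs[..., :-1]  # weights actually used on the values
    print(token_weights.sum())  # < 1.0; the remainder went to the sink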