keras-hub-nightly 0.22.0.dev202508170419__py3-none-any.whl → 0.24.0.dev202511090424__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.

Files changed (126)
  1. keras_hub/layers/__init__.py +15 -0
  2. keras_hub/models/__init__.py +93 -0
  3. keras_hub/src/layers/modeling/position_embedding.py +21 -6
  4. keras_hub/src/layers/modeling/reversible_embedding.py +8 -1
  5. keras_hub/src/layers/modeling/rotary_embedding.py +16 -6
  6. keras_hub/src/layers/modeling/sine_position_encoding.py +21 -8
  7. keras_hub/src/layers/modeling/token_and_position_embedding.py +2 -1
  8. keras_hub/src/models/backbone.py +28 -16
  9. keras_hub/src/models/causal_lm.py +37 -0
  10. keras_hub/src/models/causal_lm_preprocessor.py +14 -0
  11. keras_hub/src/models/clip/clip_presets.py +8 -8
  12. keras_hub/src/models/d_fine/__init__.py +5 -0
  13. keras_hub/src/models/d_fine/d_fine_attention.py +461 -0
  14. keras_hub/src/models/d_fine/d_fine_backbone.py +891 -0
  15. keras_hub/src/models/d_fine/d_fine_decoder.py +944 -0
  16. keras_hub/src/models/d_fine/d_fine_encoder.py +365 -0
  17. keras_hub/src/models/d_fine/d_fine_hybrid_encoder.py +642 -0
  18. keras_hub/src/models/d_fine/d_fine_image_converter.py +8 -0
  19. keras_hub/src/models/d_fine/d_fine_layers.py +1828 -0
  20. keras_hub/src/models/d_fine/d_fine_loss.py +938 -0
  21. keras_hub/src/models/d_fine/d_fine_object_detector.py +875 -0
  22. keras_hub/src/models/d_fine/d_fine_object_detector_preprocessor.py +14 -0
  23. keras_hub/src/models/d_fine/d_fine_presets.py +155 -0
  24. keras_hub/src/models/d_fine/d_fine_utils.py +827 -0
  25. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +7 -2
  26. keras_hub/src/models/depth_anything/__init__.py +9 -0
  27. keras_hub/src/models/depth_anything/depth_anything_backbone.py +232 -0
  28. keras_hub/src/models/depth_anything/depth_anything_depth_estimator.py +70 -0
  29. keras_hub/src/models/depth_anything/depth_anything_depth_estimator_preprocessor.py +16 -0
  30. keras_hub/src/models/depth_anything/depth_anything_image_converter.py +10 -0
  31. keras_hub/src/models/depth_anything/depth_anything_layers.py +725 -0
  32. keras_hub/src/models/depth_anything/depth_anything_loss.py +89 -0
  33. keras_hub/src/models/depth_anything/depth_anything_presets.py +41 -0
  34. keras_hub/src/models/depth_anything/interpolate.py +62 -0
  35. keras_hub/src/models/depth_estimator.py +239 -0
  36. keras_hub/src/models/depth_estimator_preprocessor.py +78 -0
  37. keras_hub/src/models/dinov2/dinov2_backbone.py +29 -3
  38. keras_hub/src/models/dinov2/dinov2_layers.py +16 -4
  39. keras_hub/src/models/dinov3/__init__.py +5 -0
  40. keras_hub/src/models/dinov3/dinov3_backbone.py +263 -0
  41. keras_hub/src/models/dinov3/dinov3_image_converter.py +8 -0
  42. keras_hub/src/models/dinov3/dinov3_layers.py +1013 -0
  43. keras_hub/src/models/dinov3/dinov3_presets.py +4 -0
  44. keras_hub/src/models/gemma/gemma_backbone.py +0 -1
  45. keras_hub/src/models/gemma/gemma_presets.py +30 -0
  46. keras_hub/src/models/gemma3/gemma3_attention.py +48 -0
  47. keras_hub/src/models/gemma3/gemma3_backbone.py +4 -1
  48. keras_hub/src/models/gemma3/gemma3_decoder_block.py +12 -0
  49. keras_hub/src/models/gemma3/gemma3_presets.py +39 -0
  50. keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +4 -1
  51. keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +3 -2
  52. keras_hub/src/models/hgnetv2/hgnetv2_layers.py +27 -11
  53. keras_hub/src/models/image_to_image.py +5 -0
  54. keras_hub/src/models/inpaint.py +5 -0
  55. keras_hub/src/models/mobilenetv5/__init__.py +9 -0
  56. keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py +699 -0
  57. keras_hub/src/models/mobilenetv5/mobilenetv5_backbone.py +396 -0
  58. keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py +890 -0
  59. keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py +436 -0
  60. keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py +157 -0
  61. keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_preprocessor.py +16 -0
  62. keras_hub/src/models/mobilenetv5/mobilenetv5_image_converter.py +10 -0
  63. keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py +462 -0
  64. keras_hub/src/models/mobilenetv5/mobilenetv5_presets.py +15 -0
  65. keras_hub/src/models/mobilenetv5/mobilenetv5_utils.py +146 -0
  66. keras_hub/src/models/parseq/__init__.py +5 -0
  67. keras_hub/src/models/parseq/parseq_backbone.py +134 -0
  68. keras_hub/src/models/parseq/parseq_causal_lm.py +466 -0
  69. keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py +168 -0
  70. keras_hub/src/models/parseq/parseq_decoder.py +418 -0
  71. keras_hub/src/models/parseq/parseq_image_converter.py +8 -0
  72. keras_hub/src/models/parseq/parseq_presets.py +15 -0
  73. keras_hub/src/models/parseq/parseq_tokenizer.py +221 -0
  74. keras_hub/src/models/qwen3_moe/__init__.py +5 -0
  75. keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py +371 -0
  76. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +365 -0
  77. keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm.py +357 -0
  78. keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor.py +12 -0
  79. keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py +672 -0
  80. keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py +45 -0
  81. keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +30 -0
  82. keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py +48 -0
  83. keras_hub/src/models/sam/sam_prompt_encoder.py +3 -1
  84. keras_hub/src/models/siglip/siglip_presets.py +15 -0
  85. keras_hub/src/models/smollm3/smollm3_backbone.py +211 -0
  86. keras_hub/src/models/smollm3/smollm3_causal_lm.py +310 -0
  87. keras_hub/src/models/smollm3/smollm3_causal_lm_preprocessor.py +84 -0
  88. keras_hub/src/models/smollm3/smollm3_layers.py +757 -0
  89. keras_hub/src/models/smollm3/smollm3_tokenizer.py +60 -0
  90. keras_hub/src/models/smollm3/smollm3_utils.py +56 -0
  91. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
  92. keras_hub/src/models/t5gemma/__init__.py +5 -0
  93. keras_hub/src/models/t5gemma/t5gemma_attention.py +370 -0
  94. keras_hub/src/models/t5gemma/t5gemma_backbone.py +366 -0
  95. keras_hub/src/models/t5gemma/t5gemma_decoder.py +355 -0
  96. keras_hub/src/models/t5gemma/t5gemma_encoder.py +214 -0
  97. keras_hub/src/models/t5gemma/t5gemma_layers.py +118 -0
  98. keras_hub/src/models/t5gemma/t5gemma_presets.py +374 -0
  99. keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm.py +442 -0
  100. keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py +216 -0
  101. keras_hub/src/models/t5gemma/t5gemma_tokenizer.py +84 -0
  102. keras_hub/src/models/text_to_image.py +5 -0
  103. keras_hub/src/samplers/beam_sampler.py +6 -6
  104. keras_hub/src/samplers/sampler.py +8 -6
  105. keras_hub/src/tests/test_case.py +40 -3
  106. keras_hub/src/tokenizers/tokenizer.py +15 -0
  107. keras_hub/src/utils/openvino_utils.py +141 -0
  108. keras_hub/src/utils/preset_utils.py +58 -2
  109. keras_hub/src/utils/tensor_utils.py +26 -2
  110. keras_hub/src/utils/timm/convert_mobilenetv5.py +321 -0
  111. keras_hub/src/utils/timm/preset_loader.py +8 -4
  112. keras_hub/src/utils/transformers/convert_dinov2.py +1 -0
  113. keras_hub/src/utils/transformers/convert_dinov3.py +106 -0
  114. keras_hub/src/utils/transformers/convert_qwen3_moe.py +216 -0
  115. keras_hub/src/utils/transformers/convert_smollm3.py +139 -0
  116. keras_hub/src/utils/transformers/convert_t5gemma.py +229 -0
  117. keras_hub/src/utils/transformers/convert_vit.py +4 -1
  118. keras_hub/src/utils/transformers/export/gemma.py +49 -4
  119. keras_hub/src/utils/transformers/export/hf_exporter.py +71 -25
  120. keras_hub/src/utils/transformers/preset_loader.py +12 -0
  121. keras_hub/src/version.py +1 -1
  122. keras_hub/tokenizers/__init__.py +15 -0
  123. {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/METADATA +1 -1
  124. {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/RECORD +126 -47
  125. {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/WHEEL +0 -0
  126. {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/top_level.txt +0 -0
keras_hub/src/models/smollm3/smollm3_tokenizer.py
@@ -0,0 +1,60 @@
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.smollm3.smollm3_backbone import SmolLM3Backbone
+ from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
+
+
+ @keras_hub_export(
+     [
+         "keras_hub.tokenizers.SmolLM3Tokenizer",
+         "keras_hub.tokenizers.SmolLMTokenizer",
+         "keras_hub.models.SmolLM3Tokenizer",
+         "keras_hub.models.SmolLMTokenizer",
+     ]
+ )
+ class SmolLM3Tokenizer(BytePairTokenizer):
+     """Tokenizer for SmolLM3 models.
+
+     This tokenizer implements byte-pair encoding (BPE) for SmolLM3 models,
+     handling special tokens like BOS (beginning of sequence) and EOS (end of
+     sequence).
+
+     Args:
+         vocabulary: Dictionary mapping tokens to token IDs, or path to
+             vocabulary file.
+         merges: List of BPE merges, or path to merges file.
+         **kwargs: Additional keyword arguments passed to
+             `BytePairTokenizer`. The special tokens (`"<|begin_of_text|>"`,
+             `"<|end_of_text|>"`, `"<think>"`, `"</think>"`) are added
+             automatically and are not constructor arguments.
+     """
+
+     backbone_cls = SmolLM3Backbone
+
+     def __init__(
+         self,
+         vocabulary=None,
+         merges=None,
+         **kwargs,
+     ):
+         # Add EOS token
+         eos_token = "<|end_of_text|>"
+         self._add_special_token(eos_token, "end_token")
+
+         bos_token = "<|begin_of_text|>"
+         self._add_special_token(bos_token, "bos_token")
+
+         start_think_token = "<think>"
+         self._add_special_token(start_think_token, "start_think_token")
+
+         end_think_token = "</think>"
+         self._add_special_token(end_think_token, "end_think_token")
+
+         self.start_token_id = None
+         self.start_token = None
+         self.pad_token_id = 0
+
+         super().__init__(
+             vocabulary=vocabulary,
+             merges=merges,
+             **kwargs,
+         )
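
For context, a minimal usage sketch of the new `SmolLM3Tokenizer`. The toy vocabulary and merges below are invented for illustration only; real presets ship these as asset files, and the special tokens must be present in the vocabulary.

    import keras_hub

    # Toy vocabulary/merges, for illustration only.
    vocab = {
        "<|end_of_text|>": 0,
        "<|begin_of_text|>": 1,
        "<think>": 2,
        "</think>": 3,
        "a": 4,
        "b": 5,
        "ab": 6,
    }
    merges = ["a b"]

    tokenizer = keras_hub.tokenizers.SmolLM3Tokenizer(
        vocabulary=vocab, merges=merges
    )
    print(tokenizer("ab"))  # A single BPE merge: "a" + "b" -> "ab" (id 6).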
keras_hub/src/models/smollm3/smollm3_utils.py
@@ -0,0 +1,56 @@
+ from keras import ops
+
+
+ def rotate_half(x):
+     x1 = x[..., : ops.shape(x)[-1] // 2]
+     x2 = x[..., ops.shape(x)[-1] // 2 :]
+     return ops.concatenate((-x2, x1), axis=-1)
+
+
+ def apply_rotary_pos_emb(q, k, cos, sin, expansion_axis=1):
+     cos = ops.expand_dims(cos, expansion_axis)
+     sin = ops.expand_dims(sin, expansion_axis)
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+
+ def apply_rotary_pos_single(tensor, cos, sin, expansion_axis=1):
+     cos = ops.expand_dims(cos, expansion_axis)
+     sin = ops.expand_dims(sin, expansion_axis)
+     tensor_embed = (tensor * cos) + (rotate_half(tensor) * sin)
+     return tensor_embed
+
+
+ def repeat_kv(hidden_states, n_rep):
+     batch, num_key_value_heads, slen, head_dim = ops.shape(hidden_states)
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = ops.expand_dims(hidden_states, axis=2)
+     target_shape = (batch, num_key_value_heads, n_rep, slen, head_dim)
+     hidden_states = ops.broadcast_to(hidden_states, target_shape)
+     return ops.reshape(
+         hidden_states, [batch, num_key_value_heads * n_rep, slen, head_dim]
+     )
+
+
+ def rope_init(rope_theta, partial_rotary_factor, head_dim):
+     """Initialize RoPE (Rotary Position Embedding) parameters.
+
+     Args:
+         rope_theta: float. The theta value for RoPE.
+         partial_rotary_factor: float. The factor for partial rotary embedding.
+         head_dim: int. The dimension of each attention head.
+
+     Returns:
+         A tuple of (inv_freq, attention_scaling) where inv_freq is the inverse
+         frequency tensor and attention_scaling is the scaling factor.
+     """
+     base = rope_theta
+     dim = int(head_dim * partial_rotary_factor)
+
+     inv_freq = 1.0 / (
+         ops.power(base, ops.arange(0, dim, 2, dtype="float32") / dim)
+     )
+     attention_scaling = 1.0
+     return inv_freq, attention_scaling
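
A quick numeric check of `rope_init`: with `rope_theta=10000.0`, `partial_rotary_factor=1.0`, and `head_dim=8`, the exponents are `[0, 2, 4, 6] / 8`, so the inverse frequencies come out to clean powers of ten.

    from keras_hub.src.models.smollm3.smollm3_utils import rope_init

    inv_freq, attention_scaling = rope_init(
        rope_theta=10000.0, partial_rotary_factor=1.0, head_dim=8
    )
    # inv_freq == 1 / 10000 ** ([0, 2, 4, 6] / 8) == [1.0, 0.1, 0.01, 0.001]
    print(inv_freq)  # [1.0, 0.1, 0.01, 0.001]
    print(attention_scaling)  # 1.0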
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py
@@ -11,7 +11,7 @@ backbone_presets = {
              "params": 2987080931,
              "path": "stable_diffusion_3",
          },
-         "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/4",
+         "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/5",
      },
      "stable_diffusion_3.5_medium": {
          "metadata": {
@@ -35,7 +35,7 @@ backbone_presets = {
              "params": 9048410595,
              "path": "stable_diffusion_3",
          },
-         "kaggle_handle": "kaggle://keras/stablediffusion-3.5/keras/stable_diffusion_3.5_large/2",
+         "kaggle_handle": "kaggle://keras/stablediffusion-3.5/keras/stable_diffusion_3.5_large/3",
      },
      "stable_diffusion_3.5_large_turbo": {
          "metadata": {
@@ -49,6 +49,6 @@ backbone_presets = {
              "params": 9048410595,
              "path": "stable_diffusion_3",
          },
-         "kaggle_handle": "kaggle://keras/stablediffusion-3.5/keras/stable_diffusion_3.5_large_turbo/2",
+         "kaggle_handle": "kaggle://keras/stablediffusion-3.5/keras/stable_diffusion_3.5_large_turbo/3",
      },
  }
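
These bumps only change which Kaggle artifact each preset resolves to; consumers pick them up through the usual `from_preset` flow. A minimal sketch (downloads weights):

    import keras_hub

    # Resolves the "stable_diffusion_3_medium" entry above, now pointing at
    # kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/5.
    backbone = keras_hub.models.StableDiffusion3Backbone.from_preset(
        "stable_diffusion_3_medium"
    )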
keras_hub/src/models/t5gemma/__init__.py
@@ -0,0 +1,5 @@
+ from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone
+ from keras_hub.src.models.t5gemma.t5gemma_presets import backbone_presets
+ from keras_hub.src.utils.preset_utils import register_presets
+
+ register_presets(backbone_presets, T5GemmaBackbone)
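
`register_presets` runs at import time, which is what makes the T5Gemma preset names resolvable by string. A sketch, using a placeholder preset name (the real keys live in `t5gemma_presets.py`):

    import keras_hub

    # "t5gemma_preset_name" is a placeholder; see t5gemma_presets.py for
    # the registered keys.
    backbone = keras_hub.models.T5GemmaBackbone.from_preset(
        "t5gemma_preset_name"
    )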
keras_hub/src/models/t5gemma/t5gemma_attention.py
@@ -0,0 +1,370 @@
+ import inspect
+
+ import keras
+
+ from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+ from keras_hub.src.models.gemma.gemma_attention import CachedGemmaAttention
+ from keras_hub.src.models.t5gemma.t5gemma_layers import (
+     t5gemma_kernel_initializer,
+ )
+ from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+ def repeat_kv(hidden_states, n_rep):
+     """Repeats the key/value hidden states to match the number of query heads
+     for Grouped Query Attention (GQA).
+
+     This function is used in `T5GemmaAttention` to broadcast key and value
+     states across multiple query heads when Grouped Query Attention (GQA) is
+     used (i.e., when `num_query_heads` > `num_key_value_heads`).
+
+     Args:
+         hidden_states: Tensor, The key or value hidden states with shape
+             `(batch, sequence_length, num_key_value_heads, head_dim)`.
+         n_rep: int, The number of times to repeat the key/value heads. This is
+             typically `num_query_heads // num_key_value_heads`.
+
+     Returns:
+         Tensor: The expanded key/value hidden states with shape
+         `(batch, sequence_length, num_query_heads, head_dim)`.
+     """
+     if n_rep == 1:
+         return hidden_states
+     batch, slen, num_key_value_heads, head_dim = keras.ops.shape(hidden_states)
+     hidden_states = keras.ops.expand_dims(hidden_states, 3)
+     hidden_states = keras.ops.tile(hidden_states, (1, 1, 1, n_rep, 1))
+     return keras.ops.reshape(
+         hidden_states, (batch, slen, num_key_value_heads * n_rep, head_dim)
+     )
+
+
+ class T5GemmaAttention(CachedGemmaAttention):
+     """A unified attention layer for T5Gemma that handles both self-attention
+     and cross-attention.
+
+     This layer performs attention with optional Rotary Positional Embeddings
+     (RoPE) and supports Grouped Query Attention (GQA). It is used in
+     `T5GemmaEncoderLayer` and `T5GemmaDecoderLayer`.
+
+     Args:
+         hidden_size: int, The dimensionality of the hidden states.
+         num_attention_heads: int, The number of attention heads.
+         num_key_value_heads: int, The number of key-value heads. For GQA, this
+             can be less than `num_attention_heads`.
+         query_pre_attn_scalar: float, Scalar to multiply queries by before
+             attention.
+         attention_bias: bool, Whether to include bias in the dense layers.
+         head_dim: int, The dimensionality of each attention head.
+         attention_type: str, The type of attention, either 'self' or 'cross'.
+             Defaults to 'self'.
+         cross_attention_hidden_size: int, optional, The dimensionality of
+             encoder hidden states for cross-attention. Defaults to `None`.
+         initializer_range: float, The range for the random normal initializer
+             for kernel weights. Defaults to `0.02`.
+         attention_dropout: float, The dropout rate applied to attention weights.
+             Defaults to `0.0`.
+         attn_logit_softcapping: float, optional, The softcapping value for
+             attention logits. Defaults to `None`.
+         rope_max_wavelength: float, The maximum wavelength for Rotary Positional
+             Embeddings. Defaults to `10000.0`. Only used for self-attention.
+         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+             for model computations and weights. Defaults to `None`.
+         **kwargs: Additional keyword arguments passed to the parent class.
+     """
+
+     def __init__(
+         self,
+         hidden_size,
+         num_attention_heads,
+         num_key_value_heads,
+         query_pre_attn_scalar,
+         attention_bias,
+         head_dim,
+         attention_type="self",
+         cross_attention_hidden_size=None,
+         initializer_range=0.02,
+         attention_dropout=0.0,
+         attn_logit_softcapping=None,
+         rope_max_wavelength=10000.0,
+         dtype=None,
+         **kwargs,
+     ):
+         super().__init__(
+             head_dim=head_dim,
+             num_query_heads=num_attention_heads,
+             num_key_value_heads=num_key_value_heads,
+             kernel_initializer=t5gemma_kernel_initializer(initializer_range),
+             logit_soft_cap=attn_logit_softcapping,
+             dropout=attention_dropout,
+             query_head_dim_normalize=False,
+             use_sliding_window_attention=False,
+             dtype=dtype,
+             **kwargs,
+         )
+         if attention_type not in ["self", "cross"]:
+             raise ValueError(
+                 f"attention_type must be 'self' or 'cross', but got "
+                 f"{attention_type}"
+             )
+         self.attention_type = attention_type
+         self.hidden_size = hidden_size
+         self.cross_attention_hidden_size = (
+             cross_attention_hidden_size or hidden_size
+         )
+         self.query_pre_attn_scalar = query_pre_attn_scalar
+         self.attention_bias = attention_bias
+         self.initializer_range = initializer_range
+         self.attention_dropout = attention_dropout
+         self.rope_max_wavelength = rope_max_wavelength
+         self.num_key_value_groups = (
+             self.num_query_heads // self.num_key_value_heads
+         )
+         self.scaling = self.query_pre_attn_scalar**-0.5
+         if self.attention_type == "self":
+             self.rotary_embedding = RotaryEmbedding(
+                 max_wavelength=self.rope_max_wavelength,
+                 sequence_axis=1,
+                 feature_axis=3,
+                 name="rotary_embedding",
+                 dtype=self.dtype_policy,
+             )
+
+     def build(self, input_shape):
+         self._kernel_initializer = t5gemma_kernel_initializer(
+             self.initializer_range
+         )
+
+         if self.attention_type == "cross":
+             hidden_states_shape, kv_states_shape = input_shape
+         else:
+             hidden_states_shape = input_shape
+             kv_states_shape = input_shape
+         # Query projection layer.
+         self.hidden_dim = hidden_states_shape[-1]
+         self.query_dense = keras.layers.EinsumDense(
+             equation="btd,dnh->btnh",
+             output_shape=(None, self.num_query_heads, self.head_dim),
+             kernel_initializer=clone_initializer(self._kernel_initializer),
+             bias_axes="nh" if self.attention_bias else None,
+             dtype=self.dtype_policy,
+             name="query",
+         )
+         self.query_dense.build(hidden_states_shape)
+
+         # Key projection layer.
+         self.key_dense = keras.layers.EinsumDense(
+             equation="bsd,dkh->bskh",
+             output_shape=(None, self.num_key_value_heads, self.head_dim),
+             kernel_initializer=clone_initializer(self._kernel_initializer),
+             bias_axes="kh" if self.attention_bias else None,
+             dtype=self.dtype_policy,
+             name="key",
+         )
+         self.key_dense.build(kv_states_shape)
+
+         # Value projection layer.
+         self.value_dense = keras.layers.EinsumDense(
+             equation="bsd,dkh->bskh",
+             output_shape=(None, self.num_key_value_heads, self.head_dim),
+             kernel_initializer=clone_initializer(self._kernel_initializer),
+             bias_axes="kh" if self.attention_bias else None,
+             dtype=self.dtype_policy,
+             name="value",
+         )
+         self.value_dense.build(kv_states_shape)
+
+         # Output projection layer.
+         self.output_dense = keras.layers.EinsumDense(
+             equation="btnh,nhd->btd",
+             output_shape=(None, self.hidden_dim),
+             kernel_initializer=clone_initializer(self._kernel_initializer),
+             bias_axes="d" if self.attention_bias else None,
+             dtype=self.dtype_policy,
+             name="attention_output",
+         )
+         self.output_dense.build(
+             (
+                 hidden_states_shape[0],
+                 hidden_states_shape[1],
+                 self.num_query_heads,
+                 self.head_dim,
+             )
+         )
+         self.dropout_layer = keras.layers.Dropout(
+             rate=self.attention_dropout,
+             dtype=self.dtype_policy,
+         )
+         self.softmax = keras.layers.Softmax(axis=-1, dtype="float32")
+         self.built = True
+
+     def _compute_attention_without_fused_op(
+         self, query_states, key_states, value_states, attention_mask, training
+     ):
+         attn_weights = keras.ops.einsum(
+             "btnh,bsnh->bnts", query_states, key_states
+         )
+         attn_weights *= self.scaling
+         if self.logit_soft_cap is not None:
+             attn_weights = attn_weights / self.logit_soft_cap
+             attn_weights = keras.ops.tanh(attn_weights)
+             attn_weights = attn_weights * self.logit_soft_cap
+         if attention_mask is not None:
+             attn_weights += attention_mask
+         attn_weights = keras.ops.cast(
+             self.softmax(attn_weights),
+             query_states.dtype,
+         )
+         attn_weights = self.dropout_layer(attn_weights, training=training)
+         attn_output = keras.ops.einsum(
+             "bnts,bsnh->btnh", attn_weights, value_states
+         )
+         return attn_output
+
+     def _compute_attention(
+         self, query_states, key_states, value_states, attention_mask, training
+     ):
+         if self._use_fused_attention_op():
+             kwargs = {"bias": attention_mask}
+             if self.logit_soft_cap is not None:
+                 sig = inspect.signature(keras.ops.dot_product_attention)
+                 # This is only supported in JAX TPU backend.
+                 # https://keras.io/api/ops/nn/#dot_product_attention-function
+                 if "attn_logits_soft_cap" in sig.parameters:
+                     kwargs["attn_logits_soft_cap"] = self.logit_soft_cap
+             return keras.ops.dot_product_attention(
+                 query=query_states,
+                 key=key_states,
+                 value=value_states,
+                 scale=self.scaling,
+                 **kwargs,
+             )
+         return self._compute_attention_without_fused_op(
+             query_states,
+             key_states,
+             value_states,
+             attention_mask,
+             training,
+         )
+
+     def call(
+         self,
+         inputs,
+         attention_mask=None,
+         cache=None,
+         cache_update_index=None,
+         training=None,
+     ):
+         if self.attention_type == "cross":
+             if not isinstance(inputs, (list, tuple)) or len(inputs) != 2:
+                 raise ValueError(
+                     "For cross-attention, `inputs` must be a list or tuple of "
+                     "two tensors: `[hidden_states, encoder_hidden_states]`."
+                 )
+             hidden_states, kv_states = inputs
+             query_states = self.query_dense(hidden_states)
+             if cache is not None:
+                 if cache_update_index is not None:
+                     raise ValueError(
+                         "`cache_update_index` should not be set for "
+                         "cross-attention caching."
+                     )
+                 key_states, value_states = cache[:, 0, ...], cache[:, 1, ...]
+                 updated_cache = cache
+             else:
+                 key_states = self.key_dense(kv_states)
+                 value_states = self.value_dense(kv_states)
+                 updated_cache = keras.ops.stack(
+                     (key_states, value_states), axis=1
+                 )
+             # Repeat key-value heads for GQA.
+             key_states = repeat_kv(key_states, self.num_key_value_groups)
+             value_states = repeat_kv(value_states, self.num_key_value_groups)
+             attn_output = self._compute_attention(
+                 query_states, key_states, value_states, attention_mask, training
+             )
+             attn_output = self.output_dense(attn_output)
+             return attn_output, updated_cache
+         else:  # Self-attention
+             hidden_states = inputs
+             kv_states = hidden_states
+             query_states = self.query_dense(hidden_states)
+             key_states = self.key_dense(kv_states)
+             value_states = self.value_dense(kv_states)
+             start_index = (
+                 0 if cache_update_index is None else cache_update_index
+             )
+             query_states = self.rotary_embedding(
+                 query_states, start_index=start_index
+             )
+             key_states = self.rotary_embedding(
+                 key_states, start_index=start_index
+             )
+             if cache is not None:
+                 if cache_update_index is None:
+                     raise ValueError(
+                         "Both `cache` and `cache_update_index` must be passed "
+                         "for self-attention caching."
+                     )
+                 key_cache, value_cache = cache[:, 0, ...], cache[:, 1, ...]
+                 start = [0, cache_update_index, 0, 0]
+                 key_states = keras.ops.slice_update(
+                     key_cache, start, key_states
+                 )
+                 value_states = keras.ops.slice_update(
+                     value_cache, start, value_states
+                 )
+                 cache = keras.ops.stack((key_states, value_states), axis=1)
+             elif cache_update_index is not None:
+                 raise ValueError(
+                     "`cache_update_index` should not be set if `cache` is "
+                     "`None`."
+                 )
+             else:
+                 cache = keras.ops.stack((key_states, value_states), axis=1)
+
+             # Repeat key-value heads for GQA.
+             key_states = repeat_kv(key_states, self.num_key_value_groups)
+             value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+             attn_output = self._compute_attention(
+                 query_states, key_states, value_states, attention_mask, training
+             )
+             attn_output = self.output_dense(attn_output)
+             return attn_output, cache
+
+     def compute_output_shape(self, input_shape):
+         if self.attention_type == "cross":
+             hidden_states_shape, kv_states_shape = input_shape
+         else:
+             hidden_states_shape = input_shape
+             kv_states_shape = input_shape
+         attn_output_shape = hidden_states_shape
+         kv_len = kv_states_shape[1]
+         cache_shape = (
+             hidden_states_shape[0],  # batch
+             2,  # key and value
+             kv_len,
+             self.num_key_value_heads,
+             self.head_dim,
+         )
+         return attn_output_shape, cache_shape
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_size": self.hidden_size,
+                 "head_dim": self.head_dim,
+                 "num_attention_heads": self.num_query_heads,
+                 "num_key_value_heads": self.num_key_value_heads,
+                 "query_pre_attn_scalar": self.query_pre_attn_scalar,
+                 "attention_bias": self.attention_bias,
+                 "attention_type": self.attention_type,
+                 "cross_attention_hidden_size": self.cross_attention_hidden_size,
+                 "initializer_range": self.initializer_range,
+                 "attention_dropout": self.attention_dropout,
+                 "attn_logit_softcapping": self.logit_soft_cap,
+                 "rope_max_wavelength": self.rope_max_wavelength,
+             }
+         )
+         return config
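
A shape-level check of the GQA broadcast in `repeat_kv` above: two key/value heads are tiled to line up with eight query heads (toy tensor; import path taken from this release):

    import keras

    from keras_hub.src.models.t5gemma.t5gemma_attention import repeat_kv

    # (batch, seq_len, num_key_value_heads, head_dim)
    kv = keras.ops.zeros((1, 16, 2, 64))
    out = repeat_kv(kv, n_rep=4)  # n_rep = num_query_heads // num_kv_heads
    print(keras.ops.shape(out))  # (1, 16, 8, 64)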