keras-hub-nightly 0.22.0.dev202508170419__py3-none-any.whl → 0.24.0.dev202511090424__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of keras-hub-nightly might be problematic.
- keras_hub/layers/__init__.py +15 -0
- keras_hub/models/__init__.py +93 -0
- keras_hub/src/layers/modeling/position_embedding.py +21 -6
- keras_hub/src/layers/modeling/reversible_embedding.py +8 -1
- keras_hub/src/layers/modeling/rotary_embedding.py +16 -6
- keras_hub/src/layers/modeling/sine_position_encoding.py +21 -8
- keras_hub/src/layers/modeling/token_and_position_embedding.py +2 -1
- keras_hub/src/models/backbone.py +28 -16
- keras_hub/src/models/causal_lm.py +37 -0
- keras_hub/src/models/causal_lm_preprocessor.py +14 -0
- keras_hub/src/models/clip/clip_presets.py +8 -8
- keras_hub/src/models/d_fine/__init__.py +5 -0
- keras_hub/src/models/d_fine/d_fine_attention.py +461 -0
- keras_hub/src/models/d_fine/d_fine_backbone.py +891 -0
- keras_hub/src/models/d_fine/d_fine_decoder.py +944 -0
- keras_hub/src/models/d_fine/d_fine_encoder.py +365 -0
- keras_hub/src/models/d_fine/d_fine_hybrid_encoder.py +642 -0
- keras_hub/src/models/d_fine/d_fine_image_converter.py +8 -0
- keras_hub/src/models/d_fine/d_fine_layers.py +1828 -0
- keras_hub/src/models/d_fine/d_fine_loss.py +938 -0
- keras_hub/src/models/d_fine/d_fine_object_detector.py +875 -0
- keras_hub/src/models/d_fine/d_fine_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/d_fine/d_fine_presets.py +155 -0
- keras_hub/src/models/d_fine/d_fine_utils.py +827 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +7 -2
- keras_hub/src/models/depth_anything/__init__.py +9 -0
- keras_hub/src/models/depth_anything/depth_anything_backbone.py +232 -0
- keras_hub/src/models/depth_anything/depth_anything_depth_estimator.py +70 -0
- keras_hub/src/models/depth_anything/depth_anything_depth_estimator_preprocessor.py +16 -0
- keras_hub/src/models/depth_anything/depth_anything_image_converter.py +10 -0
- keras_hub/src/models/depth_anything/depth_anything_layers.py +725 -0
- keras_hub/src/models/depth_anything/depth_anything_loss.py +89 -0
- keras_hub/src/models/depth_anything/depth_anything_presets.py +41 -0
- keras_hub/src/models/depth_anything/interpolate.py +62 -0
- keras_hub/src/models/depth_estimator.py +239 -0
- keras_hub/src/models/depth_estimator_preprocessor.py +78 -0
- keras_hub/src/models/dinov2/dinov2_backbone.py +29 -3
- keras_hub/src/models/dinov2/dinov2_layers.py +16 -4
- keras_hub/src/models/dinov3/__init__.py +5 -0
- keras_hub/src/models/dinov3/dinov3_backbone.py +263 -0
- keras_hub/src/models/dinov3/dinov3_image_converter.py +8 -0
- keras_hub/src/models/dinov3/dinov3_layers.py +1013 -0
- keras_hub/src/models/dinov3/dinov3_presets.py +4 -0
- keras_hub/src/models/gemma/gemma_backbone.py +0 -1
- keras_hub/src/models/gemma/gemma_presets.py +30 -0
- keras_hub/src/models/gemma3/gemma3_attention.py +48 -0
- keras_hub/src/models/gemma3/gemma3_backbone.py +4 -1
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +12 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +39 -0
- keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +4 -1
- keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +3 -2
- keras_hub/src/models/hgnetv2/hgnetv2_layers.py +27 -11
- keras_hub/src/models/image_to_image.py +5 -0
- keras_hub/src/models/inpaint.py +5 -0
- keras_hub/src/models/mobilenetv5/__init__.py +9 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py +699 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_backbone.py +396 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py +890 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py +436 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py +157 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_converter.py +10 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py +462 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_presets.py +15 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_utils.py +146 -0
- keras_hub/src/models/parseq/__init__.py +5 -0
- keras_hub/src/models/parseq/parseq_backbone.py +134 -0
- keras_hub/src/models/parseq/parseq_causal_lm.py +466 -0
- keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py +168 -0
- keras_hub/src/models/parseq/parseq_decoder.py +418 -0
- keras_hub/src/models/parseq/parseq_image_converter.py +8 -0
- keras_hub/src/models/parseq/parseq_presets.py +15 -0
- keras_hub/src/models/parseq/parseq_tokenizer.py +221 -0
- keras_hub/src/models/qwen3_moe/__init__.py +5 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py +371 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +365 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm.py +357 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor.py +12 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py +672 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py +45 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +30 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py +48 -0
- keras_hub/src/models/sam/sam_prompt_encoder.py +3 -1
- keras_hub/src/models/siglip/siglip_presets.py +15 -0
- keras_hub/src/models/smollm3/smollm3_backbone.py +211 -0
- keras_hub/src/models/smollm3/smollm3_causal_lm.py +310 -0
- keras_hub/src/models/smollm3/smollm3_causal_lm_preprocessor.py +84 -0
- keras_hub/src/models/smollm3/smollm3_layers.py +757 -0
- keras_hub/src/models/smollm3/smollm3_tokenizer.py +60 -0
- keras_hub/src/models/smollm3/smollm3_utils.py +56 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
- keras_hub/src/models/t5gemma/__init__.py +5 -0
- keras_hub/src/models/t5gemma/t5gemma_attention.py +370 -0
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +366 -0
- keras_hub/src/models/t5gemma/t5gemma_decoder.py +355 -0
- keras_hub/src/models/t5gemma/t5gemma_encoder.py +214 -0
- keras_hub/src/models/t5gemma/t5gemma_layers.py +118 -0
- keras_hub/src/models/t5gemma/t5gemma_presets.py +374 -0
- keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm.py +442 -0
- keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py +216 -0
- keras_hub/src/models/t5gemma/t5gemma_tokenizer.py +84 -0
- keras_hub/src/models/text_to_image.py +5 -0
- keras_hub/src/samplers/beam_sampler.py +6 -6
- keras_hub/src/samplers/sampler.py +8 -6
- keras_hub/src/tests/test_case.py +40 -3
- keras_hub/src/tokenizers/tokenizer.py +15 -0
- keras_hub/src/utils/openvino_utils.py +141 -0
- keras_hub/src/utils/preset_utils.py +58 -2
- keras_hub/src/utils/tensor_utils.py +26 -2
- keras_hub/src/utils/timm/convert_mobilenetv5.py +321 -0
- keras_hub/src/utils/timm/preset_loader.py +8 -4
- keras_hub/src/utils/transformers/convert_dinov2.py +1 -0
- keras_hub/src/utils/transformers/convert_dinov3.py +106 -0
- keras_hub/src/utils/transformers/convert_qwen3_moe.py +216 -0
- keras_hub/src/utils/transformers/convert_smollm3.py +139 -0
- keras_hub/src/utils/transformers/convert_t5gemma.py +229 -0
- keras_hub/src/utils/transformers/convert_vit.py +4 -1
- keras_hub/src/utils/transformers/export/gemma.py +49 -4
- keras_hub/src/utils/transformers/export/hf_exporter.py +71 -25
- keras_hub/src/utils/transformers/preset_loader.py +12 -0
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +15 -0
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/RECORD +126 -47
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/top_level.txt +0 -0
keras_hub/src/models/smollm3/smollm3_tokenizer.py (new file)

@@ -0,0 +1,60 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.smollm3.smollm3_backbone import SmolLM3Backbone
from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer


@keras_hub_export(
    [
        "keras_hub.tokenizers.SmolLM3Tokenizer",
        "keras_hub.tokenizers.SmolLMTokenizer",
        "keras_hub.models.SmolLM3Tokenizer",
        "keras_hub.models.SmolLMTokenizer",
    ]
)
class SmolLM3Tokenizer(BytePairTokenizer):
    """Tokenizer for SmolLM3 models.

    This tokenizer implements byte-pair encoding (BPE) for SmolLM3 models,
    handling special tokens like BOS (beginning of sequence) and EOS (end of
    sequence).

    Args:
        vocabulary: Dictionary mapping tokens to token IDs, or path to
            vocabulary file.
        merges: List of BPE merges, or path to merges file.
        bos_token: Beginning of sequence token. Defaults to None.
        eos_token: End of sequence token. Defaults to "<|endoftext|>".
        misc_special_tokens: Set of additional special tokens. Defaults to
            empty set.
    """

    backbone_cls = SmolLM3Backbone

    def __init__(
        self,
        vocabulary=None,
        merges=None,
        **kwargs,
    ):
        # Add EOS token
        eos_token = "<|end_of_text|>"
        self._add_special_token(eos_token, "end_token")

        bos_token = "<|begin_of_text|>"
        self._add_special_token(bos_token, "bos_token")

        start_think_token = "<think>"
        self._add_special_token(start_think_token, "start_think_token")

        end_think_token = "</think>"
        self._add_special_token(end_think_token, "end_think_token")

        self.start_token_id = None
        self.start_token = None
        self.pad_token_id = 0

        super().__init__(
            vocabulary=vocabulary,
            merges=merges,
            **kwargs,
        )
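Illustrative usage (not part of the diff): a minimal sketch of how this tokenizer would typically be constructed and called. The checkpoint handle below is an assumption for illustration only; this release also adds a `convert_smollm3.py` Transformers converter, which is what makes `hf://` loading plausible, but check the published presets for the real identifiers.

import keras_hub

# Assumed handle, shown only for illustration; substitute a real SmolLM3
# preset or converted Hugging Face checkpoint.
tokenizer = keras_hub.models.SmolLM3Tokenizer.from_preset(
    "hf://HuggingFaceTB/SmolLM3-3B"
)

token_ids = tokenizer("Keras Hub now ships a SmolLM3 tokenizer.")
text = tokenizer.detokenize(token_ids)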
keras_hub/src/models/smollm3/smollm3_utils.py (new file)

@@ -0,0 +1,56 @@
from keras import ops


def rotate_half(x):
    x1 = x[..., : ops.shape(x)[-1] // 2]
    x2 = x[..., ops.shape(x)[-1] // 2 :]
    return ops.concatenate((-x2, x1), axis=-1)


def apply_rotary_pos_emb(q, k, cos, sin, expansion_axis=1):
    cos = ops.expand_dims(cos, expansion_axis)
    sin = ops.expand_dims(sin, expansion_axis)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def apply_rotary_pos_single(tensor, cos, sin, expansion_axis=1):
    cos = ops.expand_dims(cos, expansion_axis)
    sin = ops.expand_dims(sin, expansion_axis)
    tensor_embed = (tensor * cos) + (rotate_half(tensor) * sin)
    return tensor_embed


def repeat_kv(hidden_states, n_rep):
    batch, num_key_value_heads, slen, head_dim = ops.shape(hidden_states)
    if n_rep == 1:
        return hidden_states
    hidden_states = ops.expand_dims(hidden_states, axis=2)
    target_shape = (batch, num_key_value_heads, n_rep, slen, head_dim)
    hidden_states = ops.broadcast_to(hidden_states, target_shape)
    return ops.reshape(
        hidden_states, [batch, num_key_value_heads * n_rep, slen, head_dim]
    )


def rope_init(rope_theta, partial_rotary_factor, head_dim):
    """Initialize RoPE (Rotary Position Embedding) parameters.

    Args:
        rope_theta: float. The theta value for RoPE.
        partial_rotary_factor: float. The factor for partial rotary embedding.
        head_dim: int. The dimension of each attention head.

    Returns:
        A tuple of (inv_freq, attention_scaling) where inv_freq is the inverse
        frequency tensor and attention_scaling is the scaling factor.
    """
    base = rope_theta
    dim = int(head_dim * partial_rotary_factor)

    inv_freq = 1.0 / (
        ops.power(base, ops.arange(0, dim, 2, dtype="float32") / dim)
    )
    attention_scaling = 1.0
    return inv_freq, attention_scaling
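Illustrative usage (not part of the diff): a minimal sketch of how these helpers compose, wiring the `rope_init` output into `apply_rotary_pos_emb`. The cos/sin construction and the tensor shapes are assumptions based on the conventions visible in the file above; the sizes are arbitrary.

import keras
from keras import ops

from keras_hub.src.models.smollm3.smollm3_utils import (
    apply_rotary_pos_emb,
    rope_init,
)

batch, num_heads, seq_len, head_dim = 2, 4, 8, 16

# Inverse frequencies and scaling factor for the rotary embedding.
inv_freq, attention_scaling = rope_init(
    rope_theta=10000.0, partial_rotary_factor=1.0, head_dim=head_dim
)  # inv_freq has shape (head_dim // 2,)

# Per-position angles, duplicated to span the full head dimension.
positions = ops.arange(seq_len, dtype="float32")
freqs = ops.expand_dims(positions, 1) * ops.expand_dims(inv_freq, 0)
emb = ops.concatenate((freqs, freqs), axis=-1)  # (seq_len, head_dim)
cos = ops.expand_dims(ops.cos(emb), 0) * attention_scaling  # (1, seq_len, head_dim)
sin = ops.expand_dims(ops.sin(emb), 0) * attention_scaling

q = keras.random.uniform((batch, num_heads, seq_len, head_dim))
k = keras.random.uniform((batch, num_heads, seq_len, head_dim))

# expansion_axis=1 broadcasts cos/sin across the head axis of q and k.
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, expansion_axis=1)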
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py

@@ -11,7 +11,7 @@ backbone_presets = {
             "params": 2987080931,
             "path": "stable_diffusion_3",
         },
-        "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/
+        "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/5",
     },
     "stable_diffusion_3.5_medium": {
         "metadata": {
@@ -35,7 +35,7 @@ backbone_presets = {
             "params": 9048410595,
             "path": "stable_diffusion_3",
         },
-        "kaggle_handle": "kaggle://keras/stablediffusion-3.5/keras/stable_diffusion_3.5_large/
+        "kaggle_handle": "kaggle://keras/stablediffusion-3.5/keras/stable_diffusion_3.5_large/3",
     },
     "stable_diffusion_3.5_large_turbo": {
         "metadata": {
@@ -49,6 +49,6 @@ backbone_presets = {
             "params": 9048410595,
             "path": "stable_diffusion_3",
         },
-        "kaggle_handle": "kaggle://keras/stablediffusion-3.5/keras/stable_diffusion_3.5_large_turbo/
+        "kaggle_handle": "kaggle://keras/stablediffusion-3.5/keras/stable_diffusion_3.5_large_turbo/3",
     },
 }
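For context (not part of the diff): the `kaggle_handle` values above are what the high-level `from_preset()` API resolves when a preset is requested by name, so this change simply points the existing preset names at newer Kaggle versions. A minimal sketch, assuming network access to download the weights:

from keras_hub.models import StableDiffusion3Backbone

# "stable_diffusion_3_medium" now resolves to
# kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/5.
backbone = StableDiffusion3Backbone.from_preset("stable_diffusion_3_medium")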
keras_hub/src/models/t5gemma/t5gemma_attention.py (new file)

@@ -0,0 +1,370 @@
import inspect

import keras

from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.models.gemma.gemma_attention import CachedGemmaAttention
from keras_hub.src.models.t5gemma.t5gemma_layers import (
    t5gemma_kernel_initializer,
)
from keras_hub.src.utils.keras_utils import clone_initializer


def repeat_kv(hidden_states, n_rep):
    """Repeats the key/value hidden states to match the number of query heads
    for Grouped Query Attention (GQA).

    This function is used in `T5GemmaAttention` to broadcast key and value
    states across multiple query heads when Grouped Query Attention (GQA) is
    used (i.e., when `num_query_heads` > `num_key_value_heads`).

    Args:
        hidden_states: Tensor, The key or value hidden states with shape
            `(batch, sequence_length, num_key_value_heads, head_dim)`.
        n_rep: int, The number of times to repeat the key/value heads. This is
            typically `num_query_heads // num_key_value_heads`.

    Returns:
        Tensor: The expanded key/value hidden states with shape
            `(batch, sequence_length, num_query_heads, head_dim)`.
    """
    if n_rep == 1:
        return hidden_states
    batch, slen, num_key_value_heads, head_dim = keras.ops.shape(hidden_states)
    hidden_states = keras.ops.expand_dims(hidden_states, 3)
    hidden_states = keras.ops.tile(hidden_states, (1, 1, 1, n_rep, 1))
    return keras.ops.reshape(
        hidden_states, (batch, slen, num_key_value_heads * n_rep, head_dim)
    )


class T5GemmaAttention(CachedGemmaAttention):
    """A unified attention layer for T5Gemma that handles both self-attention
    and cross-attention.

    This layer performs attention with optional Rotary Positional Embeddings
    (RoPE) and supports Grouped Query Attention (GQA). It is used in
    `T5GemmaEncoderLayer` and `T5GemmaDecoderLayer`.

    Args:
        hidden_size: int, The dimensionality of the hidden states.
        num_attention_heads: int, The number of attention heads.
        num_key_value_heads: int, The number of key-value heads. For GQA, this
            can be less than `num_attention_heads`.
        query_pre_attn_scalar: float, Scalar to multiply queries by before
            attention.
        attention_bias: bool, Whether to include bias in the dense layers.
        head_dim: int, The dimensionality of each attention head.
        attention_type: str, The type of attention, either 'self' or 'cross'.
            Defaults to 'self'.
        cross_attention_hidden_size: int, optional, The dimensionality of
            encoder hidden states for cross-attention. Defaults to `None`.
        initializer_range: float, The range for the random normal initializer
            for kernel weights. Defaults to `0.02`.
        attention_dropout: float, The dropout rate applied to attention weights.
            Defaults to `0.0`.
        attn_logit_softcapping: float, optional, The softcapping value for
            attention logits. Defaults to `None`.
        rope_max_wavelength: float, The maximum wavelength for Rotary Positional
            Embeddings. Defaults to `10000.0`. Only used for self-attention.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for model computations and weights. Defaults to `None`.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(
        self,
        hidden_size,
        num_attention_heads,
        num_key_value_heads,
        query_pre_attn_scalar,
        attention_bias,
        head_dim,
        attention_type="self",
        cross_attention_hidden_size=None,
        initializer_range=0.02,
        attention_dropout=0.0,
        attn_logit_softcapping=None,
        rope_max_wavelength=10000.0,
        dtype=None,
        **kwargs,
    ):
        super().__init__(
            head_dim=head_dim,
            num_query_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            kernel_initializer=t5gemma_kernel_initializer(initializer_range),
            logit_soft_cap=attn_logit_softcapping,
            dropout=attention_dropout,
            query_head_dim_normalize=False,
            use_sliding_window_attention=False,
            dtype=dtype,
            **kwargs,
        )
        if attention_type not in ["self", "cross"]:
            raise ValueError(
                f"attention_type must be 'self' or 'cross', but got "
                f"{attention_type}"
            )
        self.attention_type = attention_type
        self.hidden_size = hidden_size
        self.cross_attention_hidden_size = (
            cross_attention_hidden_size or hidden_size
        )
        self.query_pre_attn_scalar = query_pre_attn_scalar
        self.attention_bias = attention_bias
        self.initializer_range = initializer_range
        self.attention_dropout = attention_dropout
        self.rope_max_wavelength = rope_max_wavelength
        self.num_key_value_groups = (
            self.num_query_heads // self.num_key_value_heads
        )
        self.scaling = self.query_pre_attn_scalar**-0.5
        if self.attention_type == "self":
            self.rotary_embedding = RotaryEmbedding(
                max_wavelength=self.rope_max_wavelength,
                sequence_axis=1,
                feature_axis=3,
                name="rotary_embedding",
                dtype=self.dtype_policy,
            )

    def build(self, input_shape):
        self._kernel_initializer = t5gemma_kernel_initializer(
            self.initializer_range
        )

        if self.attention_type == "cross":
            hidden_states_shape, kv_states_shape = input_shape
        else:
            hidden_states_shape = input_shape
            kv_states_shape = input_shape
        # Query projection layer.
        self.hidden_dim = hidden_states_shape[-1]
        self.query_dense = keras.layers.EinsumDense(
            equation="btd,dnh->btnh",
            output_shape=(None, self.num_query_heads, self.head_dim),
            kernel_initializer=clone_initializer(self._kernel_initializer),
            bias_axes="nh" if self.attention_bias else None,
            dtype=self.dtype_policy,
            name="query",
        )
        self.query_dense.build(hidden_states_shape)

        # Key projection layer.
        self.key_dense = keras.layers.EinsumDense(
            equation="bsd,dkh->bskh",
            output_shape=(None, self.num_key_value_heads, self.head_dim),
            kernel_initializer=clone_initializer(self._kernel_initializer),
            bias_axes="kh" if self.attention_bias else None,
            dtype=self.dtype_policy,
            name="key",
        )
        self.key_dense.build(kv_states_shape)

        # Value projection layer.
        self.value_dense = keras.layers.EinsumDense(
            equation="bsd,dkh->bskh",
            output_shape=(None, self.num_key_value_heads, self.head_dim),
            kernel_initializer=clone_initializer(self._kernel_initializer),
            bias_axes="kh" if self.attention_bias else None,
            dtype=self.dtype_policy,
            name="value",
        )
        self.value_dense.build(kv_states_shape)

        # Output projection layer.
        self.output_dense = keras.layers.EinsumDense(
            equation="btnh,nhd->btd",
            output_shape=(None, self.hidden_dim),
            kernel_initializer=clone_initializer(self._kernel_initializer),
            bias_axes="d" if self.attention_bias else None,
            dtype=self.dtype_policy,
            name="attention_output",
        )
        self.output_dense.build(
            (
                hidden_states_shape[0],
                hidden_states_shape[1],
                self.num_query_heads,
                self.head_dim,
            )
        )
        self.dropout_layer = keras.layers.Dropout(
            rate=self.attention_dropout,
            dtype=self.dtype_policy,
        )
        self.softmax = keras.layers.Softmax(axis=-1, dtype="float32")
        self.built = True

    def _compute_attention_without_fused_op(
        self, query_states, key_states, value_states, attention_mask, training
    ):
        attn_weights = keras.ops.einsum(
            "btnh,bsnh->bnts", query_states, key_states
        )
        attn_weights *= self.scaling
        if self.logit_soft_cap is not None:
            attn_weights = attn_weights / self.logit_soft_cap
            attn_weights = keras.ops.tanh(attn_weights)
            attn_weights = attn_weights * self.logit_soft_cap
        if attention_mask is not None:
            attn_weights += attention_mask
        attn_weights = keras.ops.cast(
            self.softmax(attn_weights),
            query_states.dtype,
        )
        attn_weights = self.dropout_layer(attn_weights, training=training)
        attn_output = keras.ops.einsum(
            "bnts,bsnh->btnh", attn_weights, value_states
        )
        return attn_output

    def _compute_attention(
        self, query_states, key_states, value_states, attention_mask, training
    ):
        if self._use_fused_attention_op():
            kwargs = {"bias": attention_mask}
            if self.logit_soft_cap is not None:
                sig = inspect.signature(keras.ops.dot_product_attention)
                # This is only supported in JAX TPU backend.
                # https://keras.io/api/ops/nn/#dot_product_attention-function
                if "attn_logits_soft_cap" in sig.parameters:
                    kwargs["attn_logits_soft_cap"] = self.logit_soft_cap
            return keras.ops.dot_product_attention(
                query=query_states,
                key=key_states,
                value=value_states,
                scale=self.scaling,
                **kwargs,
            )
        return self._compute_attention_without_fused_op(
            query_states,
            key_states,
            value_states,
            attention_mask,
            training,
        )

    def call(
        self,
        inputs,
        attention_mask=None,
        cache=None,
        cache_update_index=None,
        training=None,
    ):
        if self.attention_type == "cross":
            if not isinstance(inputs, (list, tuple)) or len(inputs) != 2:
                raise ValueError(
                    "For cross-attention, `inputs` must be a list or tuple of "
                    "two tensors: `[hidden_states, encoder_hidden_states]`."
                )
            hidden_states, kv_states = inputs
            query_states = self.query_dense(hidden_states)
            if cache is not None:
                if cache_update_index is not None:
                    raise ValueError(
                        "`cache_update_index` should not be set for "
                        "cross-attention caching."
                    )
                key_states, value_states = cache[:, 0, ...], cache[:, 1, ...]
                updated_cache = cache
            else:
                key_states = self.key_dense(kv_states)
                value_states = self.value_dense(kv_states)
                updated_cache = keras.ops.stack(
                    (key_states, value_states), axis=1
                )
            # Repeat key-value heads for GQA.
            key_states = repeat_kv(key_states, self.num_key_value_groups)
            value_states = repeat_kv(value_states, self.num_key_value_groups)
            attn_output = self._compute_attention(
                query_states, key_states, value_states, attention_mask, training
            )
            attn_output = self.output_dense(attn_output)
            return attn_output, updated_cache
        else: # Self-attention
            hidden_states = inputs
            kv_states = hidden_states
            query_states = self.query_dense(hidden_states)
            key_states = self.key_dense(kv_states)
            value_states = self.value_dense(kv_states)
            start_index = (
                0 if cache_update_index is None else cache_update_index
            )
            query_states = self.rotary_embedding(
                query_states, start_index=start_index
            )
            key_states = self.rotary_embedding(
                key_states, start_index=start_index
            )
            if cache is not None:
                if cache_update_index is None:
                    raise ValueError(
                        "Both `cache` and `cache_update_index` must be passed "
                        "for self-attention caching."
                    )
                key_cache, value_cache = cache[:, 0, ...], cache[:, 1, ...]
                start = [0, cache_update_index, 0, 0]
                key_states = keras.ops.slice_update(
                    key_cache, start, key_states
                )
                value_states = keras.ops.slice_update(
                    value_cache, start, value_states
                )
                cache = keras.ops.stack((key_states, value_states), axis=1)
            elif cache_update_index is not None:
                raise ValueError(
                    "`cache_update_index` should not be set if `cache` is "
                    "`None`."
                )
            else:
                cache = keras.ops.stack((key_states, value_states), axis=1)

            # Repeat key-value heads for GQA.
            key_states = repeat_kv(key_states, self.num_key_value_groups)
            value_states = repeat_kv(value_states, self.num_key_value_groups)

            attn_output = self._compute_attention(
                query_states, key_states, value_states, attention_mask, training
            )
            attn_output = self.output_dense(attn_output)
            return attn_output, cache

    def compute_output_shape(self, input_shape):
        if self.attention_type == "cross":
            hidden_states_shape, kv_states_shape = input_shape
        else:
            hidden_states_shape = input_shape
            kv_states_shape = input_shape
        attn_output_shape = hidden_states_shape
        kv_len = kv_states_shape[1]
        cache_shape = (
            hidden_states_shape[0],  # batch
            2,  # key and value
            kv_len,
            self.num_key_value_heads,
            self.head_dim,
        )
        return attn_output_shape, cache_shape

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_size": self.hidden_size,
                "head_dim": self.head_dim,
                "num_attention_heads": self.num_query_heads,
                "num_key_value_heads": self.num_key_value_heads,
                "query_pre_attn_scalar": self.query_pre_attn_scalar,
                "attention_bias": self.attention_bias,
                "attention_type": self.attention_type,
                "cross_attention_hidden_size": self.cross_attention_hidden_size,
                "initializer_range": self.initializer_range,
                "attention_dropout": self.attention_dropout,
                "attn_logit_softcapping": self.logit_soft_cap,
                "rope_max_wavelength": self.rope_max_wavelength,
            }
        )
        return config
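Illustrative usage (not part of the diff): a minimal self-attention forward pass through this layer, using arbitrary small sizes, to show the `(output, cache)` return pair described above.

import keras

from keras_hub.src.models.t5gemma.t5gemma_attention import T5GemmaAttention

batch, seq_len, hidden_size, head_dim = 2, 10, 64, 16

layer = T5GemmaAttention(
    hidden_size=hidden_size,
    num_attention_heads=4,
    num_key_value_heads=2,  # GQA: two query heads share each key/value head.
    query_pre_attn_scalar=head_dim,  # Queries are scaled by this value ** -0.5.
    attention_bias=False,
    head_dim=head_dim,
)

x = keras.random.uniform((batch, seq_len, hidden_size))
attn_output, cache = layer(x)
# attn_output: (batch, seq_len, hidden_size)
# cache: (batch, 2, seq_len, num_key_value_heads, head_dim), stacked keys and values.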