keras-hub-nightly 0.24.0.dev202511220420__py3-none-any.whl → 0.26.0.dev202601010440__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
Note: this release of keras-hub-nightly has been flagged as potentially problematic.
- keras_hub/models/__init__.py +12 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
- keras_hub/src/layers/modeling/rotary_embedding.py +188 -14
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -3
- keras_hub/src/models/albert/albert_backbone.py +1 -3
- keras_hub/src/models/bart/bart_backbone.py +1 -3
- keras_hub/src/models/bert/bert_backbone.py +1 -3
- keras_hub/src/models/bloom/bloom_backbone.py +1 -3
- keras_hub/src/models/causal_lm.py +23 -1
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
- keras_hub/src/models/dinov3/dinov3_presets.py +90 -1
- keras_hub/src/models/electra/electra_backbone.py +1 -3
- keras_hub/src/models/esm/esm_attention.py +11 -4
- keras_hub/src/models/f_net/f_net_backbone.py +1 -3
- keras_hub/src/models/falcon/falcon_backbone.py +1 -3
- keras_hub/src/models/gemma/gemma_backbone.py +1 -3
- keras_hub/src/models/gemma/gemma_causal_lm.py +16 -0
- keras_hub/src/models/gemma3/gemma3_backbone.py +1 -3
- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +8 -3
- keras_hub/src/models/gemma3/gemma3_presets.py +12 -0
- keras_hub/src/models/gemma3/gemma3_tokenizer.py +20 -8
- keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +17 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
- keras_hub/src/models/gpt_oss/__init__.py +5 -0
- keras_hub/src/models/gpt_oss/gpt_oss_attention.py +330 -0
- keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +219 -0
- keras_hub/src/models/gpt_oss/gpt_oss_causal_lm.py +284 -0
- keras_hub/src/models/gpt_oss/gpt_oss_causal_lm_preprocessor.py +79 -0
- keras_hub/src/models/gpt_oss/gpt_oss_decoder.py +444 -0
- keras_hub/src/models/gpt_oss/gpt_oss_layer_norm.py +34 -0
- keras_hub/src/models/gpt_oss/gpt_oss_presets.py +51 -0
- keras_hub/src/models/gpt_oss/gpt_oss_tokenizer.py +39 -0
- keras_hub/src/models/llama/llama_backbone.py +1 -3
- keras_hub/src/models/llama3/llama3_presets.py +1 -1
- keras_hub/src/models/masked_lm.py +22 -0
- keras_hub/src/models/mistral/mistral_backbone.py +1 -3
- keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
- keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
- keras_hub/src/models/parseq/parseq_decoder.py +21 -9
- keras_hub/src/models/phi3/phi3_backbone.py +1 -3
- keras_hub/src/models/qwen/qwen_backbone.py +1 -3
- keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
- keras_hub/src/models/qwen3/qwen3_presets.py +36 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
- keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
- keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
- keras_hub/src/models/siglip/siglip_layers.py +1 -3
- keras_hub/src/models/smollm3/__init__.py +5 -0
- keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
- keras_hub/src/models/smollm3/smollm3_presets.py +16 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +1 -1
- keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
- keras_hub/src/models/t5/t5_backbone.py +1 -3
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
- keras_hub/src/tests/test_case.py +1 -3
- keras_hub/src/utils/transformers/convert_gemma3.py +353 -0
- keras_hub/src/utils/transformers/convert_gpt_oss.py +302 -0
- keras_hub/src/utils/transformers/preset_loader.py +12 -0
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +3 -0
- {keras_hub_nightly-0.24.0.dev202511220420.dist-info → keras_hub_nightly-0.26.0.dev202601010440.dist-info}/METADATA +4 -5
- {keras_hub_nightly-0.24.0.dev202511220420.dist-info → keras_hub_nightly-0.26.0.dev202601010440.dist-info}/RECORD +66 -53
- {keras_hub_nightly-0.24.0.dev202511220420.dist-info → keras_hub_nightly-0.26.0.dev202601010440.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.24.0.dev202511220420.dist-info → keras_hub_nightly-0.26.0.dev202601010440.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,93 @@
 """DINOV3 model preset configurations."""
 
 # Metadata for loading pretrained model weights.
-backbone_presets = {
+backbone_presets = {
+    "dinov3_vit_small_lvd1689m": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (small-sized model) trained on LVD-1689M "
+                "using DINOv3."
+            ),
+            "params": 21_600_000,
+            "path": "dinov3",
+        },
+        "kaggle_handle": "kaggle://keras/dinov3/keras/dinov3_vit_small_lvd1689m/1",
+    },
+    "dinov3_vit_small_plus_lvd1689m": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (small-plus-sized model) trained on "
+                "LVD-1689M using DINOv3."
+            ),
+            "params": 29_000_000,
+            "path": "dinov3",
+        },
+        "kaggle_handle": "kaggle://keras/dinov3/keras/dinov3_vit_small_plus_lvd1689m/1",
+    },
+    "dinov3_vit_base_lvd1689m": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (base-sized model) trained on LVD-1689M "
+                "using DINOv3."
+            ),
+            "params": 86_000_000,
+            "path": "dinov3",
+        },
+        "kaggle_handle": "kaggle://keras/dinov3/keras/dinov3_vit_base_lvd1689m/1",
+    },
+    "dinov3_vit_large_lvd1689m": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (large-sized model) trained on LVD-1689M "
+                "using DINOv3."
+            ),
+            "params": 300_000_000,
+            "path": "dinov3",
+        },
+        "kaggle_handle": "kaggle://keras/dinov3/keras/dinov3_vit_large_lvd1689m/1",
+    },
+    "dinov3_vit_huge_plus_lvd1689m": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (huge-plus-sized model) trained on "
+                "LVD-1689M using DINOv3."
+            ),
+            "params": 840_000_000,
+            "path": "dinov3",
+        },
+        "kaggle_handle": "kaggle://keras/dinov3/keras/dinov3_vit_huge_plus_lvd1689m/1",
+    },
+    "dinov3_vit_7b_lvd1689m": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (7B-sized model) trained on LVD-1689M "
+                "using DINOv3."
+            ),
+            "params": 6_700_000_000,
+            "path": "dinov3",
+        },
+        "kaggle_handle": "kaggle://keras/dinov3/keras/dinov3_vit_7b_lvd1689m/1",
+    },
+    "dinov3_vit_large_sat493m": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (large-sized model) trained on SAT-493M "
+                "using DINOv3."
+            ),
+            "params": 300_000_000,
+            "path": "dinov3",
+        },
+        "kaggle_handle": "kaggle://keras/dinov3/keras/dinov3_vit_large_sat493m/1",
+    },
+    "dinov3_vit_7b_sat493m": {
+        "metadata": {
+            "description": (
+                "Vision Transformer (7B-sized model) trained on SAT-493M "
+                "using DINOv3."
+            ),
+            "params": 6_700_000_000,
+            "path": "dinov3",
+        },
+        "kaggle_handle": "kaggle://keras/dinov3/keras/dinov3_vit_7b_sat493m/1",
+    },
+}
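Note: these preset entries are plain configuration data; KerasHub resolves the preset name to its kaggle_handle at load time. A minimal loading sketch, assuming the standard from_preset API and that the handles above are live on Kaggle:

    import keras_hub

    # The name is looked up in backbone_presets, resolved to its
    # kaggle_handle, and the weights are downloaded on first use.
    backbone = keras_hub.models.Backbone.from_preset("dinov3_vit_small_lvd1689m")
    print(backbone.count_params())  # roughly 21.6M, per the metadata above
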
@@ -1,10 +1,8 @@
 import keras
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.utils.keras_utils import gelu_approximate
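Note: this hunk, and the matching import hunks in the backbones below, swap the vendored keras_hub copy of ReversibleEmbedding for the one now exported by core Keras, so these backbones require a Keras version that ships keras.layers.ReversibleEmbedding. A minimal sketch of what the layer does, assuming the core Keras copy keeps the keras-hub call signature:

    import numpy as np
    from keras.layers import ReversibleEmbedding

    # Tied-weight embedding: one (vocab, hidden) matrix serves both directions.
    embedding = ReversibleEmbedding(input_dim=100, output_dim=16)
    token_ids = np.array([[3, 7, 42]])
    hidden = embedding(token_ids)             # (1, 3, 16): ids -> vectors
    logits = embedding(hidden, reverse=True)  # (1, 3, 100): vectors -> vocab logits
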
@@ -14,7 +14,8 @@ class ESMRotaryEmbedding(RotaryEmbedding):
         inv_freq = self.scaling_factor / (
             self.max_wavelength ** (ops.arange(0, dim, 2, dtype=x.dtype) / dim)
         )
-        t = ops.arange(x.shape[position], dtype=x.dtype)
+        # Use ops.shape for dynamic shape compatibility with TFLite
+        t = ops.arange(ops.shape(x)[position], dtype=x.dtype)
         freqs = ops.outer(t, inv_freq)
         emb = ops.concatenate((freqs, freqs), axis=-1)
 

@@ -32,11 +33,17 @@ class ESMRotaryEmbedding(RotaryEmbedding):
 
     def rotate_half(self, x):
         x1, x2 = ops.split(x, 2, -1)
-        return ops.concatenate((-x2, x1), axis=-1)
+        # Avoid `ops.concatenate` to prevent XLA compilation issues on JAX
+        # backend. Use stack + reshape approach from base RotaryEmbedding.
+        half_rot_x = ops.stack((-x2, x1), axis=-2)
+        half_rot_x = ops.reshape(half_rot_x, ops.shape(x))
+        return half_rot_x
 
     def apply_rotary_pos_emb(self, x, cos, sin):
-        cos = cos[:, : x.shape[1], :, :]
-        sin = sin[:, : x.shape[1], :, :]
+        # Use ops.shape for dynamic shape compatibility with TFLite
+        seq_len = ops.shape(x)[1]
+        cos = cos[:, :seq_len, :, :]
+        sin = sin[:, :seq_len, :, :]
 
         return (x * cos) + (self.rotate_half(x) * sin)
 
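Note: the stack + reshape form of rotate_half is numerically identical to the concatenate form it replaces; stacking on a new second-to-last axis and flattening it back out lays the elements down as -x2 followed by x1, exactly like the old concatenation. A standalone NumPy check of that equivalence (illustration only, not the ESM layer itself):

    import numpy as np

    def rotate_half_concat(x):
        # Original version: contiguous halves, swapped and negated.
        x1, x2 = np.split(x, 2, axis=-1)
        return np.concatenate((-x2, x1), axis=-1)

    def rotate_half_stack(x):
        # New version: stack on a fresh axis, then flatten it back out.
        x1, x2 = np.split(x, 2, axis=-1)
        return np.stack((-x2, x1), axis=-2).reshape(x.shape)

    x = np.random.rand(2, 5, 8)
    assert np.allclose(rotate_half_concat(x), rotate_half_stack(x))
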
@@ -1,11 +1,9 @@
 import keras
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.modeling.f_net_encoder import FNetEncoder
 from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.utils.keras_utils import gelu_approximate
 
@@ -1,9 +1,7 @@
 import keras
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.falcon.falcon_transformer_decoder import (
     FalconTransformerDecoder,
@@ -1,10 +1,8 @@
 import keras
 from keras import ops
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.gemma.gemma_decoder_block import GemmaDecoderBlock
 from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
@@ -431,3 +431,19 @@ class GemmaCausalLM(CausalLM):
         )
         per_token_loss = per_token_loss_fn(target_ids, logits)
         return per_token_loss
+
+    def get_quantization_layer_structure(self, mode):
+        if mode != "gptq":
+            return None
+
+        # Wrap embedding + scaling
+        backbone = self.backbone
+        inputs = keras.Input(shape=(None,), dtype="int32")
+        x = backbone.token_embedding(inputs)
+        x = x * ops.cast(ops.sqrt(backbone.hidden_dim), x.dtype)
+        pre_processor = keras.Model(inputs=inputs, outputs=x)
+
+        return {
+            "pre_block_layers": [pre_processor],
+            "sequential_blocks": backbone.transformer_layers,
+        }
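Note: get_quantization_layer_structure carves the model into the shape a sequential GPTQ pass expects: a small functional model wrapping everything before the first transformer block (embedding lookup plus the sqrt(hidden_dim) scaling), followed by the blocks to be quantized one at a time. A hedged sketch of how a driver might consume that dict; quantize_block here is a hypothetical callback, not KerasHub API:

    def run_sequential_gptq(causal_lm, calibration_token_ids, quantize_block):
        # `quantize_block` is a hypothetical callback that GPTQ-quantizes one
        # block given its calibration inputs and returns that block's outputs.
        structure = causal_lm.get_quantization_layer_structure("gptq")
        if structure is None:
            raise ValueError("Model does not expose a GPTQ layer structure.")

        # Run the pre-block layers (embedding + scaling) once to get the
        # activations that feed the first transformer block.
        x = calibration_token_ids
        for layer in structure["pre_block_layers"]:
            x = layer(x)

        # Quantize block i on the activations produced by blocks 0..i-1.
        for block in structure["sequential_blocks"]:
            x = quantize_block(block, x)
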
@@ -1,10 +1,8 @@
 import keras
 from keras import ops
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
 from keras_hub.src.models.gemma3.gemma3_decoder_block import Gemma3DecoderBlock
@@ -283,9 +283,14 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
         # is `None`.
         self.text_only_model = self.image_converter is None
 
-        self.image_placeholder = self.tokenizer.image_placeholder
-        self.start_of_image_token = self.tokenizer.start_of_image_token
-        self.end_of_image_token = self.tokenizer.end_of_image_token
+        if self.text_only_model:
+            self.image_placeholder = None
+            self.start_of_image_token = None
+            self.end_of_image_token = None
+        else:
+            self.image_placeholder = self.tokenizer.image_placeholder
+            self.start_of_image_token = self.tokenizer.start_of_image_token
+            self.end_of_image_token = self.tokenizer.end_of_image_token
 
     def build(self, input_shape):
         # Defer packer creation to `build()` so that we can be sure tokenizer
@@ -220,4 +220,16 @@ backbone_presets = {
         },
         "kaggle_handle": "kaggle://keras/medgemma/keras/medgemma_instruct_27b_text/1",
     },
+    "function_gemma_instruct_270m": {
+        "metadata": {
+            "description": (
+                "A 270M Million parameter text-only model based on Gemma 3. "
+                "This model is trained specifically for function calling "
+                "improvements."
+            ),
+            "params": 268098176,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/function-gemma/keras/function_gemma_instruct_270m/1",
+    },
 }
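Note: like the DINOv3 entries above, this preset is resolved by name at load time. A minimal generation sketch, assuming the handle is live on Kaggle and the standard Gemma3 task API:

    import keras_hub

    # Loads weights via the preset's kaggle_handle shown above.
    causal_lm = keras_hub.models.Gemma3CausalLM.from_preset(
        "function_gemma_instruct_270m"
    )
    print(causal_lm.generate("What tools do you have?", max_length=64))
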
@@ -77,20 +77,32 @@ class Gemma3Tokenizer(SentencePieceTokenizer):
 
     backbone_cls = Gemma3Backbone
 
-    def __init__(self, proto, **kwargs):
+    def __init__(self, proto, has_vision_tokens=True, **kwargs):
         # Add special tokens.
 
+        self.has_vision_tokens = has_vision_tokens
         # The usual tokens.
         self._add_special_token("<bos>", "start_token")
         self._add_special_token("<eos>", "end_token")
         self._add_special_token("<pad>", "pad_token")
 
-        # Image placeholder token.
-        self._add_special_token("<img>", "image_placeholder")
-
-        # Some tokens which are used in the preprocessor. We need to keep
-        # them here so that the preprocessor works with tf.data.
-        self._add_special_token("<start_of_image>", "start_of_image_token")
-        self._add_special_token("<end_of_image>", "end_of_image_token")
+        if has_vision_tokens:
+            # Image placeholder token.
+            self._add_special_token("<img>", "image_placeholder")
+            # Some tokens which are used in the preprocessor.
+            # We need to keep them
+            # here so that the preprocessor works with tf.data.
+            self._add_special_token("<start_of_image>", "start_of_image_token")
+            self._add_special_token("<end_of_image>", "end_of_image_token")
+        else:
+            # For text-only, skip assigning token IDs or set to -1
+            self.start_of_image_token_id = -1
+            self.image_placeholder_token_id = -1
+            self.end_of_image_token_id = -1
 
         super().__init__(proto=proto, **kwargs)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"has_vision_tokens": self.has_vision_tokens})
+        return config
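Note: with has_vision_tokens now recorded in get_config, a text-only Gemma3 tokenizer can be serialized without registering the vision special tokens. A sketch, assuming a Gemma3 SentencePiece proto file on disk (the path below is hypothetical):

    import keras_hub

    tokenizer = keras_hub.tokenizers.Gemma3Tokenizer(
        proto="gemma3_vocabulary.spm",  # hypothetical path to a proto file
        has_vision_tokens=False,
    )
    # The flag survives serialization, so a reloaded text-only tokenizer
    # does not try to register <img>/<start_of_image>/<end_of_image>.
    assert tokenizer.get_config()["has_vision_tokens"] is False
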
@@ -1,10 +1,8 @@
 import keras
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.layers.modeling.transformer_decoder import TransformerDecoder
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.utils.keras_utils import gelu_approximate
@@ -420,3 +420,20 @@ class GPT2CausalLM(CausalLM):
         )
         per_token_loss = per_token_loss_fn(target_ids, logits)
         return per_token_loss
+
+    def get_quantization_layer_structure(self, mode):
+        if mode != "gptq":
+            return None
+
+        backbone = self.backbone
+        token_ids = keras.Input(shape=(None,), dtype="int32")
+        tokens = backbone.token_embedding(token_ids)
+        positions = backbone.position_embedding(tokens)
+        x = backbone.embeddings_add((tokens, positions))
+        x = backbone.embeddings_dropout(x)
+        pre_processor = keras.Model(inputs=token_ids, outputs=x)
+
+        return {
+            "pre_block_layers": [pre_processor],
+            "sequential_blocks": backbone.transformer_layers,
+        }
@@ -1,9 +1,7 @@
 import keras
+from keras.layers import ReversibleEmbedding
 
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.modeling.reversible_embedding import (
-    ReversibleEmbedding,
-)
 from keras_hub.src.models.backbone import Backbone
 from keras_hub.src.models.gpt_neo_x.gpt_neo_x_decoder import GPTNeoXDecoder
 from keras_hub.src.utils.keras_utils import gelu_approximate
@@ -0,0 +1,330 @@
+import math
+
+import keras
+from keras import ops
+
+from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+class GptOssAttention(keras.layers.Layer):
+    """A cached attention layer with sliding window and sink tokens.
+
+    This layer implements the attention mechanism described in the GPT-OSS
+    paper. It includes grouped-query attention, rotary position embeddings,
+    sliding window attention, and sink tokens for improved performance on
+    long sequences.
+
+    Args:
+        num_query_heads: int. The number of query attention heads.
+        num_key_value_heads: int. The number of key and value attention
+            heads.
+        rope_max_wavelength: int. The maximum wavelength for the
+            rotary position embedding. Defaults to 10000.
+        rope_scaling_factor: float. The scaling factor for the
+            rotary position embedding. Defaults to 1.0.
+        kernel_initializer: str. The initializer for the kernel
+            weights. Defaults to "glorot_uniform".
+        sliding_window: int. The size of the sliding window.
+            Defaults to 4096.
+        dropout: float. The dropout rate. Defaults to 0.
+        head_dim: int. Head dimension for attention. If None,
+            calculated as hidden_dim // num_query_heads. Defaults to None.
+    """
+
+    def __init__(
+        self,
+        num_query_heads,
+        num_key_value_heads,
+        rope_max_wavelength=10000,
+        rope_scaling_factor=1.0,
+        kernel_initializer="glorot_uniform",
+        sliding_window=4096,
+        dropout=0,
+        head_dim=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.num_query_heads = num_query_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.sliding_window = sliding_window
+        self.dropout = dropout
+        self.head_dim = head_dim
+        self.rope_max_wavelength = rope_max_wavelength
+        self.rope_scaling_factor = rope_scaling_factor
+        self.num_key_value_groups = num_query_heads // num_key_value_heads
+        self._kernel_initializer = keras.initializers.get(
+            clone_initializer(kernel_initializer)
+        )
+
+    def build(self, inputs_shape):
+        # Einsum variables:
+        # b = batch size
+        # q = query length
+        # k = key/value length
+        # m = the model's hidden_dim
+        # u = num query heads
+        # v = num key/value heads
+        # h = head dim
+        self._hidden_dim = inputs_shape[-1]
+
+        if self.head_dim is not None:
+            self._head_dim = self.head_dim
+        else:
+            self._head_dim = self._hidden_dim // self.num_query_heads
+        self._inv_norm_factor = 1.0 / math.sqrt(self._head_dim)
+
+        self._rotary_dim = (self._head_dim // 2) * 2
+
+        self.query_dense = keras.layers.EinsumDense(
+            equation="bqm,muh->bquh",
+            output_shape=(None, self.num_query_heads, self._head_dim),
+            bias_axes="uh",
+            kernel_initializer=self._kernel_initializer,
+            bias_initializer="zeros",
+            dtype=self.dtype_policy,
+            name="query",
+        )
+        self.query_dense.build(inputs_shape)
+
+        self.key_dense = keras.layers.EinsumDense(
+            equation="bkm,mvh->bkvh",
+            output_shape=(
+                None,
+                self.num_key_value_heads,
+                self._head_dim,
+            ),
+            bias_axes="vh",
+            kernel_initializer=self._kernel_initializer,
+            bias_initializer="zeros",
+            dtype=self.dtype_policy,
+            name="key",
+        )
+        self.key_dense.build(inputs_shape)
+
+        self.value_dense = keras.layers.EinsumDense(
+            equation="bkm,mvh->bkvh",
+            output_shape=(
+                None,
+                self.num_key_value_heads,
+                self._head_dim,
+            ),
+            bias_axes="vh",
+            kernel_initializer=self._kernel_initializer,
+            bias_initializer="zeros",
+            dtype=self.dtype_policy,
+            name="value",
+        )
+        self.value_dense.build(inputs_shape)
+
+        self.dropout_layer = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+        )
+
+        self.output_dense = keras.layers.EinsumDense(
+            equation="bquh,uhm->bqm",
+            output_shape=(None, self._hidden_dim),
+            bias_axes="m",
+            kernel_initializer=self._kernel_initializer,
+            bias_initializer="zeros",
+            dtype=self.dtype_policy,
+            name="attention_output",
+        )
+        self.output_dense.build(
+            (None, None, self.num_query_heads, self._head_dim)
+        )
+
+        self.rotary_embedding_layer = RotaryEmbedding(
+            max_wavelength=self.rope_max_wavelength,
+            scaling_factor=self.rope_scaling_factor,  # YaRN scaling factor
+            rope_type="yarn",
+            beta_fast=32.0,
+            beta_slow=1.0,
+            original_max_position_embeddings=4096,
+            dtype=self.dtype_policy,
+        )
+
+        self.sinks = self.add_weight(
+            shape=(self.num_query_heads,),
+            initializer="random_normal",
+            dtype=self.dtype,
+            name="sinks",
+        )
+
+        self._dot_product_equation = "bquh,bkuh->buqk"
+        self._combine_equation = "buqk,bkuh->bquh"
+
+        self.built = True
+
+    def call(
+        self,
+        hidden_states,
+        attention_mask=None,
+        cache=None,
+        cache_update_index=None,
+        training=None,
+    ):
+        start_index = (
+            cache_update_index if cache_update_index is not None else 0
+        )
+
+        query = self.query_dense(hidden_states)
+
+        # Compute RoPE for queries (only
+        # to first _rotary_dim dimensions)
+        if self._rotary_dim < self._head_dim:
+            query_rot = query[..., : self._rotary_dim]
+            query_rot = self.rotary_embedding_layer(
+                query_rot, start_index=start_index
+            )
+            query = ops.concatenate(
+                [query_rot, query[..., self._rotary_dim :]], axis=-1
+            )
+        else:
+            query = self.rotary_embedding_layer(query, start_index=start_index)
+
+        def _compute_key_value(x):
+            key, value = self.key_dense(x), self.value_dense(x)
+            # Compute RoPE for keys (only apply to first _rotary_dim dimensions)
+            if self._rotary_dim < self._head_dim:
+                key_rot = key[..., : self._rotary_dim]
+                key_rot = self.rotary_embedding_layer(
+                    key_rot, start_index=start_index
+                )
+                key = ops.concatenate(
+                    [key_rot, key[..., self._rotary_dim :]], axis=-1
+                )
+            else:
+                key = self.rotary_embedding_layer(key, start_index=start_index)
+            return key, value
+
+        if cache is not None:
+            key_cache = cache[:, 0, ...]
+            value_cache = cache[:, 1, ...]
+            if cache_update_index is None:
+                key = key_cache
+                value = value_cache
+            else:
+                key_update, value_update = _compute_key_value(hidden_states)
+                start = [0, cache_update_index, 0, 0]
+                key = ops.slice_update(key_cache, start, key_update)
+                value = ops.slice_update(value_cache, start, value_update)
+                cache = ops.stack((key, value), axis=1)
+        else:
+            if cache_update_index is not None:
+                raise ValueError(
+                    "`cache_update_index` should not be set if `cache` is "
+                    f"`None`. Received: cache={cache}, "
+                    f"cache_update_index={cache_update_index}"
+                )
+            key, value = _compute_key_value(hidden_states)
+
+        # [batch_shape, seq_len, num_key_value_heads, head_dim]
+        # -> [batch_shape, seq_len, num_heads, head_dim]
+        key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
+        value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)
+
+        attention_output = self._compute_attention(
+            query, key, value, attention_mask, start_index
+        )
+
+        attention_output = self.dropout_layer(
+            attention_output, training=training
+        )
+
+        attention_output = self.output_dense(attention_output)
+
+        if cache is not None:
+            return attention_output, cache
+        return attention_output
+
+    def _compute_attention(
+        self, query, key, value, attention_mask=None, start_index=0
+    ):
+        attention_scores = ops.einsum(self._dot_product_equation, query, key)
+        attention_scores = ops.multiply(
+            attention_scores,
+            ops.cast(self._inv_norm_factor, self.compute_dtype),
+        )
+
+        # Apply sliding window mask if specified
+        if self.sliding_window is not None and self.sliding_window > 0:
+            q_len = ops.shape(attention_scores)[-2]
+            kv_len = ops.shape(attention_scores)[-1]
+
+            # Query positions are offset by start_index during generation
+            q_positions = ops.arange(q_len) + start_index
+            kv_positions = ops.arange(kv_len)
+
+            # Mask true for positions outside sliding window
+            # For causal attention: mask if kv_pos < q_pos - sliding_window
+            mask = (
+                kv_positions[None, :]
+                >= q_positions[:, None] - self.sliding_window
+            )
+            if self.compute_dtype == "float32":
+                sliding_adder = ops.cast(-1e9, self.compute_dtype)
+            else:
+                sliding_adder = ops.cast(-1e4, self.compute_dtype)
+            attention_scores = ops.where(
+                mask[None, None, :, :], attention_scores, sliding_adder
+            )
+
+        if attention_mask is not None:
+            # The mask is a boolean tensor, True for positions to be masked.
+            # We add a large negative number to the masked positions.
+            # Use a large negative value for masking
+            if self.compute_dtype == "float32":
+                adder = ops.cast(-1e9, self.compute_dtype)
+            else:
+                adder = ops.cast(-1e4, self.compute_dtype)
+            attention_scores = ops.where(
+                attention_mask[:, None, :, :], attention_scores, adder
+            )
+
+        # Handle sink tokens by concatenating them to the logits.
+        b = ops.shape(attention_scores)[0]
+        q = ops.shape(attention_scores)[2]
+
+        sinks = ops.reshape(self.sinks, (1, self.num_query_heads, 1, 1))
+        sinks = ops.broadcast_to(sinks, (b, self.num_query_heads, q, 1))
+        # attention_scores shape: [b, num_heads, q, k]
+        # sinks shape: [b, num_heads, q, 1]
+        # We need to concatenate along the last dimension
+        combined_logits = ops.concatenate([attention_scores, sinks], axis=-1)
+
+        # Stabilize logits before softmax for numerical stability.
+        max_logits = ops.max(combined_logits, axis=-1, keepdims=True)
+        max_logits = ops.stop_gradient(max_logits)
+        combined_logits = combined_logits - max_logits
+
+        probs = ops.softmax(combined_logits, axis=-1)
+
+        # Remove the sink probabilities before computing the output.
+        attention_scores = probs[..., :-1]
+        attention_scores = ops.cast(attention_scores, self.compute_dtype)
+
+        attention_output = ops.einsum(
+            self._combine_equation, attention_scores, value
+        )
+
+        return attention_output
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_query_heads": self.num_query_heads,
+                "num_key_value_heads": self.num_key_value_heads,
+                "rope_max_wavelength": self.rope_max_wavelength,
+                "rope_scaling_factor": self.rope_scaling_factor,
+                "kernel_initializer": keras.initializers.serialize(
+                    self._kernel_initializer
+                ),
+                "sliding_window": self.sliding_window,
+                "dropout": self.dropout,
+                "head_dim": self.head_dim,
+            }
+        )
+        return config