keras-hub-nightly 0.22.0.dev202508170419__py3-none-any.whl → 0.24.0.dev202511090424__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of keras-hub-nightly might be problematic.
- keras_hub/layers/__init__.py +15 -0
- keras_hub/models/__init__.py +93 -0
- keras_hub/src/layers/modeling/position_embedding.py +21 -6
- keras_hub/src/layers/modeling/reversible_embedding.py +8 -1
- keras_hub/src/layers/modeling/rotary_embedding.py +16 -6
- keras_hub/src/layers/modeling/sine_position_encoding.py +21 -8
- keras_hub/src/layers/modeling/token_and_position_embedding.py +2 -1
- keras_hub/src/models/backbone.py +28 -16
- keras_hub/src/models/causal_lm.py +37 -0
- keras_hub/src/models/causal_lm_preprocessor.py +14 -0
- keras_hub/src/models/clip/clip_presets.py +8 -8
- keras_hub/src/models/d_fine/__init__.py +5 -0
- keras_hub/src/models/d_fine/d_fine_attention.py +461 -0
- keras_hub/src/models/d_fine/d_fine_backbone.py +891 -0
- keras_hub/src/models/d_fine/d_fine_decoder.py +944 -0
- keras_hub/src/models/d_fine/d_fine_encoder.py +365 -0
- keras_hub/src/models/d_fine/d_fine_hybrid_encoder.py +642 -0
- keras_hub/src/models/d_fine/d_fine_image_converter.py +8 -0
- keras_hub/src/models/d_fine/d_fine_layers.py +1828 -0
- keras_hub/src/models/d_fine/d_fine_loss.py +938 -0
- keras_hub/src/models/d_fine/d_fine_object_detector.py +875 -0
- keras_hub/src/models/d_fine/d_fine_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/d_fine/d_fine_presets.py +155 -0
- keras_hub/src/models/d_fine/d_fine_utils.py +827 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +7 -2
- keras_hub/src/models/depth_anything/__init__.py +9 -0
- keras_hub/src/models/depth_anything/depth_anything_backbone.py +232 -0
- keras_hub/src/models/depth_anything/depth_anything_depth_estimator.py +70 -0
- keras_hub/src/models/depth_anything/depth_anything_depth_estimator_preprocessor.py +16 -0
- keras_hub/src/models/depth_anything/depth_anything_image_converter.py +10 -0
- keras_hub/src/models/depth_anything/depth_anything_layers.py +725 -0
- keras_hub/src/models/depth_anything/depth_anything_loss.py +89 -0
- keras_hub/src/models/depth_anything/depth_anything_presets.py +41 -0
- keras_hub/src/models/depth_anything/interpolate.py +62 -0
- keras_hub/src/models/depth_estimator.py +239 -0
- keras_hub/src/models/depth_estimator_preprocessor.py +78 -0
- keras_hub/src/models/dinov2/dinov2_backbone.py +29 -3
- keras_hub/src/models/dinov2/dinov2_layers.py +16 -4
- keras_hub/src/models/dinov3/__init__.py +5 -0
- keras_hub/src/models/dinov3/dinov3_backbone.py +263 -0
- keras_hub/src/models/dinov3/dinov3_image_converter.py +8 -0
- keras_hub/src/models/dinov3/dinov3_layers.py +1013 -0
- keras_hub/src/models/dinov3/dinov3_presets.py +4 -0
- keras_hub/src/models/gemma/gemma_backbone.py +0 -1
- keras_hub/src/models/gemma/gemma_presets.py +30 -0
- keras_hub/src/models/gemma3/gemma3_attention.py +48 -0
- keras_hub/src/models/gemma3/gemma3_backbone.py +4 -1
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +12 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +39 -0
- keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +4 -1
- keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +3 -2
- keras_hub/src/models/hgnetv2/hgnetv2_layers.py +27 -11
- keras_hub/src/models/image_to_image.py +5 -0
- keras_hub/src/models/inpaint.py +5 -0
- keras_hub/src/models/mobilenetv5/__init__.py +9 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py +699 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_backbone.py +396 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py +890 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py +436 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py +157 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_converter.py +10 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py +462 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_presets.py +15 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_utils.py +146 -0
- keras_hub/src/models/parseq/__init__.py +5 -0
- keras_hub/src/models/parseq/parseq_backbone.py +134 -0
- keras_hub/src/models/parseq/parseq_causal_lm.py +466 -0
- keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py +168 -0
- keras_hub/src/models/parseq/parseq_decoder.py +418 -0
- keras_hub/src/models/parseq/parseq_image_converter.py +8 -0
- keras_hub/src/models/parseq/parseq_presets.py +15 -0
- keras_hub/src/models/parseq/parseq_tokenizer.py +221 -0
- keras_hub/src/models/qwen3_moe/__init__.py +5 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py +371 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +365 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm.py +357 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor.py +12 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py +672 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py +45 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +30 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py +48 -0
- keras_hub/src/models/sam/sam_prompt_encoder.py +3 -1
- keras_hub/src/models/siglip/siglip_presets.py +15 -0
- keras_hub/src/models/smollm3/smollm3_backbone.py +211 -0
- keras_hub/src/models/smollm3/smollm3_causal_lm.py +310 -0
- keras_hub/src/models/smollm3/smollm3_causal_lm_preprocessor.py +84 -0
- keras_hub/src/models/smollm3/smollm3_layers.py +757 -0
- keras_hub/src/models/smollm3/smollm3_tokenizer.py +60 -0
- keras_hub/src/models/smollm3/smollm3_utils.py +56 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
- keras_hub/src/models/t5gemma/__init__.py +5 -0
- keras_hub/src/models/t5gemma/t5gemma_attention.py +370 -0
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +366 -0
- keras_hub/src/models/t5gemma/t5gemma_decoder.py +355 -0
- keras_hub/src/models/t5gemma/t5gemma_encoder.py +214 -0
- keras_hub/src/models/t5gemma/t5gemma_layers.py +118 -0
- keras_hub/src/models/t5gemma/t5gemma_presets.py +374 -0
- keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm.py +442 -0
- keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py +216 -0
- keras_hub/src/models/t5gemma/t5gemma_tokenizer.py +84 -0
- keras_hub/src/models/text_to_image.py +5 -0
- keras_hub/src/samplers/beam_sampler.py +6 -6
- keras_hub/src/samplers/sampler.py +8 -6
- keras_hub/src/tests/test_case.py +40 -3
- keras_hub/src/tokenizers/tokenizer.py +15 -0
- keras_hub/src/utils/openvino_utils.py +141 -0
- keras_hub/src/utils/preset_utils.py +58 -2
- keras_hub/src/utils/tensor_utils.py +26 -2
- keras_hub/src/utils/timm/convert_mobilenetv5.py +321 -0
- keras_hub/src/utils/timm/preset_loader.py +8 -4
- keras_hub/src/utils/transformers/convert_dinov2.py +1 -0
- keras_hub/src/utils/transformers/convert_dinov3.py +106 -0
- keras_hub/src/utils/transformers/convert_qwen3_moe.py +216 -0
- keras_hub/src/utils/transformers/convert_smollm3.py +139 -0
- keras_hub/src/utils/transformers/convert_t5gemma.py +229 -0
- keras_hub/src/utils/transformers/convert_vit.py +4 -1
- keras_hub/src/utils/transformers/export/gemma.py +49 -4
- keras_hub/src/utils/transformers/export/hf_exporter.py +71 -25
- keras_hub/src/utils/transformers/preset_loader.py +12 -0
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +15 -0
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/RECORD +126 -47
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/top_level.txt +0 -0
keras_hub/src/models/parseq/parseq_tokenizer.py

@@ -0,0 +1,221 @@
+import os
+import re
+from typing import Iterable
+
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.tokenizers import tokenizer
+from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
+from keras_hub.src.utils.tensor_utils import is_int_dtype
+from keras_hub.src.utils.tensor_utils import is_string_dtype
+from keras_hub.src.utils.tensor_utils import preprocessing_function
+
+try:
+    import tensorflow as tf
+    import tensorflow_text as tf_text
+except ImportError:
+    tf = None
+    tf_text = None
+
+PARSEQ_VOCAB = list(
+    "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"
+    "\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
+)
+
+VOCAB_FILENAME = "vocabulary.txt"
+
+
+@keras_hub_export(
+    [
+        "keras_hub.tokenizers.PARSeqTokenizer",
+        "keras_hub.models.PARSeqTokenizer",
+    ]
+)
+class PARSeqTokenizer(tokenizer.Tokenizer):
+    """A Tokenizer for PARSeq models, designed for OCR tasks.
+
+    This tokenizer converts strings into sequences of integer IDs or string
+    tokens, and vice-versa. It supports various preprocessing steps such as
+    whitespace removal, Unicode normalization, and limiting the maximum label
+    length. It also provides functionality to save and load the vocabulary
+    from a file.
+
+    Args:
+        vocabulary: str. A string or iterable representing the vocabulary to
+            use. If a string, it's treated as the path to a vocabulary file.
+            If an iterable, it's treated as a list of characters forming
+            the vocabulary. Defaults to `PARSEQ_VOCAB`.
+        remove_whitespace: bool. Whether to remove whitespace characters from
+            the input. Defaults to `True`.
+        normalize_unicode: bool. Whether to normalize Unicode characters in the
+            input using NFKD normalization and remove non-ASCII characters.
+            Defaults to `True`.
+        max_label_length: int. The maximum length of the tokenized output.
+            Longer labels will be truncated. Defaults to `25`.
+        dtype: str. The data type of the tokenized output. Must be an integer
+            type (e.g., "int32") or a string type ("string").
+            Defaults to `"int32"`.
+        **kwargs: Additional keyword arguments passed to the base
+            `keras.layers.Layer` constructor.
+    """
+
+    def __init__(
+        self,
+        vocabulary=PARSEQ_VOCAB,
+        remove_whitespace=True,
+        normalize_unicode=True,
+        max_label_length=25,
+        dtype="int32",
+        **kwargs,
+    ):
+        if not is_int_dtype(dtype) and not is_string_dtype(dtype):
+            raise ValueError(
+                "Output dtype must be an integer type or a string. "
+                f"Received: dtype={dtype}"
+            )
+        super().__init__(dtype=dtype, **kwargs)
+        self.remove_whitespace = remove_whitespace
+        self.normalize_unicode = normalize_unicode
+        self.max_label_length = max_label_length
+        self.file_assets = [VOCAB_FILENAME]
+
+        self.set_vocabulary(vocabulary)
+
+    def save_assets(self, dir_path):
+        path = os.path.join(dir_path, VOCAB_FILENAME)
+        with open(path, "w", encoding="utf-8") as file:
+            for token in self.vocabulary:
+                file.write(f"{token}\n")
+
+    def load_assets(self, dir_path):
+        path = os.path.join(dir_path, VOCAB_FILENAME)
+        self.set_vocabulary(path)
+
+    def set_vocabulary(self, vocabulary):
+        """Set the tokenizer vocabulary to a file or list of strings."""
+        if vocabulary is None:
+            self.vocabulary = None
+            return
+
+        if isinstance(vocabulary, str):
+            with open(vocabulary, "r", encoding="utf-8") as file:
+                self.vocabulary = [line.rstrip() for line in file]
+            self.vocabulary = "".join(self.vocabulary)
+        elif isinstance(vocabulary, Iterable):
+            self.vocabulary = "".join(vocabulary)
+        else:
+            raise ValueError(
+                "Vocabulary must be a file path or list of terms. "
+                f"Received: vocabulary={vocabulary}"
+            )
+
+        self.lowercase_only = self.vocabulary == self.vocabulary.lower()
+        self.uppercase_only = self.vocabulary == self.vocabulary.upper()
+        escaped_charset = re.escape(self.vocabulary)  # Escape for safe regex.
+        self.unsupported_regex = f"[^{escaped_charset}]"
+        self._itos = ("[E]",) + tuple(self.vocabulary) + ("[B]", "[P]")
+        self._stoi = {s: i for i, s in enumerate(self._itos)}
+
+        self._add_special_token("[B]", "start_token")
+        self._add_special_token("[E]", "end_token")
+        self._add_special_token("[P]", "pad_token")
+        # Create lookup tables.
+        self.char_to_id = tf.lookup.StaticHashTable(
+            initializer=tf.lookup.KeyValueTensorInitializer(
+                keys=list(self._stoi.keys()),
+                values=list(self._stoi.values()),
+                key_dtype=tf.string,
+                value_dtype=tf.int32,
+            ),
+            default_value=self._stoi["[E]"],
+        )
+        self.id_to_char = tf.lookup.StaticHashTable(
+            initializer=tf.lookup.KeyValueTensorInitializer(
+                keys=list(self._stoi.values()),
+                values=list(self._stoi.keys()),
+                key_dtype=tf.int32,
+                value_dtype=tf.string,
+            ),
+            default_value=self.pad_token,
+        )
+
+    def get_vocabulary(self):
+        """Get the tokenizer vocabulary as a list of string tokens."""
+        return list(self.vocabulary)
+
+    def id_to_token(self, id):
+        if id >= self.vocabulary_size() or id < 0:
+            raise ValueError(
+                f"`id` must be in range [0, {self.vocabulary_size() - 1}]. "
+                f"Received: {id}"
+            )
+        return self._itos[id]
+
+    def token_to_id(self, token):
+        return self._stoi[token]
+
+    def _preprocess(self, inputs):
+        """Strip whitespace, normalize Unicode, and drop unsupported chars."""
+        if self.remove_whitespace:
+            inputs = tf.strings.regex_replace(inputs, r"\s+", "")
+
+        if self.normalize_unicode:
+            inputs = tf_text.normalize_utf8(inputs, normalization_form="NFKD")
+            inputs = tf.strings.regex_replace(inputs, r"[^!-~]", "")
+
+        if self.lowercase_only:
+            inputs = tf.strings.lower(inputs)
+        elif self.uppercase_only:
+            inputs = tf.strings.upper(inputs)
+
+        inputs = tf.strings.regex_replace(inputs, self.unsupported_regex, "")
+        inputs = tf.strings.substr(inputs, 0, self.max_label_length)
+
+        return inputs
+
+    @preprocessing_function
+    def tokenize(self, inputs):
+        inputs = tf.convert_to_tensor(inputs)
+        unbatched = inputs.shape.rank == 0
+        if unbatched:
+            inputs = tf.expand_dims(inputs, 0)
+
+        inputs = tf.map_fn(
+            self._preprocess, inputs, fn_output_signature=tf.string
+        )
+
+        token_ids = tf.cond(
+            tf.size(inputs) > 0,
+            lambda: self.char_to_id.lookup(
+                tf.strings.unicode_split(inputs, "UTF-8")
+            ),
+            lambda: tf.RaggedTensor.from_row_splits(
+                values=tf.constant([], dtype=tf.int32),
+                row_splits=tf.constant([0], dtype=tf.int64),
+            ),
+        )
+        if unbatched:
+            token_ids = tf.squeeze(token_ids, 0)
+            tf.ensure_shape(token_ids, shape=[self.max_label_length])
+        return token_ids
+
+    @preprocessing_function
+    def detokenize(self, inputs):
+        inputs, unbatched, rectangular = convert_to_ragged_batch(inputs)
+        # The `id_to_char` lookup table expects int32 keys.
+        inputs = tf.cast(inputs, "int32")
+        outputs = self.id_to_char.lookup(inputs)
+        if unbatched:
+            outputs = tf.squeeze(outputs, 0)
+        return outputs
+
+    def vocabulary_size(self):
+        """Get the integer size of the tokenizer vocabulary."""
+        return len(self.vocabulary) + 3
+
+    def compute_output_spec(self, input_spec):
+        return keras.KerasTensor(
+            input_spec.shape + (self.max_label_length,),
+            dtype=self.compute_dtype,
+        )
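
For orientation, a minimal usage sketch of the new tokenizer, assuming `tensorflow` and `tensorflow_text` are installed (the input string is illustrative; the class and its defaults come from the hunk above). Note the id layout: `[E]` sits at index 0, the 94 charset characters at 1-94, then `[B]` and `[P]`, so `vocabulary_size()` returns 97.

    from keras_hub.src.models.parseq.parseq_tokenizer import PARSeqTokenizer

    # Default charset: digits, ASCII letters, and printable punctuation.
    tokenizer = PARSeqTokenizer(max_label_length=25)

    # Whitespace is stripped and unsupported characters are dropped before
    # each remaining character is looked up in the char -> id table.
    token_ids = tokenizer.tokenize("Hello, world!")

    # Inverse lookup; out-of-vocabulary ids map to the pad token "[P]".
    text = tokenizer.detokenize(token_ids)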
keras_hub/src/models/qwen3_moe/__init__.py

@@ -0,0 +1,5 @@
+from keras_hub.src.models.qwen3_moe.qwen3_moe_backbone import Qwen3MoeBackbone
+from keras_hub.src.models.qwen3_moe.qwen3_moe_presets import backbone_presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, Qwen3MoeBackbone)
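
This is the standard KerasHub registration hook: importing the sub-package attaches the preset table to the backbone class so the `from_preset` constructors can resolve those names. A sketch of the effect, assuming the public `Qwen3MoeBackbone` export added in `keras_hub/models/__init__.py` (the actual preset names live in `qwen3_moe_presets.py`, which this diff does not display):

    import keras_hub

    # After registration, the preset names are listed on the class and any
    # of them can be passed to `from_preset` to build the backbone.
    print(keras_hub.models.Qwen3MoeBackbone.presets.keys())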
keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py

@@ -0,0 +1,371 @@
+import math
+
+import keras
+from keras import ops
+
+from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+from keras_hub.src.models.qwen3_moe.qwen3_moe_layernorm import Qwen3MoeLayerNorm
+from keras_hub.src.utils.keras_utils import clone_initializer
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
+
+
+class Qwen3MoeAttention(keras.layers.Layer):
+    """A multi-head attention layer for Qwen3Moe models.
+
+    This attention implementation supports grouped-query attention (GQA) where
+    the number of key-value heads can be less than the number of query heads.
+
+    Args:
+        num_query_heads: int. Number of query heads.
+        num_key_value_heads: int. Number of key/value heads (for GQA).
+        head_dim: int. The dimension of each attention head.
+        rope_max_wavelength: int. Maximum wavelength for RoPE (Rotary Position
+            Embedding).
+        rope_scaling_factor: float. Scaling factor for RoPE, used for extending
+            context length.
+        kernel_initializer: Initializer for the kernel weights.
+        dropout: float. Dropout rate for attention weights.
+        layer_norm_epsilon: float. The epsilon value for layer normalization.
+        sliding_window_size: int. Size of the sliding window for attention.
+        **kwargs: Additional keyword arguments to pass to the Layer.
+    """
+
+    def __init__(
+        self,
+        num_query_heads,
+        num_key_value_heads,
+        head_dim=None,
+        rope_max_wavelength=10000,
+        rope_scaling_factor=1,
+        kernel_initializer="glorot_uniform",
+        dropout=0.0,
+        layer_norm_epsilon=1e-6,
+        sliding_window_size=None,
+        **kwargs,
+    ):
+        super().__init__(
+            **kwargs,
+        )
+        self.num_query_heads = num_query_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.head_dim = head_dim
+        self.dropout = dropout
+
+        self.layer_norm_epsilon = layer_norm_epsilon
+
+        self.num_key_value_groups = num_query_heads // num_key_value_heads
+        self.rope_max_wavelength = rope_max_wavelength
+
+        self.kernel_initializer = keras.initializers.get(
+            clone_initializer(kernel_initializer)
+        )
+
+        self.rope_scaling_factor = rope_scaling_factor
+        self.sliding_window_size = sliding_window_size
+
+    def build(self, inputs_shape):
+        # Einsum variables:
+        # b = batch size
+        # q = query length
+        # k = key/value length
+        # m = model dim
+        # u = num query heads
+        # v = num key/value heads
+        # h = head dim
+        hidden_dim = inputs_shape[-1]
+        if not self.head_dim:
+            self.head_dim = hidden_dim // self.num_query_heads
+
+        self._inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
+        self._query_dense = keras.layers.EinsumDense(
+            equation="bqm,muh->bquh",
+            output_shape=(None, self.num_query_heads, self.head_dim),
+            kernel_initializer=self.kernel_initializer,
+            dtype=self.dtype_policy,
+            name="query",
+        )
+        self._query_dense.build(inputs_shape)
+
+        self._query_dense_layer_norm = Qwen3MoeLayerNorm(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            head_dim=self.head_dim,
+            name="query_dense_layernorm",
+        )
+        self._query_dense_layer_norm.build(inputs_shape)
+
+        self._key_dense = keras.layers.EinsumDense(
+            equation="bkm,mvh->bkvh",
+            output_shape=(
+                None,
+                self.num_key_value_heads,
+                self.head_dim,
+            ),
+            kernel_initializer=self.kernel_initializer,
+            dtype=self.dtype_policy,
+            name="key",
+        )
+        self._key_dense.build(inputs_shape)
+
+        self._key_dense_layer_norm = Qwen3MoeLayerNorm(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            head_dim=self.head_dim,
+            name="key_dense_layernorm",
+        )
+        self._key_dense_layer_norm.build(inputs_shape)
+
+        self._value_dense = keras.layers.EinsumDense(
+            equation="bkm,mvh->bkvh",
+            output_shape=(
+                None,
+                self.num_key_value_heads,
+                self.head_dim,
+            ),
+            kernel_initializer=self.kernel_initializer,
+            dtype=self.dtype_policy,
+            name="value",
+        )
+        self._value_dense.build(inputs_shape)
+
+        self._softmax = keras.layers.Softmax(
+            axis=-1,
+            dtype="float32",
+            name="attention_softmax",
+        )
+
+        self._dropout_layer = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+        )
+
+        self._output_dense = keras.layers.EinsumDense(
+            equation="bquh,uhm->bqm",
+            output_shape=(None, hidden_dim),
+            kernel_initializer=self.kernel_initializer,
+            dtype=self.dtype_policy,
+            name="attention_output",
+        )
+        self._output_dense.build(
+            (None, None, self.num_query_heads, self.head_dim)
+        )
+
+        self.rotary_embedding_layer = RotaryEmbedding(
+            max_wavelength=self.rope_max_wavelength,
+            scaling_factor=self.rope_scaling_factor,
+            dtype=self.dtype_policy,
+        )
+
+        self._dot_product_equation = "bquh,bkuh->buqk"
+        self._combine_equation = "buqk,bkuh->bquh"
+
+        self.built = True
+
+    def call(
+        self,
+        hidden_states,
+        attention_mask=None,
+        cache=None,
+        cache_update_index=None,
+        training=None,
+    ):
+        """Applies attention mechanism to the input hidden states.
+
+        Args:
+            hidden_states: Input tensor of shape [batch_size, seq_length,
+                hidden_size].
+            attention_mask: Mask tensor of shape [batch_size, seq_length,
+                seq_length].
+            cache: Optional cached key and value tensors.
+            cache_update_index: Index at which to update the cache.
+            training: Boolean indicating whether in training mode.
+
+        Returns:
+            attention_output: Output tensor after applying attention.
+            cache: Updated cache tensors (if cache is provided).
+        """
+        start_index = (
+            cache_update_index if cache_update_index is not None else 0
+        )
+
+        query = self._query_dense(hidden_states)
+        query = self._query_dense_layer_norm(query)
+
+        # Compute RoPE for queries.
+        query = self.rotary_embedding_layer(query, start_index=start_index)
+
+        def _compute_key_value(x):
+            key = self._key_dense(x)
+            key = self._key_dense_layer_norm(key)
+            key = self.rotary_embedding_layer(key, start_index=start_index)
+
+            value = self._value_dense(x)
+
+            return key, value
+
+        if cache is not None:
+            key_cache = cache[:, 0, ...]
+            value_cache = cache[:, 1, ...]
+            if cache_update_index is None:
+                key = key_cache
+                value = value_cache
+            else:
+                key_update, value_update = _compute_key_value(hidden_states)
+                start = [0, cache_update_index, 0, 0]
+                key = ops.slice_update(key_cache, start, key_update)
+                value = ops.slice_update(value_cache, start, value_update)
+                cache = ops.stack((key, value), axis=1)
+        else:
+            if cache_update_index is not None:
+                raise ValueError(
+                    "`cache_update_index` should not be set if `cache` is "
+                    f"`None`. Received: cache={cache}, "
+                    f"cache_update_index={cache_update_index}"
+                )
+            key, value = _compute_key_value(hidden_states)
+
+        # [batch_shape, seq_len, num_key_value_heads, head_dim]
+        # -> [batch_shape, seq_len, num_heads, head_dim]
+        key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
+        value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)
+
+        attention_output = self._compute_attention(
+            query,
+            key,
+            value,
+            attention_mask,
+            cache_update_index=cache_update_index,
+        )
+
+        attention_output = self._dropout_layer(
+            attention_output, training=training
+        )
+
+        attention_output = self._output_dense(attention_output)
+
+        if cache is not None:
+            return attention_output, cache
+        return attention_output
+
+    def _masked_softmax(self, attention_scores, attention_mask=None):
+        """Applies softmax with optional masking.
+
+        Args:
+            attention_scores: Attention score tensor.
+            attention_mask: Optional mask tensor.
+
+        Returns:
+            Masked softmax attention weights.
+        """
+        if attention_mask is not None:
+            return self._softmax(
+                attention_scores, attention_mask[:, None, :, :]
+            )
+        return self._softmax(attention_scores)
+
+    def _compute_attention(
+        self, query, key, value, attention_mask=None, cache_update_index=None
+    ):
+        """Computes attention using query, key, and value tensors.
+
+        Uses Flash Attention when available for better performance.
+
+        Args:
+            query: Query tensor.
+            key: Key tensor.
+            value: Value tensor.
+            attention_mask: Optional mask tensor.
+            cache_update_index: Index for sliding window computation.
+
+        Returns:
+            attention_output: Output tensor after applying attention.
+        """
+        if fused_attention_op_available():
+            # Use `dot_product_attention` with Flash Attention support if
+            # available.
+            if attention_mask is not None:
+                attention_mask = ops.expand_dims(attention_mask, axis=1)
+                attention_mask = ops.cast(attention_mask, dtype="bool")
+            attention_output = ops.dot_product_attention(
+                query,
+                key,
+                value,
+                mask=attention_mask,
+                scale=self._inv_norm_factor,
+            )
+            return attention_output
+
+        attention_scores = ops.einsum(self._dot_product_equation, query, key)
+
+        attention_scores = ops.multiply(
+            attention_scores,
+            ops.cast(self._inv_norm_factor, self.compute_dtype),
+        )
+        if self.sliding_window_size:
+            attention_mask = self._mask_sliding_window(
+                attention_mask,
+                cache_update_index=cache_update_index
+                if cache_update_index is not None
+                else 0,
+            )
+        attention_scores = self._masked_softmax(
+            attention_scores, attention_mask
+        )
+        attention_scores = ops.cast(attention_scores, self.compute_dtype)
+        attention_output = ops.einsum(
+            self._combine_equation, attention_scores, value
+        )
+
+        return attention_output
+
+    def _mask_sliding_window(
+        self,
+        attention_mask,
+        cache_update_index=0,
+    ):
+        """Creates and combines a sliding window mask with the attention mask.
+
+        Args:
+            attention_mask: Original attention mask.
+            cache_update_index: Starting index for the sliding window.
+
+        Returns:
+            Combined attention mask with sliding window constraints.
+        """
+        _, query_len, key_len = ops.shape(attention_mask)
+        # Compute the sliding window for square attention.
+        all_ones = ops.ones((key_len, key_len), "bool")
+        if keras.config.backend() == "tensorflow":
+            # TODO: triu/tril has issues with dynamic shape on the tensorflow
+            # backend. We should fix, but use `band_part` for now.
+            import tensorflow as tf
+
+            band_size = ops.minimum(key_len, self.sliding_window_size - 1)
+            band_size = ops.cast(band_size, "int32")
+            sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size)
+        else:
+            sliding_mask = ops.triu(
+                all_ones, -1 * self.sliding_window_size + 1
+            ) * ops.tril(all_ones, self.sliding_window_size - 1)
+        # Slice the window for short queries during generation.
+        start = (cache_update_index, 0)
+        sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len))
+        sliding_mask = ops.expand_dims(sliding_mask, 0)
+        return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool"))
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_query_heads": self.num_query_heads,
+                "num_key_value_heads": self.num_key_value_heads,
+                "rope_max_wavelength": self.rope_max_wavelength,
+                "rope_scaling_factor": self.rope_scaling_factor,
+                "kernel_initializer": keras.initializers.serialize(
+                    self.kernel_initializer
+                ),
+                "dropout": self.dropout,
+                "sliding_window_size": self.sliding_window_size,
+                "head_dim": self.head_dim,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+            }
+        )
+        return config
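
Two details of this layer are worth unpacking. For grouped-query attention, keys and values are projected with `num_key_value_heads` heads and then repeated `num_key_value_groups` times along the head axis so they line up with the query heads. For local attention, the non-TensorFlow branch of `_mask_sliding_window` keeps only key positions within `sliding_window_size - 1` of each query position via a `triu`/`tril` band. A standalone sketch of that band construction, with illustrative sizes (not taken from the diff):

    from keras import ops

    window = 3  # stands in for `sliding_window_size`
    key_len = 6
    all_ones = ops.ones((key_len, key_len), "bool")
    # Keep positions (i, j) with |i - j| <= window - 1, matching the
    # triu/tril product in `_mask_sliding_window` above.
    band = ops.logical_and(
        ops.triu(all_ones, -window + 1),
        ops.tril(all_ones, window - 1),
    )
    print(ops.convert_to_numpy(band).astype(int))
    # [[1 1 1 0 0 0]
    #  [1 1 1 1 0 0]
    #  [1 1 1 1 1 0]
    #  [0 1 1 1 1 1]
    #  [0 0 1 1 1 1]
    #  [0 0 0 1 1 1]]

During cached generation the band is then sliced at `cache_update_index`, so a single-token query row still sees only its window of keys.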