keras-hub 0.20.0.dev1__py3-none-any.whl → 0.21.0.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +15 -33
- keras_hub/layers/__init__.py +134 -0
- keras_hub/metrics/__init__.py +11 -0
- keras_hub/models/__init__.py +642 -0
- keras_hub/samplers/__init__.py +18 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +25 -35
- keras_hub/src/layers/preprocessing/image_converter.py +1 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +1 -1
- keras_hub/src/layers/preprocessing/random_swap.py +1 -1
- keras_hub/src/models/audio_to_text.py +66 -0
- keras_hub/src/models/audio_to_text_preprocessor.py +80 -0
- keras_hub/src/models/backbone.py +5 -2
- keras_hub/src/models/cspnet/cspnet_backbone.py +51 -26
- keras_hub/src/models/cspnet/cspnet_presets.py +38 -3
- keras_hub/src/models/falcon/falcon_backbone.py +1 -1
- keras_hub/src/models/gemma/gemma_presets.py +10 -10
- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +3 -2
- keras_hub/src/models/gemma3/gemma3_presets.py +8 -8
- keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
- keras_hub/src/models/llama/llama_attention.py +24 -6
- keras_hub/src/models/llama/llama_backbone.py +50 -16
- keras_hub/src/models/llama/llama_decoder.py +20 -3
- keras_hub/src/models/llama/llama_presets.py +3 -3
- keras_hub/src/models/llama/llama_rotary_embedding.py +180 -0
- keras_hub/src/models/llama3/llama3_backbone.py +10 -2
- keras_hub/src/models/llama3/llama3_presets.py +84 -2
- keras_hub/src/models/mistral/mistral_presets.py +3 -3
- keras_hub/src/models/mixtral/__init__.py +5 -0
- keras_hub/src/models/mixtral/mixtral_attention.py +252 -0
- keras_hub/src/models/mixtral/mixtral_backbone.py +207 -0
- keras_hub/src/models/mixtral/mixtral_causal_lm.py +281 -0
- keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py +76 -0
- keras_hub/src/models/mixtral/mixtral_decoder.py +494 -0
- keras_hub/src/models/mixtral/mixtral_layer_norm.py +34 -0
- keras_hub/src/models/mixtral/mixtral_presets.py +26 -0
- keras_hub/src/models/mixtral/mixtral_tokenizer.py +21 -0
- keras_hub/src/models/moonshine/__init__.py +5 -0
- keras_hub/src/models/moonshine/moonshine_audio_converter.py +301 -0
- keras_hub/src/models/moonshine/moonshine_audio_to_text.py +383 -0
- keras_hub/src/models/moonshine/moonshine_audio_to_text_preprocessor.py +272 -0
- keras_hub/src/models/moonshine/moonshine_backbone.py +478 -0
- keras_hub/src/models/moonshine/moonshine_decoder.py +313 -0
- keras_hub/src/models/moonshine/moonshine_encoder.py +212 -0
- keras_hub/src/models/moonshine/moonshine_layers.py +239 -0
- keras_hub/src/models/moonshine/moonshine_multi_head_attention.py +355 -0
- keras_hub/src/models/moonshine/moonshine_presets.py +25 -0
- keras_hub/src/models/moonshine/moonshine_tokenizer.py +62 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +11 -11
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +1 -1
- keras_hub/src/models/qwen/__init__.py +4 -0
- keras_hub/src/models/qwen/qwen_attention.py +3 -1
- keras_hub/src/models/qwen/qwen_backbone.py +8 -1
- keras_hub/src/models/qwen/qwen_causal_lm.py +7 -0
- keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +7 -0
- keras_hub/src/models/qwen/qwen_presets.py +61 -0
- keras_hub/src/models/qwen/qwen_tokenizer.py +9 -0
- keras_hub/src/models/qwen_moe/__init__.py +5 -0
- keras_hub/src/models/qwen_moe/qwen_moe_attention.py +375 -0
- keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +373 -0
- keras_hub/src/models/qwen_moe/qwen_moe_causal_lm.py +350 -0
- keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor.py +17 -0
- keras_hub/src/models/qwen_moe/qwen_moe_decoder.py +625 -0
- keras_hub/src/models/qwen_moe/qwen_moe_layernorm.py +32 -0
- keras_hub/src/models/qwen_moe/qwen_moe_presets.py +15 -0
- keras_hub/src/models/qwen_moe/qwen_moe_tokenizer.py +46 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +0 -13
- keras_hub/src/models/retinanet/retinanet_presets.py +2 -2
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +0 -18
- keras_hub/src/models/segformer/segformer_presets.py +12 -12
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +6 -0
- keras_hub/src/models/task.py +5 -2
- keras_hub/src/models/xception/__init__.py +5 -0
- keras_hub/src/models/xception/xception_backbone.py +188 -0
- keras_hub/src/models/xception/xception_image_classifier.py +12 -0
- keras_hub/src/models/xception/xception_image_classifier_preprocessor.py +14 -0
- keras_hub/src/models/xception/xception_image_converter.py +8 -0
- keras_hub/src/models/xception/xception_presets.py +14 -0
- keras_hub/src/tests/mocks/mock_gemma3_tokenizer.py +155 -0
- keras_hub/src/utils/coco/__init__.py +0 -0
- keras_hub/src/utils/coco/coco_utils.py +133 -0
- keras_hub/src/utils/imagenet/imagenet_utils.py +36 -0
- keras_hub/src/utils/keras_utils.py +11 -0
- keras_hub/src/utils/preset_utils.py +70 -10
- keras_hub/src/utils/tensor_utils.py +27 -1
- keras_hub/src/utils/timm/convert_cspnet.py +94 -23
- keras_hub/src/utils/timm/preset_loader.py +6 -6
- keras_hub/src/utils/transformers/convert_llama3.py +21 -1
- keras_hub/src/utils/transformers/convert_mixtral.py +139 -0
- keras_hub/src/utils/transformers/convert_qwen.py +1 -0
- keras_hub/src/utils/transformers/convert_qwen_moe.py +253 -0
- keras_hub/src/utils/transformers/preset_loader.py +6 -0
- keras_hub/src/{version_utils.py → version.py} +1 -1
- keras_hub/tokenizers/__init__.py +117 -0
- keras_hub/utils/__init__.py +21 -0
- {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/METADATA +6 -20
- {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/RECORD +98 -55
- {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/WHEEL +1 -1
- keras_hub/api/__init__.py +0 -15
- keras_hub/api/layers/__init__.py +0 -86
- keras_hub/api/metrics/__init__.py +0 -11
- keras_hub/api/models/__init__.py +0 -416
- keras_hub/api/samplers/__init__.py +0 -16
- keras_hub/api/tokenizers/__init__.py +0 -58
- keras_hub/api/utils/__init__.py +0 -9
- {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,625 @@
|
|
|
1
|
+
import keras
|
|
2
|
+
from keras import ops
|
|
3
|
+
|
|
4
|
+
from keras_hub.src.layers.modeling.transformer_layer_utils import (
|
|
5
|
+
compute_causal_mask,
|
|
6
|
+
)
|
|
7
|
+
from keras_hub.src.layers.modeling.transformer_layer_utils import (
|
|
8
|
+
merge_padding_and_attention_mask,
|
|
9
|
+
)
|
|
10
|
+
from keras_hub.src.models.qwen_moe.qwen_moe_attention import QwenMoeAttention
|
|
11
|
+
from keras_hub.src.models.qwen_moe.qwen_moe_layernorm import QwenMoeLayerNorm
|
|
12
|
+
from keras_hub.src.utils.keras_utils import clone_initializer
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def compute_load_balancing_loss(
    router_logits, num_experts, top_k, attention_mask=None
):
    """
    Compute the load balancing auxiliary loss for a single MoE layer.

    This is the Switch-Transformer-style auxiliary loss: it multiplies the
    fraction of tokens dispatched to each expert by the mean router
    probability assigned to that expert, sums over experts, and scales by
    `num_experts`. A perfectly uniform router yields a value of 1.0.

    Args:
        router_logits: Tensor of shape (batch_size * seq_len, num_experts).
        num_experts: Integer, total number of experts.
        top_k: Integer, number of experts to select per token.
        attention_mask: Tensor of shape (batch_size, seq_len, seq_len),
            optional mask for padding. Only per-token validity is used:
            a token counts as valid if any entry in its last axis is
            non-zero.

    Returns:
        Scalar tensor representing the auxiliary loss.
    """
    # Compute routing probabilities
    routing_weights = ops.softmax(
        router_logits, axis=-1
    )  # Shape: (batch_size * seq_len, num_experts)

    # Get top-k experts
    _, selected_experts = ops.top_k(
        routing_weights, k=top_k
    )  # Shape: (batch_size * seq_len, top_k)

    # Create one-hot encoding for selected experts
    expert_mask = ops.one_hot(
        selected_experts, num_experts
    )  # Shape: (batch_size * seq_len, top_k, num_experts)

    if attention_mask is not None:
        # Convert attention mask to (batch_size, seq_len)
        # NOTE(review): assumes a rank-3 mask (batch, seq, seq) as per the
        # docstring; a rank-4 (batch, heads, seq, seq) mask would fail the
        # 3-way unpack below — confirm against callers.
        batch_size, seq_len, _ = ops.shape(attention_mask)
        # Reduce each token's mask row to a single validity flag.
        flat_mask = ops.any(attention_mask, axis=-1)
        flat_mask = ops.reshape(
            flat_mask, (-1,)
        )  # Shape: (batch_size * seq_len,)
        # Expand mask for broadcasting
        expert_attention_mask = ops.expand_dims(
            flat_mask, axis=-1
        )  # Shape: (batch_size * seq_len, 1)
        expert_attention_mask = ops.cast(expert_attention_mask, dtype="float32")

        # Compute masked means
        # `[:, None, :]` lifts the mask to (tokens, 1, 1) so it broadcasts
        # against `expert_mask`'s (tokens, top_k, num_experts); the 1e-9
        # floor guards against division by zero when all tokens are masked.
        tokens_per_expert = ops.sum(
            expert_mask * expert_attention_mask[:, None, :], axis=0
        ) / ops.maximum(
            ops.sum(expert_attention_mask[:, None, :], axis=0), 1e-9
        )  # Shape: (top_k, num_experts)
        router_prob_per_expert = ops.sum(
            routing_weights * expert_attention_mask, axis=0
        ) / ops.maximum(
            ops.sum(expert_attention_mask, axis=0), 1e-9
        )  # Shape: (num_experts,)
    else:
        # Unmasked means
        tokens_per_expert = ops.mean(
            expert_mask, axis=0
        )  # Shape: (top_k, num_experts)
        router_prob_per_expert = ops.mean(
            routing_weights, axis=0
        )  # Shape: (num_experts,)

    # Average over top_k dimension if necessary
    tokens_per_expert = ops.mean(
        tokens_per_expert, axis=0
    )  # Shape: (num_experts,)

    # Compute the loss
    overall_loss = ops.sum(tokens_per_expert * router_prob_per_expert)
    return overall_loss * num_experts
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class QwenMoeMLP(keras.layers.Layer):
    """Gated (SwiGLU-style) feedforward network for Qwen-MoE.

    Computes `output_dense(intermediate_dense(x) * act(gate_dense(x)))`,
    i.e. a gated MLP where the gate activation runs in float32 for parity
    with `torch.nn.functional.silu`.

    Args:
        intermediate_dim: int. Width of the gate and intermediate
            projections.
        hidden_dim: int. Output width; must match the residual stream.
        activation_fn: Activation identifier for the gate branch. Defaults
            to `"silu"`.
        layer_norm_epsilon: float. Stored for config compatibility; not
            used by this layer itself.
        kernel_initializer: Initializer for the dense kernels.
        **kwargs: Forwarded to `keras.layers.Layer`.
    """

    def __init__(
        self,
        intermediate_dim,
        hidden_dim,
        activation_fn="silu",
        layer_norm_epsilon=1e-5,
        kernel_initializer="glorot_uniform",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.intermediate_dim = intermediate_dim
        self.hidden_dim = hidden_dim
        self.activation_fn = activation_fn
        self.kernel_initializer = kernel_initializer
        self.layer_norm_epsilon = layer_norm_epsilon

    def build(self, decoder_sequence_shape):
        # Feedforward layers.
        self._feedforward_intermediate_dense = keras.layers.Dense(
            self.intermediate_dim,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            use_bias=False,
            dtype=self.dtype_policy,
            name="feedforward_intermediate_dense",
        )
        self._feedforward_intermediate_dense.build(decoder_sequence_shape)

        self._feedforward_gate_dense = keras.layers.Dense(
            self.intermediate_dim,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            use_bias=False,
            dtype=self.dtype_policy,
            name="feedforward_gate_dense",
        )
        self._feedforward_gate_dense.build(decoder_sequence_shape)

        self._feedforward_output_dense = keras.layers.Dense(
            self.hidden_dim,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            use_bias=False,
            dtype=self.dtype_policy,
            name="feedforward_output_dense",
        )
        self._feedforward_output_dense.build(
            self._feedforward_gate_dense.compute_output_shape(
                decoder_sequence_shape
            )
        )

        self.activation = keras.activations.get(self.activation_fn)
        self.built = True

    def call(self, x):
        gate_output = self._feedforward_gate_dense(x)

        # Note that we run the activation function in full 32-bit
        # precision since this is what `torch.nn.functional.silu`
        # does. Internally, `torch.nn.functional.silu` converts the
        # inputs to float32, computes SiLU, and converts the outputs
        # back to compute dtype.
        # CPU Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cpu/Activation.cpp#L1221-L1235 # noqa: E501
        # CUDA Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cuda/ActivationSiluKernel.cu # noqa: E501
        gate_output = ops.cast(gate_output, "float32")
        gate_output = self.activation(gate_output)
        gate_output = ops.cast(gate_output, self.compute_dtype)

        x = self._feedforward_intermediate_dense(x)

        x = self._feedforward_output_dense(ops.multiply(x, gate_output))

        return x

    def get_config(self):
        """Returns the config of the layer.

        Added so the layer round-trips through Keras serialization; without
        it, every constructor argument was lost on save/load.
        """
        config = super().get_config()
        config.update(
            {
                "intermediate_dim": self.intermediate_dim,
                "hidden_dim": self.hidden_dim,
                "activation_fn": self.activation_fn,
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "kernel_initializer": keras.initializers.serialize(
                    keras.initializers.get(self.kernel_initializer)
                ),
            }
        )
        return config
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class QwenMoeExperts(keras.layers.Layer):
    """Batched Experts Layer.

    Packs the weights of every expert MLP into two batched tensors (one
    fused gate+up projection, one down projection) so the whole expert bank
    can be applied with a pair of einsums instead of `num_experts` separate
    dense layers.
    """

    def __init__(
        self,
        num_experts,
        hidden_dim,
        intermediate_dim,
        activation_fn="silu",
        kernel_initializer="glorot_uniform",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_experts = num_experts
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.activation = keras.activations.get(activation_fn)
        self.kernel_initializer = kernel_initializer

    def build(self, _):
        # Last axis is the gate half followed by the up half, hence the
        # factor of two on `intermediate_dim`.
        fused_gate_up_shape = (
            self.num_experts,
            self.hidden_dim,
            2 * self.intermediate_dim,
        )
        self._expert_feedforward_gate_dense = self.add_weight(
            shape=fused_gate_up_shape,
            initializer=self.kernel_initializer,
            trainable=True,
            dtype=self.variable_dtype,
            name="expert_feedforward_gate_dense",
        )

        down_projection_shape = (
            self.num_experts,
            self.intermediate_dim,
            self.hidden_dim,
        )
        self._expert_feedforward_output_dense = self.add_weight(
            shape=down_projection_shape,
            initializer=self.kernel_initializer,
            trainable=True,
            dtype=self.variable_dtype,
            name="expert_feedforward_output_dense",
        )

        self.built = True

    def call(self, hidden_states):
        # (tokens, hidden) x (experts, hidden, 2*inter)
        #   -> (experts, tokens, 2*inter)
        fused_projection = ops.einsum(
            "th,ehm->etm", hidden_states, self._expert_feedforward_gate_dense
        )
        gate_half, up_half = ops.split(fused_projection, 2, axis=-1)
        activated = up_half * self.activation(gate_half)
        # (experts, tokens, inter) x (experts, inter, hidden)
        #   -> (experts, tokens, hidden)
        expert_outputs = ops.einsum(
            "eti,eih->eth", activated, self._expert_feedforward_output_dense
        )
        return expert_outputs
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
class QwenSparseMoeBlock(keras.layers.Layer):
    """Qwen-2 Sparse MoE block.

    Routes each token to its `top_k` experts (a shared batched
    `QwenMoeExperts` bank), adds a gated shared-expert MLP that every token
    passes through, and — during training — registers the router load
    balancing auxiliary loss via `self.add_loss`.
    """

    def __init__(
        self,
        hidden_dim,
        moe_intermediate_dim,
        shared_expert_intermediate_dim,
        num_experts,
        top_k,
        norm_top_k_prob,
        kernel_initializer="glorot_uniform",
        layer_norm_epsilon=1e-5,
        router_aux_loss_coefficient=0.01,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.intermediate_dim = moe_intermediate_dim
        self.intermediate_dim_shared = shared_expert_intermediate_dim
        self.num_experts = num_experts
        # Number of experts each token is routed to.
        self.top_k = top_k
        # If true, re-normalize the top-k routing weights to sum to 1.
        self.norm_top_k_prob = norm_top_k_prob
        self.kernel_initializer = kernel_initializer
        self.layer_norm_epsilon = layer_norm_epsilon
        # Scales the load-balancing auxiliary loss added during training.
        self.router_aux_loss_coefficient = router_aux_loss_coefficient

    def build(self, decoder_sequence_shape):
        # Router: produces one logit per expert for every token.
        self._sparse_feedforward_gate_dense = keras.layers.Dense(
            self.num_experts,
            use_bias=False,
            kernel_initializer=self.kernel_initializer,
            name="sparse_feedforward_gate_dense",
            dtype=self.dtype_policy,
        )
        self._sparse_feedforward_gate_dense.build(decoder_sequence_shape)

        # NOTE: Experts are implemented as a single layer to enable efficient
        # batched computation. Implementing each expert individually is
        # currently avoided due to the lack of `ragged_dot` support in the
        # Keras ops API, which would make individual implementations unstable
        # and prone to bugs.
        self.expert_bank = QwenMoeExperts(
            num_experts=self.num_experts,
            hidden_dim=self.hidden_dim,
            intermediate_dim=self.intermediate_dim,
            kernel_initializer=self.kernel_initializer,
            name="experts",
            dtype=self.dtype_policy,
        )
        self.expert_bank.build(decoder_sequence_shape)

        # Shared expert: a dense MLP applied to every token regardless of
        # routing, blended in via a learned sigmoid gate below.
        self.shared_expert_dense = QwenMoeMLP(
            intermediate_dim=self.intermediate_dim_shared,
            hidden_dim=self.hidden_dim,
            kernel_initializer=self.kernel_initializer,
            layer_norm_epsilon=self.layer_norm_epsilon,
            name="shared_expert_dense",
            dtype=self.dtype_policy,
        )
        self.shared_expert_dense.build(decoder_sequence_shape)

        self.shared_expert_gate_dense = keras.layers.Dense(
            1,
            use_bias=False,
            name="shared_expert_gate_dense",
            dtype=self.dtype_policy,
        )
        self.shared_expert_gate_dense.build(decoder_sequence_shape)

        self.built = True

    def call(self, hidden_states, attention_mask=None, training=None):
        """Returns `(output, router_logits)` for the given hidden states."""
        batch_size, seq_len, _ = ops.shape(hidden_states)
        # Flatten (batch, seq, hidden) -> (tokens, hidden) so routing is
        # done per token.
        hidden_states_flattened = ops.reshape(
            hidden_states, (-1, self.hidden_dim)
        )

        router_logits = self._sparse_feedforward_gate_dense(
            hidden_states_flattened
        )
        router_probs = ops.softmax(router_logits, axis=-1)

        # top_p: (tokens, top_k) routing weights; top_i: expert indices.
        top_p, top_i = ops.top_k(router_probs, k=self.top_k)
        if self.norm_top_k_prob:
            top_p = top_p / ops.sum(top_p, axis=-1, keepdims=True)

        # Scatter the top-k weights back to a dense (tokens, num_experts)
        # matrix, then transpose to (num_experts, tokens) so it lines up
        # with the expert-major output of the expert bank.
        one_hot = ops.one_hot(top_i, self.num_experts)
        one_hot = ops.cast(one_hot, top_p.dtype)
        routing_full = ops.sum(one_hot * top_p[..., None], axis=1)
        routing_full = ops.transpose(routing_full, (1, 0))
        routing_full = ops.cast(routing_full, hidden_states_flattened.dtype)

        # NOTE(review): every expert runs on every token here; routing
        # weights of 0 zero out non-selected experts. Dense compute, sparse
        # semantics.
        expert_out = self.expert_bank(hidden_states_flattened)

        weighted_out = expert_out * routing_full[:, :, None]
        expert_contribution = ops.sum(weighted_out, axis=0)

        shared_expert_output = self.shared_expert_dense(hidden_states_flattened)
        shared_expert_output *= ops.sigmoid(
            self.shared_expert_gate_dense(hidden_states_flattened)
        )

        out_flat = expert_contribution + shared_expert_output
        out = ops.reshape(out_flat, (batch_size, seq_len, self.hidden_dim))

        # Compute and add auxiliary loss during training
        if training:
            aux_loss = compute_load_balancing_loss(
                router_logits=router_logits,
                num_experts=self.num_experts,
                top_k=self.top_k,
                attention_mask=attention_mask,
            )
            self.add_loss(self.router_aux_loss_coefficient * aux_loss)

        return out, router_logits
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
class QwenMoeTransformerDecoder(keras.layers.Layer):
    """A single Qwen-MoE transformer decoder layer.

    Applies pre-norm self-attention followed by either a sparse
    Mixture-of-Experts feedforward (`QwenSparseMoeBlock`) or a dense
    `QwenMoeMLP`. A layer is sparse when its index is not listed in
    `mlp_only_layers`, `num_experts > 0`, and
    `(layer_index + 1) % decoder_sparse_step == 0`.

    Args:
        intermediate_dim: int. Hidden size of the dense MLP branch.
        num_query_heads: int. Number of attention query heads.
        num_key_value_heads: int. Number of attention key/value heads.
        moe_intermediate_dim: int. Per-expert hidden size.
        shared_expert_intermediate_dim: int. Shared expert hidden size.
        num_experts: int. Number of routed experts.
        top_k: int. Experts selected per token.
        norm_top_k_prob: bool. Whether to renormalize top-k routing weights.
        decoder_sparse_step: int. Interval at which layers are sparse.
        rope_max_wavelength: Maximum wavelength for rotary embeddings.
        rope_scaling_factor: Scaling factor for rotary embeddings.
        activation: Activation identifier for the feedforward branch.
        layer_norm_epsilon: Epsilon for the RMS norms.
        kernel_initializer: Initializer for dense kernels.
        dropout: Dropout rate applied after self-attention.
        use_sliding_window_attention: Whether to use sliding-window
            attention.
        sliding_window_size: Sliding window size.
        layer_index: This layer's index in the decoder stack; controls
            sparse/dense selection.
        mlp_only_layers: Optional list of layer indices forced to use the
            dense MLP. Defaults to no forced layers.
        output_router_logits: Whether `call` also returns router logits.
        router_aux_loss_coefficient: Scale of the MoE auxiliary loss.
        **kwargs: Forwarded to `keras.layers.Layer`.
    """

    def __init__(
        self,
        intermediate_dim,
        num_query_heads,
        num_key_value_heads,
        moe_intermediate_dim,
        shared_expert_intermediate_dim,
        num_experts,
        top_k,
        norm_top_k_prob,
        decoder_sparse_step,
        rope_max_wavelength=10000,
        rope_scaling_factor=1.0,
        activation="silu",
        layer_norm_epsilon=1e-5,
        kernel_initializer="glorot_uniform",
        dropout=0,
        use_sliding_window_attention=False,
        sliding_window_size=4096,
        layer_index=0,
        mlp_only_layers=None,
        output_router_logits=False,
        router_aux_loss_coefficient=0.001,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.intermediate_dim = intermediate_dim
        self.num_query_heads = num_query_heads
        self.num_key_value_heads = num_key_value_heads
        self.rope_max_wavelength = rope_max_wavelength
        self.rope_scaling_factor = rope_scaling_factor
        self.dropout = dropout
        self.use_sliding_window_attention = use_sliding_window_attention
        self.sliding_window_size = sliding_window_size
        self.activation = keras.activations.get(activation)
        self.layer_norm_epsilon = layer_norm_epsilon
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.layer_index = layer_index
        # Avoid the mutable-default-argument pitfall: a `[]` default would
        # be shared by every instance of this layer.
        self.mlp_only_layers = (
            [] if mlp_only_layers is None else mlp_only_layers
        )
        self.moe_intermediate_dim = moe_intermediate_dim
        self.shared_expert_intermediate_dim = shared_expert_intermediate_dim
        self.num_experts = num_experts
        self.top_k = top_k
        self.norm_top_k_prob = norm_top_k_prob
        self.decoder_sparse_step = decoder_sparse_step
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coefficient = router_aux_loss_coefficient
        self.supports_masking = True

    def build(self, decoder_sequence_shape):
        self._decoder_sequence_shape = decoder_sequence_shape
        self.hidden_dim = decoder_sequence_shape[-1]

        # Self attention layer.
        self._self_attention_layer = QwenMoeAttention(
            num_query_heads=self.num_query_heads,
            num_key_value_heads=self.num_key_value_heads,
            rope_max_wavelength=self.rope_max_wavelength,
            rope_scaling_factor=self.rope_scaling_factor,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            dropout=self.dropout,
            use_sliding_window_attention=self.use_sliding_window_attention,
            sliding_window_size=self.sliding_window_size,
            name="self_attention",
            dtype=self.dtype_policy,
        )
        self._self_attention_layer.build(decoder_sequence_shape)

        self._self_attention_layernorm = QwenMoeLayerNorm(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="self_attention_layernorm",
        )
        self._self_attention_layernorm.build(decoder_sequence_shape)

        self._self_attention_dropout = keras.layers.Dropout(
            rate=self.dropout,
            dtype=self.dtype_policy,
            name="self_attention_dropout",
        )

        # Feedforward layers. Sparse MoE when this layer index qualifies,
        # dense MLP otherwise.
        if (self.layer_index not in self.mlp_only_layers) and (
            self.num_experts > 0
            and (self.layer_index + 1) % self.decoder_sparse_step == 0
        ):
            self.mlp = QwenSparseMoeBlock(
                hidden_dim=self.hidden_dim,
                moe_intermediate_dim=self.moe_intermediate_dim,
                shared_expert_intermediate_dim=self.shared_expert_intermediate_dim,
                num_experts=self.num_experts,
                top_k=self.top_k,
                norm_top_k_prob=self.norm_top_k_prob,
                router_aux_loss_coefficient=self.router_aux_loss_coefficient,
                kernel_initializer=self.kernel_initializer,
                dtype=self.dtype_policy,
            )
            self.mlp.build(decoder_sequence_shape)
        else:
            self.mlp = QwenMoeMLP(
                intermediate_dim=self.intermediate_dim,
                hidden_dim=self.hidden_dim,
                # Fix: previously a custom `kernel_initializer` was
                # silently dropped on the dense-MLP branch.
                kernel_initializer=self.kernel_initializer,
                dtype=self.dtype_policy,
            )
            self.mlp.build(decoder_sequence_shape)

        self._feedforward_layernorm = QwenMoeLayerNorm(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="feedforward_layernorm",
        )
        self._feedforward_layernorm.build(decoder_sequence_shape)

        self.built = True

    def call(
        self,
        decoder_sequence,
        decoder_padding_mask=None,
        decoder_attention_mask=None,
        self_attention_cache=None,
        self_attention_cache_update_index=None,
        training=None,
    ):
        """Forward pass for the decoder layer.

        Args:
            decoder_sequence: Input tensor of shape [batch_size, seq_length,
                hidden_size].
            decoder_padding_mask: Mask tensor for padding tokens.
            decoder_attention_mask: Additional attention mask.
            self_attention_cache: Optional cached key and value tensors for
                self-attention.
            self_attention_cache_update_index: Index at which to update the
                cache.
            training: Boolean indicating whether in training mode.

        Returns:
            decoder_output: Output tensor after applying transformer decoder
                block.
            self_attention_cache: Updated cache tensors (if cache is
                provided).
            router_logits: Router logits (only if `output_router_logits`).
        """
        self_attention_mask = self._compute_self_attention_mask(
            decoder_sequence=decoder_sequence,
            decoder_padding_mask=decoder_padding_mask,
            decoder_attention_mask=decoder_attention_mask,
            self_attention_cache=self_attention_cache,
            self_attention_cache_update_index=self_attention_cache_update_index,
        )
        residual = decoder_sequence

        x = self._self_attention_layernorm(decoder_sequence)

        # Self attention block.
        x = self._self_attention_layer(
            hidden_states=x,
            attention_mask=self_attention_mask,
            cache=self_attention_cache,
            cache_update_index=self_attention_cache_update_index,
        )

        # With a cache, the attention layer returns (output, updated_cache).
        if self_attention_cache is not None:
            x, self_attention_cache = x

        x = self._self_attention_dropout(x, training=training)

        x = x + residual
        residual = x

        x = self._feedforward_layernorm(x)
        if isinstance(self.mlp, QwenSparseMoeBlock):
            x = self.mlp(
                x, training=training, attention_mask=self_attention_mask
            )
        else:
            x = self.mlp(x)
        # The sparse block returns (output, router_logits); the dense MLP
        # returns a single tensor.
        if isinstance(x, tuple):
            x, router_logits = x
        else:
            router_logits = None

        x = ops.cast(x, ops.dtype(residual))
        decoder_output = x + residual

        output = (decoder_output,)

        if self_attention_cache is not None:
            output += (self_attention_cache,)

        if self.output_router_logits:
            output += (router_logits,)

        return output[0] if len(output) == 1 else output

    def _compute_self_attention_mask(
        self,
        decoder_sequence,
        decoder_padding_mask,
        decoder_attention_mask,
        self_attention_cache,
        self_attention_cache_update_index,
    ):
        """Computes the self-attention mask combining causal, padding and
        attention masks.

        Args:
            decoder_sequence: Input tensor.
            decoder_padding_mask: Mask tensor for padding tokens.
            decoder_attention_mask: Additional attention mask.
            self_attention_cache: Optional cached key and value tensors.
            self_attention_cache_update_index: Index at which to update the
                cache.

        Returns:
            Combined attention mask tensor.
        """
        decoder_mask = merge_padding_and_attention_mask(
            decoder_sequence, decoder_padding_mask, decoder_attention_mask
        )
        batch_size = ops.shape(decoder_sequence)[0]
        input_length = output_length = ops.shape(decoder_sequence)[1]
        # We need to handle a rectangular causal mask when doing cached
        # decoding. For generative inference, `decoder_sequence` will
        # generally be length 1, and `cache` will be the full generation length.
        if self_attention_cache is not None:
            input_length = ops.shape(self_attention_cache)[2]

        cache_update_index = (
            0
            if self_attention_cache_update_index is None
            else self_attention_cache_update_index
        )

        causal_mask = compute_causal_mask(
            batch_size, input_length, output_length, cache_update_index
        )

        return (
            ops.minimum(decoder_mask, causal_mask)
            if decoder_mask is not None
            else causal_mask
        )

    def compute_output_shape(self, decoder_sequence_shape):
        """Computes the output shape of the layer.

        Args:
            decoder_sequence_shape: Shape of the decoder sequence input.

        Returns:
            Output shape, which is the same as the input shape.
        """
        return decoder_sequence_shape

    def get_config(self):
        """Returns the config of the layer.

        Returns:
            Dictionary containing the parameters used to initialize this
            layer. `layer_index`, `activation` and `kernel_initializer`
            are now serialized; previously they were silently dropped, so
            a reloaded layer got `layer_index=0` (wrong sparse/dense
            placement) and default activation/initializer.
        """
        config = super().get_config()
        config.update(
            {
                "num_query_heads": self.num_query_heads,
                "intermediate_dim": self.intermediate_dim,
                "moe_intermediate_dim": self.moe_intermediate_dim,
                "shared_expert_intermediate_dim": (
                    self.shared_expert_intermediate_dim
                ),
                "rope_max_wavelength": self.rope_max_wavelength,
                "num_key_value_heads": self.num_key_value_heads,
                "rope_scaling_factor": self.rope_scaling_factor,
                "activation": keras.activations.serialize(self.activation),
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "kernel_initializer": keras.initializers.serialize(
                    self.kernel_initializer
                ),
                "dropout": self.dropout,
                "use_sliding_window_attention": (
                    self.use_sliding_window_attention
                ),
                "sliding_window_size": self.sliding_window_size,
                "num_experts": self.num_experts,
                "top_k": self.top_k,
                "norm_top_k_prob": self.norm_top_k_prob,
                "decoder_sparse_step": self.decoder_sparse_step,
                "layer_index": self.layer_index,
                "mlp_only_layers": self.mlp_only_layers,
                "output_router_logits": self.output_router_logits,
                "router_aux_loss_coefficient": self.router_aux_loss_coefficient,
            }
        )
        return config
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
import keras
|
|
2
|
+
from keras import ops
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class QwenMoeLayerNorm(keras.layers.Layer):
    """RMS (root-mean-square) normalization layer for Qwen-MoE.

    Normalizes the last axis by the root mean square of its values — no
    mean subtraction, no bias — then applies a learned per-feature scale.
    The statistics are computed in float32 for numerical stability.
    """

    def __init__(self, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        # Small constant added to the variance to avoid division by zero.
        self.epsilon = epsilon

    def build(self, input_shape):
        feature_dim = input_shape[-1]
        # Per-feature gain, initialized to the identity transform.
        self.scale = self.add_weight(
            name="scale",
            shape=(feature_dim,),
            initializer="ones",
            trainable=True,
            dtype=self.variable_dtype,
        )
        self.built = True

    def call(self, x):
        upcast = ops.cast(x, "float32")
        mean_square = ops.mean(ops.power(upcast, 2), axis=-1, keepdims=True)
        normalized = upcast * ops.rsqrt(mean_square + self.epsilon)
        return ops.cast(normalized * self.scale, self.compute_dtype)

    def get_config(self):
        base = super().get_config()
        base.update({"epsilon": self.epsilon})
        return base
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Qwen MoE preset configurations."""
|
|
2
|
+
|
|
3
|
+
backbone_presets = {
|
|
4
|
+
"qwen1.5_moe_2.7b_en": {
|
|
5
|
+
"metadata": {
|
|
6
|
+
"description": (
|
|
7
|
+
"24-layer Qwen MoE model with 2.7 billion active parameters ",
|
|
8
|
+
"and 8 experts per MoE layer.",
|
|
9
|
+
),
|
|
10
|
+
"params": 14315784192,
|
|
11
|
+
"path": "qwen-1.5-moe",
|
|
12
|
+
},
|
|
13
|
+
"kaggle_handle": "kaggle://keras/qwen-1.5-moe/Keras/qwen1.5_moe_2.7b_en/3",
|
|
14
|
+
},
|
|
15
|
+
}
|