keras-hub-nightly 0.23.0.dev202509180413__py3-none-any.whl → 0.23.0.dev202509280419__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of keras-hub-nightly might be problematic; see the package's registry page for more details.
- keras_hub/layers/__init__.py +3 -0
- keras_hub/models/__init__.py +24 -0
- keras_hub/src/models/depth_anything/__init__.py +9 -0
- keras_hub/src/models/depth_anything/depth_anything_backbone.py +232 -0
- keras_hub/src/models/depth_anything/depth_anything_depth_estimator.py +70 -0
- keras_hub/src/models/depth_anything/depth_anything_depth_estimator_preprocessor.py +16 -0
- keras_hub/src/models/depth_anything/depth_anything_image_converter.py +10 -0
- keras_hub/src/models/depth_anything/depth_anything_layers.py +725 -0
- keras_hub/src/models/depth_anything/depth_anything_loss.py +89 -0
- keras_hub/src/models/depth_anything/depth_anything_presets.py +4 -0
- keras_hub/src/models/depth_anything/interpolate.py +62 -0
- keras_hub/src/models/depth_estimator.py +239 -0
- keras_hub/src/models/depth_estimator_preprocessor.py +78 -0
- keras_hub/src/models/dinov2/dinov2_backbone.py +29 -3
- keras_hub/src/models/dinov2/dinov2_layers.py +13 -3
- keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py +371 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +365 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm.py +357 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor.py +12 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py +672 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py +45 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py +48 -0
- keras_hub/src/tests/test_case.py +3 -2
- keras_hub/src/utils/transformers/convert_dinov2.py +1 -0
- keras_hub/src/utils/transformers/convert_qwen3_moe.py +216 -0
- keras_hub/src/utils/transformers/preset_loader.py +3 -0
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +3 -0
- {keras_hub_nightly-0.23.0.dev202509180413.dist-info → keras_hub_nightly-0.23.0.dev202509280419.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.23.0.dev202509180413.dist-info → keras_hub_nightly-0.23.0.dev202509280419.dist-info}/RECORD +32 -13
- {keras_hub_nightly-0.23.0.dev202509180413.dist-info → keras_hub_nightly-0.23.0.dev202509280419.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.23.0.dev202509180413.dist-info → keras_hub_nightly-0.23.0.dev202509280419.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,672 @@
|
|
|
1
|
+
import keras
|
|
2
|
+
from keras import ops
|
|
3
|
+
|
|
4
|
+
from keras_hub.src.layers.modeling.transformer_layer_utils import (
|
|
5
|
+
compute_causal_mask,
|
|
6
|
+
)
|
|
7
|
+
from keras_hub.src.layers.modeling.transformer_layer_utils import (
|
|
8
|
+
merge_padding_and_attention_mask,
|
|
9
|
+
)
|
|
10
|
+
from keras_hub.src.models.qwen3_moe.qwen3_moe_attention import Qwen3MoeAttention
|
|
11
|
+
from keras_hub.src.models.qwen3_moe.qwen3_moe_layernorm import Qwen3MoeLayerNorm
|
|
12
|
+
from keras_hub.src.utils.keras_utils import clone_initializer
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def compute_load_balancing_loss(
    router_logits, num_experts, top_k, attention_mask=None
):
    """Compute the load-balancing auxiliary loss for a single MoE layer.

    The loss is `num_experts * sum_e(f_e * P_e)`, where `f_e` is the fraction
    of routing slots assigned to expert `e` (averaged over the `top_k` slots)
    and `P_e` is the mean router probability for expert `e`. Tokens removed
    by `attention_mask` are excluded from both averages.

    Args:
        router_logits: Tensor of shape `(batch_size * seq_len, num_experts)`.
        num_experts: Integer, total number of experts.
        top_k: Integer, number of experts to select per token.
        attention_mask: Optional mask selecting which tokens contribute.
            Either a padding mask of shape `(batch_size, seq_len)`, or an
            attention mask of shape `(batch_size, seq_len, key_len)`; for
            the 3D form a token counts when its row has any nonzero entry.

    Returns:
        Scalar tensor representing the auxiliary loss.
    """
    # Routing probabilities. Shape: (batch_size * seq_len, num_experts)
    routing_weights = ops.softmax(router_logits, axis=-1)

    # Experts selected per token. Shape: (batch_size * seq_len, top_k)
    _, selected_experts = ops.top_k(routing_weights, k=top_k)

    # Shape: (batch_size * seq_len, top_k, num_experts)
    expert_mask = ops.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Shape: (top_k, num_experts)
        tokens_per_expert = ops.mean(expert_mask, axis=0)
        # Shape: (num_experts,)
        router_prob_per_expert = ops.mean(routing_weights, axis=0)
    else:
        if len(ops.shape(attention_mask)) == 3:
            # Collapse an attention mask to one flag per query token.
            token_mask = ops.any(attention_mask, axis=-1)
        else:
            # Already a per-token padding mask.
            token_mask = attention_mask
        # Shape: (batch_size * seq_len, 1)
        token_mask = ops.cast(ops.reshape(token_mask, (-1, 1)), "float32")

        # Number of contributing tokens, guarded against an all-masked
        # batch. Shape: (1,)
        num_tokens = ops.maximum(ops.sum(token_mask, axis=0), 1e-9)

        # Masked means. Shapes: (top_k, num_experts) and (num_experts,)
        tokens_per_expert = (
            ops.sum(expert_mask * token_mask[:, None, :], axis=0) / num_tokens
        )
        router_prob_per_expert = (
            ops.sum(routing_weights * token_mask, axis=0) / num_tokens
        )

    # Average over the top_k slots. Shape: (num_experts,)
    tokens_per_expert = ops.mean(tokens_per_expert, axis=0)

    overall_loss = ops.sum(tokens_per_expert * router_prob_per_expert)
    return overall_loss * num_experts
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class Qwen3MoeMLP(keras.layers.Layer):
    """Gated (GLU-style) feedforward block for dense Transformer layers.

    Three bias-free projections are combined as
    `output_dense(activation(gate_dense(x)) * intermediate_dense(x))`:
    the gate and intermediate projections map `hidden_dim` to
    `intermediate_dim`, and the output projection maps back to `hidden_dim`.

    Args:
        intermediate_dim (int): The size of the intermediate (hidden) layer.
        hidden_dim (int): The size of the input and output layers.
        activation_fn (str, optional): The activation function to use.
            Defaults to "silu".
        layer_norm_epsilon (float, optional): Epsilon for layer normalization.
            Stored on the layer but not used by it. Defaults to 1e-6.
        kernel_initializer (str, optional): The initializer for the kernel
            weights. Defaults to "glorot_uniform".
    """

    def __init__(
        self,
        intermediate_dim,
        hidden_dim,
        activation_fn="silu",
        layer_norm_epsilon=1e-6,
        kernel_initializer="glorot_uniform",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.intermediate_dim = intermediate_dim
        self.hidden_dim = hidden_dim
        self.activation_fn = activation_fn
        self.kernel_initializer = kernel_initializer
        self.layer_norm_epsilon = layer_norm_epsilon

    def _make_projection(self, units, name):
        """Return a bias-free Dense layer sharing this layer's dtype policy."""
        return keras.layers.Dense(
            units,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            use_bias=False,
            dtype=self.dtype_policy,
            name=name,
        )

    def build(self, decoder_sequence_shape):
        # Creation order (intermediate, gate, output) is kept stable so that
        # weight tracking order does not change.
        self._feedforward_intermediate_dense = self._make_projection(
            self.intermediate_dim, "feedforward_intermediate_dense"
        )
        self._feedforward_intermediate_dense.build(decoder_sequence_shape)

        self._feedforward_gate_dense = self._make_projection(
            self.intermediate_dim, "feedforward_gate_dense"
        )
        self._feedforward_gate_dense.build(decoder_sequence_shape)

        # The output projection consumes the (gated) intermediate features.
        gated_shape = self._feedforward_gate_dense.compute_output_shape(
            decoder_sequence_shape
        )
        self._feedforward_output_dense = self._make_projection(
            self.hidden_dim, "feedforward_output_dense"
        )
        self._feedforward_output_dense.build(gated_shape)

        self.activation = keras.activations.get(self.activation_fn)
        self.built = True

    def call(self, x):
        # The activation runs in full 32-bit precision because that is what
        # `torch.nn.functional.silu` does internally: it upcasts its inputs
        # to float32, applies SiLU, and casts back to the compute dtype.
        # CPU Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cpu/Activation.cpp#L1221-L1235 # noqa: E501
        # CUDA Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cuda/ActivationSiluKernel.cu # noqa: E501
        gate = ops.cast(self._feedforward_gate_dense(x), "float32")
        gate = ops.cast(self.activation(gate), self.compute_dtype)

        up = self._feedforward_intermediate_dense(x)
        return self._feedforward_output_dense(ops.multiply(up, gate))
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
class Qwen3MoeExperts(keras.layers.Layer):
    """A layer that contains a bank of feedforward experts for MoE.

    This layer implements the expert part of a Mixture-of-Experts (MoE)
    model. The weights of all experts are stored in two stacked tensors, and
    `einsum` applies every expert to every input token in one batched
    computation. Choosing and weighting experts per token is left to the
    caller.

    Args:
        num_experts (int): The total number of experts in the layer.
        hidden_dim (int): The dimension of the input and output of each expert.
        intermediate_dim (int): The intermediate dimension of each expert's
            feedforward network.
        activation_fn (str, optional): The activation function to use within
            each expert. Defaults to "silu".
        kernel_initializer (str, optional): The initializer for the kernel
            weights. Defaults to "glorot_uniform".
    """

    def __init__(
        self,
        num_experts,
        hidden_dim,
        intermediate_dim,
        activation_fn="silu",
        kernel_initializer="glorot_uniform",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_experts = num_experts
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.activation = keras.activations.get(activation_fn)
        self.kernel_initializer = kernel_initializer

    def build(self, _):
        # Fused gate/up projection. Along the last axis the first half is the
        # gate and the second half is the up projection (split in `call`).
        # Each weight gets its own initializer clone: reusing one seeded
        # initializer instance would produce correlated initial values.
        self._expert_feedforward_gate_dense = self.add_weight(
            shape=(
                self.num_experts,
                self.hidden_dim,
                2 * self.intermediate_dim,
            ),
            initializer=clone_initializer(self.kernel_initializer),
            trainable=True,
            dtype=self.variable_dtype,
            name="expert_feedforward_gate_dense",
        )

        self._expert_feedforward_output_dense = self.add_weight(
            shape=(self.num_experts, self.intermediate_dim, self.hidden_dim),
            initializer=clone_initializer(self.kernel_initializer),
            trainable=True,
            dtype=self.variable_dtype,
            name="expert_feedforward_output_dense",
        )

        self.built = True

    def call(self, hidden_states):
        """Apply every expert to every token.

        Args:
            hidden_states: Tensor of shape `(num_tokens, hidden_dim)`.

        Returns:
            Tensor of shape `(num_experts, num_tokens, hidden_dim)`.
        """
        # (t, h) x (e, h, 2i) -> (e, t, 2i)
        gate_up = ops.einsum(
            "th,ehm->etm", hidden_states, self._expert_feedforward_gate_dense
        )
        gate, up = ops.split(gate_up, 2, axis=-1)

        # Run the activation in float32 and cast back, like the dense MLP in
        # this file, to match PyTorch's SiLU which computes in float32.
        gate = ops.cast(gate, "float32")
        gate = ops.cast(self.activation(gate), self.compute_dtype)
        hidden = up * gate

        # (e, t, i) x (e, i, h) -> (e, t, h)
        out = ops.einsum(
            "eti,eih->eth", hidden, self._expert_feedforward_output_dense
        )
        return out
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
class Qwen3SparseMoeBlock(keras.layers.Layer):
    """A sparse Mixture-of-Experts (MoE) block.

    A linear 'router' scores every expert for each token. The `top_k`
    highest-probability experts are kept (optionally renormalized), and the
    output is the routing-weighted sum of those experts' outputs. When called
    with `training=True`, a load-balancing auxiliary loss is added via
    `add_loss` to encourage an even distribution of tokens across experts.

    Args:
        hidden_dim (int): The dimension of the input and output tensors.
        moe_intermediate_dim (int): The intermediate dimension of each expert.
        num_experts (int): The total number of experts available.
        top_k (int): The number of experts to route each token to.
        norm_top_k_prob (bool): If True, normalize the probabilities of the
            top-k experts so they sum to 1.
        kernel_initializer (str, optional): The initializer for kernel weights.
            Defaults to "glorot_uniform".
        layer_norm_epsilon (float, optional): Epsilon for layer normalization.
            Stored on the layer but not used by it. Defaults to 1e-6.
        router_aux_loss_coefficient (float, optional): The coefficient for the
            load-balancing auxiliary loss. Defaults to 0.01.
    """

    def __init__(
        self,
        hidden_dim,
        moe_intermediate_dim,
        num_experts,
        top_k,
        norm_top_k_prob,
        kernel_initializer="glorot_uniform",
        layer_norm_epsilon=1e-6,
        router_aux_loss_coefficient=0.01,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.intermediate_dim = moe_intermediate_dim
        self.num_experts = num_experts
        self.top_k = top_k
        self.norm_top_k_prob = norm_top_k_prob
        self.kernel_initializer = kernel_initializer
        self.layer_norm_epsilon = layer_norm_epsilon
        self.router_aux_loss_coefficient = router_aux_loss_coefficient

    def build(self, decoder_sequence_shape):
        # Router: one logit per expert for every token.
        self._sparse_feedforward_gate_dense = keras.layers.Dense(
            self.num_experts,
            use_bias=False,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            name="sparse_feedforward_gate_dense",
            dtype=self.dtype_policy,
        )
        self._sparse_feedforward_gate_dense.build(decoder_sequence_shape)

        # NOTE: Experts are implemented as a single layer to enable efficient
        # batched computation. Implementing each expert individually is
        # currently avoided due to the lack of `ragged_dot` support in the
        # Keras ops API, which would make individual implementations unstable
        # and prone to bugs.
        self.expert_bank = Qwen3MoeExperts(
            num_experts=self.num_experts,
            hidden_dim=self.hidden_dim,
            intermediate_dim=self.intermediate_dim,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            name="experts",
            dtype=self.dtype_policy,
        )
        self.expert_bank.build(decoder_sequence_shape)

        self.built = True

    def call(self, hidden_states, attention_mask=None, training=None):
        """Route tokens to experts and combine their outputs.

        Args:
            hidden_states: Tensor of shape `(batch, seq_len, hidden_dim)`.
            attention_mask: Optional mask forwarded to the auxiliary loss.
            training: If truthy, the auxiliary loss is added via `add_loss`.

        Returns:
            A tuple `(output, router_logits)`. `output` has the same shape as
            `hidden_states`; `router_logits` has shape
            `(batch * seq_len, num_experts)`.
        """
        batch_size, seq_len, _ = ops.shape(hidden_states)
        hidden_states_flattened = ops.reshape(
            hidden_states, (-1, self.hidden_dim)
        )

        router_logits = self._sparse_feedforward_gate_dense(
            hidden_states_flattened
        )
        # Softmax in float32 (as the Hugging Face Qwen3-MoE router does) so
        # expert selection and weights are not degraded under bf16/fp16
        # policies. Weights are cast back to the compute dtype below.
        router_probs = ops.softmax(
            ops.cast(router_logits, "float32"), axis=-1
        )

        top_p, top_i = ops.top_k(router_probs, k=self.top_k)
        if self.norm_top_k_prob:
            top_p = top_p / ops.sum(top_p, axis=-1, keepdims=True)

        # Scatter the top-k weights into a dense (num_experts, tokens)
        # matrix; experts that were not selected receive weight 0.
        one_hot = ops.one_hot(top_i, self.num_experts)
        one_hot = ops.cast(one_hot, top_p.dtype)
        routing_full = ops.sum(one_hot * top_p[..., None], axis=1)
        routing_full = ops.transpose(routing_full, (1, 0))
        routing_full = ops.cast(routing_full, hidden_states_flattened.dtype)

        # Shape: (num_experts, tokens, hidden_dim)
        expert_out = self.expert_bank(hidden_states_flattened)

        weighted_out = expert_out * routing_full[:, :, None]
        expert_contribution = ops.sum(weighted_out, axis=0)

        out = ops.reshape(
            expert_contribution, (batch_size, seq_len, self.hidden_dim)
        )

        # Compute and add auxiliary loss during training
        if training:
            aux_loss = compute_load_balancing_loss(
                router_logits=router_logits,
                num_experts=self.num_experts,
                top_k=self.top_k,
                attention_mask=attention_mask,
            )
            self.add_loss(self.router_aux_loss_coefficient * aux_loss)

        return out, router_logits
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
class Qwen3MoeTransformerDecoder(keras.layers.Layer):
    """A Transformer decoder layer for the Qwen3 Moe backbone.

    This layer implements a pre-norm Transformer decoder block: layer norm,
    self-attention (with optional sliding window attention) and a residual
    connection, followed by layer norm, a feed-forward network and a second
    residual connection. The feed-forward network is either a sparse
    Mixture-of-Experts (MoE) block or a dense gated MLP.

    Args:
        intermediate_dim: Output dimension of the first dense layer in the
            feed-forward network (for non-MoE layers).
        num_query_heads: Number of query attention heads.
        num_key_value_heads: Number of key/value attention heads (for GQA).
        moe_intermediate_dim: The intermediate dimension for each expert in the
            MoE layer.
        num_experts: The total number of experts in the MoE layer.
        top_k: The number of experts to which each token is routed.
        norm_top_k_prob: If True, normalize the top-k probabilities.
        head_dim: The dimension of each attention head. If None, it is
            inferred from other dimensions.
        is_sparse_mlp: If True, uses a sparse MoE block instead of the dense
            MLP.
        rope_max_wavelength: Maximum wavelength for RoPE (Rotary Position
            Embedding).
        rope_scaling_factor: Scaling factor for RoPE, used for extending
            context length.
        activation: Activation function used in the dense feed-forward
            network.
        layer_norm_epsilon: Small float added to variance to avoid dividing
            by zero in layer norm.
        kernel_initializer: Initializer for the kernel weights.
        dropout: Dropout rate for attention and hidden layers.
        sliding_window_size: Size of the sliding window for attention when
            enabled.
        router_aux_loss_coefficient: The coefficient for the router's auxiliary
            loss, used for load balancing.
        **kwargs: Additional keyword arguments to pass to the Layer.
    """

    def __init__(
        self,
        intermediate_dim,
        num_query_heads,
        num_key_value_heads,
        moe_intermediate_dim,
        num_experts,
        top_k,
        norm_top_k_prob,
        head_dim=None,
        is_sparse_mlp=False,
        rope_max_wavelength=10000,
        rope_scaling_factor=1.0,
        activation="silu",
        layer_norm_epsilon=1e-6,
        kernel_initializer="glorot_uniform",
        dropout=0,
        sliding_window_size=4096,
        router_aux_loss_coefficient=0.001,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.intermediate_dim = intermediate_dim
        self.num_query_heads = num_query_heads
        self.num_key_value_heads = num_key_value_heads
        self.rope_max_wavelength = rope_max_wavelength
        self.rope_scaling_factor = rope_scaling_factor
        self.dropout = dropout
        self.sliding_window_size = sliding_window_size
        self.activation = keras.activations.get(activation)
        self.layer_norm_epsilon = layer_norm_epsilon
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.moe_intermediate_dim = moe_intermediate_dim
        self.head_dim = head_dim
        self.num_experts = num_experts
        self.top_k = top_k
        self.norm_top_k_prob = norm_top_k_prob
        self.is_sparse_mlp = is_sparse_mlp
        self.router_aux_loss_coefficient = router_aux_loss_coefficient
        self.supports_masking = True

    def build(self, decoder_sequence_shape):
        self._decoder_sequence_shape = decoder_sequence_shape
        self.hidden_dim = decoder_sequence_shape[-1]

        # Self attention layer.
        self._self_attention_layer = Qwen3MoeAttention(
            num_query_heads=self.num_query_heads,
            num_key_value_heads=self.num_key_value_heads,
            rope_max_wavelength=self.rope_max_wavelength,
            head_dim=self.head_dim,
            rope_scaling_factor=self.rope_scaling_factor,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            dropout=self.dropout,
            sliding_window_size=self.sliding_window_size,
            dtype=self.dtype_policy,
            name="self_attention",
        )
        self._self_attention_layer.build(decoder_sequence_shape)

        self._self_attention_layernorm = Qwen3MoeLayerNorm(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="self_attention_layernorm",
        )

        self._self_attention_layernorm.build(decoder_sequence_shape)
        self._self_attention_dropout = keras.layers.Dropout(
            rate=self.dropout,
            dtype=self.dtype_policy,
            name="self_attention_dropout",
        )

        # Feedforward layers. Each sub-layer gets its own clone of the
        # initializer, and the dense MLP receives this layer's activation and
        # initializer (previously both were silently dropped and the MLP
        # fell back to its own defaults).
        if self.is_sparse_mlp:
            self.mlp = Qwen3SparseMoeBlock(
                hidden_dim=self.hidden_dim,
                moe_intermediate_dim=self.moe_intermediate_dim,
                num_experts=self.num_experts,
                top_k=self.top_k,
                norm_top_k_prob=self.norm_top_k_prob,
                router_aux_loss_coefficient=self.router_aux_loss_coefficient,
                kernel_initializer=clone_initializer(self.kernel_initializer),
                layer_norm_epsilon=self.layer_norm_epsilon,
                dtype=self.dtype_policy,
            )
        else:
            self.mlp = Qwen3MoeMLP(
                intermediate_dim=self.intermediate_dim,
                hidden_dim=self.hidden_dim,
                activation_fn=self.activation,
                layer_norm_epsilon=self.layer_norm_epsilon,
                kernel_initializer=clone_initializer(self.kernel_initializer),
                dtype=self.dtype_policy,
            )
        self.mlp.build(decoder_sequence_shape)

        self._feedforward_layernorm = Qwen3MoeLayerNorm(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="feedforward_layernorm",
        )
        self._feedforward_layernorm.build(decoder_sequence_shape)

        self.built = True

    def call(
        self,
        decoder_sequence,
        decoder_padding_mask=None,
        decoder_attention_mask=None,
        self_attention_cache=None,
        self_attention_cache_update_index=None,
        training=None,
    ):
        """Forward pass for the decoder layer.

        Args:
            decoder_sequence: Input tensor of shape [batch_size, seq_length,
                hidden_size].
            decoder_padding_mask: Mask tensor for padding tokens.
            decoder_attention_mask: Additional attention mask.
            self_attention_cache: Optional cached key and value tensors for
                self-attention.
            self_attention_cache_update_index: Index at which to update the
                cache.
            training: Boolean indicating whether in training mode.

        Returns:
            decoder_output: Output tensor after applying transformer decoder
                block.
            self_attention_cache: Updated cache tensors (if cache is provided).
        """
        self_attention_mask = self._compute_self_attention_mask(
            decoder_sequence=decoder_sequence,
            decoder_padding_mask=decoder_padding_mask,
            decoder_attention_mask=decoder_attention_mask,
            self_attention_cache=self_attention_cache,
            self_attention_cache_update_index=self_attention_cache_update_index,
        )
        residual = decoder_sequence

        x = self._self_attention_layernorm(decoder_sequence)

        # Self attention block.
        x = self._self_attention_layer(
            hidden_states=x,
            attention_mask=self_attention_mask,
            cache=self_attention_cache,
            cache_update_index=self_attention_cache_update_index,
        )

        # With a cache, the attention layer returns (output, updated_cache).
        if self_attention_cache is not None:
            x, self_attention_cache = x

        x = self._self_attention_dropout(x, training=training)

        x = x + residual
        residual = x

        x = self._feedforward_layernorm(x)
        if isinstance(self.mlp, Qwen3SparseMoeBlock):
            # The MoE block returns (output, router_logits); the logits are
            # not needed here (its auxiliary loss is registered internally).
            x, _ = self.mlp(
                x, training=training, attention_mask=self_attention_mask
            )
        else:
            x = self.mlp(x)

        x = ops.cast(x, ops.dtype(residual))
        decoder_output = x + residual

        if self_attention_cache is not None:
            return decoder_output, self_attention_cache
        return decoder_output

    def _compute_self_attention_mask(
        self,
        decoder_sequence,
        decoder_padding_mask,
        decoder_attention_mask,
        self_attention_cache,
        self_attention_cache_update_index,
    ):
        """Computes the self-attention mask combining causal, padding and
        attention masks.

        Args:
            decoder_sequence: Input tensor.
            decoder_padding_mask: Mask tensor for padding tokens.
            decoder_attention_mask: Additional attention mask.
            self_attention_cache: Optional cached key and value tensors.
            self_attention_cache_update_index: Index at which to update the
                cache.

        Returns:
            Combined attention mask tensor.
        """
        decoder_mask = merge_padding_and_attention_mask(
            decoder_sequence, decoder_padding_mask, decoder_attention_mask
        )
        batch_size = ops.shape(decoder_sequence)[0]
        input_length = output_length = ops.shape(decoder_sequence)[1]
        # We need to handle a rectangular causal mask when doing cached
        # decoding. For generative inference, `decoder_sequence` will
        # generally be length 1, and `cache` will be the full generation length.
        if self_attention_cache is not None:
            input_length = ops.shape(self_attention_cache)[2]

        cache_update_index = (
            0
            if self_attention_cache_update_index is None
            else self_attention_cache_update_index
        )

        causal_mask = compute_causal_mask(
            batch_size, input_length, output_length, cache_update_index
        )

        return (
            ops.minimum(decoder_mask, causal_mask)
            if decoder_mask is not None
            else causal_mask
        )

    def compute_output_shape(self, decoder_sequence_shape):
        """Computes the output shape of the layer.

        Args:
            decoder_sequence_shape: Shape of the decoder sequence input.

        Returns:
            Output shape, which is the same as the input shape.
        """
        return decoder_sequence_shape

    def get_config(self):
        """Returns the config of the layer.

        Returns:
            Dictionary containing the parameters used to initialize this layer.
        """
        config = super().get_config()
        config.update(
            {
                "num_query_heads": self.num_query_heads,
                "intermediate_dim": self.intermediate_dim,
                "moe_intermediate_dim": self.moe_intermediate_dim,
                "rope_max_wavelength": self.rope_max_wavelength,
                "num_key_value_heads": self.num_key_value_heads,
                "rope_scaling_factor": self.rope_scaling_factor,
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "dropout": self.dropout,
                "sliding_window_size": self.sliding_window_size,
                "num_experts": self.num_experts,
                "top_k": self.top_k,
                "norm_top_k_prob": self.norm_top_k_prob,
                "router_aux_loss_coefficient": self.router_aux_loss_coefficient,
                "head_dim": self.head_dim,
                "is_sparse_mlp": self.is_sparse_mlp,
                "activation": keras.activations.serialize(self.activation),
                "kernel_initializer": keras.initializers.serialize(
                    self.kernel_initializer
                ),
            }
        )
        return config
|