keras-hub-nightly 0.23.0.dev202509190415__py3-none-any.whl → 0.23.0.dev202509280419__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of keras-hub-nightly might be problematic.

Files changed (32)
  1. keras_hub/layers/__init__.py +3 -0
  2. keras_hub/models/__init__.py +24 -0
  3. keras_hub/src/models/depth_anything/__init__.py +9 -0
  4. keras_hub/src/models/depth_anything/depth_anything_backbone.py +232 -0
  5. keras_hub/src/models/depth_anything/depth_anything_depth_estimator.py +70 -0
  6. keras_hub/src/models/depth_anything/depth_anything_depth_estimator_preprocessor.py +16 -0
  7. keras_hub/src/models/depth_anything/depth_anything_image_converter.py +10 -0
  8. keras_hub/src/models/depth_anything/depth_anything_layers.py +725 -0
  9. keras_hub/src/models/depth_anything/depth_anything_loss.py +89 -0
  10. keras_hub/src/models/depth_anything/depth_anything_presets.py +4 -0
  11. keras_hub/src/models/depth_anything/interpolate.py +62 -0
  12. keras_hub/src/models/depth_estimator.py +239 -0
  13. keras_hub/src/models/depth_estimator_preprocessor.py +78 -0
  14. keras_hub/src/models/dinov2/dinov2_backbone.py +29 -3
  15. keras_hub/src/models/dinov2/dinov2_layers.py +13 -3
  16. keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py +371 -0
  17. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +365 -0
  18. keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm.py +357 -0
  19. keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor.py +12 -0
  20. keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py +672 -0
  21. keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py +45 -0
  22. keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py +48 -0
  23. keras_hub/src/tests/test_case.py +3 -2
  24. keras_hub/src/utils/transformers/convert_dinov2.py +1 -0
  25. keras_hub/src/utils/transformers/convert_qwen3_moe.py +216 -0
  26. keras_hub/src/utils/transformers/preset_loader.py +3 -0
  27. keras_hub/src/version.py +1 -1
  28. keras_hub/tokenizers/__init__.py +3 -0
  29. {keras_hub_nightly-0.23.0.dev202509190415.dist-info → keras_hub_nightly-0.23.0.dev202509280419.dist-info}/METADATA +1 -1
  30. {keras_hub_nightly-0.23.0.dev202509190415.dist-info → keras_hub_nightly-0.23.0.dev202509280419.dist-info}/RECORD +32 -13
  31. {keras_hub_nightly-0.23.0.dev202509190415.dist-info → keras_hub_nightly-0.23.0.dev202509280419.dist-info}/WHEEL +0 -0
  32. {keras_hub_nightly-0.23.0.dev202509190415.dist-info → keras_hub_nightly-0.23.0.dev202509280419.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py
@@ -0,0 +1,672 @@
+import keras
+from keras import ops
+
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
+    compute_causal_mask,
+)
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
+    merge_padding_and_attention_mask,
+)
+from keras_hub.src.models.qwen3_moe.qwen3_moe_attention import Qwen3MoeAttention
+from keras_hub.src.models.qwen3_moe.qwen3_moe_layernorm import Qwen3MoeLayerNorm
+from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+def compute_load_balancing_loss(
+    router_logits, num_experts, top_k, attention_mask=None
+):
+    """
+    Compute the load balancing auxiliary loss for a single MoE layer.
+
+    Args:
+        router_logits: Tensor of shape (batch_size * seq_len, num_experts).
+        num_experts: Integer, total number of experts.
+        top_k: Integer, number of experts to select per token.
+        attention_mask: Tensor of shape (batch_size, seq_len, seq_len),
+            optional mask for padding.
+
+    Returns:
+        Scalar tensor representing the auxiliary loss.
+    """
+    # Compute routing probabilities
+    routing_weights = ops.softmax(
+        router_logits, axis=-1
+    )  # Shape: (batch_size * seq_len, num_experts)
+
+    # Get top-k experts
+    _, selected_experts = ops.top_k(
+        routing_weights, k=top_k
+    )  # Shape: (batch_size * seq_len, top_k)
+
+    # Create one-hot encoding for selected experts
+    expert_mask = ops.one_hot(
+        selected_experts, num_experts
+    )  # Shape: (batch_size * seq_len, top_k, num_experts)
+
+    if attention_mask is not None:
+        # Convert attention mask to (batch_size, seq_len)
+        batch_size, seq_len, _ = ops.shape(attention_mask)
+        flat_mask = ops.any(attention_mask, axis=-1)
+        flat_mask = ops.reshape(
+            flat_mask, (-1,)
+        )  # Shape: (batch_size * seq_len,)
+        # Expand mask for broadcasting
+        expert_attention_mask = ops.expand_dims(
+            flat_mask, axis=-1
+        )  # Shape: (batch_size * seq_len, 1)
+        expert_attention_mask = ops.cast(expert_attention_mask, dtype="float32")
+
+        # Compute masked means
+        tokens_per_expert = ops.sum(
+            expert_mask * expert_attention_mask[:, None, :], axis=0
+        ) / ops.maximum(
+            ops.sum(expert_attention_mask[:, None, :], axis=0), 1e-9
+        )  # Shape: (top_k, num_experts)
+        router_prob_per_expert = ops.sum(
+            routing_weights * expert_attention_mask, axis=0
+        ) / ops.maximum(
+            ops.sum(expert_attention_mask, axis=0), 1e-9
+        )  # Shape: (num_experts,)
+    else:
+        # Unmasked means
+        tokens_per_expert = ops.mean(
+            expert_mask, axis=0
+        )  # Shape: (top_k, num_experts)
+        router_prob_per_expert = ops.mean(
+            routing_weights, axis=0
+        )  # Shape: (num_experts,)
+
+    # Average over top_k dimension if necessary
+    tokens_per_expert = ops.mean(
+        tokens_per_expert, axis=0
+    )  # Shape: (num_experts,)
+
+    # Compute the loss
+    overall_loss = ops.sum(tokens_per_expert * router_prob_per_expert)
+    return overall_loss * num_experts
+
+
+class Qwen3MoeMLP(keras.layers.Layer):
+    """A feedforward network layer for a Transformer model.
+
+    This layer implements the gated linear unit (GLU) variant of a
+    feedforward network, which is a common setup in modern Transformers.
+    It consists of three dense layers: a gate layer, an intermediate layer,
+    and an output layer. The output is computed as
+    `output_dense(activation(gate_dense(x)) * intermediate_dense(x))`.
+
+    Args:
+        intermediate_dim (int): The size of the intermediate (hidden) layer.
+        hidden_dim (int): The size of the input and output layers.
+        activation_fn (str, optional): The activation function to use.
+            Defaults to "silu".
+        layer_norm_epsilon (float, optional): Epsilon for layer normalization.
+            Defaults to 1e-6.
+        kernel_initializer (str, optional): The initializer for the kernel
+            weights. Defaults to "glorot_uniform".
+    """
+
+    def __init__(
+        self,
+        intermediate_dim,
+        hidden_dim,
+        activation_fn="silu",
+        layer_norm_epsilon=1e-6,
+        kernel_initializer="glorot_uniform",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.intermediate_dim = intermediate_dim
+        self.hidden_dim = hidden_dim
+        self.activation_fn = activation_fn
+        self.kernel_initializer = kernel_initializer
+        self.layer_norm_epsilon = layer_norm_epsilon
+
+    def build(self, decoder_sequence_shape):
+        # Feedforward layers.
+        self._feedforward_intermediate_dense = keras.layers.Dense(
+            self.intermediate_dim,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            use_bias=False,
+            dtype=self.dtype_policy,
+            name="feedforward_intermediate_dense",
+        )
+        self._feedforward_intermediate_dense.build(decoder_sequence_shape)
+
+        self._feedforward_gate_dense = keras.layers.Dense(
+            self.intermediate_dim,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            use_bias=False,
+            dtype=self.dtype_policy,
+            name="feedforward_gate_dense",
+        )
+        self._feedforward_gate_dense.build(decoder_sequence_shape)
+
+        self._feedforward_output_dense = keras.layers.Dense(
+            self.hidden_dim,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            use_bias=False,
+            dtype=self.dtype_policy,
+            name="feedforward_output_dense",
+        )
+
+        self._feedforward_output_dense.build(
+            self._feedforward_gate_dense.compute_output_shape(
+                decoder_sequence_shape
+            )
+        )
+
+        self.activation = keras.activations.get(self.activation_fn)
+        self.built = True
+
+    def call(self, x):
+        gate_output = self._feedforward_gate_dense(x)
+
+        # Note that we run the activation function in full 32-bit
+        # precision since this is what `torch.nn.functional.silu`
+        # does. Internally, `torch.nn.functional.silu` converts the
+        # inputs to float32, computes SiLU, and converts the outputs
+        # back to compute dtype.
+        # CPU Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cpu/Activation.cpp#L1221-L1235 # noqa: E501
+        # CUDA Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cuda/ActivationSiluKernel.cu # noqa: E501
+        gate_output = ops.cast(gate_output, "float32")
+        gate_output = self.activation(gate_output)
+        gate_output = ops.cast(gate_output, self.compute_dtype)
+
+        x = self._feedforward_intermediate_dense(x)
+
+        return self._feedforward_output_dense(ops.multiply(x, gate_output))
+
+
+class Qwen3MoeExperts(keras.layers.Layer):
+    """A layer that contains a bank of feedforward experts for MoE.
+
+    This layer implements the expert part of a Mixture-of-Experts (MoE) model.
+    It creates a set of 'expert' feedforward networks that are computed in a
+    batched manner for efficiency. The weights for all experts are stored in
+    a single tensor, and computations are performed using `einsum` to process
+    all experts simultaneously.
+
+    Args:
+        num_experts (int): The total number of experts in the layer.
+        hidden_dim (int): The dimension of the input and output of each expert.
+        intermediate_dim (int): The intermediate dimension of each expert's
+            feedforward network.
+        activation_fn (str, optional): The activation function to use within
+            each expert. Defaults to "silu".
+        kernel_initializer (str, optional): The initializer for the kernel
+            weights. Defaults to "glorot_uniform".
+    """
+
+    def __init__(
+        self,
+        num_experts,
+        hidden_dim,
+        intermediate_dim,
+        activation_fn="silu",
+        kernel_initializer="glorot_uniform",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.num_experts = num_experts
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.activation = keras.activations.get(activation_fn)
+        self.kernel_initializer = kernel_initializer
+
+    def build(self, _):
+        self._expert_feedforward_gate_dense = self.add_weight(
+            shape=(
+                self.num_experts,
+                self.hidden_dim,
+                2 * self.intermediate_dim,
+            ),
+            initializer=self.kernel_initializer,
+            trainable=True,
+            dtype=self.variable_dtype,
+            name="expert_feedforward_gate_dense",
+        )
+
+        self._expert_feedforward_output_dense = self.add_weight(
+            shape=(self.num_experts, self.intermediate_dim, self.hidden_dim),
+            initializer=self.kernel_initializer,
+            trainable=True,
+            dtype=self.variable_dtype,
+            name="expert_feedforward_output_dense",
+        )
+
+        self.built = True
+
+    def call(self, hidden_states):
+        gate_up = ops.einsum(
+            "th,ehm->etm", hidden_states, self._expert_feedforward_gate_dense
+        )
+        gate, up = ops.split(gate_up, 2, axis=-1)
+        hidden = up * self.activation(gate)
+        out = ops.einsum(
+            "eti,eih->eth", hidden, self._expert_feedforward_output_dense
+        )
+        return out
+
+
+class Qwen3SparseMoeBlock(keras.layers.Layer):
+    """A sparse Mixture-of-Experts (MoE) block.
+
+    This block implements the full MoE logic. It contains a 'router' that
+    learns to send each input token to a subset of 'experts'. The final output
+    is a weighted combination of the outputs from the selected experts.
+    It also computes a load-balancing auxiliary loss during training to
+    encourage the router to distribute tokens evenly across all experts.
+
+    Args:
+        hidden_dim (int): The dimension of the input and output tensors.
+        moe_intermediate_dim (int): The intermediate dimension of each expert.
+        num_experts (int): The total number of experts available.
+        top_k (int): The number of experts to route each token to.
+        norm_top_k_prob (bool): If True, normalize the probabilities of the
+            top-k experts.
+        kernel_initializer (str, optional): The initializer for kernel weights.
+            Defaults to "glorot_uniform".
+        layer_norm_epsilon (float, optional): Epsilon for layer normalization.
+            Defaults to 1e-6.
+        router_aux_loss_coefficient (float, optional): The coefficient for the
+            load-balancing auxiliary loss. Defaults to 0.01.
+    """
+
+    def __init__(
+        self,
+        hidden_dim,
+        moe_intermediate_dim,
+        num_experts,
+        top_k,
+        norm_top_k_prob,
+        kernel_initializer="glorot_uniform",
+        layer_norm_epsilon=1e-6,
+        router_aux_loss_coefficient=0.01,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = moe_intermediate_dim
+        self.num_experts = num_experts
+        self.top_k = top_k
+        self.norm_top_k_prob = norm_top_k_prob
+        self.kernel_initializer = kernel_initializer
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.router_aux_loss_coefficient = router_aux_loss_coefficient
+
+    def build(self, decoder_sequence_shape):
+        self._sparse_feedforward_gate_dense = keras.layers.Dense(
+            self.num_experts,
+            use_bias=False,
+            kernel_initializer=self.kernel_initializer,
+            name="sparse_feedforward_gate_dense",
+            dtype=self.dtype_policy,
+        )
+        self._sparse_feedforward_gate_dense.build(decoder_sequence_shape)
+
+        # NOTE: Experts are implemented as a single layer to enable efficient
+        # batched computation. Implementing each expert individually is
+        # currently avoided due to the lack of `ragged_dot` support in the
+        # Keras ops API, which would make individual implementations unstable
+        # and prone to bugs.
+        self.expert_bank = Qwen3MoeExperts(
+            num_experts=self.num_experts,
+            hidden_dim=self.hidden_dim,
+            intermediate_dim=self.intermediate_dim,
+            kernel_initializer=self.kernel_initializer,
+            name="experts",
+            dtype=self.dtype_policy,
+        )
+        self.expert_bank.build(decoder_sequence_shape)
+
+        self.built = True
+
+    def call(self, hidden_states, attention_mask=None, training=None):
+        batch_size, seq_len, _ = ops.shape(hidden_states)
+        hidden_states_flattened = ops.reshape(
+            hidden_states, (-1, self.hidden_dim)
+        )
+
+        router_logits = self._sparse_feedforward_gate_dense(
+            hidden_states_flattened
+        )
+        router_probs = ops.softmax(router_logits, axis=-1)
+
+        top_p, top_i = ops.top_k(router_probs, k=self.top_k)
+        if self.norm_top_k_prob:
+            top_p = top_p / ops.sum(top_p, axis=-1, keepdims=True)
+
+        one_hot = ops.one_hot(top_i, self.num_experts)
+        one_hot = ops.cast(one_hot, top_p.dtype)
+        routing_full = ops.sum(one_hot * top_p[..., None], axis=1)
+        routing_full = ops.transpose(routing_full, (1, 0))
+        routing_full = ops.cast(routing_full, hidden_states_flattened.dtype)
+
+        expert_out = self.expert_bank(hidden_states_flattened)
+
+        weighted_out = expert_out * routing_full[:, :, None]
+        expert_contribution = ops.sum(weighted_out, axis=0)
+
+        out = ops.reshape(
+            expert_contribution, (batch_size, seq_len, self.hidden_dim)
+        )
+
+        # Compute and add auxiliary loss during training
+        if training:
+            aux_loss = compute_load_balancing_loss(
+                router_logits=router_logits,
+                num_experts=self.num_experts,
+                top_k=self.top_k,
+                attention_mask=attention_mask,
+            )
+            self.add_loss(self.router_aux_loss_coefficient * aux_loss)
+
+        return out, router_logits
+
+
+class Qwen3MoeTransformerDecoder(keras.layers.Layer):
+    """A Transformer decoder layer for the Qwen3 Moe backbone.
+
+    This layer implements a Transformer decoder block that includes
+    self-attention with optional sliding window attention and a
+    Mixture-of-Experts (MoE) feed-forward network.
+
+    Args:
+        intermediate_dim: Output dimension of the first dense layer in the
+            feed-forward network (for non-MoE layers).
+        num_query_heads: Number of query attention heads.
+        num_key_value_heads: Number of key/value attention heads (for GQA).
+        moe_intermediate_dim: The intermediate dimension for each expert in the
+            MoE layer.
+        num_experts: The total number of experts in the MoE layer.
+        top_k: The number of experts to which each token is routed.
+        norm_top_k_prob: If True, normalize the top-k probabilities.
+        head_dim: The dimension of each attention head. If None, it is
+            inferred from other dimensions.
+        is_sparse_mlp: If True, uses a sparse MLP.
+        rope_max_wavelength: Maximum wavelength for RoPE (Rotary Position
+            Embedding).
+        rope_scaling_factor: Scaling factor for RoPE, used for extending
+            context length.
+        activation: Activation function to use in the feed-forward network.
+        layer_norm_epsilon: Small float added to variance to avoid dividing
+            by zero in layer norm.
+        kernel_initializer: Initializer for the kernel weights.
+        dropout: Dropout rate for attention and hidden layers.
+        sliding_window_size: Size of the sliding window for attention when
+            enabled.
+        router_aux_loss_coefficient: The coefficient for the router's auxiliary
+            loss, used for load balancing.
+        **kwargs: Additional keyword arguments to pass to the Layer.
+    """
+
+    def __init__(
+        self,
+        intermediate_dim,
+        num_query_heads,
+        num_key_value_heads,
+        moe_intermediate_dim,
+        num_experts,
+        top_k,
+        norm_top_k_prob,
+        head_dim=None,
+        is_sparse_mlp=False,
+        rope_max_wavelength=10000,
+        rope_scaling_factor=1.0,
+        activation="silu",
+        layer_norm_epsilon=1e-6,
+        kernel_initializer="glorot_uniform",
+        dropout=0,
+        sliding_window_size=4096,
+        router_aux_loss_coefficient=0.001,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.intermediate_dim = intermediate_dim
+        self.num_query_heads = num_query_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.rope_max_wavelength = rope_max_wavelength
+        self.rope_scaling_factor = rope_scaling_factor
+        self.dropout = dropout
+        self.sliding_window_size = sliding_window_size
+        self.activation = keras.activations.get(activation)
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.kernel_initializer = keras.initializers.get(kernel_initializer)
+        self.moe_intermediate_dim = moe_intermediate_dim
+        self.head_dim = head_dim
+        self.num_experts = num_experts
+        self.top_k = top_k
+        self.norm_top_k_prob = norm_top_k_prob
+        self.is_sparse_mlp = is_sparse_mlp
+        self.router_aux_loss_coefficient = router_aux_loss_coefficient
+        self.supports_masking = True
+
+    def build(self, decoder_sequence_shape):
+        self._decoder_sequence_shape = decoder_sequence_shape
+        self.hidden_dim = decoder_sequence_shape[-1]
+
+        # Self attention layer.
+        self._self_attention_layer = Qwen3MoeAttention(
+            num_query_heads=self.num_query_heads,
+            num_key_value_heads=self.num_key_value_heads,
+            rope_max_wavelength=self.rope_max_wavelength,
+            head_dim=self.head_dim,
+            rope_scaling_factor=self.rope_scaling_factor,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            dropout=self.dropout,
+            sliding_window_size=self.sliding_window_size,
+            dtype=self.dtype_policy,
+            name="self_attention",
+        )
+        self._self_attention_layer.build(decoder_sequence_shape)
+
+        self._self_attention_layernorm = Qwen3MoeLayerNorm(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="self_attention_layernorm",
+        )
+
+        self._self_attention_layernorm.build(decoder_sequence_shape)
+        self._self_attention_dropout = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+            name="self_attention_dropout",
+        )
+
+        # Feedforward layers.
+        if self.is_sparse_mlp:
+            self.mlp = Qwen3SparseMoeBlock(
+                hidden_dim=self.hidden_dim,
+                moe_intermediate_dim=self.moe_intermediate_dim,
+                num_experts=self.num_experts,
+                top_k=self.top_k,
+                norm_top_k_prob=self.norm_top_k_prob,
+                router_aux_loss_coefficient=self.router_aux_loss_coefficient,
+                kernel_initializer=self.kernel_initializer,
+                dtype=self.dtype_policy,
+            )
+            self.mlp.build(decoder_sequence_shape)
+        else:
+            self.mlp = Qwen3MoeMLP(
+                intermediate_dim=self.intermediate_dim,
+                hidden_dim=self.hidden_dim,
+                dtype=self.dtype_policy,
+            )
+            self.mlp.build(decoder_sequence_shape)
+
+        self._feedforward_layernorm = Qwen3MoeLayerNorm(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="feedforward_layernorm",
+        )
+        self._feedforward_layernorm.build(decoder_sequence_shape)
+
+        self.built = True
+
+    def call(
+        self,
+        decoder_sequence,
+        decoder_padding_mask=None,
+        decoder_attention_mask=None,
+        self_attention_cache=None,
+        self_attention_cache_update_index=None,
+        training=None,
+    ):
+        """Forward pass for the decoder layer.
+
+        Args:
+            decoder_sequence: Input tensor of shape [batch_size, seq_length,
+                hidden_size].
+            decoder_padding_mask: Mask tensor for padding tokens.
+            decoder_attention_mask: Additional attention mask.
+            self_attention_cache: Optional cached key and value tensors for
+                self-attention.
+            self_attention_cache_update_index: Index at which to update the
+                cache.
+            training: Boolean indicating whether in training mode.
+
+        Returns:
+            decoder_output: Output tensor after applying transformer decoder
+                block.
+            self_attention_cache: Updated cache tensors (if cache is provided).
+        """
+        self_attention_mask = self._compute_self_attention_mask(
+            decoder_sequence=decoder_sequence,
+            decoder_padding_mask=decoder_padding_mask,
+            decoder_attention_mask=decoder_attention_mask,
+            self_attention_cache=self_attention_cache,
+            self_attention_cache_update_index=self_attention_cache_update_index,
+        )
+        residual = decoder_sequence
+
+        x = self._self_attention_layernorm(decoder_sequence)
+
+        # Self attention block.
+        x = self._self_attention_layer(
+            hidden_states=x,
+            attention_mask=self_attention_mask,
+            cache=self_attention_cache,
+            cache_update_index=self_attention_cache_update_index,
+        )
+
+        if self_attention_cache is not None:
+            x, self_attention_cache = x
+
+        x = self._self_attention_dropout(x, training=training)
+
+        x = x + residual
+        residual = x
+
+        x = self._feedforward_layernorm(x)
+        if isinstance(self.mlp, Qwen3SparseMoeBlock):
+            x = self.mlp(
+                x, training=training, attention_mask=self_attention_mask
+            )
+        else:
+            x = self.mlp(x)
+
+        if isinstance(x, tuple):
+            x, _ = x
+
+        x = ops.cast(x, ops.dtype(residual))
+        decoder_output = x + residual
+
+        output = (decoder_output,)
+
+        if self_attention_cache is not None:
+            output += (self_attention_cache,)
+
+        return output[0] if len(output) == 1 else output
+
+    def _compute_self_attention_mask(
+        self,
+        decoder_sequence,
+        decoder_padding_mask,
+        decoder_attention_mask,
+        self_attention_cache,
+        self_attention_cache_update_index,
+    ):
+        """Computes the self-attention mask combining causal, padding and
+        attention masks.
+
+        Args:
+            decoder_sequence: Input tensor.
+            decoder_padding_mask: Mask tensor for padding tokens.
+            decoder_attention_mask: Additional attention mask.
+            self_attention_cache: Optional cached key and value tensors.
+            self_attention_cache_update_index: Index at which to update the
+                cache.
+
+        Returns:
+            Combined attention mask tensor.
+        """
+        decoder_mask = merge_padding_and_attention_mask(
+            decoder_sequence, decoder_padding_mask, decoder_attention_mask
+        )
+        batch_size = ops.shape(decoder_sequence)[0]
+        input_length = output_length = ops.shape(decoder_sequence)[1]
+        # We need to handle a rectangular causal mask when doing cached
+        # decoding. For generative inference, `decoder_sequence` will
+        # generally be length 1, and `cache` will be the full generation length.
+        if self_attention_cache is not None:
+            input_length = ops.shape(self_attention_cache)[2]
+
+        cache_update_index = (
+            0
+            if self_attention_cache_update_index is None
+            else self_attention_cache_update_index
+        )
+
+        causal_mask = compute_causal_mask(
+            batch_size, input_length, output_length, cache_update_index
+        )
+
+        return (
+            ops.minimum(decoder_mask, causal_mask)
+            if decoder_mask is not None
+            else causal_mask
+        )
+
+    def compute_output_shape(self, decoder_sequence_shape):
+        """Computes the output shape of the layer.
+
+        Args:
+            decoder_sequence_shape: Shape of the decoder sequence input.
+
+        Returns:
+            Output shape, which is the same as the input shape.
+        """
+        return decoder_sequence_shape
+
+    def get_config(self):
+        """Returns the config of the layer.
+
+        Returns:
+            Dictionary containing the parameters used to initialize this layer.
+        """
+        config = super().get_config()
+        config.update(
+            {
+                "num_query_heads": self.num_query_heads,
+                "intermediate_dim": self.intermediate_dim,
+                "moe_intermediate_dim": self.moe_intermediate_dim,
+                "rope_max_wavelength": self.rope_max_wavelength,
+                "num_key_value_heads": self.num_key_value_heads,
+                "rope_scaling_factor": self.rope_scaling_factor,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "dropout": self.dropout,
+                "sliding_window_size": self.sliding_window_size,
+                "num_experts": self.num_experts,
+                "top_k": self.top_k,
+                "norm_top_k_prob": self.norm_top_k_prob,
+                "router_aux_loss_coefficient": self.router_aux_loss_coefficient,
+                "head_dim": self.head_dim,
+                "is_sparse_mlp": self.is_sparse_mlp,
+                "activation": keras.activations.serialize(self.activation),
+                "kernel_initializer": keras.initializers.serialize(
+                    self.kernel_initializer
+                ),
+            }
+        )
+        return config
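
For context on how the new MoE pieces fit together, here is a minimal usage sketch. It is not part of the release: it assumes the hunk above is the new `keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py` module listed in the file table, and the sizes (`hidden_dim=32`, `moe_intermediate_dim=64`, 4 experts, top-2 routing) are arbitrary toy values rather than any preset configuration.

import keras
from keras import ops

# Hypothetical import path, assuming this hunk is qwen3_moe_decoder.py.
from keras_hub.src.models.qwen3_moe.qwen3_moe_decoder import (
    Qwen3SparseMoeBlock,
    compute_load_balancing_loss,
)

batch_size, seq_len, hidden_dim = 2, 8, 32  # toy sizes for illustration only

block = Qwen3SparseMoeBlock(
    hidden_dim=hidden_dim,
    moe_intermediate_dim=64,
    num_experts=4,
    top_k=2,
    norm_top_k_prob=True,
)

x = keras.random.normal((batch_size, seq_len, hidden_dim))
# With `training=True` the block also adds the scaled auxiliary loss
# to `block.losses` via `self.add_loss`.
out, router_logits = block(x, training=True)
# out: (2, 8, 32); router_logits: (16, 4) = (batch_size * seq_len, num_experts)

# The same router logits can be fed to the standalone loss helper.
aux_loss = compute_load_balancing_loss(router_logits, num_experts=4, top_k=2)
print(ops.convert_to_numpy(aux_loss))

The block returns both the combined expert output and the raw router logits; the logits are what `compute_load_balancing_loss` consumes, and `Qwen3MoeTransformerDecoder` discards them after the MoE feed-forward step.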