keras-hub-nightly 0.21.0.dev202505050407__py3-none-any.whl → 0.21.0.dev202505060405__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (34)
  1. keras_hub/models/__init__.py +21 -0
  2. keras_hub/src/models/backbone.py +5 -2
  3. keras_hub/src/models/mixtral/mixtral_attention.py +263 -0
  4. keras_hub/src/models/mixtral/mixtral_backbone.py +207 -0
  5. keras_hub/src/models/mixtral/mixtral_causal_lm.py +281 -0
  6. keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py +76 -0
  7. keras_hub/src/models/mixtral/mixtral_decoder.py +494 -0
  8. keras_hub/src/models/mixtral/mixtral_layer_norm.py +34 -0
  9. keras_hub/src/models/mixtral/mixtral_tokenizer.py +21 -0
  10. keras_hub/src/models/qwen/qwen_attention.py +3 -1
  11. keras_hub/src/models/qwen/qwen_presets.py +61 -0
  12. keras_hub/src/models/qwen_moe/__init__.py +0 -0
  13. keras_hub/src/models/qwen_moe/qwen_moe_attention.py +377 -0
  14. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +373 -0
  15. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm.py +350 -0
  16. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor.py +17 -0
  17. keras_hub/src/models/qwen_moe/qwen_moe_decoder.py +625 -0
  18. keras_hub/src/models/qwen_moe/qwen_moe_layernorm.py +32 -0
  19. keras_hub/src/models/qwen_moe/qwen_moe_tokenizer.py +46 -0
  20. keras_hub/src/models/retinanet/retinanet_image_converter.py +0 -13
  21. keras_hub/src/models/retinanet/retinanet_presets.py +2 -2
  22. keras_hub/src/models/task.py +5 -2
  23. keras_hub/src/utils/keras_utils.py +11 -0
  24. keras_hub/src/utils/preset_utils.py +69 -9
  25. keras_hub/src/utils/tensor_utils.py +27 -1
  26. keras_hub/src/utils/transformers/convert_mixtral.py +139 -0
  27. keras_hub/src/utils/transformers/convert_qwen_moe.py +253 -0
  28. keras_hub/src/utils/transformers/preset_loader.py +6 -0
  29. keras_hub/src/version.py +1 -1
  30. keras_hub/tokenizers/__init__.py +6 -0
  31. {keras_hub_nightly-0.21.0.dev202505050407.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/METADATA +1 -1
  32. {keras_hub_nightly-0.21.0.dev202505050407.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/RECORD +34 -16
  33. {keras_hub_nightly-0.21.0.dev202505050407.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/WHEEL +0 -0
  34. {keras_hub_nightly-0.21.0.dev202505050407.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/top_level.txt +0 -0

keras_hub/src/models/qwen_moe/qwen_moe_decoder.py
@@ -0,0 +1,625 @@
+ import keras
+ from keras import ops
+
+ from keras_hub.src.layers.modeling.transformer_layer_utils import (
+     compute_causal_mask,
+ )
+ from keras_hub.src.layers.modeling.transformer_layer_utils import (
+     merge_padding_and_attention_mask,
+ )
+ from keras_hub.src.models.qwen_moe.qwen_moe_attention import QwenMoeAttention
+ from keras_hub.src.models.qwen_moe.qwen_moe_layernorm import QwenMoeLayerNorm
+ from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+ def compute_load_balancing_loss(
+     router_logits, num_experts, top_k, attention_mask=None
+ ):
+     """
+     Compute the load balancing auxiliary loss for a single MoE layer.
+
+     Args:
+         router_logits: Tensor of shape (batch_size * seq_len, num_experts).
+         num_experts: Integer, total number of experts.
+         top_k: Integer, number of experts to select per token.
+         attention_mask: Tensor of shape (batch_size, seq_len, seq_len),
+             optional mask for padding.
+
+     Returns:
+         Scalar tensor representing the auxiliary loss.
+     """
+     # Compute routing probabilities
+     routing_weights = ops.softmax(
+         router_logits, axis=-1
+     )  # Shape: (batch_size * seq_len, num_experts)
+
+     # Get top-k experts
+     _, selected_experts = ops.top_k(
+         routing_weights, k=top_k
+     )  # Shape: (batch_size * seq_len, top_k)
+
+     # Create one-hot encoding for selected experts
+     expert_mask = ops.one_hot(
+         selected_experts, num_experts
+     )  # Shape: (batch_size * seq_len, top_k, num_experts)
+
+     if attention_mask is not None:
+         # Convert attention mask to (batch_size, seq_len)
+         batch_size, seq_len, _ = ops.shape(attention_mask)
+         flat_mask = ops.any(attention_mask, axis=-1)
+         flat_mask = ops.reshape(
+             flat_mask, (-1,)
+         )  # Shape: (batch_size * seq_len,)
+         # Expand mask for broadcasting
+         expert_attention_mask = ops.expand_dims(
+             flat_mask, axis=-1
+         )  # Shape: (batch_size * seq_len, 1)
+         expert_attention_mask = ops.cast(expert_attention_mask, dtype="float32")
+
+         # Compute masked means
+         tokens_per_expert = ops.sum(
+             expert_mask * expert_attention_mask[:, None, :], axis=0
+         ) / ops.maximum(
+             ops.sum(expert_attention_mask[:, None, :], axis=0), 1e-9
+         )  # Shape: (top_k, num_experts)
+         router_prob_per_expert = ops.sum(
+             routing_weights * expert_attention_mask, axis=0
+         ) / ops.maximum(
+             ops.sum(expert_attention_mask, axis=0), 1e-9
+         )  # Shape: (num_experts,)
+     else:
+         # Unmasked means
+         tokens_per_expert = ops.mean(
+             expert_mask, axis=0
+         )  # Shape: (top_k, num_experts)
+         router_prob_per_expert = ops.mean(
+             routing_weights, axis=0
+         )  # Shape: (num_experts,)
+
+     # Average over top_k dimension if necessary
+     tokens_per_expert = ops.mean(
+         tokens_per_expert, axis=0
+     )  # Shape: (num_experts,)
+
+     # Compute the loss
+     overall_loss = ops.sum(tokens_per_expert * router_prob_per_expert)
+     return overall_loss * num_experts
+
+
+ class QwenMoeMLP(keras.layers.Layer):
+     def __init__(
+         self,
+         intermediate_dim,
+         hidden_dim,
+         activation_fn="silu",
+         layer_norm_epsilon=1e-5,
+         kernel_initializer="glorot_uniform",
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.intermediate_dim = intermediate_dim
+         self.hidden_dim = hidden_dim
+         self.activation_fn = activation_fn
+         self.kernel_initializer = kernel_initializer
+         self.layer_norm_epsilon = layer_norm_epsilon
+
+     def build(self, decoder_sequence_shape):
+         # Feedforward layers.
+         self._feedforward_intermediate_dense = keras.layers.Dense(
+             self.intermediate_dim,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             use_bias=False,
+             dtype=self.dtype_policy,
+             name="feedforward_intermediate_dense",
+         )
+         self._feedforward_intermediate_dense.build(decoder_sequence_shape)
+
+         self._feedforward_gate_dense = keras.layers.Dense(
+             self.intermediate_dim,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             use_bias=False,
+             dtype=self.dtype_policy,
+             name="feedforward_gate_dense",
+         )
+         self._feedforward_gate_dense.build(decoder_sequence_shape)
+
+         self._feedforward_output_dense = keras.layers.Dense(
+             self.hidden_dim,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             use_bias=False,
+             dtype=self.dtype_policy,
+             name="feedforward_output_dense",
+         )
+
+         self._feedforward_output_dense.build(
+             self._feedforward_gate_dense.compute_output_shape(
+                 decoder_sequence_shape
+             )
+         )
+
+         self.activation = keras.activations.get(self.activation_fn)
+         self.built = True
+
+     def call(self, x):
+         gate_output = self._feedforward_gate_dense(x)
+
+         # Note that we run the activation function in full 32-bit
+         # precision since this is what `torch.nn.functional.silu`
+         # does. Internally, `torch.nn.functional.silu` converts the
+         # inputs to float32, computes SiLU, and converts the outputs
+         # back to compute dtype.
+         # CPU Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cpu/Activation.cpp#L1221-L1235 # noqa: E501
+         # CUDA Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cuda/ActivationSiluKernel.cu # noqa: E501
+         gate_output = ops.cast(gate_output, "float32")
+         gate_output = self.activation(gate_output)
+         gate_output = ops.cast(gate_output, self.compute_dtype)
+
+         x = self._feedforward_intermediate_dense(x)
+
+         x = self._feedforward_output_dense(ops.multiply(x, gate_output))
+
+         return x
+
+
+ class QwenMoeExperts(keras.layers.Layer):
+     """Batched Experts Layer"""
+
+     def __init__(
+         self,
+         num_experts,
+         hidden_dim,
+         intermediate_dim,
+         activation_fn="silu",
+         kernel_initializer="glorot_uniform",
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.num_experts = num_experts
+         self.hidden_dim = hidden_dim
+         self.intermediate_dim = intermediate_dim
+         self.activation = keras.activations.get(activation_fn)
+         self.kernel_initializer = kernel_initializer
+
+     def build(self, _):
+         self._expert_feedforward_gate_dense = self.add_weight(
+             shape=(
+                 self.num_experts,
+                 self.hidden_dim,
+                 2 * self.intermediate_dim,
+             ),
+             initializer=self.kernel_initializer,
+             trainable=True,
+             dtype=self.variable_dtype,
+             name="expert_feedforward_gate_dense",
+         )
+
+         self._expert_feedforward_output_dense = self.add_weight(
+             shape=(self.num_experts, self.intermediate_dim, self.hidden_dim),
+             initializer=self.kernel_initializer,
+             trainable=True,
+             dtype=self.variable_dtype,
+             name="expert_feedforward_output_dense",
+         )
+
+         self.built = True
+
+     def call(self, hidden_states):
+         gate_up = ops.einsum(
+             "th,ehm->etm", hidden_states, self._expert_feedforward_gate_dense
+         )
+         gate, up = ops.split(gate_up, 2, axis=-1)
+         hidden = up * self.activation(gate)
+         out = ops.einsum(
+             "eti,eih->eth", hidden, self._expert_feedforward_output_dense
+         )
+         return out
+
+
+ class QwenSparseMoeBlock(keras.layers.Layer):
+     """Qwen-2 Sparse Moe Block"""
+
+     def __init__(
+         self,
+         hidden_dim,
+         moe_intermediate_dim,
+         shared_expert_intermediate_dim,
+         num_experts,
+         top_k,
+         norm_top_k_prob,
+         kernel_initializer="glorot_uniform",
+         layer_norm_epsilon=1e-5,
+         router_aux_loss_coefficient=0.01,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.hidden_dim = hidden_dim
+         self.intermediate_dim = moe_intermediate_dim
+         self.intermediate_dim_shared = shared_expert_intermediate_dim
+         self.num_experts = num_experts
+         self.top_k = top_k
+         self.norm_top_k_prob = norm_top_k_prob
+         self.kernel_initializer = kernel_initializer
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.router_aux_loss_coefficient = router_aux_loss_coefficient
+
+     def build(self, decoder_sequence_shape):
+         self._sparse_feedforward_gate_dense = keras.layers.Dense(
+             self.num_experts,
+             use_bias=False,
+             kernel_initializer=self.kernel_initializer,
+             name="sparse_feedforward_gate_dense",
+             dtype=self.dtype_policy,
+         )
+         self._sparse_feedforward_gate_dense.build(decoder_sequence_shape)
+
+         # NOTE: Experts are implemented as a single layer to enable efficient
+         # batched computation. Implementing each expert individually is
+         # currently avoided due to the lack of `ragged_dot` support in the
+         # Keras ops API, which would make individual implementations unstable
+         # and prone to bugs.
+         self.expert_bank = QwenMoeExperts(
+             num_experts=self.num_experts,
+             hidden_dim=self.hidden_dim,
+             intermediate_dim=self.intermediate_dim,
+             kernel_initializer=self.kernel_initializer,
+             name="experts",
+             dtype=self.dtype_policy,
+         )
+         self.expert_bank.build(decoder_sequence_shape)
+
+         self.shared_expert_dense = QwenMoeMLP(
+             intermediate_dim=self.intermediate_dim_shared,
+             hidden_dim=self.hidden_dim,
+             kernel_initializer=self.kernel_initializer,
+             layer_norm_epsilon=self.layer_norm_epsilon,
+             name="shared_expert_dense",
+             dtype=self.dtype_policy,
+         )
+         self.shared_expert_dense.build(decoder_sequence_shape)
+
+         self.shared_expert_gate_dense = keras.layers.Dense(
+             1,
+             use_bias=False,
+             name="shared_expert_gate_dense",
+             dtype=self.dtype_policy,
+         )
+         self.shared_expert_gate_dense.build(decoder_sequence_shape)
+
+         self.built = True
+
+     def call(self, hidden_states, attention_mask=None, training=None):
+         batch_size, seq_len, _ = ops.shape(hidden_states)
+         hidden_states_flattened = ops.reshape(
+             hidden_states, (-1, self.hidden_dim)
+         )
+
+         router_logits = self._sparse_feedforward_gate_dense(
+             hidden_states_flattened
+         )
+         router_probs = ops.softmax(router_logits, axis=-1)
+
+         top_p, top_i = ops.top_k(router_probs, k=self.top_k)
+         if self.norm_top_k_prob:
+             top_p = top_p / ops.sum(top_p, axis=-1, keepdims=True)
+
+         one_hot = ops.one_hot(top_i, self.num_experts)
+         one_hot = ops.cast(one_hot, top_p.dtype)
+         routing_full = ops.sum(one_hot * top_p[..., None], axis=1)
+         routing_full = ops.transpose(routing_full, (1, 0))
+         routing_full = ops.cast(routing_full, hidden_states_flattened.dtype)
+
+         expert_out = self.expert_bank(hidden_states_flattened)
+
+         weighted_out = expert_out * routing_full[:, :, None]
+         expert_contribution = ops.sum(weighted_out, axis=0)
+
+         shared_expert_output = self.shared_expert_dense(hidden_states_flattened)
+         shared_expert_output *= ops.sigmoid(
+             self.shared_expert_gate_dense(hidden_states_flattened)
+         )
+
+         out_flat = expert_contribution + shared_expert_output
+         out = ops.reshape(out_flat, (batch_size, seq_len, self.hidden_dim))
+
+         # Compute and add auxiliary loss during training
+         if training:
+             aux_loss = compute_load_balancing_loss(
+                 router_logits=router_logits,
+                 num_experts=self.num_experts,
+                 top_k=self.top_k,
+                 attention_mask=attention_mask,
+             )
+             self.add_loss(self.router_aux_loss_coefficient * aux_loss)
+
+         return out, router_logits
+
+
+ class QwenMoeTransformerDecoder(keras.layers.Layer):
+     def __init__(
+         self,
+         intermediate_dim,
+         num_query_heads,
+         num_key_value_heads,
+         moe_intermediate_dim,
+         shared_expert_intermediate_dim,
+         num_experts,
+         top_k,
+         norm_top_k_prob,
+         decoder_sparse_step,
+         rope_max_wavelength=10000,
+         rope_scaling_factor=1.0,
+         activation="silu",
+         layer_norm_epsilon=1e-5,
+         kernel_initializer="glorot_uniform",
+         dropout=0,
+         use_sliding_window_attention=False,
+         sliding_window_size=4096,
+         layer_index=0,
+         mlp_only_layers=[],
+         output_router_logits=False,
+         router_aux_loss_coefficient=0.001,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.intermediate_dim = intermediate_dim
+         self.num_query_heads = num_query_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.rope_max_wavelength = rope_max_wavelength
+         self.rope_scaling_factor = rope_scaling_factor
+         self.dropout = dropout
+         self.use_sliding_window_attention = use_sliding_window_attention
+         self.sliding_window_size = sliding_window_size
+         self.activation = keras.activations.get(activation)
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.kernel_initializer = keras.initializers.get(kernel_initializer)
+         self.layer_index = layer_index
+         self.mlp_only_layers = mlp_only_layers
+         self.moe_intermediate_dim = moe_intermediate_dim
+         self.shared_expert_intermediate_dim = shared_expert_intermediate_dim
+         self.num_experts = num_experts
+         self.top_k = top_k
+         self.norm_top_k_prob = norm_top_k_prob
+         self.decoder_sparse_step = decoder_sparse_step
+         self.output_router_logits = output_router_logits
+         self.router_aux_loss_coefficient = router_aux_loss_coefficient
+         self.supports_masking = True
+
+     def build(self, decoder_sequence_shape):
+         self._decoder_sequence_shape = decoder_sequence_shape
+         self.hidden_dim = decoder_sequence_shape[-1]
+
+         # Self attention layer.
+         self._self_attention_layer = QwenMoeAttention(
+             num_query_heads=self.num_query_heads,
+             num_key_value_heads=self.num_key_value_heads,
+             rope_max_wavelength=self.rope_max_wavelength,
+             rope_scaling_factor=self.rope_scaling_factor,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             dropout=self.dropout,
+             use_sliding_window_attention=self.use_sliding_window_attention,
+             sliding_window_size=self.sliding_window_size,
+             name="self_attention",
+             dtype=self.dtype_policy,
+         )
+         self._self_attention_layer.build(decoder_sequence_shape)
+
+         self._self_attention_layernorm = QwenMoeLayerNorm(
+             epsilon=self.layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="self_attention_layernorm",
+         )
+
+         self._self_attention_layernorm.build(decoder_sequence_shape)
+         self._self_attention_dropout = keras.layers.Dropout(
+             rate=self.dropout,
+             dtype=self.dtype_policy,
+             name="self_attention_dropout",
+         )
+
+         # Feedforward layers.
+         if (self.layer_index not in self.mlp_only_layers) and (
+             self.num_experts > 0
+             and (self.layer_index + 1) % self.decoder_sparse_step == 0
+         ):
+             self.mlp = QwenSparseMoeBlock(
+                 hidden_dim=self.hidden_dim,
+                 moe_intermediate_dim=self.moe_intermediate_dim,
+                 shared_expert_intermediate_dim=self.shared_expert_intermediate_dim,
+                 num_experts=self.num_experts,
+                 top_k=self.top_k,
+                 norm_top_k_prob=self.norm_top_k_prob,
+                 router_aux_loss_coefficient=self.router_aux_loss_coefficient,
+                 kernel_initializer=self.kernel_initializer,
+                 dtype=self.dtype_policy,
+             )
+             self.mlp.build(decoder_sequence_shape)
+         else:
+             self.mlp = QwenMoeMLP(
+                 intermediate_dim=self.intermediate_dim,
+                 hidden_dim=self.hidden_dim,
+                 dtype=self.dtype_policy,
+             )
+             self.mlp.build(decoder_sequence_shape)
+
+         self._feedforward_layernorm = QwenMoeLayerNorm(
+             epsilon=self.layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="feedforward_layernorm",
+         )
+         self._feedforward_layernorm.build(decoder_sequence_shape)
+
+         self.built = True
+
+     def call(
+         self,
+         decoder_sequence,
+         decoder_padding_mask=None,
+         decoder_attention_mask=None,
+         self_attention_cache=None,
+         self_attention_cache_update_index=None,
+         training=None,
+     ):
+         """Forward pass for the decoder layer.
+
+         Args:
+             decoder_sequence: Input tensor of shape [batch_size, seq_length,
+                 hidden_size].
+             decoder_padding_mask: Mask tensor for padding tokens.
+             decoder_attention_mask: Additional attention mask.
+             self_attention_cache: Optional cached key and value tensors for
+                 self-attention.
+             self_attention_cache_update_index: Index at which to update the
+                 cache.
+             training: Boolean indicating whether in training mode.
+
+         Returns:
+             decoder_output: Output tensor after applying transformer decoder
+                 block.
+             self_attention_cache: Updated cache tensors (if cache is provided).
+         """
+         self_attention_mask = self._compute_self_attention_mask(
+             decoder_sequence=decoder_sequence,
+             decoder_padding_mask=decoder_padding_mask,
+             decoder_attention_mask=decoder_attention_mask,
+             self_attention_cache=self_attention_cache,
+             self_attention_cache_update_index=self_attention_cache_update_index,
+         )
+         residual = decoder_sequence
+
+         x = self._self_attention_layernorm(decoder_sequence)
+
+         # Self attention block.
+         x = self._self_attention_layer(
+             hidden_states=x,
+             attention_mask=self_attention_mask,
+             cache=self_attention_cache,
+             cache_update_index=self_attention_cache_update_index,
+         )
+
+         if self_attention_cache is not None:
+             x, self_attention_cache = x
+
+         x = self._self_attention_dropout(x, training=training)
+
+         x = x + residual
+         residual = x
+
+         x = self._feedforward_layernorm(x)
+         if isinstance(self.mlp, QwenSparseMoeBlock):
+             x = self.mlp(
+                 x, training=training, attention_mask=self_attention_mask
+             )
+         else:
+             x = self.mlp(x)
+         if isinstance(x, tuple):
+             x, router_logits = x
+         else:
+             router_logits = None
+
+         x = ops.cast(x, ops.dtype(residual))
+         decoder_output = x + residual
+
+         output = (decoder_output,)
+
+         if self_attention_cache is not None:
+             output += (self_attention_cache,)
+
+         if self.output_router_logits:
+             output += (router_logits,)
+
+         return output[0] if len(output) == 1 else output
+
+     def _compute_self_attention_mask(
+         self,
+         decoder_sequence,
+         decoder_padding_mask,
+         decoder_attention_mask,
+         self_attention_cache,
+         self_attention_cache_update_index,
+     ):
+         """Computes the self-attention mask combining causal, padding and
+         attention masks.
+
+         Args:
+             decoder_sequence: Input tensor.
+             decoder_padding_mask: Mask tensor for padding tokens.
+             decoder_attention_mask: Additional attention mask.
+             self_attention_cache: Optional cached key and value tensors.
+             self_attention_cache_update_index: Index at which to update the
+                 cache.
+
+         Returns:
+             Combined attention mask tensor.
+         """
+         decoder_mask = merge_padding_and_attention_mask(
+             decoder_sequence, decoder_padding_mask, decoder_attention_mask
+         )
+         batch_size = ops.shape(decoder_sequence)[0]
+         input_length = output_length = ops.shape(decoder_sequence)[1]
+         # We need to handle a rectangular causal mask when doing cached
+         # decoding. For generative inference, `decoder_sequence` will
+         # generally be length 1, and `cache` will be the full generation length.
+         if self_attention_cache is not None:
+             input_length = ops.shape(self_attention_cache)[2]
+
+         cache_update_index = (
+             0
+             if self_attention_cache_update_index is None
+             else self_attention_cache_update_index
+         )
+
+         causal_mask = compute_causal_mask(
+             batch_size, input_length, output_length, cache_update_index
+         )
+
+         return (
+             ops.minimum(decoder_mask, causal_mask)
+             if decoder_mask is not None
+             else causal_mask
+         )
+
+     def compute_output_shape(self, decoder_sequence_shape):
+         """Computes the output shape of the layer.
+
+         Args:
+             decoder_sequence_shape: Shape of the decoder sequence input.
+
+         Returns:
+             Output shape, which is the same as the input shape.
+         """
+         return decoder_sequence_shape
+
+     def get_config(self):
+         """Returns the config of the layer.
+
+         Returns:
+             Dictionary containing the parameters used to initialize this layer.
+         """
+         config = super().get_config()
+         config.update(
+             {
+                 "num_query_heads": self.num_query_heads,
+                 "intermediate_dim": self.intermediate_dim,
+                 "moe_intermediate_dim": self.moe_intermediate_dim,
+                 "shared_expert_intermediate_dim": (
+                     self.shared_expert_intermediate_dim
+                 ),
+                 "rope_max_wavelength": self.rope_max_wavelength,
+                 "num_key_value_heads": self.num_key_value_heads,
+                 "rope_scaling_factor": self.rope_scaling_factor,
+                 "layer_norm_epsilon": self.layer_norm_epsilon,
+                 "dropout": self.dropout,
+                 "use_sliding_window_attention": (
+                     self.use_sliding_window_attention
+                 ),
+                 "sliding_window_size": self.sliding_window_size,
+                 "num_experts": self.num_experts,
+                 "top_k": self.top_k,
+                 "norm_top_k_prob": self.norm_top_k_prob,
+                 "decoder_sparse_step": self.decoder_sparse_step,
+                 "mlp_only_layers": self.mlp_only_layers,
+                 "output_router_logits": self.output_router_logits,
+                 "router_aux_loss_coefficient": self.router_aux_loss_coefficient,
+             }
+         )
+         return config
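
For reference, here is a minimal NumPy sketch (not part of the diff) of the routing and load-balancing math that QwenSparseMoeBlock and compute_load_balancing_loss implement above: top-k expert selection with renormalized weights, a dense per-expert routing matrix, and the unmasked auxiliary loss. The toy shapes and values are illustrative only.

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# Toy router output: 4 tokens routed across 3 experts, top_k = 2.
num_experts, top_k = 3, 2
router_logits = np.array(
    [[2.0, 0.5, -1.0],
     [0.1, 1.5, 0.3],
     [-0.2, 0.4, 2.2],
     [1.0, 1.0, 1.0]]
)
router_probs = softmax(router_logits, axis=-1)

# Select the top-k experts per token and renormalize their weights,
# mirroring the `norm_top_k_prob=True` branch of the block.
top_i = np.argsort(router_probs, axis=-1)[:, ::-1][:, :top_k]
top_p = np.take_along_axis(router_probs, top_i, axis=-1)
top_p = top_p / top_p.sum(axis=-1, keepdims=True)

# Scatter the renormalized weights into a dense (num_experts, num_tokens)
# routing matrix, as the layer does before weighting the batched expert
# outputs.
routing_full = np.zeros((router_logits.shape[0], num_experts))
np.put_along_axis(routing_full, top_i, top_p, axis=-1)
routing_full = routing_full.T

# Load-balancing auxiliary loss (unmasked branch): fraction of tokens
# dispatched to each expert times its mean routing probability.
expert_mask = np.eye(num_experts)[top_i]  # (tokens, top_k, experts)
tokens_per_expert = expert_mask.mean(axis=0).mean(axis=0)
router_prob_per_expert = router_probs.mean(axis=0)
aux_loss = (tokens_per_expert * router_prob_per_expert).sum() * num_experts
print(routing_full.shape, float(aux_loss))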

keras_hub/src/models/qwen_moe/qwen_moe_layernorm.py
@@ -0,0 +1,32 @@
+ import keras
+ from keras import ops
+
+
+ class QwenMoeLayerNorm(keras.layers.Layer):
+     """A normalization layer for Qwen that implements RMS normalization."""
+
+     def __init__(self, epsilon=1e-6, **kwargs):
+         super().__init__(**kwargs)
+         self.epsilon = epsilon
+
+     def build(self, input_shape):
+         dim = input_shape[-1]
+         self.scale = self.add_weight(
+             name="scale",
+             trainable=True,
+             shape=(dim,),
+             initializer="ones",
+             dtype=self.variable_dtype,
+         )
+         self.built = True
+
+     def call(self, x):
+         x = ops.cast(x, "float32")
+         var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
+         x = x * ops.rsqrt(var + self.epsilon)
+         return ops.cast(x * self.scale, self.compute_dtype)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update({"epsilon": self.epsilon})
+         return config
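
As a quick reference, this is the RMS normalization computed by QwenMoeLayerNorm above, sketched with NumPy on toy values (not part of the diff); the layer performs the same math with keras.ops in float32 and casts back to the compute dtype.

import numpy as np

def rms_norm(x, scale, epsilon=1e-6):
    # Root-mean-square normalization: divide each feature vector by the
    # RMS of its components, then apply the learned gain.
    x32 = x.astype("float32")
    var = np.mean(np.square(x32), axis=-1, keepdims=True)
    return (x32 / np.sqrt(var + epsilon)) * scale

x = np.array([[1.0, -2.0, 3.0, 0.5]], dtype="float16")
scale = np.ones(4, dtype="float32")  # stands in for the layer's "scale" weight
print(rms_norm(x, scale))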

keras_hub/src/models/qwen_moe/qwen_moe_tokenizer.py
@@ -0,0 +1,46 @@
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.qwen_moe.qwen_moe_backbone import QwenMoeBackbone
+ from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
+
+
+ @keras_hub_export(
+     "keras_hub.tokenizers.QwenMoeTokenizer",
+ )
+ class QwenMoeTokenizer(BytePairTokenizer):
+     """Tokenizer for Qwen Moe model.
+
+     This tokenizer implements byte-pair encoding (BPE) for Qwen models,
+     handling special tokens like BOS (beginning of sequence) and EOS (end of
+     sequence).
+
+     Args:
+         vocabulary: Dictionary mapping tokens to token IDs, or path to
+             vocabulary file.
+         merges: List of BPE merges, or path to merges file.
+         bos_token: Beginning of sequence token. Defaults to None.
+         eos_token: End of sequence token. Defaults to "<|endoftext|>".
+         misc_special_tokens: Set of additional special tokens. Defaults to
+             empty set.
+     """
+
+     backbone_cls = QwenMoeBackbone
+
+     def __init__(
+         self,
+         vocabulary=None,
+         merges=None,
+         **kwargs,
+     ):
+         # Add EOS token
+         eos_token = "<|endoftext|>"
+         self._add_special_token(eos_token, "end_token")
+
+         self.start_token_id = None
+         self.start_token = None
+         self.pad_token_id = 0
+
+         super().__init__(
+             vocabulary=vocabulary,
+             merges=merges,
+             **kwargs,
+         )
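
For context, a hedged usage sketch (not part of the diff): assuming the new convert_qwen_moe.py converter and Transformers preset loader entry wire Qwen-MoE checkpoints into keras_hub's preset loading, the tokenizer could plausibly be constructed from a Hugging Face handle such as "Qwen/Qwen1.5-MoE-A2.7B". The handle and preset support are assumptions, not confirmed by this diff.

import keras_hub

# Hypothetical checkpoint handle; the exact preset is an assumption.
tokenizer = keras_hub.tokenizers.QwenMoeTokenizer.from_preset(
    "hf://Qwen/Qwen1.5-MoE-A2.7B"
)
token_ids = tokenizer("The quick brown fox.")
print(token_ids)
print(tokenizer.detokenize(token_ids))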