keras-hub-nightly 0.21.0.dev202505050407__py3-none-any.whl → 0.21.0.dev202505060405__py3-none-any.whl

This diff compares the contents of two package versions that were publicly released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
Files changed (34)
  1. keras_hub/models/__init__.py +21 -0
  2. keras_hub/src/models/backbone.py +5 -2
  3. keras_hub/src/models/mixtral/mixtral_attention.py +263 -0
  4. keras_hub/src/models/mixtral/mixtral_backbone.py +207 -0
  5. keras_hub/src/models/mixtral/mixtral_causal_lm.py +281 -0
  6. keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py +76 -0
  7. keras_hub/src/models/mixtral/mixtral_decoder.py +494 -0
  8. keras_hub/src/models/mixtral/mixtral_layer_norm.py +34 -0
  9. keras_hub/src/models/mixtral/mixtral_tokenizer.py +21 -0
  10. keras_hub/src/models/qwen/qwen_attention.py +3 -1
  11. keras_hub/src/models/qwen/qwen_presets.py +61 -0
  12. keras_hub/src/models/qwen_moe/__init__.py +0 -0
  13. keras_hub/src/models/qwen_moe/qwen_moe_attention.py +377 -0
  14. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +373 -0
  15. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm.py +350 -0
  16. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor.py +17 -0
  17. keras_hub/src/models/qwen_moe/qwen_moe_decoder.py +625 -0
  18. keras_hub/src/models/qwen_moe/qwen_moe_layernorm.py +32 -0
  19. keras_hub/src/models/qwen_moe/qwen_moe_tokenizer.py +46 -0
  20. keras_hub/src/models/retinanet/retinanet_image_converter.py +0 -13
  21. keras_hub/src/models/retinanet/retinanet_presets.py +2 -2
  22. keras_hub/src/models/task.py +5 -2
  23. keras_hub/src/utils/keras_utils.py +11 -0
  24. keras_hub/src/utils/preset_utils.py +69 -9
  25. keras_hub/src/utils/tensor_utils.py +27 -1
  26. keras_hub/src/utils/transformers/convert_mixtral.py +139 -0
  27. keras_hub/src/utils/transformers/convert_qwen_moe.py +253 -0
  28. keras_hub/src/utils/transformers/preset_loader.py +6 -0
  29. keras_hub/src/version.py +1 -1
  30. keras_hub/tokenizers/__init__.py +6 -0
  31. {keras_hub_nightly-0.21.0.dev202505050407.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/METADATA +1 -1
  32. {keras_hub_nightly-0.21.0.dev202505050407.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/RECORD +34 -16
  33. {keras_hub_nightly-0.21.0.dev202505050407.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/WHEEL +0 -0
  34. {keras_hub_nightly-0.21.0.dev202505050407.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/top_level.txt +0 -0
keras_hub/src/models/mixtral/mixtral_decoder.py
@@ -0,0 +1,494 @@
+ import keras
+ from keras import ops
+
+ from keras_hub.src.layers.modeling.transformer_layer_utils import (
+     compute_causal_mask,
+ )
+ from keras_hub.src.layers.modeling.transformer_layer_utils import (
+     merge_padding_and_attention_mask,
+ )
+ from keras_hub.src.models.mixtral.mixtral_attention import (
+     CachedMixtralAttention,
+ )
+ from keras_hub.src.models.mixtral.mixtral_layer_norm import (
+     MixtralLayerNormalization,
+ )
+ from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+ def compute_load_balancing_loss(
+     router_logits, num_experts, top_k, attention_mask=None
+ ):
+     """Compute the load balancing auxiliary loss for a single MoE layer.
+
+     Args:
+         router_logits: Tensor of shape (batch_size * seq_len, num_experts).
+         num_experts: Integer, total number of experts.
+         top_k: Integer, number of experts to select per token.
+         attention_mask: Tensor of shape (batch_size, seq_len), optional mask
+             for padding.
+     Returns:
+         Scalar tensor representing the auxiliary loss.
+     """
+     # Compute routing probabilities
+     routing_weights = ops.softmax(
+         router_logits, axis=-1
+     )  # Shape: (batch_size * seq_len, num_experts)
+
+     # Get top-k experts
+     top_k_weights, selected_experts = ops.top_k(
+         routing_weights, k=top_k
+     )  # Shape: (batch_size * seq_len, top_k) for both
+
+     # Create one-hot encoding for selected experts
+     expert_mask = ops.one_hot(
+         selected_experts, num_experts
+     )  # Shape: (batch_size * seq_len, top_k, num_experts)
+
+     if attention_mask is not None:
+         # Flatten attention_mask to match router_logits
+         seq_len = ops.shape(attention_mask)[1]
+         batch_seq_len = ops.shape(router_logits)[0]
+         # Dynamically compute the batch size to match router_logits
+         target_batch_size = batch_seq_len // seq_len
+         # Slice attention_mask to match the expected batch size
+         attention_mask = ops.slice(
+             attention_mask, [0, 0], [target_batch_size, seq_len]
+         )
+         flat_mask = ops.reshape(
+             attention_mask, (-1,)
+         )  # Shape: (batch_size * seq_len,)
+         flat_mask = ops.cast(flat_mask, dtype="float32")
+         # Expand mask for broadcasting
+         expert_attention_mask = ops.expand_dims(
+             flat_mask, axis=-1
+         )  # Shape: (batch_size * seq_len, 1)
+         expert_attention_mask = ops.expand_dims(
+             expert_attention_mask, axis=1
+         )  # Shape: (batch_size * seq_len, 1, 1)
+
+         # Compute masked token counts
+         tokens_per_expert = ops.sum(
+             expert_mask * expert_attention_mask, axis=0
+         )  # Shape: (top_k, num_experts)
+         mask_sum = ops.sum(expert_attention_mask, axis=0)  # Shape: (1, 1)
+         tokens_per_expert = tokens_per_expert / ops.maximum(mask_sum, 1e-9)
+
+         # Compute masked router probabilities
+         router_prob_per_expert = ops.sum(
+             routing_weights * flat_mask[:, None], axis=0
+         )  # Shape: (num_experts,)
+         router_prob_per_expert = router_prob_per_expert / ops.maximum(
+             ops.sum(flat_mask), 1e-9
+         )
+     else:
+         # Unmasked means
+         tokens_per_expert = ops.mean(
+             expert_mask, axis=0
+         )  # Shape: (top_k, num_experts)
+         router_prob_per_expert = ops.mean(
+             routing_weights, axis=0
+         )  # Shape: (num_experts,)
+
+     # Average over top_k dimension
+     tokens_per_expert = ops.mean(
+         tokens_per_expert, axis=0
+     )  # Shape: (num_experts,)
+
+     # Compute the loss
+     overall_loss = ops.sum(tokens_per_expert * router_prob_per_expert)
+     return overall_loss * num_experts
+
+
+ class MixtralMoeExperts(keras.layers.Layer):
+     """Batched feed-forward experts for Mixtral (pure keras.ops)."""
+
+     def __init__(
+         self,
+         num_experts,
+         hidden_dim,
+         intermediate_dim,
+         activation_fn="silu",
+         kernel_initializer="glorot_uniform",
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.num_experts = num_experts
+         self.hidden_dim = hidden_dim
+         self.intermediate_dim = intermediate_dim
+         self.activation = keras.activations.get(activation_fn)
+         self.kernel_initializer = keras.initializers.get(kernel_initializer)
+
+     def build(self, _):
+         # Weight for gate dense layer:
+         # [num_experts, hidden_dim, intermediate_dim]
+         self._expert_feedforward_gate_dense = self.add_weight(
+             shape=(self.num_experts, self.hidden_dim, self.intermediate_dim),
+             initializer=self.kernel_initializer,
+             trainable=True,
+             dtype=self.variable_dtype,
+             name="expert_feedforward_gate_dense",
+         )
+         # Weight for intermediate dense layer:
+         # [num_experts, hidden_dim, intermediate_dim]
+         self._expert_feedforward_intermediate_dense = self.add_weight(
+             shape=(self.num_experts, self.hidden_dim, self.intermediate_dim),
+             initializer=self.kernel_initializer,
+             trainable=True,
+             dtype=self.variable_dtype,
+             name="expert_feedforward_intermediate_dense",
+         )
+         # Weight for output dense layer:
+         # [num_experts, intermediate_dim, hidden_dim]
+         self._expert_feedforward_output_dense = self.add_weight(
+             shape=(self.num_experts, self.intermediate_dim, self.hidden_dim),
+             initializer=self.kernel_initializer,
+             trainable=True,
+             name="expert_feedforward_output_dense",
+         )
+         self.built = True
+
+     def call(self, hidden_states):
+         # Compute gate output for all experts:
+         # [num_experts, tokens, intermediate_dim]
+         gate = ops.einsum(
+             "th,ehm->etm", hidden_states, self._expert_feedforward_gate_dense
+         )
+         gate = ops.cast(gate, "float32")  # Match PyTorch SiLU precision
+         gate = self.activation(gate)
+         gate = ops.cast(gate, self.compute_dtype)
+
+         # Compute intermediate output for all experts:
+         # [num_experts, tokens, intermediate_dim]
+         intermediate = ops.einsum(
+             "th,ehm->etm",
+             hidden_states,
+             self._expert_feedforward_intermediate_dense,
+         )
+         hidden = intermediate * gate  # Element-wise multiplication
+
+         # Compute final output: [num_experts, tokens, hidden_dim]
+         out = ops.einsum(
+             "eti,eih->eth", hidden, self._expert_feedforward_output_dense
+         )
+         return out
+
+
+ class MixtralSparseMoeBlock(keras.layers.Layer):
+     """Mixtral sparse MoE block rewritten in batched style."""
+
+     def __init__(
+         self,
+         hidden_dim,
+         intermediate_dim,
+         num_experts,
+         top_k=2,
+         router_jitter_noise=0.0,
+         layer_norm_epsilon=1e-5,
+         router_aux_loss_coef=0.02,
+         kernel_initializer="glorot_uniform",
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.hidden_dim = hidden_dim
+         self.intermediate_dim = intermediate_dim
+         self.num_experts = num_experts
+         self.top_k = top_k
+         self.router_jitter_noise = router_jitter_noise
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.router_aux_loss_coef = router_aux_loss_coef
+         self.kernel_initializer = keras.initializers.get(kernel_initializer)
+
+     def build(self, decoder_sequence_shape):
+         # Router dense layer to compute logits for expert selection
+         self._sparse_feedforward_gate_dense = keras.layers.Dense(
+             self.num_experts,
+             kernel_initializer=self.kernel_initializer,
+             use_bias=False,
+             dtype=self.dtype_policy,
+             name="sparse_feedforward_gate_dense",
+         )
+         self._sparse_feedforward_gate_dense.build(decoder_sequence_shape)
+
+         # Batched expert bank
+         self.expert_bank = MixtralMoeExperts(
+             num_experts=self.num_experts,
+             hidden_dim=self.hidden_dim,
+             intermediate_dim=self.intermediate_dim,
+             kernel_initializer=self.kernel_initializer,
+             name="experts",
+             dtype=self.dtype_policy,
+         )
+         self.expert_bank.build(decoder_sequence_shape)
+         self.built = True
+
+     def call(self, hidden_states, attention_mask=None, training=False):
+         batch_size, seq_len, _ = ops.shape(hidden_states)
+         hidden_states_flattened = ops.reshape(
+             hidden_states, (-1, self.hidden_dim)
+         )
+
+         # Apply jitter noise during training if specified
+         if training and self.router_jitter_noise > 0:
+             random_factors = ops.random.uniform(
+                 shape=ops.shape(hidden_states_flattened),
+                 minval=1.0 - self.router_jitter_noise,
+                 maxval=1.0 + self.router_jitter_noise,
+                 dtype=hidden_states_flattened.dtype,
+             )
+             hidden_states_flattened = hidden_states_flattened * random_factors
+
+         # Compute router logits and probabilities
+         router_logits = self._sparse_feedforward_gate_dense(
+             hidden_states_flattened
+         )
+         router_probs = ops.softmax(router_logits, axis=-1)
+
+         top_p, top_i = ops.top_k(router_probs, k=self.top_k)
+         sum_topk = ops.sum(top_p, axis=-1, keepdims=True)
+         top_p = top_p / sum_topk  # Normalize top-k probabilities
+
+         one_hot = ops.one_hot(top_i, self.num_experts)
+         one_hot = ops.cast(one_hot, top_p.dtype)
+         routing_full = ops.sum(one_hot * top_p[..., None], axis=1)
+         routing_full = ops.transpose(routing_full, (1, 0))
+         routing_full = ops.cast(routing_full, hidden_states_flattened.dtype)
+
+         expert_out = self.expert_bank(hidden_states_flattened)
+
+         weighted_out = expert_out * routing_full[:, :, None]
+         expert_contribution = ops.sum(weighted_out, axis=0)
+
+         out = ops.reshape(
+             expert_contribution, (batch_size, seq_len, self.hidden_dim)
+         )
+
+         if training:
+             aux_loss = compute_load_balancing_loss(
+                 router_logits=router_logits,
+                 num_experts=self.num_experts,
+                 top_k=self.top_k,
+                 attention_mask=attention_mask,
+             )
+             self.add_loss(self.router_aux_loss_coef * aux_loss)
+
+         return out, router_logits
+
+
+ class MixtralTransformerDecoder(keras.layers.Layer):
+     def __init__(
+         self,
+         intermediate_dim,
+         num_query_heads,
+         num_key_value_heads,
+         num_experts,
+         top_k=2,
+         router_jitter_noise=0.0,
+         output_router_logits=False,
+         rope_max_wavelength=10000,
+         rope_scaling_factor=1.0,
+         activation="silu",
+         layer_norm_epsilon=1e-5,
+         router_aux_loss_coef=0.02,
+         kernel_initializer="glorot_uniform",
+         sliding_window=512,
+         dropout=0,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.intermediate_dim = intermediate_dim
+         self.num_query_heads = num_query_heads
+         self.num_key_value_heads = num_key_value_heads
+
+         self.num_experts = num_experts
+         self.top_k = top_k
+         self.router_jitter_noise = router_jitter_noise
+
+         self.rope_max_wavelength = rope_max_wavelength
+         self.rope_scaling_factor = rope_scaling_factor
+
+         self.dropout = dropout
+
+         self.sliding_window = sliding_window
+         self.activation = keras.activations.get(activation)
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.kernel_initializer = keras.initializers.get(kernel_initializer)
+
+         self.router_aux_loss_coef = router_aux_loss_coef
+         self.output_router_logits = output_router_logits
+
+         self.supports_masking = True
+
+     def build(self, decoder_sequence_shape):
+         self._decoder_sequence_shape = decoder_sequence_shape
+         self.hidden_dim = decoder_sequence_shape[-1]
+
+         # Self attention layer.
+         self._self_attention_layer = CachedMixtralAttention(
+             num_query_heads=self.num_query_heads,
+             num_key_value_heads=self.num_key_value_heads,
+             rope_max_wavelength=self.rope_max_wavelength,
+             rope_scaling_factor=self.rope_scaling_factor,
+             sliding_window=self.sliding_window,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             dropout=self.dropout,
+             dtype=self.dtype_policy,
+             name="self_attention",
+         )
+         self._self_attention_layer.build(decoder_sequence_shape)
+
+         self._self_attention_layernorm = MixtralLayerNormalization(
+             epsilon=self.layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="self_attention_layernorm",
+         )
+         self._self_attention_layernorm.build(decoder_sequence_shape)
+         self._self_attention_dropout = keras.layers.Dropout(
+             rate=self.dropout,
+             dtype=self.dtype_policy,
+             name="self_attention_dropout",
+         )
+
+         self._sparse_moe_block = MixtralSparseMoeBlock(
+             hidden_dim=self.hidden_dim,
+             intermediate_dim=self.intermediate_dim,
+             num_experts=self.num_experts,
+             top_k=self.top_k,
+             router_jitter_noise=self.router_jitter_noise,
+             router_aux_loss_coef=self.router_aux_loss_coef,
+             dtype=self.dtype_policy,
+         )
+         self._sparse_moe_block.build(decoder_sequence_shape)
+
+         self._feedforward_layernorm = MixtralLayerNormalization(
+             epsilon=self.layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="feedforward_layernorm",
+         )
+         self._feedforward_layernorm.build(decoder_sequence_shape)
+
+         self.built = True
+
+     def call(
+         self,
+         decoder_sequence,
+         decoder_padding_mask=None,
+         decoder_attention_mask=None,
+         self_attention_cache=None,
+         self_attention_cache_update_index=None,
+         training=None,
+     ):
+         self_attention_mask = self._compute_self_attention_mask(
+             decoder_sequence=decoder_sequence,
+             decoder_padding_mask=decoder_padding_mask,
+             decoder_attention_mask=decoder_attention_mask,
+             self_attention_cache=self_attention_cache,
+             self_attention_cache_update_index=self_attention_cache_update_index,
+         )
+         residual = decoder_sequence
+
+         x = self._self_attention_layernorm(decoder_sequence)
+
+         # Self attention block.
+         x = self._self_attention_layer(
+             hidden_states=x,
+             attention_mask=self_attention_mask,
+             cache=self_attention_cache,
+             cache_update_index=self_attention_cache_update_index,
+         )
+
+         if self_attention_cache is not None:
+             x, self_attention_cache = x
+
+         x = self._self_attention_dropout(x, training=training)
+
+         x = x + residual
+         residual = x
+
+         x = self._feedforward_layernorm(x)
+         x, router_logits = self._sparse_moe_block(
+             x, attention_mask=decoder_padding_mask
+         )
+
+         decoder_output = x + residual
+
+         output = (decoder_output,)
+
+         if self_attention_cache is not None:
+             output += (self_attention_cache,)
+
+         if self.output_router_logits:
+             output += (router_logits,)
+
+         return output[0] if len(output) == 1 else output
+
+     def _compute_self_attention_mask(
+         self,
+         decoder_sequence,
+         decoder_padding_mask,
+         decoder_attention_mask,
+         self_attention_cache,
+         self_attention_cache_update_index,
+     ):
+         decoder_mask = merge_padding_and_attention_mask(
+             decoder_sequence, decoder_padding_mask, decoder_attention_mask
+         )
+         batch_size = ops.shape(decoder_sequence)[0]
+         input_length = output_length = ops.shape(decoder_sequence)[1]
+         # We need to handle a rectangular causal mask when doing cached
+         # decoding. For generative inference, `decoder_sequence` will
+         # generally be length 1, and `cache` will be the full generation length.
+         if self_attention_cache is not None:
+             input_length = ops.shape(self_attention_cache)[2]
+
+         cache_update_index = (
+             0
+             if self_attention_cache_update_index is None
+             else self_attention_cache_update_index
+         )
+
+         # The lower triangular attention mask
+         causal_mask = compute_causal_mask(
+             batch_size, input_length, output_length, cache_update_index
+         )
+
+         # Mixtral uses a banded attention mask if sliding window is not None
+         if self.sliding_window is not None:
+             # ops.triu/tril has issues with dynamic shape on the TensorFlow backend;
+             # equivalent to: causal_mask = ops.triu(causal_mask, k=-self.sliding_window)
+             i = ops.arange(output_length)[:, None] + cache_update_index
+             j = ops.arange(input_length)[None, :]
+             causal_mask_upper = ops.cast(i < j + self.sliding_window, "int32")
+             causal_mask = ops.minimum(causal_mask, causal_mask_upper)
+
+         return (
+             ops.minimum(decoder_mask, causal_mask)
+             if decoder_mask is not None
+             else causal_mask
+         )
+
+     def compute_output_shape(self, decoder_sequence_shape):
+         return decoder_sequence_shape
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "intermediate_dim": self.intermediate_dim,
+                 "num_query_heads": self.num_query_heads,
+                 "rope_max_wavelength": self.rope_max_wavelength,
+                 "rope_scaling_factor": self.rope_scaling_factor,
+                 "num_key_value_heads": self.num_key_value_heads,
+                 "num_experts": self.num_experts,
+                 "top_k": self.top_k,
+                 "router_jitter_noise": self.router_jitter_noise,
+                 "sliding_window": self.sliding_window,
+                 "activation": keras.activations.serialize(self.activation),
+                 "layer_norm_epsilon": self.layer_norm_epsilon,
+                 "kernel_initializer": keras.initializers.serialize(
+                     self.kernel_initializer
+                 ),
+                 "dropout": self.dropout,
+             }
+         )
+         return config
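The hunk above implements top-k routing with a load-balancing penalty: the fraction of top-k assignments each expert receives is multiplied by its mean router probability and scaled by the number of experts, so perfectly uniform routing drives the value toward 1.0. Below is a minimal sketch of exercising the helper directly, assuming this nightly wheel is installed; the shapes and values are illustrative only.

import numpy as np
from keras import ops

from keras_hub.src.models.mixtral.mixtral_decoder import (
    compute_load_balancing_loss,
)

# Illustrative shapes: 2 sequences of 4 tokens routed over 8 experts, top-2.
batch_size, seq_len, num_experts, top_k = 2, 4, 8, 2
router_logits = ops.convert_to_tensor(
    np.random.randn(batch_size * seq_len, num_experts).astype("float32")
)
# Optional padding mask of shape (batch_size, seq_len); zeros mark padding.
attention_mask = ops.ones((batch_size, seq_len), dtype="float32")

aux_loss = compute_load_balancing_loss(
    router_logits,
    num_experts=num_experts,
    top_k=top_k,
    attention_mask=attention_mask,
)
# During training, MixtralSparseMoeBlock adds router_aux_loss_coef * aux_loss
# (0.02 by default) to the layer losses via self.add_loss.
print(float(aux_loss))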
keras_hub/src/models/mixtral/mixtral_layer_norm.py
@@ -0,0 +1,34 @@
+ import keras
+ from keras import ops
+
+
+ # NOTE: `keras.layers.LayerNormalization(rms_scaling=True)`
+ # does not produce the same results
+ class MixtralLayerNormalization(keras.layers.Layer):
+     """A normalization layer for Mixtral that implements RMS normalization."""
+
+     def __init__(self, epsilon=1e-6, **kwargs):
+         super().__init__(**kwargs)
+         self.epsilon = epsilon
+
+     def build(self, input_shape):
+         dim = input_shape[-1]
+         self.scale = self.add_weight(
+             name="scale",
+             trainable=True,
+             shape=(dim,),
+             initializer="ones",
+             dtype=self.variable_dtype,
+         )
+         self.built = True
+
+     def call(self, x):
+         x = ops.cast(x, "float32")
+         var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
+         x = x * ops.rsqrt(var + self.epsilon)
+         return ops.cast(x * self.scale, self.compute_dtype)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update({"epsilon": self.epsilon})
+         return config
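As the NOTE at the top of this file says, `keras.layers.LayerNormalization(rms_scaling=True)` does not reproduce this layer; it is a plain RMS norm computed in float32 and then cast back to the compute dtype. A small sanity-check sketch against a manual NumPy computation (illustrative only, not part of the package):

import numpy as np
from keras import ops

from keras_hub.src.models.mixtral.mixtral_layer_norm import (
    MixtralLayerNormalization,
)

x = np.random.randn(2, 3, 16).astype("float32")
layer = MixtralLayerNormalization(epsilon=1e-6)
y = ops.convert_to_numpy(layer(x))

# Manual RMS norm: x / sqrt(mean(x**2) + eps), scaled by the (all-ones) weight.
expected = x / np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + 1e-6)
np.testing.assert_allclose(y, expected, atol=1e-5)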
keras_hub/src/models/mixtral/mixtral_tokenizer.py
@@ -0,0 +1,21 @@
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.mixtral.mixtral_backbone import MixtralBackbone
+ from keras_hub.src.tokenizers.sentence_piece_tokenizer import (
+     SentencePieceTokenizer,
+ )
+
+
+ @keras_hub_export(
+     [
+         "keras_hub.tokenizers.MixtralTokenizer",
+         "keras_hub.models.MixtralTokenizer",
+     ]
+ )
+ class MixtralTokenizer(SentencePieceTokenizer):
+     backbone_cls = MixtralBackbone
+
+     def __init__(self, proto, **kwargs):
+         self._add_special_token("<s>", "start_token")
+         self._add_special_token("</s>", "end_token")
+         self.pad_token_id = 0
+         super().__init__(proto=proto, **kwargs)
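MixtralTokenizer is a thin SentencePiece wrapper that registers `<s>` and `</s>` as the start and end tokens. A hedged usage sketch, assuming a Mixtral-compatible SentencePiece model file is available locally (the proto path below is a placeholder, not something shipped in this wheel):

from keras_hub.src.models.mixtral.mixtral_tokenizer import MixtralTokenizer

# Hypothetical path to a SentencePiece vocabulary trained for Mixtral.
tokenizer = MixtralTokenizer(proto="path/to/mixtral_vocabulary.model")

token_ids = tokenizer("The quick brown fox.")  # string -> token ids
text = tokenizer.detokenize(token_ids)  # token ids -> string
start_id = tokenizer.token_to_id("<s>")  # special-token lookup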
keras_hub/src/models/qwen/qwen_attention.py
@@ -287,7 +287,9 @@ class QwenAttention(keras.layers.Layer):
          if self.use_sliding_window_attention:
              attention_mask = self._mask_sliding_window(
                  attention_mask,
-                 cache_update_index=cache_update_index,
+                 cache_update_index=cache_update_index
+                 if cache_update_index
+                 else 0,
              )
          attention_scores = self._masked_softmax(
              attention_scores, attention_mask
keras_hub/src/models/qwen/qwen_presets.py
@@ -0,0 +1,61 @@
+ """Qwen preset configurations."""
+
+ backbone_presets = {
+     "qwen2.5_0.5b_en": {
+         "metadata": {
+             "description": ("24-layer Qwen model with 0.5 billion parameters."),
+             "params": 494032768,
+             "path": "qwen",
+         },
+         "kaggle_handle": "kaggle://keras/qwen/keras/qwen2.5_0.5b_en",
+     },
+     "qwen2.5_3b_en": {
+         "metadata": {
+             "description": ("36-layer Qwen model with 3.1 billion parameters."),
+             "params": 3085938688,
+             "path": "qwen",
+         },
+         "kaggle_handle": "kaggle://keras/qwen/keras/qwen2.5_3b_en",
+     },
+     "qwen2.5_7b_en": {
+         "metadata": {
+             "description": ("48-layer Qwen model with 7 billion parameters."),
+             "params": 6993420288,
+             "path": "qwen",
+         },
+         "kaggle_handle": "kaggle://keras/qwen/keras/qwen2.5_7b_en/2",
+     },
+     "qwen2.5_instruct_0.5b_en": {
+         "metadata": {
+             "description": (
+                 "Instruction fine-tuned 24-layer Qwen model with 0.5 "
+                 "billion parameters."
+             ),
+             "params": 494032768,
+             "path": "qwen",
+         },
+         "kaggle_handle": "kaggle://keras/qwen/keras/qwen2.5_instruct_0.5b_en",
+     },
+     "qwen2.5_instruct_32b_en": {
+         "metadata": {
+             "description": (
+                 "Instruction fine-tuned 64-layer Qwen model with 32 "
+                 "billion parameters."
+             ),
+             "params": 32763876352,
+             "path": "qwen",
+         },
+         "kaggle_handle": "kaggle://keras/qwen/keras/qwen2.5_instruct_32b_en",
+     },
+     "qwen2.5_instruct_72b_en": {
+         "metadata": {
+             "description": (
+                 "Instruction fine-tuned 80-layer Qwen model with 72 "
+                 "billion parameters."
+             ),
+             "params": 72706203648,
+             "path": "qwen",
+         },
+         "kaggle_handle": "kaggle://keras/qwen/keras/qwen2.5_instruct_72b_en",
+     },
+ }
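With these presets registered, the new Qwen checkpoints should be reachable through the standard KerasHub `from_preset` constructors. A hedged sketch (downloading weights requires network access to the `kaggle://keras/qwen/...` handles above, and the larger checkpoints need substantial memory):

import keras_hub

# Smallest preset registered above; the preset name resolves to its
# kaggle_handle entry in backbone_presets.
backbone = keras_hub.models.Backbone.from_preset("qwen2.5_0.5b_en")
tokenizer = keras_hub.tokenizers.Tokenizer.from_preset("qwen2.5_0.5b_en")

# The same name works with the task-level constructor for text generation.
causal_lm = keras_hub.models.CausalLM.from_preset("qwen2.5_0.5b_en")
print(causal_lm.generate("What is Keras?", max_length=64))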