keras-hub-nightly 0.21.0.dev202505050407__py3-none-any.whl → 0.21.0.dev202505060405__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. keras_hub/models/__init__.py +21 -0
  2. keras_hub/src/models/backbone.py +5 -2
  3. keras_hub/src/models/mixtral/mixtral_attention.py +263 -0
  4. keras_hub/src/models/mixtral/mixtral_backbone.py +207 -0
  5. keras_hub/src/models/mixtral/mixtral_causal_lm.py +281 -0
  6. keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py +76 -0
  7. keras_hub/src/models/mixtral/mixtral_decoder.py +494 -0
  8. keras_hub/src/models/mixtral/mixtral_layer_norm.py +34 -0
  9. keras_hub/src/models/mixtral/mixtral_tokenizer.py +21 -0
  10. keras_hub/src/models/qwen/qwen_attention.py +3 -1
  11. keras_hub/src/models/qwen/qwen_presets.py +61 -0
  12. keras_hub/src/models/qwen_moe/__init__.py +0 -0
  13. keras_hub/src/models/qwen_moe/qwen_moe_attention.py +377 -0
  14. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +373 -0
  15. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm.py +350 -0
  16. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor.py +17 -0
  17. keras_hub/src/models/qwen_moe/qwen_moe_decoder.py +625 -0
  18. keras_hub/src/models/qwen_moe/qwen_moe_layernorm.py +32 -0
  19. keras_hub/src/models/qwen_moe/qwen_moe_tokenizer.py +46 -0
  20. keras_hub/src/models/retinanet/retinanet_image_converter.py +0 -13
  21. keras_hub/src/models/retinanet/retinanet_presets.py +2 -2
  22. keras_hub/src/models/task.py +5 -2
  23. keras_hub/src/utils/keras_utils.py +11 -0
  24. keras_hub/src/utils/preset_utils.py +69 -9
  25. keras_hub/src/utils/tensor_utils.py +27 -1
  26. keras_hub/src/utils/transformers/convert_mixtral.py +139 -0
  27. keras_hub/src/utils/transformers/convert_qwen_moe.py +253 -0
  28. keras_hub/src/utils/transformers/preset_loader.py +6 -0
  29. keras_hub/src/version.py +1 -1
  30. keras_hub/tokenizers/__init__.py +6 -0
  31. {keras_hub_nightly-0.21.0.dev202505050407.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/METADATA +1 -1
  32. {keras_hub_nightly-0.21.0.dev202505050407.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/RECORD +34 -16
  33. {keras_hub_nightly-0.21.0.dev202505050407.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/WHEEL +0 -0
  34. {keras_hub_nightly-0.21.0.dev202505050407.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/keras_hub/src/models/qwen_moe/qwen_moe_backbone.py
@@ -0,0 +1,373 @@
+import keras
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.modeling.reversible_embedding import (
+    ReversibleEmbedding,
+)
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.qwen.qwen_layernorm import QwenLayerNorm
+from keras_hub.src.models.qwen_moe.qwen_moe_decoder import (
+    QwenMoeTransformerDecoder,
+)
+
+
+def _qwen_moe_kernel_initializer(stddev=0.02):
+    return keras.initializers.RandomNormal(stddev=stddev)
+
+
+@keras_hub_export(
+    "keras_hub.models.QwenMoeBackbone",
+)
+class QwenMoeBackbone(Backbone):
+    """Qwen MoE core network with hyperparameters.
+
+    This backbone implements the base Transformer network for the Qwen MoE
+    model. It includes embedding lookups and transformer layers with a
+    Mixture of Experts (MoE) architecture, where each layer uses a sparse set
+    of experts for efficient computation. This backbone outputs the final
+    hidden states for each token, not generative predictions over the
+    vocabulary space. For a higher-level object for text generation, see
+    `keras_hub.models.QwenMoeCausalLM`.
+
+    The default constructor gives a fully customizable, randomly initialized
+    Qwen MoE model with any number of layers, heads, and embedding dimensions.
+    To load preset architectures and weights, use the `from_preset`
+    constructor.
+
+    Args:
+        vocabulary_size: int. The size of the token vocabulary.
+        num_layers: int. The number of transformer layers.
+        num_query_heads: int. The number of heads for the query projections in
+            the attention layer.
+        num_key_value_heads: int. The number of heads for the key and value
+            projections in the attention layer.
+        hidden_dim: int. The size of the transformer hidden state at the end
+            of each transformer layer.
+        intermediate_dim: int. The output dimension of the first Dense layer
+            in the feedforward network for each transformer.
+        moe_intermediate_dim: int. The intermediate dimension for each expert
+            in the MoE feedforward network.
+        shared_expert_intermediate_dim: int. The intermediate dimension for
+            the shared expert in the MoE feedforward network.
+        num_experts: int. The number of experts in each MoE layer.
+        top_k: int. The number of top experts to select for each token in the
+            MoE layer.
+        head_dim: int. The size of each attention head.
+        layer_norm_epsilon: float. The epsilon value used for every layer norm
+            in the transformer model.
+        dropout: float. Dropout probability for the transformer encoder.
+        use_sliding_window_attention: bool. Whether to use sliding local
+            window attention. Defaults to `False`.
+        sliding_window_size: int. Size of the sliding local window. Defaults
+            to 4096.
+        max_sequence_length: int. The maximum sequence length supported by
+            the model. Defaults to 4096.
+        dtype: str or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for the model's computations and weights. Note that some
+            computations, such as softmax and layer normalization, will
+            always be done at float32 precision regardless of dtype.
+
+    Example:
+    ```python
+    input_data = {
+        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
+    }
+
+    # Pretrained Qwen MoE decoder.
+    model = keras_hub.models.QwenMoeBackbone.from_preset("qwen_moe_a2_7b")
+    model(input_data)
+
+    # Randomly initialized Qwen MoE decoder with custom config.
+    model = keras_hub.models.QwenMoeBackbone(
+        vocabulary_size=151936,
+        num_layers=28,
+        num_query_heads=16,
+        num_key_value_heads=8,
+        hidden_dim=2048,
+        intermediate_dim=4096,
+        moe_intermediate_dim=128,
+        shared_expert_intermediate_dim=4096,
+        num_experts=60,
+        top_k=4,
+        head_dim=128,
+        max_sequence_length=4096,
+    )
+    model(input_data)
+    ```
+    """
+
+    def __init__(
+        self,
+        vocabulary_size,
+        num_layers,
+        num_query_heads,
+        num_key_value_heads,
+        hidden_dim,
+        intermediate_dim,
+        moe_intermediate_dim,
+        shared_expert_intermediate_dim,
+        num_experts,
+        top_k=4,
+        norm_top_k_prob=False,
+        decoder_sparse_step=1,
+        rope_max_wavelength=10000,
+        rope_scaling_factor=1.0,
+        layer_norm_epsilon=1e-6,
+        dropout=0,
+        dtype=None,
+        tie_word_embeddings=False,
+        use_sliding_window_attention=False,
+        sliding_window_size=32768,
+        output_router_logits=False,
+        router_aux_loss_coefficient=0.001,
+        mlp_only_layers=[],
+        training=None,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.token_embedding = ReversibleEmbedding(
+            input_dim=vocabulary_size,
+            output_dim=hidden_dim,
+            tie_weights=tie_word_embeddings,
+            embeddings_initializer=_qwen_moe_kernel_initializer(stddev=0.01),
+            dtype=dtype,
+            name="token_embedding",
+        )
+        self.transformer_layers = []
+        for i in range(num_layers):
+            layer = QwenMoeTransformerDecoder(
+                intermediate_dim=intermediate_dim,
+                num_query_heads=num_query_heads,
+                num_key_value_heads=num_key_value_heads,
+                moe_intermediate_dim=moe_intermediate_dim,
+                shared_expert_intermediate_dim=shared_expert_intermediate_dim,
+                num_experts=num_experts,
+                top_k=top_k,
+                norm_top_k_prob=norm_top_k_prob,
+                decoder_sparse_step=decoder_sparse_step,
+                rope_max_wavelength=rope_max_wavelength,
+                rope_scaling_factor=rope_scaling_factor,
+                layer_norm_epsilon=layer_norm_epsilon,
+                activation=ops.silu,
+                kernel_initializer=_qwen_moe_kernel_initializer(stddev=0.02),
+                dropout=dropout,
+                dtype=dtype,
+                use_sliding_window_attention=use_sliding_window_attention,
+                sliding_window_size=sliding_window_size,
+                output_router_logits=output_router_logits,
+                router_aux_loss_coefficient=router_aux_loss_coefficient,
+                mlp_only_layers=mlp_only_layers,
+                name=f"transformer_layer_{i}",
+            )
+            self.transformer_layers.append(layer)
+        self.layer_norm = QwenLayerNorm(
+            epsilon=layer_norm_epsilon,
+            dtype=dtype,
+            name="sequence_output_layernorm",
+        )
+
+        # === Functional Model ===
+        token_id_input = keras.Input(
+            shape=(None,), dtype="int32", name="token_ids"
+        )
+        padding_mask_input = keras.Input(
+            shape=(None,), dtype="int32", name="padding_mask"
+        )
+        x = self.token_embedding(token_id_input)
+        for transformer_layer in self.transformer_layers:
+            x = transformer_layer(
+                x, decoder_padding_mask=padding_mask_input, training=training
+            )
+        sequence_output = self.layer_norm(x)
+        super().__init__(
+            inputs={
+                "token_ids": token_id_input,
+                "padding_mask": padding_mask_input,
+            },
+            outputs=sequence_output,
+            dtype=dtype,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.vocabulary_size = vocabulary_size
+        self.num_layers = num_layers
+        self.num_query_heads = num_query_heads
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.moe_intermediate_dim = moe_intermediate_dim
+        self.shared_expert_intermediate_dim = shared_expert_intermediate_dim
+        self.rope_max_wavelength = rope_max_wavelength
+        self.num_key_value_heads = num_key_value_heads
+        self.rope_scaling_factor = rope_scaling_factor
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.dropout = dropout
+        self.tie_word_embeddings = tie_word_embeddings
+        self.use_sliding_window_attention = use_sliding_window_attention
+        self.sliding_window_size = sliding_window_size
+        self.num_experts = num_experts
+        self.top_k = top_k
+        self.norm_top_k_prob = norm_top_k_prob
+        self.decoder_sparse_step = decoder_sparse_step
+        self.mlp_only_layers = mlp_only_layers
+        self.router_aux_loss_coefficient = router_aux_loss_coefficient
+        self.output_router_logits = output_router_logits
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_query_heads": self.num_query_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "moe_intermediate_dim": self.moe_intermediate_dim,
+                "shared_expert_intermediate_dim": (
+                    self.shared_expert_intermediate_dim
+                ),
+                "rope_max_wavelength": self.rope_max_wavelength,
+                "num_key_value_heads": self.num_key_value_heads,
+                "rope_scaling_factor": self.rope_scaling_factor,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "dropout": self.dropout,
+                "tie_word_embeddings": self.tie_word_embeddings,
+                "use_sliding_window_attention": (
+                    self.use_sliding_window_attention
+                ),
+                "sliding_window_size": self.sliding_window_size,
+                "num_experts": self.num_experts,
+                "top_k": self.top_k,
+                "norm_top_k_prob": self.norm_top_k_prob,
+                "decoder_sparse_step": self.decoder_sparse_step,
+                "mlp_only_layers": self.mlp_only_layers,
+                "output_router_logits": self.output_router_logits,
+            }
+        )
+        return config
+
+    @staticmethod
+    def get_layout_map(
+        device_mesh,
+        model_parallel_dim_name="model",
+        data_parallel_dim_name="batch",
+    ):
+        """Get a `keras.distribution.LayoutMap` for model parallel distribution.
+
+        The returned `LayoutMap` contains the sharding spec for the QwenMoe
+        backbone weights, so that you can use it to distribute weights across
+        the accelerators.
+
+        Example:
+        ```
+        # Feel free to change the mesh shape to balance data and model
+        # parallelism.
+        mesh = keras.distribution.DeviceMesh(
+            shape=(1, 8),
+            axis_names=('batch', 'model'),
+            devices=keras.distribution.list_devices(),
+        )
+        layout_map = QwenMoeBackbone.get_layout_map(
+            mesh,
+            model_parallel_dim_name="model",
+        )
+
+        distribution = keras.distribution.ModelParallel(
+            layout_map=layout_map,
+            batch_dim_name='batch',
+        )
+
+        with distribution.scope():
+            qwen_moe_model = keras_hub.models.QwenMoeBackbone.from_preset(
+                "qwen_moe_a2_7b"
+            )
+        ```
+
+        To see how the layout map was applied, load the model then run
+        (for one decoder block):
+        ```
+        embedding_layer = qwen_moe_model.get_layer("token_embedding")
+        decoder_block_1 = qwen_moe_model.get_layer("transformer_layer_0")
+        for variable in embedding_layer.weights + decoder_block_1.weights:
+            print(
+                f"{variable.path:<58} {str(variable.shape):<16} "
+                f"{str(variable.value.sharding.spec)}"
+            )
+        ```
+
+        Args:
+            device_mesh: The `keras.distribution.DeviceMesh` instance for
+                distribution.
+            model_parallel_dim_name: The axis name of the device mesh on
+                which the weights should be partitioned.
+            data_parallel_dim_name: The axis name of the device mesh on
+                which the data should be partitioned.
+
+        Returns:
+            `keras.distribution.LayoutMap` that contains the sharding spec
+            for all the model weights.
+        """
+        # The weight paths and shapes of a Llama-style backbone are listed
+        # below for reference; the QwenMoe backbone follows the same naming
+        # pattern.
+        # token_embedding/embeddings (128256, 2048)
+        # repeat block for decoder
+        # transformer_layer_0/self_attention/query/kernel (2048, 32, 64)
+        # transformer_layer_0/self_attention/key/kernel (2048, 8, 64)
+        # transformer_layer_0/self_attention/value/kernel (2048, 8, 64)
+        # transformer_layer_0/self_attention/attention_output/kernel
+        #     (32, 64, 2048)
+        # transformer_layer_0/self_attention_layernorm/scale (2048,)
+        # transformer_layer_0/feedforward_intermediate_dense/kernel
+        #     (2048, 8192)
+        # transformer_layer_0/feedforward_gate_dense/kernel (2048, 8192)
+        # transformer_layer_0/feedforward_output_dense/kernel (8192, 2048)
+        # transformer_layer_0/feedforward_layernorm/scale (2048,)
+
+        if not isinstance(device_mesh, keras.distribution.DeviceMesh):
+            raise ValueError(
+                "Invalid device_mesh type. Expected "
+                f"`keras.distribution.DeviceMesh`, got {type(device_mesh)}"
+            )
+        if model_parallel_dim_name not in device_mesh.axis_names:
+            raise ValueError(
+                f"{model_parallel_dim_name} is not found in the "
+                f"device_mesh.axis_names. {device_mesh.axis_names=}"
+            )
+        if data_parallel_dim_name not in device_mesh.axis_names:
+            raise ValueError(
+                f"{data_parallel_dim_name} is not found in the "
+                f"device_mesh.axis_names. {device_mesh.axis_names=}"
+            )
+        # Note that it is possible to further configure the mesh to be 3D,
+        # e.g. (data, seq, model). We leave it as 2D for now for simplicity.
+        data_dim = data_parallel_dim_name
+        model_dim = model_parallel_dim_name
+        # The sharding config is based on the Gemma team training config.
+        # See https://arxiv.org/abs/2403.08295
+        layout_map = keras.distribution.LayoutMap(device_mesh)
+        layout_map["token_embedding/embeddings"] = (model_dim, data_dim)
+        layout_map[
+            "transformer_layer.*self_attention.*(query|key|value).kernel"
+        ] = (
+            model_dim,
+            data_dim,
+            None,
+        )
+        layout_map["transformer_layer.*attention_output.kernel"] = (
+            model_dim,
+            None,
+            data_dim,
+        )
+        layout_map[
+            "transformer_layer.*feedforward_intermediate_dense.kernel"
+        ] = (
+            data_dim,
+            model_dim,
+        )
+        layout_map["transformer_layer.*feedforward_gate_dense.kernel"] = (
+            data_dim,
+            model_dim,
+        )
+        layout_map["transformer_layer.*feedforward_output_dense.kernel"] = (
+            model_dim,
+            data_dim,
+        )
+
+        return layout_map
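
For context, the two pieces added in this file compose as follows: `get_layout_map` produces the sharding rules, and the backbone is then created under a `ModelParallel` distribution scope so its weights pick up that layout. This is a minimal sketch following the docstrings above; the 1x8 mesh shape is an assumption to adjust for your hardware, and the `qwen_moe_a2_7b` preset name is the one shown in the class docstring (availability depends on the release).

```python
import keras
import keras_hub

# Describe the accelerators as a 2D mesh: one data-parallel group of eight
# model-parallel shards. (Assumed shape; match it to your device count.)
mesh = keras.distribution.DeviceMesh(
    shape=(1, 8),
    axis_names=("batch", "model"),
    devices=keras.distribution.list_devices(),
)

# Sharding rules for the QwenMoe weights, as defined by get_layout_map above.
layout_map = keras_hub.models.QwenMoeBackbone.get_layout_map(mesh)
distribution = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="batch"
)

# Create the pretrained backbone under the distribution scope so its weights
# are built with the layout above.
with distribution.scope():
    backbone = keras_hub.models.QwenMoeBackbone.from_preset("qwen_moe_a2_7b")
```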