keras-hub-nightly 0.21.0.dev202505040408__py3-none-any.whl → 0.21.0.dev202505060405__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/models/__init__.py +21 -0
- keras_hub/src/models/backbone.py +5 -2
- keras_hub/src/models/mixtral/mixtral_attention.py +263 -0
- keras_hub/src/models/mixtral/mixtral_backbone.py +207 -0
- keras_hub/src/models/mixtral/mixtral_causal_lm.py +281 -0
- keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py +76 -0
- keras_hub/src/models/mixtral/mixtral_decoder.py +494 -0
- keras_hub/src/models/mixtral/mixtral_layer_norm.py +34 -0
- keras_hub/src/models/mixtral/mixtral_tokenizer.py +21 -0
- keras_hub/src/models/qwen/qwen_attention.py +3 -1
- keras_hub/src/models/qwen/qwen_presets.py +61 -0
- keras_hub/src/models/qwen_moe/__init__.py +0 -0
- keras_hub/src/models/qwen_moe/qwen_moe_attention.py +377 -0
- keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +373 -0
- keras_hub/src/models/qwen_moe/qwen_moe_causal_lm.py +350 -0
- keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor.py +17 -0
- keras_hub/src/models/qwen_moe/qwen_moe_decoder.py +625 -0
- keras_hub/src/models/qwen_moe/qwen_moe_layernorm.py +32 -0
- keras_hub/src/models/qwen_moe/qwen_moe_tokenizer.py +46 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +0 -13
- keras_hub/src/models/retinanet/retinanet_presets.py +2 -2
- keras_hub/src/models/task.py +5 -2
- keras_hub/src/utils/keras_utils.py +11 -0
- keras_hub/src/utils/preset_utils.py +69 -9
- keras_hub/src/utils/tensor_utils.py +27 -1
- keras_hub/src/utils/transformers/convert_mixtral.py +139 -0
- keras_hub/src/utils/transformers/convert_qwen_moe.py +253 -0
- keras_hub/src/utils/transformers/preset_loader.py +6 -0
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +6 -0
- {keras_hub_nightly-0.21.0.dev202505040408.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.21.0.dev202505040408.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/RECORD +34 -16
- {keras_hub_nightly-0.21.0.dev202505040408.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/WHEEL +1 -1
- {keras_hub_nightly-0.21.0.dev202505040408.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/top_level.txt +0 -0
keras_hub/src/models/mixtral/mixtral_decoder.py
@@ -0,0 +1,494 @@
+import keras
+from keras import ops
+
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
+    compute_causal_mask,
+)
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
+    merge_padding_and_attention_mask,
+)
+from keras_hub.src.models.mixtral.mixtral_attention import (
+    CachedMixtralAttention,
+)
+from keras_hub.src.models.mixtral.mixtral_layer_norm import (
+    MixtralLayerNormalization,
+)
+from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+def compute_load_balancing_loss(
+    router_logits, num_experts, top_k, attention_mask=None
+):
+    """Compute the load balancing auxiliary loss for a single MoE layer.
+
+    Args:
+        router_logits: Tensor of shape (batch_size * seq_len, num_experts).
+        num_experts: Integer, total number of experts.
+        top_k: Integer, number of experts to select per token.
+        attention_mask: Tensor of shape (batch_size, seq_len), optional mask
+            for padding.
+    Returns:
+        Scalar tensor representing the auxiliary loss.
+    """
+    # Compute routing probabilities
+    routing_weights = ops.softmax(
+        router_logits, axis=-1
+    )  # Shape: (batch_size * seq_len, num_experts)
+
+    # Get top-k experts
+    top_k_weights, selected_experts = ops.top_k(
+        routing_weights, k=top_k
+    )  # Shape: (batch_size * seq_len, top_k) for both
+
+    # Create one-hot encoding for selected experts
+    expert_mask = ops.one_hot(
+        selected_experts, num_experts
+    )  # Shape: (batch_size * seq_len, top_k, num_experts)
+
+    if attention_mask is not None:
+        # Flatten attention_mask to match router_logits
+        seq_len = ops.shape(attention_mask)[1]
+        batch_seq_len = ops.shape(router_logits)[0]
+        # Dynamically compute the batch size to match router_logits
+        target_batch_size = batch_seq_len // seq_len
+        # Slice attention_mask to match the expected batch size
+        attention_mask = ops.slice(
+            attention_mask, [0, 0], [target_batch_size, seq_len]
+        )
+        flat_mask = ops.reshape(
+            attention_mask, (-1,)
+        )  # Shape: (batch_size * seq_len,)
+        flat_mask = ops.cast(flat_mask, dtype="float32")
+        # Expand mask for broadcasting
+        expert_attention_mask = ops.expand_dims(
+            flat_mask, axis=-1
+        )  # Shape: (batch_size * seq_len, 1)
+        expert_attention_mask = ops.expand_dims(
+            expert_attention_mask, axis=1
+        )  # Shape: (batch_size * seq_len, 1, 1)
+
+        # Compute masked token counts
+        tokens_per_expert = ops.sum(
+            expert_mask * expert_attention_mask, axis=0
+        )  # Shape: (top_k, num_experts)
+        mask_sum = ops.sum(expert_attention_mask, axis=0)  # Shape: (1, 1)
+        tokens_per_expert = tokens_per_expert / ops.maximum(mask_sum, 1e-9)
+
+        # Compute masked router probabilities
+        router_prob_per_expert = ops.sum(
+            routing_weights * flat_mask[:, None], axis=0
+        )  # Shape: (num_experts,)
+        router_prob_per_expert = router_prob_per_expert / ops.maximum(
+            ops.sum(flat_mask), 1e-9
+        )
+    else:
+        # Unmasked means
+        tokens_per_expert = ops.mean(
+            expert_mask, axis=0
+        )  # Shape: (top_k, num_experts)
+        router_prob_per_expert = ops.mean(
+            routing_weights, axis=0
+        )  # Shape: (num_experts,)
+
+    # Average over top_k dimension
+    tokens_per_expert = ops.mean(
+        tokens_per_expert, axis=0
+    )  # Shape: (num_experts,)
+
+    # Compute the loss
+    overall_loss = ops.sum(tokens_per_expert * router_prob_per_expert)
+    return overall_loss * num_experts
+
+
+class MixtralMoeExperts(keras.layers.Layer):
+    """Batched feed-forward experts for Mixtral (pure keras.ops)."""
+
+    def __init__(
+        self,
+        num_experts,
+        hidden_dim,
+        intermediate_dim,
+        activation_fn="silu",
+        kernel_initializer="glorot_uniform",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.num_experts = num_experts
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.activation = keras.activations.get(activation_fn)
+        self.kernel_initializer = keras.initializers.get(kernel_initializer)
+
+    def build(self, _):
+        # Weight for gate dense layer:
+        # [num_experts, hidden_dim, intermediate_dim]
+        self._expert_feedforward_gate_dense = self.add_weight(
+            shape=(self.num_experts, self.hidden_dim, self.intermediate_dim),
+            initializer=self.kernel_initializer,
+            trainable=True,
+            dtype=self.variable_dtype,
+            name="expert_feedforward_gate_dense",
+        )
+        # Weight for intermediate dense layer:
+        # [num_experts, hidden_dim, intermediate_dim]
+        self._expert_feedforward_intermediate_dense = self.add_weight(
+            shape=(self.num_experts, self.hidden_dim, self.intermediate_dim),
+            initializer=self.kernel_initializer,
+            trainable=True,
+            dtype=self.variable_dtype,
+            name="expert_feedforward_intermediate_dense",
+        )
+        # Weight for output dense layer:
+        # [num_experts, intermediate_dim, hidden_dim]
+        self._expert_feedforward_output_dense = self.add_weight(
+            shape=(self.num_experts, self.intermediate_dim, self.hidden_dim),
+            initializer=self.kernel_initializer,
+            trainable=True,
+            name="expert_feedforward_output_dense",
+        )
+        self.built = True
+
+    def call(self, hidden_states):
+        # Compute gate output for all experts:
+        # [num_experts, tokens, intermediate_dim]
+        gate = ops.einsum(
+            "th,ehm->etm", hidden_states, self._expert_feedforward_gate_dense
+        )
+        gate = ops.cast(gate, "float32")  # Match PyTorch SiLU precision
+        gate = self.activation(gate)
+        gate = ops.cast(gate, self.compute_dtype)
+
+        # Compute intermediate output for all experts:
+        # [num_experts, tokens, intermediate_dim]
+        intermediate = ops.einsum(
+            "th,ehm->etm",
+            hidden_states,
+            self._expert_feedforward_intermediate_dense,
+        )
+        hidden = intermediate * gate  # Element-wise multiplication
+
+        # Compute final output: [num_experts, tokens, hidden_dim]
+        out = ops.einsum(
+            "eti,eih->eth", hidden, self._expert_feedforward_output_dense
+        )
+        return out
+
+
+class MixtralSparseMoeBlock(keras.layers.Layer):
+    """Mixtral sparse MoE block rewritten in batched style."""
+
+    def __init__(
+        self,
+        hidden_dim,
+        intermediate_dim,
+        num_experts,
+        top_k=2,
+        router_jitter_noise=0.0,
+        layer_norm_epsilon=1e-5,
+        router_aux_loss_coef=0.02,
+        kernel_initializer="glorot_uniform",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.num_experts = num_experts
+        self.top_k = top_k
+        self.router_jitter_noise = router_jitter_noise
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.router_aux_loss_coef = router_aux_loss_coef
+        self.kernel_initializer = keras.initializers.get(kernel_initializer)
+
+    def build(self, decoder_sequence_shape):
+        # Router dense layer to compute logits for expert selection
+        self._sparse_feedforward_gate_dense = keras.layers.Dense(
+            self.num_experts,
+            kernel_initializer=self.kernel_initializer,
+            use_bias=False,
+            dtype=self.dtype_policy,
+            name="sparse_feedforward_gate_dense",
+        )
+        self._sparse_feedforward_gate_dense.build(decoder_sequence_shape)
+
+        # Batched expert bank
+        self.expert_bank = MixtralMoeExperts(
+            num_experts=self.num_experts,
+            hidden_dim=self.hidden_dim,
+            intermediate_dim=self.intermediate_dim,
+            kernel_initializer=self.kernel_initializer,
+            name="experts",
+            dtype=self.dtype_policy,
+        )
+        self.expert_bank.build(decoder_sequence_shape)
+        self.built = True
+
+    def call(self, hidden_states, attention_mask=None, training=False):
+        batch_size, seq_len, _ = ops.shape(hidden_states)
+        hidden_states_flattened = ops.reshape(
+            hidden_states, (-1, self.hidden_dim)
+        )
+
+        # Apply jitter noise during training if specified
+        if training and self.router_jitter_noise > 0:
+            random_factors = ops.random.uniform(
+                shape=ops.shape(hidden_states_flattened),
+                minval=1.0 - self.router_jitter_noise,
+                maxval=1.0 + self.router_jitter_noise,
+                dtype=hidden_states_flattened.dtype,
+            )
+            hidden_states_flattened = hidden_states_flattened * random_factors
+
+        # Compute router logits and probabilities
+        router_logits = self._sparse_feedforward_gate_dense(
+            hidden_states_flattened
+        )
+        router_probs = ops.softmax(router_logits, axis=-1)
+
+        top_p, top_i = ops.top_k(router_probs, k=self.top_k)
+        sum_topk = ops.sum(top_p, axis=-1, keepdims=True)
+        top_p = top_p / sum_topk  # Normalize top-k probabilities
+
+        one_hot = ops.one_hot(top_i, self.num_experts)
+        one_hot = ops.cast(one_hot, top_p.dtype)
+        routing_full = ops.sum(one_hot * top_p[..., None], axis=1)
+        routing_full = ops.transpose(routing_full, (1, 0))
+        routing_full = ops.cast(routing_full, hidden_states_flattened.dtype)
+
+        expert_out = self.expert_bank(hidden_states_flattened)
+
+        weighted_out = expert_out * routing_full[:, :, None]
+        expert_contribution = ops.sum(weighted_out, axis=0)
+
+        out = ops.reshape(
+            expert_contribution, (batch_size, seq_len, self.hidden_dim)
+        )
+
+        if training:
+            aux_loss = compute_load_balancing_loss(
+                router_logits=router_logits,
+                num_experts=self.num_experts,
+                top_k=self.top_k,
+                attention_mask=attention_mask,
+            )
+            self.add_loss(self.router_aux_loss_coef * aux_loss)
+
+        return out, router_logits
+
+
+class MixtralTransformerDecoder(keras.layers.Layer):
+    def __init__(
+        self,
+        intermediate_dim,
+        num_query_heads,
+        num_key_value_heads,
+        num_experts,
+        top_k=2,
+        router_jitter_noise=0.0,
+        output_router_logits=False,
+        rope_max_wavelength=10000,
+        rope_scaling_factor=1.0,
+        activation="silu",
+        layer_norm_epsilon=1e-5,
+        router_aux_loss_coef=0.02,
+        kernel_initializer="glorot_uniform",
+        sliding_window=512,
+        dropout=0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.intermediate_dim = intermediate_dim
+        self.num_query_heads = num_query_heads
+        self.num_key_value_heads = num_key_value_heads
+
+        self.num_experts = num_experts
+        self.top_k = top_k
+        self.router_jitter_noise = router_jitter_noise
+
+        self.rope_max_wavelength = rope_max_wavelength
+        self.rope_scaling_factor = rope_scaling_factor
+
+        self.dropout = dropout
+
+        self.sliding_window = sliding_window
+        self.activation = keras.activations.get(activation)
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.kernel_initializer = keras.initializers.get(kernel_initializer)
+
+        self.router_aux_loss_coef = router_aux_loss_coef
+        self.output_router_logits = output_router_logits
+
+        self.supports_masking = True
+
+    def build(self, decoder_sequence_shape):
+        self._decoder_sequence_shape = decoder_sequence_shape
+        self.hidden_dim = decoder_sequence_shape[-1]
+
+        # Self attention layer.
+        self._self_attention_layer = CachedMixtralAttention(
+            num_query_heads=self.num_query_heads,
+            num_key_value_heads=self.num_key_value_heads,
+            rope_max_wavelength=self.rope_max_wavelength,
+            rope_scaling_factor=self.rope_scaling_factor,
+            sliding_window=self.sliding_window,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            dropout=self.dropout,
+            dtype=self.dtype_policy,
+            name="self_attention",
+        )
+        self._self_attention_layer.build(decoder_sequence_shape)
+
+        self._self_attention_layernorm = MixtralLayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="self_attention_layernorm",
+        )
+        self._self_attention_layernorm.build(decoder_sequence_shape)
+        self._self_attention_dropout = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+            name="self_attention_dropout",
+        )
+
+        self._sparse_moe_block = MixtralSparseMoeBlock(
+            hidden_dim=self.hidden_dim,
+            intermediate_dim=self.intermediate_dim,
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            router_jitter_noise=self.router_jitter_noise,
+            router_aux_loss_coef=self.router_aux_loss_coef,
+            dtype=self.dtype_policy,
+        )
+        self._sparse_moe_block.build(decoder_sequence_shape)
+
+        self._feedforward_layernorm = MixtralLayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="feedforward_layernorm",
+        )
+        self._feedforward_layernorm.build(decoder_sequence_shape)
+
+        self.built = True
+
+    def call(
+        self,
+        decoder_sequence,
+        decoder_padding_mask=None,
+        decoder_attention_mask=None,
+        self_attention_cache=None,
+        self_attention_cache_update_index=None,
+        training=None,
+    ):
+        self_attention_mask = self._compute_self_attention_mask(
+            decoder_sequence=decoder_sequence,
+            decoder_padding_mask=decoder_padding_mask,
+            decoder_attention_mask=decoder_attention_mask,
+            self_attention_cache=self_attention_cache,
+            self_attention_cache_update_index=self_attention_cache_update_index,
+        )
+        residual = decoder_sequence
+
+        x = self._self_attention_layernorm(decoder_sequence)
+
+        # Self attention block.
+        x = self._self_attention_layer(
+            hidden_states=x,
+            attention_mask=self_attention_mask,
+            cache=self_attention_cache,
+            cache_update_index=self_attention_cache_update_index,
+        )
+
+        if self_attention_cache is not None:
+            x, self_attention_cache = x
+
+        x = self._self_attention_dropout(x, training=training)
+
+        x = x + residual
+        residual = x
+
+        x = self._feedforward_layernorm(x)
+        x, router_logits = self._sparse_moe_block(
+            x, attention_mask=decoder_padding_mask
+        )
+
+        decoder_output = x + residual
+
+        output = (decoder_output,)
+
+        if self_attention_cache is not None:
+            output += (self_attention_cache,)
+
+        if self.output_router_logits:
+            output += (router_logits,)
+
+        return output[0] if len(output) == 1 else output
+
+    def _compute_self_attention_mask(
+        self,
+        decoder_sequence,
+        decoder_padding_mask,
+        decoder_attention_mask,
+        self_attention_cache,
+        self_attention_cache_update_index,
+    ):
+        decoder_mask = merge_padding_and_attention_mask(
+            decoder_sequence, decoder_padding_mask, decoder_attention_mask
+        )
+        batch_size = ops.shape(decoder_sequence)[0]
+        input_length = output_length = ops.shape(decoder_sequence)[1]
+        # We need to handle a rectangular causal mask when doing cached
+        # decoding. For generative inference, `decoder_sequence` will
+        # generally be length 1, and `cache` will be the full generation length.
+        if self_attention_cache is not None:
+            input_length = ops.shape(self_attention_cache)[2]
+
+        cache_update_index = (
+            0
+            if self_attention_cache_update_index is None
+            else self_attention_cache_update_index
+        )
+
+        # The lower traingular attention mask
+        causal_mask = compute_causal_mask(
+            batch_size, input_length, output_length, cache_update_index
+        )
+
+        # Mixtral uses a banded attention mask if sliding window is not None
+        if self.sliding_window is not None:
+            # ops.trui/tril has issues with dynamic shape on the tensorflow
+            # causal_mask = ops.triu(causal_mask, k=-self.sliding_window)
+            i = ops.arange(output_length)[:, None] + cache_update_index
+            j = ops.arange(input_length)[None, :]
+            causal_mask_upper = ops.cast(i < j + self.sliding_window, "int32")
+            causal_mask = ops.minimum(causal_mask, causal_mask_upper)
+
+        return (
+            ops.minimum(decoder_mask, causal_mask)
+            if decoder_mask is not None
+            else causal_mask
+        )
+
+    def compute_output_shape(self, decoder_sequence_shape):
+        return decoder_sequence_shape
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "intermediate_dim": self.intermediate_dim,
+                "num_query_heads": self.num_query_heads,
+                "rope_max_wavelength": self.rope_max_wavelength,
+                "rope_scaling_factor": self.rope_scaling_factor,
+                "num_key_value_heads": self.num_key_value_heads,
+                "num_experts": self.num_experts,
+                "top_k": self.top_k,
+                "router_jitter_noise": self.router_jitter_noise,
+                "sliding_window": self.sliding_window,
+                "activation": keras.activations.serialize(self.activation),
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "kernel_initializer": keras.initializers.serialize(
+                    self.kernel_initializer
+                ),
+                "dropout": self.dropout,
+            }
+        )
+        return config
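The auxiliary loss above is the usual MoE load-balancing term: the fraction of tokens routed to each expert multiplied by the mean router probability for that expert, summed over experts and scaled by the number of experts. A minimal NumPy sketch of the unmasked branch, included here only to illustrate the math on toy shapes (not code from the package):

import numpy as np

def toy_load_balancing_loss(router_logits, num_experts, top_k):
    # router_logits: (tokens, num_experts)
    probs = np.exp(router_logits)
    probs = probs / probs.sum(axis=-1, keepdims=True)   # softmax over experts
    top_idx = np.argsort(-probs, axis=-1)[:, :top_k]    # selected expert ids
    expert_mask = np.eye(num_experts)[top_idx]          # (tokens, top_k, num_experts)
    tokens_per_expert = expert_mask.mean(axis=(0, 1))   # fraction of slots routed to each expert
    router_prob_per_expert = probs.mean(axis=0)         # mean router probability per expert
    return num_experts * (tokens_per_expert * router_prob_per_expert).sum()

# A perfectly uniform router gives the balanced baseline value of 1.0.
print(toy_load_balancing_loss(np.zeros((8, 4)), num_experts=4, top_k=2))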
keras_hub/src/models/mixtral/mixtral_layer_norm.py
@@ -0,0 +1,34 @@
+import keras
+from keras import ops
+
+
+# NOTE: `keras.layers.LayerNormalization(rms_scaling=True)`
+# does not produce same results
+class MixtralLayerNormalization(keras.layers.Layer):
+    """A normalization layer for Mixtral that implements RMS normalization."""
+
+    def __init__(self, epsilon=1e-6, **kwargs):
+        super().__init__(**kwargs)
+        self.epsilon = epsilon
+
+    def build(self, input_shape):
+        dim = input_shape[-1]
+        self.scale = self.add_weight(
+            name="scale",
+            trainable=True,
+            shape=(dim,),
+            initializer="ones",
+            dtype=self.variable_dtype,
+        )
+        self.built = True
+
+    def call(self, x):
+        x = ops.cast(x, "float32")
+        var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True)
+        x = x * ops.rsqrt(var + self.epsilon)
+        return ops.cast(x * self.scale, self.compute_dtype)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"epsilon": self.epsilon})
+        return config
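For reference, MixtralLayerNormalization above is a plain RMS norm: the input is upcast to float32, divided by the root mean square of its last axis, scaled, and cast back to the compute dtype. A small NumPy sketch of the same math (illustrative only, not shipped in the package):

import numpy as np

def rms_norm(x, scale, epsilon=1e-6):
    # Normalize by the root mean square of the last axis; no mean subtraction.
    x32 = x.astype("float32")
    var = np.mean(np.square(x32), axis=-1, keepdims=True)
    return (x32 / np.sqrt(var + epsilon) * scale).astype(x.dtype)

x = np.random.randn(2, 8).astype("float16")
print(rms_norm(x, scale=np.ones(8, dtype="float32")).dtype)  # float16, like compute_dtype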
keras_hub/src/models/mixtral/mixtral_tokenizer.py
@@ -0,0 +1,21 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.mixtral.mixtral_backbone import MixtralBackbone
+from keras_hub.src.tokenizers.sentence_piece_tokenizer import (
+    SentencePieceTokenizer,
+)
+
+
+@keras_hub_export(
+    [
+        "keras_hub.tokenizers.MixtralTokenizer",
+        "keras_hub.models.MixtralTokenizer",
+    ]
+)
+class MixtralTokenizer(SentencePieceTokenizer):
+    backbone_cls = MixtralBackbone
+
+    def __init__(self, proto, **kwargs):
+        self._add_special_token("<s>", "start_token")
+        self._add_special_token("</s>", "end_token")
+        self.pad_token_id = 0
+        super().__init__(proto=proto, **kwargs)
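A short usage sketch for the new tokenizer export; the SentencePiece proto path below is a placeholder, since no proto file ships with the wheel:

import keras_hub

# MixtralTokenizer is a SentencePieceTokenizer that registers "<s>"/"</s>"
# as start/end tokens. "mixtral.spm" is a hypothetical proto file path.
tokenizer = keras_hub.tokenizers.MixtralTokenizer(proto="mixtral.spm")
token_ids = tokenizer("The quick brown fox.")
text = tokenizer.detokenize(token_ids)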
keras_hub/src/models/qwen/qwen_attention.py
@@ -287,7 +287,9 @@ class QwenAttention(keras.layers.Layer):
         if self.use_sliding_window_attention:
             attention_mask = self._mask_sliding_window(
                 attention_mask,
-                cache_update_index=cache_update_index
+                cache_update_index=cache_update_index
+                if cache_update_index
+                else 0,
             )
         attention_scores = self._masked_softmax(
             attention_scores, attention_mask
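The two added lines guard against `cache_update_index` being `None` on the non-cached path before `_mask_sliding_window` uses it for index arithmetic. A trivial sketch of the fallback behavior (illustration only):

def start_index(cache_update_index):
    # Mirrors the guard added above: treat a missing cache index as position 0.
    return cache_update_index if cache_update_index else 0

assert start_index(None) == 0
assert start_index(7) == 7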
keras_hub/src/models/qwen/qwen_presets.py
@@ -0,0 +1,61 @@
+"""Qwen preset configurations."""
+
+backbone_presets = {
+    "qwen2.5_0.5b_en": {
+        "metadata": {
+            "description": ("24-layer Qwen model with 0.5 billion parameters."),
+            "params": 494032768,
+            "path": "qwen",
+        },
+        "kaggle_handle": "kaggle://keras/qwen/keras/qwen2.5_0.5b_en",
+    },
+    "qwen2.5_3b_en": {
+        "metadata": {
+            "description": ("36-layer Qwen model with 3.1 billion parameters."),
+            "params": 3085938688,
+            "path": "qwen",
+        },
+        "kaggle_handle": "kaggle://keras/qwen/keras/qwen2.5_3b_en",
+    },
+    "qwen2.5_7b_en": {
+        "metadata": {
+            "description": ("48-layer Qwen model with 7 billion parameters."),
+            "params": 6993420288,
+            "path": "qwen",
+        },
+        "kaggle_handle": "kaggle://keras/qwen/keras/qwen2.5_7b_en/2",
+    },
+    "qwen2.5_instruct_0.5b_en": {
+        "metadata": {
+            "description": (
+                "Instruction fine-tuned 24-layer Qwen model with 0.5 ",
+                "billion parameters.",
+            ),
+            "params": 494032768,
+            "path": "qwen",
+        },
+        "kaggle_handle": "kaggle://keras/qwen/keras/qwen2.5_instruct_0.5b_en",
+    },
+    "qwen2.5_instruct_32b_en": {
+        "metadata": {
+            "description": (
+                "Instruction fine-tuned 64-layer Qwen model with 32 ",
+                "billion parameters.",
+            ),
+            "params": 32763876352,
+            "path": "qwen",
+        },
+        "kaggle_handle": "kaggle://keras/qwen/keras/qwen2.5_instruct_32b_en",
+    },
+    "qwen2.5_instruct_72b_en": {
+        "metadata": {
+            "description": (
+                "Instruction fine-tuned 80-layer Qwen model with 72 ",
+                "billion parameters.",
+            ),
+            "params": 72706203648,
+            "path": "qwen",
+        },
+        "kaggle_handle": "kaggle://keras/qwen/keras/qwen2.5_instruct_72b_en",
+    },
+}
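The presets above register Kaggle handles under the `qwen` path. A hedged loading sketch, assuming the generic `from_preset` entry point and network access (the exact task classes are not part of this diff):

import keras_hub

# "qwen2.5_0.5b_en" is one of the preset names added above.
backbone = keras_hub.models.Backbone.from_preset("qwen2.5_0.5b_en")
print(backbone.count_params())  # roughly 494M per the preset metadata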