keras-hub-nightly 0.21.0.dev202505050407__py3-none-any.whl → 0.21.0.dev202505060405__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. keras_hub/models/__init__.py +21 -0
  2. keras_hub/src/models/backbone.py +5 -2
  3. keras_hub/src/models/mixtral/mixtral_attention.py +263 -0
  4. keras_hub/src/models/mixtral/mixtral_backbone.py +207 -0
  5. keras_hub/src/models/mixtral/mixtral_causal_lm.py +281 -0
  6. keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py +76 -0
  7. keras_hub/src/models/mixtral/mixtral_decoder.py +494 -0
  8. keras_hub/src/models/mixtral/mixtral_layer_norm.py +34 -0
  9. keras_hub/src/models/mixtral/mixtral_tokenizer.py +21 -0
  10. keras_hub/src/models/qwen/qwen_attention.py +3 -1
  11. keras_hub/src/models/qwen/qwen_presets.py +61 -0
  12. keras_hub/src/models/qwen_moe/__init__.py +0 -0
  13. keras_hub/src/models/qwen_moe/qwen_moe_attention.py +377 -0
  14. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +373 -0
  15. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm.py +350 -0
  16. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor.py +17 -0
  17. keras_hub/src/models/qwen_moe/qwen_moe_decoder.py +625 -0
  18. keras_hub/src/models/qwen_moe/qwen_moe_layernorm.py +32 -0
  19. keras_hub/src/models/qwen_moe/qwen_moe_tokenizer.py +46 -0
  20. keras_hub/src/models/retinanet/retinanet_image_converter.py +0 -13
  21. keras_hub/src/models/retinanet/retinanet_presets.py +2 -2
  22. keras_hub/src/models/task.py +5 -2
  23. keras_hub/src/utils/keras_utils.py +11 -0
  24. keras_hub/src/utils/preset_utils.py +69 -9
  25. keras_hub/src/utils/tensor_utils.py +27 -1
  26. keras_hub/src/utils/transformers/convert_mixtral.py +139 -0
  27. keras_hub/src/utils/transformers/convert_qwen_moe.py +253 -0
  28. keras_hub/src/utils/transformers/preset_loader.py +6 -0
  29. keras_hub/src/version.py +1 -1
  30. keras_hub/tokenizers/__init__.py +6 -0
  31. {keras_hub_nightly-0.21.0.dev202505050407.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/METADATA +1 -1
  32. {keras_hub_nightly-0.21.0.dev202505050407.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/RECORD +34 -16
  33. {keras_hub_nightly-0.21.0.dev202505050407.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/WHEEL +0 -0
  34. {keras_hub_nightly-0.21.0.dev202505050407.dist-info → keras_hub_nightly-0.21.0.dev202505060405.dist-info}/top_level.txt +0 -0
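The new qwen_moe and mixtral directories add complete model implementations (backbone, causal LM task, preprocessor, tokenizer, layer norm, decoder) plus Hugging Face checkpoint converters, and keras_hub/models/__init__.py and keras_hub/tokenizers/__init__.py gain the corresponding exports. As a rough sketch only (not part of this diff), and assuming the new classes are exported from keras_hub.models under their usual names, loading one of them would follow the standard KerasHub pattern; the preset name below is a placeholder, not a preset registered by this change.

    import keras_hub

    # Placeholder preset identifier for illustration only; check the presets
    # registered for Qwen-MoE in this release for real names.
    causal_lm = keras_hub.models.QwenMoeCausalLM.from_preset("<qwen_moe_preset>")
    print(causal_lm.generate("Mixture-of-experts layers route each token to", max_length=64))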
keras_hub/src/models/qwen_moe/qwen_moe_attention.py
@@ -0,0 +1,377 @@
+ import inspect
+ import math
+
+ import keras
+ from keras import ops
+
+ from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+ from keras_hub.src.utils.keras_utils import clone_initializer
+ from keras_hub.src.utils.keras_utils import fused_attention_op_available
+ from keras_hub.src.utils.keras_utils import gpu_supports_fused_attention_op
+ from keras_hub.src.utils.keras_utils import running_on_gpu
+ from keras_hub.src.utils.keras_utils import running_on_tpu
+
+
+ class QwenMoeAttention(keras.layers.Layer):
+     """A multi-head attention layer for the Qwen-MoE model.
+
+     This attention implementation supports grouped-query attention (GQA),
+     where the number of key/value heads can be smaller than the number of
+     query heads.
+
+     Args:
+         num_query_heads: Number of query heads.
+         num_key_value_heads: Number of key/value heads (for GQA).
+         rope_max_wavelength: Maximum wavelength for RoPE (Rotary Position
+             Embedding).
+         rope_scaling_factor: Scaling factor for RoPE, used for extending
+             context length.
+         kernel_initializer: Initializer for the kernel weights.
+         bias_initializer: Initializer for the bias weights.
+         dropout: Dropout rate for attention weights.
+         use_sliding_window_attention: Whether to use sliding window
+             attention.
+         sliding_window_size: Size of the sliding window for attention.
+         **kwargs: Additional keyword arguments to pass to the Layer.
+     """
+
+     def __init__(
+         self,
+         num_query_heads,
+         num_key_value_heads,
+         rope_max_wavelength=10000,
+         rope_scaling_factor=1,
+         kernel_initializer="glorot_uniform",
+         bias_initializer="zeros",
+         dropout=0,
+         use_sliding_window_attention=False,
+         sliding_window_size=4096,
+         **kwargs,
+     ):
+         super().__init__(
+             **kwargs,
+         )
+         self.num_query_heads = num_query_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.dropout = dropout
+
+         self.num_key_value_groups = num_query_heads // num_key_value_heads
+         self.rope_max_wavelength = rope_max_wavelength
+
+         self.kernel_initializer = keras.initializers.get(
+             clone_initializer(kernel_initializer)
+         )
+         self.bias_initializer = keras.initializers.get(
+             clone_initializer(bias_initializer)
+         )
+
+         self.rope_scaling_factor = rope_scaling_factor
+         self.use_sliding_window_attention = use_sliding_window_attention
+         self.sliding_window_size = sliding_window_size
+         # The fused-attention helpers below check `logit_soft_cap`, but this
+         # layer does not expose it as an argument, so default it to `None`.
+         self.logit_soft_cap = None
+
+     def build(self, inputs_shape):
+         # Einsum variables:
+         # b = batch size
+         # q = query length
+         # k = key/value length
+         # m = model dim
+         # u = num query heads
+         # v = num key/value heads
+         # h = head dim
+         hidden_dim = inputs_shape[-1]
+         head_dim = hidden_dim // self.num_query_heads
+         self._inv_norm_factor = 1.0 / math.sqrt(head_dim)
+         self.query_dense = keras.layers.EinsumDense(
+             equation="bqm,muh->bquh",
+             output_shape=(None, self.num_query_heads, head_dim),
+             kernel_initializer=self.kernel_initializer,
+             bias_initializer=self.bias_initializer,
+             bias_axes="uh",
+             dtype=self.dtype_policy,
+             name="query",
+         )
+         self.query_dense.build(inputs_shape)
+
+         self.key_dense = keras.layers.EinsumDense(
+             equation="bkm,mvh->bkvh",
+             output_shape=(
+                 None,
+                 self.num_key_value_heads,
+                 head_dim,
+             ),
+             kernel_initializer=self.kernel_initializer,
+             bias_initializer=self.bias_initializer,
+             bias_axes="vh",
+             dtype=self.dtype_policy,
+             name="key",
+         )
+         self.key_dense.build(inputs_shape)
+
+         self.value_dense = keras.layers.EinsumDense(
+             equation="bkm,mvh->bkvh",
+             output_shape=(
+                 None,
+                 self.num_key_value_heads,
+                 head_dim,
+             ),
+             kernel_initializer=self.kernel_initializer,
+             bias_initializer=self.bias_initializer,
+             bias_axes="vh",
+             dtype=self.dtype_policy,
+             name="value",
+         )
+         self.value_dense.build(inputs_shape)
+
+         self._softmax = keras.layers.Softmax(
+             axis=-1,
+             dtype="float32",
+             name="attention_softmax",
+         )
+
+         self._dropout_layer = keras.layers.Dropout(
+             rate=self.dropout,
+             dtype=self.dtype_policy,
+         )
+
+         self._output_dense = keras.layers.EinsumDense(
+             equation="bquh,uhm->bqm",
+             output_shape=(None, hidden_dim),
+             kernel_initializer=self.kernel_initializer,
+             dtype=self.dtype_policy,
+             name="attention_output",
+         )
+         self._output_dense.build((None, None, self.num_query_heads, head_dim))
+
+         self.rotary_embedding_layer = RotaryEmbedding(
+             max_wavelength=self.rope_max_wavelength,
+             scaling_factor=self.rope_scaling_factor,
+             dtype=self.dtype_policy,
+         )
+
+         self._dot_product_equation = "bquh,bkuh->buqk"
+         self._combine_equation = "buqk,bkuh->bquh"
+
+         self.built = True
+
+     def call(
+         self,
+         hidden_states,
+         attention_mask=None,
+         cache=None,
+         cache_update_index=None,
+         training=None,
+     ):
+         """Applies the attention mechanism to the input hidden states.
+
+         Args:
+             hidden_states: Input tensor of shape [batch_size, seq_length,
+                 hidden_size].
+             attention_mask: Mask tensor of shape [batch_size, seq_length,
+                 seq_length].
+             cache: Optional cached key and value tensors.
+             cache_update_index: Index at which to update the cache.
+             training: Boolean indicating whether in training mode.
+
+         Returns:
+             attention_output: Output tensor after applying attention.
+             cache: Updated cache tensors (if cache is provided).
+         """
+         start_index = (
+             cache_update_index if cache_update_index is not None else 0
+         )
+
+         query = self.query_dense(hidden_states)
+
+         # Compute RoPE for queries
+         query = self.rotary_embedding_layer(query, start_index=start_index)
+
+         def _compute_key_value(x):
+             key, value = self.key_dense(x), self.value_dense(x)
+             # Compute RoPE for keys
+             key = self.rotary_embedding_layer(key, start_index=start_index)
+             return key, value
+
+         if cache is not None:
+             key_cache = cache[:, 0, ...]
+             value_cache = cache[:, 1, ...]
+             if cache_update_index is None:
+                 key = key_cache
+                 value = value_cache
+             else:
+                 key_update, value_update = _compute_key_value(hidden_states)
+                 start = [0, cache_update_index, 0, 0]
+                 key = ops.slice_update(key_cache, start, key_update)
+                 value = ops.slice_update(value_cache, start, value_update)
+                 cache = ops.stack((key, value), axis=1)
+         else:
+             if cache_update_index is not None:
+                 raise ValueError(
+                     "`cache_update_index` should not be set if `cache` is "
+                     f"`None`. Received: cache={cache}, "
+                     f"cache_update_index={cache_update_index}"
+                 )
+             key, value = _compute_key_value(hidden_states)
+
+         # [batch_shape, seq_len, num_key_value_heads, head_dim]
+         # -> [batch_shape, seq_len, num_heads, head_dim]
+         key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
+         value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)
+
+         attention_output = self._compute_attention(
+             query,
+             key,
+             value,
+             attention_mask,
+             cache_update_index=cache_update_index,
+         )
+
+         attention_output = self._dropout_layer(
+             attention_output, training=training
+         )
+
+         attention_output = self._output_dense(attention_output)
+
+         if cache is not None:
+             return attention_output, cache
+         return attention_output
+
+     def _masked_softmax(self, attention_scores, attention_mask=None):
+         """Applies softmax with optional masking.
+
+         Args:
+             attention_scores: Attention score tensor.
+             attention_mask: Optional mask tensor.
+
+         Returns:
+             Masked softmax attention weights.
+         """
+         if attention_mask is not None:
+             return self._softmax(
+                 attention_scores, attention_mask[:, None, :, :]
+             )
+         return self._softmax(attention_scores)
+
+     def _use_fused_attention_op(self):
+         if not fused_attention_op_available():
+             return False
+         if self.dropout > 0.0:
+             return False
+         if running_on_gpu():
+             # GPU never supports softcap in the fused op.
+             if self.logit_soft_cap is not None:
+                 return False
+             return gpu_supports_fused_attention_op()
+         elif running_on_tpu():
+             # TPU supports softcap on Keras >= 3.10.
+             sig = inspect.signature(ops.dot_product_attention)
+             return "attn_logits_soft_cap" in sig.parameters
+         else:
+             return False
+
+     def _compute_attention(
+         self, query, key, value, attention_mask=None, cache_update_index=None
+     ):
+         """Computes attention using query, key, and value tensors.
+
+         Uses Flash Attention when available for better performance.
+
+         Args:
+             query: Query tensor.
+             key: Key tensor.
+             value: Value tensor.
+             attention_mask: Optional mask tensor.
+             cache_update_index: Index for sliding window computation.
+
+         Returns:
+             attention_output: Output tensor after applying attention.
+         """
+         if self._use_fused_attention_op():
+             if attention_mask is not None:
+                 attention_mask = ops.expand_dims(attention_mask, axis=1)
+                 attention_mask = ops.cast(attention_mask, dtype="bool")
+
+             if self.logit_soft_cap:
+                 kwargs = {"attn_logits_soft_cap": self.logit_soft_cap}
+             else:
+                 kwargs = {}
+
+             attention_output = ops.dot_product_attention(
+                 query,
+                 key,
+                 value,
+                 mask=attention_mask,
+                 scale=self._inv_norm_factor,
+                 **kwargs,
+             )
+             return attention_output
+
+         attention_scores = ops.einsum(self._dot_product_equation, query, key)
+
+         attention_scores = ops.multiply(
+             attention_scores,
+             ops.cast(self._inv_norm_factor, self.compute_dtype),
+         )
+         if self.use_sliding_window_attention:
+             attention_mask = self._mask_sliding_window(
+                 attention_mask,
+                 cache_update_index=cache_update_index
+                 if cache_update_index
+                 else 0,
+             )
+         attention_scores = self._masked_softmax(
+             attention_scores, attention_mask
+         )
+         attention_scores = ops.cast(attention_scores, self.compute_dtype)
+         attention_output = ops.einsum(
+             self._combine_equation, attention_scores, value
+         )
+
+         return attention_output
+
+     def _mask_sliding_window(
+         self,
+         attention_mask,
+         cache_update_index=0,
+     ):
+         """Creates and combines a sliding window mask with the attention mask.
+
+         Args:
+             attention_mask: Original attention mask.
+             cache_update_index: Starting index for the sliding window.
+
+         Returns:
+             Combined attention mask with sliding window constraints.
+         """
+         _, query_len, key_len = ops.shape(attention_mask)
+         # Compute the sliding window for square attention.
+         all_ones = ops.ones((key_len, key_len), "bool")
+         sliding_mask = ops.triu(
+             all_ones, -1 * self.sliding_window_size + 1
+         ) * ops.tril(all_ones, self.sliding_window_size - 1)
+         # Slice the window for short queries during generation.
+         start = (cache_update_index, 0)
+         sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len))
+         sliding_mask = ops.expand_dims(sliding_mask, 0)
+         return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool"))
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "num_query_heads": self.num_query_heads,
+                 "num_key_value_heads": self.num_key_value_heads,
+                 "rope_max_wavelength": self.rope_max_wavelength,
+                 "rope_scaling_factor": self.rope_scaling_factor,
+                 "kernel_initializer": keras.initializers.serialize(
+                     self.kernel_initializer
+                 ),
+                 "bias_initializer": keras.initializers.serialize(
+                     self.bias_initializer
+                 ),
+                 "dropout": self.dropout,
+                 "use_sliding_window_attention": (
+                     self.use_sliding_window_attention
+                 ),
+                 "sliding_window_size": self.sliding_window_size,
+             }
+         )
+         return config
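For orientation beyond the diff itself, here is a minimal, hedged sketch of exercising the new attention layer in isolation. It is not part of the package; it imports the module path shown in the file list above (note that keras_hub.src is internal rather than public API), and the shapes, head counts, and mask construction are illustrative only.

    from keras import ops

    # Internal module path taken from the file list above.
    from keras_hub.src.models.qwen_moe.qwen_moe_attention import QwenMoeAttention

    batch_size, seq_len, hidden_dim = 2, 8, 64
    layer = QwenMoeAttention(
        num_query_heads=8,      # head_dim becomes 64 // 8 = 8
        num_key_value_heads=2,  # GQA: four query heads share each key/value head
    )

    hidden_states = ops.ones((batch_size, seq_len, hidden_dim))
    # Boolean causal mask of shape [batch, query_len, key_len].
    causal_mask = ops.cast(ops.tril(ops.ones((seq_len, seq_len), dtype="int32")), "bool")
    causal_mask = ops.broadcast_to(causal_mask, (batch_size, seq_len, seq_len))

    # Without a cache the layer returns only the attention output, shaped
    # [batch, seq_len, hidden_dim]; passing `cache`/`cache_update_index` also
    # returns the updated key/value cache.
    outputs = layer(hidden_states, attention_mask=causal_mask)
    print(ops.shape(outputs))  # (2, 8, 64)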