keras-hub-nightly 0.21.0.dev202505200408__py3-none-any.whl → 0.21.0.dev202505220409__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- keras_hub/src/models/mixtral/mixtral_attention.py +31 -42
- keras_hub/src/models/qwen_moe/qwen_moe_attention.py +7 -9
- keras_hub/src/utils/transformers/convert_mixtral.py +1 -1
- keras_hub/src/version.py +1 -1
- {keras_hub_nightly-0.21.0.dev202505200408.dist-info → keras_hub_nightly-0.21.0.dev202505220409.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.21.0.dev202505200408.dist-info → keras_hub_nightly-0.21.0.dev202505220409.dist-info}/RECORD +8 -8
- {keras_hub_nightly-0.21.0.dev202505200408.dist-info → keras_hub_nightly-0.21.0.dev202505220409.dist-info}/WHEEL +1 -1
- {keras_hub_nightly-0.21.0.dev202505200408.dist-info → keras_hub_nightly-0.21.0.dev202505220409.dist-info}/top_level.txt +0 -0
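To confirm which of the two nightlies is installed locally, a standard-library check is enough (the distribution name comes from the wheel filenames above):

from importlib.metadata import version

# Prints the installed nightly build, e.g. 0.21.0.dev202505220409 after upgrading.
print(version("keras-hub-nightly"))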
keras_hub/src/models/mixtral/mixtral_attention.py
CHANGED
@@ -27,19 +27,19 @@ class CachedMixtralAttention(keras.layers.Layer):
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.
-        self.
-        self.
-        self.
+        self.num_query_heads = num_query_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.sliding_window = sliding_window
+        self.dropout = dropout
 
-        self.
-        self.
+        self.num_key_value_groups = num_query_heads // num_key_value_heads
+        self.rope_max_wavelength = rope_max_wavelength
 
         self._kernel_initializer = keras.initializers.get(
             clone_initializer(kernel_initializer)
         )
 
-        self.
+        self.rope_scaling_factor = rope_scaling_factor
 
     def build(self, inputs_shape):
         # Einsum variables:
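The renamed attributes keep the same grouped-query bookkeeping: num_key_value_groups is just the ratio of query heads to key/value heads, so each key/value head serves a fixed-size group of query heads. A minimal sketch of that arithmetic, with hypothetical head counts:

# Hypothetical head counts, only to illustrate the grouping arithmetic above.
num_query_heads = 32
num_key_value_heads = 8
assert num_query_heads % num_key_value_heads == 0
num_key_value_groups = num_query_heads // num_key_value_heads
print(num_key_value_groups)  # 4: each key/value head is shared by 4 query heads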
@@ -51,12 +51,12 @@ class CachedMixtralAttention(keras.layers.Layer):
         # v = num key/value heads
         # h = head dim
         self._hidden_dim = inputs_shape[-1]
-        self._head_dim = self._hidden_dim // self.
+        self._head_dim = self._hidden_dim // self.num_query_heads
         self._inv_norm_factor = 1.0 / math.sqrt(self._head_dim)
 
         self.query_dense = keras.layers.EinsumDense(
             equation="bqm,muh->bquh",
-            output_shape=(None, self.
+            output_shape=(None, self.num_query_heads, self._head_dim),
             kernel_initializer=self._kernel_initializer,
             dtype=self.dtype_policy,
             name="query",
@@ -67,7 +67,7 @@ class CachedMixtralAttention(keras.layers.Layer):
             equation="bkm,mvh->bkvh",
             output_shape=(
                 None,
-                self.
+                self.num_key_value_heads,
                 self._head_dim,
             ),
             kernel_initializer=self._kernel_initializer,
@@ -80,7 +80,7 @@ class CachedMixtralAttention(keras.layers.Layer):
             equation="bkm,mvh->bkvh",
             output_shape=(
                 None,
-                self.
+                self.num_key_value_heads,
                 self._head_dim,
             ),
             kernel_initializer=self._kernel_initializer,
@@ -89,31 +89,31 @@ class CachedMixtralAttention(keras.layers.Layer):
         )
         self.value_dense.build(inputs_shape)
 
-        self.
+        self.softmax = keras.layers.Softmax(
             axis=-1,
             dtype="float32",
             name="attention_softmax",
         )
 
-        self.
-            rate=self.
+        self.dropout_layer = keras.layers.Dropout(
+            rate=self.dropout,
             dtype=self.dtype_policy,
         )
 
-        self.
+        self.output_dense = keras.layers.EinsumDense(
             equation="bquh,uhm->bqm",
             output_shape=(None, self._hidden_dim),
             kernel_initializer=self._kernel_initializer,
             dtype=self.dtype_policy,
             name="attention_output",
         )
-        self.
-            (None, None, self.
+        self.output_dense.build(
+            (None, None, self.num_query_heads, self._head_dim)
         )
 
         self.rotary_embedding_layer = RotaryEmbedding(
-            max_wavelength=self.
-            scaling_factor=self.
+            max_wavelength=self.rope_max_wavelength,
+            scaling_factor=self.rope_scaling_factor,
             dtype=self.dtype_policy,
         )
 
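The rebuilt sublayers keep the same einsum projections; only the attribute names change. A small sketch, with hypothetical sizes, of how the "bqm,muh->bquh" query projection and the "bquh,uhm->bqm" output projection map between the hidden size and per-head shapes:

import numpy as np
import keras

batch, seq_len, hidden_dim, num_query_heads = 2, 4, 64, 8  # hypothetical sizes
head_dim = hidden_dim // num_query_heads

# Hidden states -> per-head queries: (batch, q, hidden) -> (batch, q, heads, head_dim).
query_dense = keras.layers.EinsumDense(
    "bqm,muh->bquh", output_shape=(None, num_query_heads, head_dim)
)
# Per-head outputs -> hidden states: (batch, q, heads, head_dim) -> (batch, q, hidden).
output_dense = keras.layers.EinsumDense(
    "bquh,uhm->bqm", output_shape=(None, hidden_dim)
)

x = np.random.rand(batch, seq_len, hidden_dim).astype("float32")
q = query_dense(x)
y = output_dense(q)
print(q.shape, y.shape)  # (2, 4, 8, 8) (2, 4, 64)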
@@ -168,18 +168,18 @@ class CachedMixtralAttention(keras.layers.Layer):
 
         # [batch_shape, seq_len, num_key_value_heads, head_dim]
         # -> [batch_shape, seq_len, num_heads, head_dim]
-        key = ops.repeat(key, repeats=self.
-        value = ops.repeat(value, repeats=self.
+        key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
+        value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)
 
         attention_output = self._compute_attention(
             query, key, value, attention_mask
         )
 
-        attention_output = self.
+        attention_output = self.dropout_layer(
             attention_output, training=training
         )
 
-        attention_output = self.
+        attention_output = self.output_dense(attention_output)
 
         if cache is not None:
             return attention_output, cache
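The repeat step expands the key/value heads along axis 2 so they line up one-to-one with the query heads, which is how grouped-query attention reuses each key/value head across its group. A tiny sketch with made-up shapes:

import numpy as np
from keras import ops

num_query_heads, num_key_value_heads = 4, 2  # hypothetical
num_key_value_groups = num_query_heads // num_key_value_heads

# (batch, seq_len, num_key_value_heads, head_dim)
key = np.zeros((1, 3, num_key_value_heads, 8), dtype="float32")
key = ops.repeat(key, repeats=num_key_value_groups, axis=2)
print(key.shape)  # (1, 3, 4, 8): one key/value slice per query head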
@@ -187,10 +187,8 @@ class CachedMixtralAttention(keras.layers.Layer):
 
     def _masked_softmax(self, attention_scores, attention_mask=None):
         if attention_mask is not None:
-            return self.
-
-            )
-        return self._softmax(attention_scores)
+            return self.softmax(attention_scores, attention_mask[:, None, :, :])
+        return self.softmax(attention_scores)
 
     def _use_fused_attention_op(self):
         if not fused_attention_op_available():
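The renamed softmax layer still receives the attention mask broadcast over the head axis (attention_mask[:, None, :, :]). A minimal sketch of that masking behaviour with toy scores and a causal mask (shapes are hypothetical):

import numpy as np
from keras import layers, ops

softmax = layers.Softmax(axis=-1, dtype="float32")

scores = np.zeros((1, 2, 3, 3), dtype="float32")  # (batch, heads, q_len, k_len)
mask = np.tril(np.ones((1, 3, 3), dtype=bool))    # (batch, q_len, k_len); True = attend

# Broadcast the mask across the head axis, as in _masked_softmax above.
probs = softmax(scores, mask[:, None, :, :])
print(ops.convert_to_numpy(probs)[0, 0])  # each row sums to 1 over the unmasked keys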
@@ -198,9 +196,6 @@ class CachedMixtralAttention(keras.layers.Layer):
         if self.dropout > 0.0:
             return False
         if running_on_gpu():
-            # GPU never supports softcap in the fused op.
-            if self.logit_soft_cap is not None:
-                return False
             return gpu_supports_fused_attention_op()
         elif running_on_tpu():
             # TPU supports softcap with on keras >= 3.10.
@@ -215,18 +210,12 @@ class CachedMixtralAttention(keras.layers.Layer):
                 attention_mask = ops.expand_dims(attention_mask, axis=1)
                 attention_mask = ops.cast(attention_mask, dtype="bool")
 
-            if self.logit_soft_cap:
-                kwargs = {"attn_logits_soft_cap": self.logit_soft_cap}
-            else:
-                kwargs = {}
-
             attention_output = ops.dot_product_attention(
                 query,
                 key,
                 value,
                 mask=attention_mask,
                 scale=self._inv_norm_factor,
-                **kwargs,
             )
             return attention_output
 
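With the unused soft-cap plumbing gone, the fused path is a plain ops.dot_product_attention call with only a mask and a scale. A sketch of the call, mirroring the arguments left in the diff (tensor sizes are hypothetical):

import math

import numpy as np
from keras import ops

batch, seq_len, num_heads, head_dim = 1, 4, 2, 8  # hypothetical
query = np.random.rand(batch, seq_len, num_heads, head_dim).astype("float32")
key = np.random.rand(batch, seq_len, num_heads, head_dim).astype("float32")
value = np.random.rand(batch, seq_len, num_heads, head_dim).astype("float32")
mask = np.ones((batch, 1, seq_len, seq_len), dtype=bool)  # broadcast over heads

out = ops.dot_product_attention(
    query, key, value, mask=mask, scale=1.0 / math.sqrt(head_dim)
)
print(out.shape)  # (1, 4, 2, 8): same shape as the query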
@@ -249,15 +238,15 @@ class CachedMixtralAttention(keras.layers.Layer):
         config = super().get_config()
         config.update(
             {
-                "num_query_heads": self.
-                "num_key_value_heads": self.
-                "rope_max_wavelength": self.
-                "rope_scaling_factor": self.
+                "num_query_heads": self.num_query_heads,
+                "num_key_value_heads": self.num_key_value_heads,
+                "rope_max_wavelength": self.rope_max_wavelength,
+                "rope_scaling_factor": self.rope_scaling_factor,
                 "kernel_initializer": keras.initializers.serialize(
                     self._kernel_initializer
                 ),
-                "sliding_window": self.
-                "dropout": self.
+                "sliding_window": self.sliding_window,
+                "dropout": self.dropout,
             }
         )
         return config
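Since get_config now emits the public attribute names, a config round trip stays symmetric with the constructor. A hedged sketch, assuming the __init__ arguments mirror the serialized keys (the values below are made up, not a real Mixtral configuration):

from keras_hub.src.models.mixtral.mixtral_attention import (
    CachedMixtralAttention,
)

layer = CachedMixtralAttention(
    num_query_heads=8,
    num_key_value_heads=2,
    rope_max_wavelength=10000,
    rope_scaling_factor=1.0,
    sliding_window=512,
    dropout=0.0,
)
restored = CachedMixtralAttention.from_config(layer.get_config())
assert restored.num_query_heads == 8
assert restored.sliding_window == 512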
keras_hub/src/models/qwen_moe/qwen_moe_attention.py
CHANGED
@@ -256,9 +256,6 @@ class QwenMoeAttention(keras.layers.Layer):
         if self.dropout > 0.0:
             return False
         if running_on_gpu():
-            # GPU never supports softcap in the fused op.
-            if self.logit_soft_cap is not None:
-                return False
             return gpu_supports_fused_attention_op()
         elif running_on_tpu():
             # TPU supports softcap with on keras >= 3.10.
@@ -268,7 +265,13 @@ class QwenMoeAttention(keras.layers.Layer):
         return False
 
     def _compute_attention(
-        self,
+        self,
+        query,
+        key,
+        value,
+        attention_mask=None,
+        cache_update_index=None,
+        **kwargs,
     ):
         """Computes attention using query, key, and value tensors.
 
@@ -289,11 +292,6 @@ class QwenMoeAttention(keras.layers.Layer):
                 attention_mask = ops.expand_dims(attention_mask, axis=1)
                 attention_mask = ops.cast(attention_mask, dtype="bool")
 
-            if self.logit_soft_cap:
-                kwargs = {"attn_logits_soft_cap": self.logit_soft_cap}
-            else:
-                kwargs = {}
-
             attention_output = ops.dot_product_attention(
                 query,
                 key,
keras_hub/src/utils/transformers/convert_mixtral.py
CHANGED
@@ -68,7 +68,7 @@ def convert_weights(backbone, loader, transformers_config):
         )
         ## Output
         loader.port_weight(
-            keras_variable=decoder_layer._self_attention_layer.
+            keras_variable=decoder_layer._self_attention_layer.output_dense.kernel,
             hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight",
             hook_fn=transpose_and_reshape,
         )
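The corrected line points the loader at the renamed output_dense kernel; the transpose_and_reshape hook still has to map the Hugging Face o_proj layout (out_features, in_features) onto the (heads, head_dim, hidden) kernel used by the "bquh,uhm->bqm" einsum. A rough numpy illustration of that mapping; the helper below is an illustrative stand-in, not the library's actual hook:

import numpy as np

hidden_dim, num_heads = 64, 8  # hypothetical sizes
head_dim = hidden_dim // num_heads

# HF stores o_proj.weight as (out_features, in_features) = (hidden, heads * head_dim).
hf_o_proj = np.random.rand(hidden_dim, num_heads * head_dim).astype("float32")

def transpose_and_reshape(weight, target_shape):
    # Stand-in for the loader hook: torch Linear layout -> EinsumDense kernel shape.
    return np.reshape(np.transpose(weight), target_shape)

kernel = transpose_and_reshape(hf_o_proj, (num_heads, head_dim, hidden_dim))
print(kernel.shape)  # (8, 8, 64)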
keras_hub/src/version.py
CHANGED
{keras_hub_nightly-0.21.0.dev202505200408.dist-info → keras_hub_nightly-0.21.0.dev202505220409.dist-info}/RECORD
CHANGED
@@ -5,7 +5,7 @@ keras_hub/models/__init__.py,sha256=itSzodVUeuX6HQnmsSXY0Wv-5Htbu397410R-SFW_4I,
 keras_hub/samplers/__init__.py,sha256=aFQIkiqbZpi8vjrPp2MVII4QUfE-eQjra5fMeHsoy7k,886
 keras_hub/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_hub/src/api_export.py,sha256=9pQZK27JObxWZ96QPLBp1OBsjWigh1iuV6RglPGMRk0,1499
-keras_hub/src/version.py,sha256=
+keras_hub/src/version.py,sha256=ZWHai9U-yJxL-dj1yBgjl16y6XtOeP2SreCCjSf9xgA,222
 keras_hub/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_hub/src/layers/modeling/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_hub/src/layers/modeling/alibi_bias.py,sha256=1XBTHI52L_iJDhN_w5ydu_iMhCuTgQAxEPwcLA6BPuk,4411
@@ -250,7 +250,7 @@ keras_hub/src/models/mit/mit_image_classifier_preprocessor.py,sha256=oNYs-pUK8Vn
 keras_hub/src/models/mit/mit_image_converter.py,sha256=Mw7nV-OzyBveGuZUNFsPPKyq9jXJVW2_cVH024CNkXM,311
 keras_hub/src/models/mit/mit_layers.py,sha256=HUJO5uhJ6jgwANpwbQdPlEVwLRVb3BZQ-Ftjg3B9XvY,9734
 keras_hub/src/models/mit/mit_presets.py,sha256=ooLrh2OoGZKxnCGnhB6BynYJtVCXH7nDDFhgQRWt36U,4528
-keras_hub/src/models/mixtral/mixtral_attention.py,sha256=
+keras_hub/src/models/mixtral/mixtral_attention.py,sha256=f5aiTtstWeKG_ZwumAlYIzjIN08CpnxNdenxWNJSwZw,8713
 keras_hub/src/models/mixtral/mixtral_backbone.py,sha256=vUAFXvqwVBgKxYbOsqIHzPN59bhaDrGWwOnBCzeUtt0,8034
 keras_hub/src/models/mixtral/mixtral_causal_lm.py,sha256=JA1t6xTeaYX_fNo9ftRyvzdRDG3vndC-Rlwn5fnsbQo,12001
 keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py,sha256=q2qXa9QAUWBvOWv9DeNvwsBNXSORJAbQFoQsWQ7e8V8,3079
@@ -311,7 +311,7 @@ keras_hub/src/models/qwen/qwen_layernorm.py,sha256=DS35r3qd6g5ocL7Nhf_vNzLLMo1aI
 keras_hub/src/models/qwen/qwen_presets.py,sha256=_jRG7bB4yBGWteBLbK2elc1e9doRl8zdzQRZgxFvnfc,1988
 keras_hub/src/models/qwen/qwen_tokenizer.py,sha256=LCv3IyiDDHqVnM9N3lf5-BE3iwicIh0nKS1hjoPw9lE,1532
 keras_hub/src/models/qwen_moe/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-keras_hub/src/models/qwen_moe/qwen_moe_attention.py,sha256=
+keras_hub/src/models/qwen_moe/qwen_moe_attention.py,sha256=pE79_iHUm2LGkoWL6zMJw_pNfzIvmyq3yJaiq47W2TY,13242
 keras_hub/src/models/qwen_moe/qwen_moe_backbone.py,sha256=nrfELvIvRLmrgKrUNXci2CrecmeI6bWzJj7HH-RcWJA,15341
 keras_hub/src/models/qwen_moe/qwen_moe_causal_lm.py,sha256=MeP60v7GcN_SmH5_ULRpqgmFVgaYAosSecZiSQVlJvU,13256
 keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor.py,sha256=uKaXRrJs02vkVudjdehzJPp0B84tPMkxNHlp166kceE,589
@@ -490,7 +490,7 @@ keras_hub/src/utils/transformers/convert_gemma.py,sha256=ElCgwBpSN5Q7rV5PJawTsoy
 keras_hub/src/utils/transformers/convert_gpt2.py,sha256=HCeHN_-GiQJRxLCM9OCJJ1watPVpIBF8ujS8pGbBOWc,5703
 keras_hub/src/utils/transformers/convert_llama3.py,sha256=c5phNl-QayQ_BS0s-lenbu6oHxqfwDShKJoh9DluxUU,6146
 keras_hub/src/utils/transformers/convert_mistral.py,sha256=kVhN9h1ZFVhwkNW8p3wnS7eANJUXIsNy1RxWXy20Gqw,4760
-keras_hub/src/utils/transformers/convert_mixtral.py,sha256=
+keras_hub/src/utils/transformers/convert_mixtral.py,sha256=PxeCY8Xe7U_caICugwOCEjuSZ51ZUtmef6rUxh-Wt54,5508
 keras_hub/src/utils/transformers/convert_pali_gemma.py,sha256=B1leeDw96Yvu81hYumf66hIid07k5NLqoeWAJgPnaLs,10649
 keras_hub/src/utils/transformers/convert_qwen.py,sha256=WUxMAEFVqRs7TRw7QU5TH3_ev4yf02R1xFVliMvTQqg,5886
 keras_hub/src/utils/transformers/convert_qwen_moe.py,sha256=a7R28aln-PdAcNuKAXdrtzvslho2Co6GypChxLMKPpc,10618
@@ -499,7 +499,7 @@ keras_hub/src/utils/transformers/preset_loader.py,sha256=1nfS5xVsl-JROGXJXltTqV1
 keras_hub/src/utils/transformers/safetensor_utils.py,sha256=CYUHyA4y-B61r7NDnCsFb4t_UmSwZ1k9L-8gzEd6KRg,3339
 keras_hub/tokenizers/__init__.py,sha256=uMjjm0mzUkRb0e4Ac_JK8aJ9cKGUi5UqmzWoWAFJprE,4164
 keras_hub/utils/__init__.py,sha256=jXPqVGBpJr_PpYmqD8aDG-fRMlxH-ulqCR2SZMn288Y,646
-keras_hub_nightly-0.21.0.
-keras_hub_nightly-0.21.0.
-keras_hub_nightly-0.21.0.
-keras_hub_nightly-0.21.0.
+keras_hub_nightly-0.21.0.dev202505220409.dist-info/METADATA,sha256=EqRkCDIuHYBX4sLxSObub9YnmlNwhf_d2-IKG1tm4Xw,7393
+keras_hub_nightly-0.21.0.dev202505220409.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
+keras_hub_nightly-0.21.0.dev202505220409.dist-info/top_level.txt,sha256=N4J6piIWBKa38A4uV-CnIopnOEf8mHAbkNXafXm_CuA,10
+keras_hub_nightly-0.21.0.dev202505220409.dist-info/RECORD,,