keras-hub-nightly 0.20.0.dev202503170356__py3-none-any.whl → 0.20.0.dev202503190355__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/models/__init__.py +18 -0
- keras_hub/api/tokenizers/__init__.py +4 -0
- keras_hub/src/layers/preprocessing/image_converter.py +5 -2
- keras_hub/src/models/qwen/__init__.py +1 -0
- keras_hub/src/models/qwen/qwen_attention.py +358 -0
- keras_hub/src/models/qwen/qwen_backbone.py +327 -0
- keras_hub/src/models/qwen/qwen_causal_lm.py +300 -0
- keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +18 -0
- keras_hub/src/models/qwen/qwen_decoder.py +311 -0
- keras_hub/src/models/qwen/qwen_layernorm.py +32 -0
- keras_hub/src/models/qwen/qwen_tokenizer.py +51 -0
- keras_hub/src/utils/transformers/convert_qwen.py +148 -0
- keras_hub/src/utils/transformers/preset_loader.py +3 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.20.0.dev202503170356.dist-info → keras_hub_nightly-0.20.0.dev202503190355.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.20.0.dev202503170356.dist-info → keras_hub_nightly-0.20.0.dev202503190355.dist-info}/RECORD +18 -9
- {keras_hub_nightly-0.20.0.dev202503170356.dist-info → keras_hub_nightly-0.20.0.dev202503190355.dist-info}/WHEEL +1 -1
- {keras_hub_nightly-0.20.0.dev202503170356.dist-info → keras_hub_nightly-0.20.0.dev202503190355.dist-info}/top_level.txt +0 -0
keras_hub/api/models/__init__.py
CHANGED
@@ -267,6 +267,24 @@ from keras_hub.src.models.phi3.phi3_causal_lm_preprocessor import (
 )
 from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer
 from keras_hub.src.models.preprocessor import Preprocessor
+from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone
+from keras_hub.src.models.qwen.qwen_backbone import (
+    QwenBackbone as Qwen2Backbone,
+)
+from keras_hub.src.models.qwen.qwen_causal_lm import QwenCausalLM
+from keras_hub.src.models.qwen.qwen_causal_lm import (
+    QwenCausalLM as Qwen2CausalLM,
+)
+from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import (
+    QwenCausalLMPreprocessor,
+)
+from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import (
+    QwenCausalLMPreprocessor as Qwen2CausalLMPreprocessor,
+)
+from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer
+from keras_hub.src.models.qwen.qwen_tokenizer import (
+    QwenTokenizer as Qwen2Tokenizer,
+)
 from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
 from keras_hub.src.models.resnet.resnet_image_classifier import (
     ResNetImageClassifier,
keras_hub/api/tokenizers/__init__.py
CHANGED
@@ -29,6 +29,10 @@ from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import (
     PaliGemmaTokenizer,
 )
 from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer
+from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer
+from keras_hub.src.models.qwen.qwen_tokenizer import (
+    QwenTokenizer as Qwen2Tokenizer,
+)
 from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer
 from keras_hub.src.models.siglip.siglip_tokenizer import SigLIPTokenizer
 from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer
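Both namespaces export each new symbol twice: once under its `Qwen*` name and once under a `Qwen2*` alias. A minimal sketch of what that implies for callers, assuming the `keras_hub/api/` packages surface as the public `keras_hub.models` and `keras_hub.tokenizers` namespaces as in prior releases:

```python
import keras_hub

# The Qwen2* aliases added in this release resolve to the same classes
# as the Qwen* names, so either spelling can be used interchangeably.
assert keras_hub.models.Qwen2Backbone is keras_hub.models.QwenBackbone
assert keras_hub.models.Qwen2CausalLM is keras_hub.models.QwenCausalLM
assert (
    keras_hub.models.Qwen2CausalLMPreprocessor
    is keras_hub.models.QwenCausalLMPreprocessor
)
assert (
    keras_hub.tokenizers.Qwen2Tokenizer is keras_hub.tokenizers.QwenTokenizer
)
```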
keras_hub/src/layers/preprocessing/image_converter.py
CHANGED
@@ -280,6 +280,8 @@ class ImageConverter(PreprocessingLayer):
         return inputs
 
     def _expand_non_channel_dims(self, value, inputs):
+        input_dtype = keras.backend.standardize_dtype(inputs.dtype)
+
         unbatched = len(ops.shape(inputs)) == 3
         channels_first = self.data_format == "channels_first"
         if unbatched:
@@ -294,9 +296,10 @@ class ImageConverter(PreprocessingLayer):
             # device (potentially GPU) after preprocessing.
             if keras.backend.backend() == "torch" and self.image_size is None:
                 return ops.expand_dims(value, broadcast_dims).cpu()
-            return ops.expand_dims(value, broadcast_dims)
+            expanded = ops.expand_dims(value, broadcast_dims)
+            return ops.cast(expanded, input_dtype)
         else:
-            return np.expand_dims(value, broadcast_dims)
+            return np.expand_dims(value, broadcast_dims).astype(input_dtype)
 
     def get_config(self):
         config = super().get_config()
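The `image_converter.py` change records the input dtype up front and casts the broadcast value back to it on both the tensor and NumPy paths, so normalization constants no longer silently upcast the image. An illustrative sketch of the NumPy branch (shapes, axes, and the per-channel mean are invented for the example):

```python
import numpy as np

# A half-precision image batch and a float64 per-channel mean, as NumPy
# produces by default from a Python list.
inputs = np.zeros((2, 224, 224, 3), dtype="float16")
value = np.array([0.485, 0.456, 0.406])

# Mirrors the new behavior: expand to a broadcastable shape, then cast
# back to the input dtype instead of keeping float64.
expanded = np.expand_dims(value, (0, 1, 2)).astype(inputs.dtype)
assert expanded.shape == (1, 1, 1, 3)
assert expanded.dtype == np.float16
```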
keras_hub/src/models/qwen/__init__.py
ADDED
@@ -0,0 +1 @@
+from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone
keras_hub/src/models/qwen/qwen_attention.py
ADDED
@@ -0,0 +1,358 @@
+import math
+
+import keras
+from keras import ops
+
+from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+from keras_hub.src.utils.keras_utils import clone_initializer
+from keras_hub.src.utils.keras_utils import has_flash_attention_support
+
+
+class QwenAttention(keras.layers.Layer):
+    """A multi-head attention layer for Qwen models
+
+    This attention implementation supports grouped-query attention (GQA) where
+    the number of key-value heads can be less than the number of query heads.
+
+    Args:
+        num_query_heads: Number of query heads.
+        num_key_value_heads: Number of key/value heads (for GQA).
+        rope_max_wavelength: Maximum wavelength for RoPE (Rotary Position
+            Embedding).
+        rope_scaling_factor: Scaling factor for RoPE, used for extending
+            context length.
+        kernel_initializer: Initializer for the kernel weights.
+        bias_initializer: Initializer for the bias weights.
+        dropout: Dropout rate for attention weights.
+        use_sliding_window_attention: Whether to use sliding window
+            attention.
+        sliding_window_size: Size of the sliding window for attention.
+        **kwargs: Additional keyword arguments to pass to the Layer.
+    """
+
+    def __init__(
+        self,
+        num_query_heads,
+        num_key_value_heads,
+        rope_max_wavelength=10000,
+        rope_scaling_factor=1,
+        kernel_initializer="glorot_uniform",
+        bias_initializer="zeros",
+        dropout=0,
+        use_sliding_window_attention=False,
+        sliding_window_size=4096,
+        **kwargs,
+    ):
+        super().__init__(
+            **kwargs,
+        )
+        self.num_query_heads = num_query_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.dropout = dropout
+
+        self.num_key_value_groups = num_query_heads // num_key_value_heads
+        self.rope_max_wavelength = rope_max_wavelength
+
+        self.kernel_initializer = keras.initializers.get(
+            clone_initializer(kernel_initializer)
+        )
+        self.bias_initializer = keras.initializers.get(
+            clone_initializer(bias_initializer)
+        )
+
+        self.rope_scaling_factor = rope_scaling_factor
+        self.use_sliding_window_attention = use_sliding_window_attention
+        self.sliding_window_size = sliding_window_size
+
+    def build(self, inputs_shape):
+        # Einsum variables:
+        # b = batch size
+        # q = query length
+        # k = key/value length
+        # m = model dim
+        # u = num query heads
+        # v = num key/value heads
+        # h = head dim
+        hidden_dim = inputs_shape[-1]
+        head_dim = hidden_dim // self.num_query_heads
+        self._inv_norm_factor = 1.0 / math.sqrt(head_dim)
+        self._query_dense = keras.layers.EinsumDense(
+            equation="bqm,muh->bquh",
+            output_shape=(None, self.num_query_heads, head_dim),
+            kernel_initializer=self.kernel_initializer,
+            bias_initializer=self.bias_initializer,
+            bias_axes="uh",
+            dtype=self.dtype_policy,
+            name="query",
+        )
+        self._query_dense.build(inputs_shape)
+
+        self._key_dense = keras.layers.EinsumDense(
+            equation="bkm,mvh->bkvh",
+            output_shape=(
+                None,
+                self.num_key_value_heads,
+                head_dim,
+            ),
+            kernel_initializer=self.kernel_initializer,
+            bias_initializer=self.bias_initializer,
+            bias_axes="vh",
+            dtype=self.dtype_policy,
+            name="key",
+        )
+        self._key_dense.build(inputs_shape)
+
+        self._value_dense = keras.layers.EinsumDense(
+            equation="bkm,mvh->bkvh",
+            output_shape=(
+                None,
+                self.num_key_value_heads,
+                head_dim,
+            ),
+            kernel_initializer=self.kernel_initializer,
+            bias_initializer=self.bias_initializer,
+            bias_axes="vh",
+            dtype=self.dtype_policy,
+            name="value",
+        )
+        self._value_dense.build(inputs_shape)
+
+        self._softmax = keras.layers.Softmax(
+            axis=-1,
+            dtype="float32",
+            name="attention_softmax",
+        )
+
+        self._dropout_layer = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+        )
+
+        self._output_dense = keras.layers.EinsumDense(
+            equation="bquh,uhm->bqm",
+            output_shape=(None, hidden_dim),
+            kernel_initializer=self.kernel_initializer,
+            dtype=self.dtype_policy,
+            name="attention_output",
+        )
+        self._output_dense.build((None, None, self.num_query_heads, head_dim))
+
+        self.rotary_embedding_layer = RotaryEmbedding(
+            max_wavelength=self.rope_max_wavelength,
+            scaling_factor=self.rope_scaling_factor,
+            dtype=self.dtype_policy,
+        )
+
+        self._dot_product_equation = "bquh,bkuh->buqk"
+        self._combine_equation = "buqk,bkuh->bquh"
+
+        self.built = True
+
+    def call(
+        self,
+        hidden_states,
+        attention_mask=None,
+        cache=None,
+        cache_update_index=None,
+        training=None,
+    ):
+        """Applies attention mechanism to the input hidden states.
+
+        Args:
+            hidden_states: Input tensor of shape [batch_size, seq_length,
+                hidden_size].
+            attention_mask: Mask tensor of shape [batch_size, seq_length,
+                seq_length].
+            cache: Optional cached key and value tensors.
+            cache_update_index: Index at which to update the cache.
+            training: Boolean indicating whether in training mode.
+
+        Returns:
+            attention_output: Output tensor after applying attention.
+            cache: Updated cache tensors (if cache is provided).
+        """
+        start_index = (
+            cache_update_index if cache_update_index is not None else 0
+        )
+
+        query = self._query_dense(hidden_states)
+
+        # Compute RoPE for queries
+        query = self.rotary_embedding_layer(query, start_index=start_index)
+
+        def _compute_key_value(x):
+            key, value = self._key_dense(x), self._value_dense(x)
+            # Compute RoPE for keys
+            key = self.rotary_embedding_layer(key, start_index=start_index)
+            return key, value
+
+        if cache is not None:
+            key_cache = cache[:, 0, ...]
+            value_cache = cache[:, 1, ...]
+            if cache_update_index is None:
+                key = key_cache
+                value = value_cache
+            else:
+                key_update, value_update = _compute_key_value(hidden_states)
+                start = [0, cache_update_index, 0, 0]
+                key = ops.slice_update(key_cache, start, key_update)
+                value = ops.slice_update(value_cache, start, value_update)
+                cache = ops.stack((key, value), axis=1)
+        else:
+            if cache_update_index is not None:
+                raise ValueError(
+                    "`cache_update_index` should not be set if `cache` is "
+                    f"`None`. Received: cache={cache}, "
+                    f"cache_update_index={cache_update_index}"
+                )
+            key, value = _compute_key_value(hidden_states)
+
+        # [batch_shape, seq_len, num_key_value_heads, head_dim]
+        # -> [batch_shape, seq_len, num_heads, head_dim]
+        key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
+        value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)
+
+        attention_output = self._compute_attention(
+            query,
+            key,
+            value,
+            attention_mask,
+            cache_update_index=cache_update_index,
+        )
+
+        attention_output = self._dropout_layer(
+            attention_output, training=training
+        )
+
+        attention_output = self._output_dense(attention_output)
+
+        if cache is not None:
+            return attention_output, cache
+        return attention_output
+
+    def _masked_softmax(self, attention_scores, attention_mask=None):
+        """Applies softmax with optional masking.
+
+        Args:
+            attention_scores: Attention score tensor.
+            attention_mask: Optional mask tensor.
+
+        Returns:
+            Masked softmax attention weights.
+        """
+        if attention_mask is not None:
+            return self._softmax(
+                attention_scores, attention_mask[:, None, :, :]
+            )
+        return self._softmax(attention_scores)
+
+    def _compute_attention(
+        self, query, key, value, attention_mask=None, cache_update_index=None
+    ):
+        """Computes attention using query, key, and value tensors.
+
+        Uses Flash Attention when available for better performance.
+
+        Args:
+            query: Query tensor.
+            key: Key tensor.
+            value: Value tensor.
+            attention_mask: Optional mask tensor.
+            cache_update_index: Index for sliding window computation.
+
+        Returns:
+            attention_output: Output tensor after applying attention.
+        """
+        if has_flash_attention_support():
+            # Use `dot_product_attention` with Flash Attention support if
+            # available.
+            if attention_mask is not None:
+                attention_mask = ops.expand_dims(attention_mask, axis=1)
+                attention_mask = ops.cast(attention_mask, dtype="bool")
+            attention_output = ops.dot_product_attention(
+                query,
+                key,
+                value,
+                mask=attention_mask,
+                scale=self._inv_norm_factor,
+            )
+            return attention_output
+
+        attention_scores = ops.einsum(self._dot_product_equation, query, key)
+
+        attention_scores = ops.multiply(
+            attention_scores,
+            ops.cast(self._inv_norm_factor, self.compute_dtype),
+        )
+        if self.use_sliding_window_attention:
+            attention_mask = self._mask_sliding_window(
+                attention_mask,
+                cache_update_index=cache_update_index,
+            )
+        attention_scores = self._masked_softmax(
+            attention_scores, attention_mask
+        )
+        attention_scores = ops.cast(attention_scores, self.compute_dtype)
+        attention_output = ops.einsum(
+            self._combine_equation, attention_scores, value
+        )
+
+        return attention_output
+
+    def _mask_sliding_window(
+        self,
+        attention_mask,
+        cache_update_index=0,
+    ):
+        """Creates and combines a sliding window mask with the attention mask.
+
+        Args:
+            attention_mask: Original attention mask.
+            cache_update_index: Starting index for the sliding window.
+
+        Returns:
+            Combined attention mask with sliding window constraints.
+        """
+        _, query_len, key_len = ops.shape(attention_mask)
+        # Compute the sliding window for square attention.
+        all_ones = ops.ones((key_len, key_len), "bool")
+        if keras.config.backend() == "tensorflow":
+            # TODO: trui/tril has issues with dynamic shape on the tensorflow
+            # backend. We should fix, but use `band_part` for now.
+            import tensorflow as tf
+
+            band_size = ops.minimum(key_len, self.sliding_window_size - 1)
+            band_size = ops.cast(band_size, "int32")
+            sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size)
+        else:
+            sliding_mask = ops.triu(
+                all_ones, -1 * self.sliding_window_size + 1
+            ) * ops.tril(all_ones, self.sliding_window_size - 1)
+        # Slice the window for short queries during generation.
+        start = (cache_update_index, 0)
+        sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len))
+        sliding_mask = ops.expand_dims(sliding_mask, 0)
+        return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool"))
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_query_heads": self.num_query_heads,
+                "num_key_value_heads": self.num_key_value_heads,
+                "rope_max_wavelength": self.rope_max_wavelength,
+                "rope_scaling_factor": self.rope_scaling_factor,
+                "kernel_initializer": keras.initializers.serialize(
+                    self.kernel_initializer
+                ),
+                "bias_initializer": keras.initializers.serialize(
+                    self.bias_initializer
+                ),
+                "dropout": self.dropout,
+                "use_sliding_window_attention": (
+                    self.use_sliding_window_attention
+                ),
+                "sliding_window_size": self.sliding_window_size,
+            }
+        )
+        return config
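For orientation, a minimal sketch of exercising the new layer on its own; the hyperparameters below are invented for the example, and in the released package the layer is normally constructed by the Qwen decoder blocks rather than directly:

```python
import keras
from keras import ops

from keras_hub.src.models.qwen.qwen_attention import QwenAttention

# 8 query heads sharing 2 key/value heads (grouped-query attention).
layer = QwenAttention(num_query_heads=8, num_key_value_heads=2)

batch, seq_len, hidden_dim = 2, 16, 64
hidden_states = keras.random.normal((batch, seq_len, hidden_dim))

# Causal mask of shape [batch, query_len, key_len], as expected by `call`.
causal = ops.tril(ops.ones((seq_len, seq_len), dtype="int32"))
attention_mask = ops.broadcast_to(
    ops.expand_dims(causal, 0), (batch, seq_len, seq_len)
)

outputs = layer(hidden_states, attention_mask=attention_mask)
print(outputs.shape)  # (2, 16, 64)
```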