keras-hub-nightly 0.20.0.dev202503170356__py3-none-any.whl → 0.20.0.dev202503180354__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to its public registry. It is provided for informational purposes only.
@@ -267,6 +267,24 @@ from keras_hub.src.models.phi3.phi3_causal_lm_preprocessor import (
 )
 from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer
 from keras_hub.src.models.preprocessor import Preprocessor
+from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone
+from keras_hub.src.models.qwen.qwen_backbone import (
+    QwenBackbone as Qwen2Backbone,
+)
+from keras_hub.src.models.qwen.qwen_causal_lm import QwenCausalLM
+from keras_hub.src.models.qwen.qwen_causal_lm import (
+    QwenCausalLM as Qwen2CausalLM,
+)
+from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import (
+    QwenCausalLMPreprocessor,
+)
+from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import (
+    QwenCausalLMPreprocessor as Qwen2CausalLMPreprocessor,
+)
+from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer
+from keras_hub.src.models.qwen.qwen_tokenizer import (
+    QwenTokenizer as Qwen2Tokenizer,
+)
 from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
 from keras_hub.src.models.resnet.resnet_image_classifier import (
     ResNetImageClassifier,
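
The hunk above wires the new Qwen model family into the package's exports and registers `Qwen2*` aliases for every symbol. A minimal sketch of what that buys callers, assuming these re-exports surface on the `keras_hub.models` namespace (the target file's path is not shown in this diff):

# Sketch only, not part of the diff: both spellings resolve to the same class.
import keras_hub

assert keras_hub.models.Qwen2Backbone is keras_hub.models.QwenBackbone
assert keras_hub.models.Qwen2CausalLM is keras_hub.models.QwenCausalLM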
@@ -29,6 +29,10 @@ from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import (
     PaliGemmaTokenizer,
 )
 from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer
+from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer
+from keras_hub.src.models.qwen.qwen_tokenizer import (
+    QwenTokenizer as Qwen2Tokenizer,
+)
 from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer
 from keras_hub.src.models.siglip.siglip_tokenizer import SigLIPTokenizer
 from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer
@@ -280,6 +280,8 @@ class ImageConverter(PreprocessingLayer):
         return inputs
 
     def _expand_non_channel_dims(self, value, inputs):
+        input_dtype = keras.backend.standardize_dtype(inputs.dtype)
+
         unbatched = len(ops.shape(inputs)) == 3
         channels_first = self.data_format == "channels_first"
         if unbatched:
@@ -294,9 +296,10 @@ class ImageConverter(PreprocessingLayer):
             # device (potentially GPU) after preprocessing.
             if keras.backend.backend() == "torch" and self.image_size is None:
                 return ops.expand_dims(value, broadcast_dims).cpu()
-            return ops.expand_dims(value, broadcast_dims)
+            expanded = ops.expand_dims(value, broadcast_dims)
+            return ops.cast(expanded, input_dtype)
         else:
-            return np.expand_dims(value, broadcast_dims)
+            return np.expand_dims(value, broadcast_dims).astype(input_dtype)
 
     def get_config(self):
         config = super().get_config()
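
The two `ImageConverter` hunks above make `_expand_non_channel_dims` dtype-preserving: the broadcast `value` is now cast to the dtype of `inputs` on both the tensor and numpy paths. A rough illustration of the numpy branch, assuming `value` is a per-channel scale or offset (names and shapes below are illustrative):

# Sketch only, not part of the diff: effect of the added `.astype(input_dtype)`.
import numpy as np

inputs = np.zeros((224, 224, 3), dtype="float16")  # unbatched, channels_last
value = np.array([0.5, 0.5, 0.5])                  # per-channel value, float64
broadcast_dims = (0, 1)                            # broadcast over height/width

expanded = np.expand_dims(value, broadcast_dims).astype(inputs.dtype)
print(expanded.shape, expanded.dtype)  # (1, 1, 3) float16, matching `inputs`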
@@ -0,0 +1 @@
+from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone
@@ -0,0 +1,358 @@
+import math
+
+import keras
+from keras import ops
+
+from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+from keras_hub.src.utils.keras_utils import clone_initializer
+from keras_hub.src.utils.keras_utils import has_flash_attention_support
+
+
+class QwenAttention(keras.layers.Layer):
+    """A multi-head attention layer for Qwen models.
+
+    This attention implementation supports grouped-query attention (GQA) where
+    the number of key-value heads can be less than the number of query heads.
+
+    Args:
+        num_query_heads: Number of query heads.
+        num_key_value_heads: Number of key/value heads (for GQA).
+        rope_max_wavelength: Maximum wavelength for RoPE (Rotary Position
+            Embedding).
+        rope_scaling_factor: Scaling factor for RoPE, used for extending
+            context length.
+        kernel_initializer: Initializer for the kernel weights.
+        bias_initializer: Initializer for the bias weights.
+        dropout: Dropout rate for attention weights.
+        use_sliding_window_attention: Whether to use sliding window
+            attention.
+        sliding_window_size: Size of the sliding window for attention.
+        **kwargs: Additional keyword arguments to pass to the Layer.
+    """
+
+    def __init__(
+        self,
+        num_query_heads,
+        num_key_value_heads,
+        rope_max_wavelength=10000,
+        rope_scaling_factor=1,
+        kernel_initializer="glorot_uniform",
+        bias_initializer="zeros",
+        dropout=0,
+        use_sliding_window_attention=False,
+        sliding_window_size=4096,
+        **kwargs,
+    ):
+        super().__init__(
+            **kwargs,
+        )
+        self.num_query_heads = num_query_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.dropout = dropout
+
+        self.num_key_value_groups = num_query_heads // num_key_value_heads
+        self.rope_max_wavelength = rope_max_wavelength
+
+        self.kernel_initializer = keras.initializers.get(
+            clone_initializer(kernel_initializer)
+        )
+        self.bias_initializer = keras.initializers.get(
+            clone_initializer(bias_initializer)
+        )
+
+        self.rope_scaling_factor = rope_scaling_factor
+        self.use_sliding_window_attention = use_sliding_window_attention
+        self.sliding_window_size = sliding_window_size
+
+    def build(self, inputs_shape):
+        # Einsum variables:
+        # b = batch size
+        # q = query length
+        # k = key/value length
+        # m = model dim
+        # u = num query heads
+        # v = num key/value heads
+        # h = head dim
+        hidden_dim = inputs_shape[-1]
+        head_dim = hidden_dim // self.num_query_heads
+        self._inv_norm_factor = 1.0 / math.sqrt(head_dim)
+        self._query_dense = keras.layers.EinsumDense(
+            equation="bqm,muh->bquh",
+            output_shape=(None, self.num_query_heads, head_dim),
+            kernel_initializer=self.kernel_initializer,
+            bias_initializer=self.bias_initializer,
+            bias_axes="uh",
+            dtype=self.dtype_policy,
+            name="query",
+        )
+        self._query_dense.build(inputs_shape)
+
+        self._key_dense = keras.layers.EinsumDense(
+            equation="bkm,mvh->bkvh",
+            output_shape=(
+                None,
+                self.num_key_value_heads,
+                head_dim,
+            ),
+            kernel_initializer=self.kernel_initializer,
+            bias_initializer=self.bias_initializer,
+            bias_axes="vh",
+            dtype=self.dtype_policy,
+            name="key",
+        )
+        self._key_dense.build(inputs_shape)
+
+        self._value_dense = keras.layers.EinsumDense(
+            equation="bkm,mvh->bkvh",
+            output_shape=(
+                None,
+                self.num_key_value_heads,
+                head_dim,
+            ),
+            kernel_initializer=self.kernel_initializer,
+            bias_initializer=self.bias_initializer,
+            bias_axes="vh",
+            dtype=self.dtype_policy,
+            name="value",
+        )
+        self._value_dense.build(inputs_shape)
+
+        self._softmax = keras.layers.Softmax(
+            axis=-1,
+            dtype="float32",
+            name="attention_softmax",
+        )
+
+        self._dropout_layer = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+        )
+
+        self._output_dense = keras.layers.EinsumDense(
+            equation="bquh,uhm->bqm",
+            output_shape=(None, hidden_dim),
+            kernel_initializer=self.kernel_initializer,
+            dtype=self.dtype_policy,
+            name="attention_output",
+        )
+        self._output_dense.build((None, None, self.num_query_heads, head_dim))
+
+        self.rotary_embedding_layer = RotaryEmbedding(
+            max_wavelength=self.rope_max_wavelength,
+            scaling_factor=self.rope_scaling_factor,
+            dtype=self.dtype_policy,
+        )
+
+        self._dot_product_equation = "bquh,bkuh->buqk"
+        self._combine_equation = "buqk,bkuh->bquh"
+
+        self.built = True
+
+    def call(
+        self,
+        hidden_states,
+        attention_mask=None,
+        cache=None,
+        cache_update_index=None,
+        training=None,
+    ):
+        """Applies attention mechanism to the input hidden states.
+
+        Args:
+            hidden_states: Input tensor of shape [batch_size, seq_length,
+                hidden_size].
+            attention_mask: Mask tensor of shape [batch_size, seq_length,
+                seq_length].
+            cache: Optional cached key and value tensors.
+            cache_update_index: Index at which to update the cache.
+            training: Boolean indicating whether in training mode.
+
+        Returns:
+            attention_output: Output tensor after applying attention.
+            cache: Updated cache tensors (if cache is provided).
+        """
+        start_index = (
+            cache_update_index if cache_update_index is not None else 0
+        )
+
+        query = self._query_dense(hidden_states)
+
+        # Compute RoPE for queries
+        query = self.rotary_embedding_layer(query, start_index=start_index)
+
+        def _compute_key_value(x):
+            key, value = self._key_dense(x), self._value_dense(x)
+            # Compute RoPE for keys
+            key = self.rotary_embedding_layer(key, start_index=start_index)
+            return key, value
+
+        if cache is not None:
+            key_cache = cache[:, 0, ...]
+            value_cache = cache[:, 1, ...]
+            if cache_update_index is None:
+                key = key_cache
+                value = value_cache
+            else:
+                key_update, value_update = _compute_key_value(hidden_states)
+                start = [0, cache_update_index, 0, 0]
+                key = ops.slice_update(key_cache, start, key_update)
+                value = ops.slice_update(value_cache, start, value_update)
+                cache = ops.stack((key, value), axis=1)
+        else:
+            if cache_update_index is not None:
+                raise ValueError(
+                    "`cache_update_index` should not be set if `cache` is "
+                    f"`None`. Received: cache={cache}, "
+                    f"cache_update_index={cache_update_index}"
+                )
+            key, value = _compute_key_value(hidden_states)
+
+        # [batch_shape, seq_len, num_key_value_heads, head_dim]
+        # -> [batch_shape, seq_len, num_heads, head_dim]
+        key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
+        value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)
+
+        attention_output = self._compute_attention(
+            query,
+            key,
+            value,
+            attention_mask,
+            cache_update_index=cache_update_index,
+        )
+
+        attention_output = self._dropout_layer(
+            attention_output, training=training
+        )
+
+        attention_output = self._output_dense(attention_output)
+
+        if cache is not None:
+            return attention_output, cache
+        return attention_output
+
+    def _masked_softmax(self, attention_scores, attention_mask=None):
+        """Applies softmax with optional masking.
+
+        Args:
+            attention_scores: Attention score tensor.
+            attention_mask: Optional mask tensor.
+
+        Returns:
+            Masked softmax attention weights.
+        """
+        if attention_mask is not None:
+            return self._softmax(
+                attention_scores, attention_mask[:, None, :, :]
+            )
+        return self._softmax(attention_scores)
+
+    def _compute_attention(
+        self, query, key, value, attention_mask=None, cache_update_index=None
+    ):
+        """Computes attention using query, key, and value tensors.
+
+        Uses Flash Attention when available for better performance.
+
+        Args:
+            query: Query tensor.
+            key: Key tensor.
+            value: Value tensor.
+            attention_mask: Optional mask tensor.
+            cache_update_index: Index for sliding window computation.
+
+        Returns:
+            attention_output: Output tensor after applying attention.
+        """
+        if has_flash_attention_support():
+            # Use `dot_product_attention` with Flash Attention support if
+            # available.
+            if attention_mask is not None:
+                attention_mask = ops.expand_dims(attention_mask, axis=1)
+                attention_mask = ops.cast(attention_mask, dtype="bool")
+            attention_output = ops.dot_product_attention(
+                query,
+                key,
+                value,
+                mask=attention_mask,
+                scale=self._inv_norm_factor,
+            )
+            return attention_output
+
+        attention_scores = ops.einsum(self._dot_product_equation, query, key)
+
+        attention_scores = ops.multiply(
+            attention_scores,
+            ops.cast(self._inv_norm_factor, self.compute_dtype),
+        )
+        if self.use_sliding_window_attention:
+            attention_mask = self._mask_sliding_window(
+                attention_mask,
+                cache_update_index=cache_update_index,
+            )
+        attention_scores = self._masked_softmax(
+            attention_scores, attention_mask
+        )
+        attention_scores = ops.cast(attention_scores, self.compute_dtype)
+        attention_output = ops.einsum(
+            self._combine_equation, attention_scores, value
+        )
+
+        return attention_output
+
+    def _mask_sliding_window(
+        self,
+        attention_mask,
+        cache_update_index=0,
+    ):
+        """Creates and combines a sliding window mask with the attention mask.
+
+        Args:
+            attention_mask: Original attention mask.
+            cache_update_index: Starting index for the sliding window.
+
+        Returns:
+            Combined attention mask with sliding window constraints.
+        """
+        _, query_len, key_len = ops.shape(attention_mask)
+        # Compute the sliding window for square attention.
+        all_ones = ops.ones((key_len, key_len), "bool")
+        if keras.config.backend() == "tensorflow":
+            # TODO: triu/tril has issues with dynamic shape on the tensorflow
+            # backend. We should fix, but use `band_part` for now.
+            import tensorflow as tf
+
+            band_size = ops.minimum(key_len, self.sliding_window_size - 1)
+            band_size = ops.cast(band_size, "int32")
+            sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size)
+        else:
+            sliding_mask = ops.triu(
+                all_ones, -1 * self.sliding_window_size + 1
+            ) * ops.tril(all_ones, self.sliding_window_size - 1)
+        # Slice the window for short queries during generation.
+        start = (cache_update_index, 0)
+        sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len))
+        sliding_mask = ops.expand_dims(sliding_mask, 0)
+        return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool"))
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_query_heads": self.num_query_heads,
+                "num_key_value_heads": self.num_key_value_heads,
+                "rope_max_wavelength": self.rope_max_wavelength,
+                "rope_scaling_factor": self.rope_scaling_factor,
+                "kernel_initializer": keras.initializers.serialize(
+                    self.kernel_initializer
+                ),
+                "bias_initializer": keras.initializers.serialize(
+                    self.bias_initializer
+                ),
+                "dropout": self.dropout,
+                "use_sliding_window_attention": (
+                    self.use_sliding_window_attention
+                ),
+                "sliding_window_size": self.sliding_window_size,
+            }
+        )
+        return config
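
For readers skimming the diff, here is a rough, self-contained smoke test of the new layer. It assumes the file lands at `keras_hub/src/models/qwen/qwen_attention.py` (implied by the sibling Qwen imports above but not shown in this diff); the shapes and hyperparameters below are illustrative, not Qwen defaults.

# Sketch only, not part of the diff: exercise QwenAttention with a
# grouped-query configuration and a boolean causal mask.
import numpy as np
from keras_hub.src.models.qwen.qwen_attention import QwenAttention

batch, seq_len, hidden_dim = 2, 8, 64
layer = QwenAttention(
    num_query_heads=8,  # head_dim = 64 // 8 = 8
    num_key_value_heads=2,  # GQA: every 4 query heads share one KV head
)
hidden_states = np.random.uniform(size=(batch, seq_len, hidden_dim)).astype(
    "float32"
)
# Causal mask of shape [batch, query_len, key_len].
causal = np.tril(np.ones((seq_len, seq_len), dtype="bool"))
attention_mask = np.broadcast_to(causal, (batch, seq_len, seq_len))
outputs = layer(hidden_states, attention_mask=attention_mask)
print(outputs.shape)  # (2, 8, 64)

Note that the sliding-window band built by `_mask_sliding_window` is only applied on the einsum fallback path; the flash-attention branch of `_compute_attention` returns before that mask is computed.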