keras-hub-nightly 0.23.0.dev202509180413__py3-none-any.whl → 0.23.0.dev202509280419__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

This version of keras-hub-nightly has been flagged as potentially problematic.

Files changed (32)
  1. keras_hub/layers/__init__.py +3 -0
  2. keras_hub/models/__init__.py +24 -0
  3. keras_hub/src/models/depth_anything/__init__.py +9 -0
  4. keras_hub/src/models/depth_anything/depth_anything_backbone.py +232 -0
  5. keras_hub/src/models/depth_anything/depth_anything_depth_estimator.py +70 -0
  6. keras_hub/src/models/depth_anything/depth_anything_depth_estimator_preprocessor.py +16 -0
  7. keras_hub/src/models/depth_anything/depth_anything_image_converter.py +10 -0
  8. keras_hub/src/models/depth_anything/depth_anything_layers.py +725 -0
  9. keras_hub/src/models/depth_anything/depth_anything_loss.py +89 -0
  10. keras_hub/src/models/depth_anything/depth_anything_presets.py +4 -0
  11. keras_hub/src/models/depth_anything/interpolate.py +62 -0
  12. keras_hub/src/models/depth_estimator.py +239 -0
  13. keras_hub/src/models/depth_estimator_preprocessor.py +78 -0
  14. keras_hub/src/models/dinov2/dinov2_backbone.py +29 -3
  15. keras_hub/src/models/dinov2/dinov2_layers.py +13 -3
  16. keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py +371 -0
  17. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +365 -0
  18. keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm.py +357 -0
  19. keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor.py +12 -0
  20. keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py +672 -0
  21. keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py +45 -0
  22. keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py +48 -0
  23. keras_hub/src/tests/test_case.py +3 -2
  24. keras_hub/src/utils/transformers/convert_dinov2.py +1 -0
  25. keras_hub/src/utils/transformers/convert_qwen3_moe.py +216 -0
  26. keras_hub/src/utils/transformers/preset_loader.py +3 -0
  27. keras_hub/src/version.py +1 -1
  28. keras_hub/tokenizers/__init__.py +3 -0
  29. {keras_hub_nightly-0.23.0.dev202509180413.dist-info → keras_hub_nightly-0.23.0.dev202509280419.dist-info}/METADATA +1 -1
  30. {keras_hub_nightly-0.23.0.dev202509180413.dist-info → keras_hub_nightly-0.23.0.dev202509280419.dist-info}/RECORD +32 -13
  31. {keras_hub_nightly-0.23.0.dev202509180413.dist-info → keras_hub_nightly-0.23.0.dev202509280419.dist-info}/WHEEL +0 -0
  32. {keras_hub_nightly-0.23.0.dev202509180413.dist-info → keras_hub_nightly-0.23.0.dev202509280419.dist-info}/top_level.txt +0 -0
keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py
@@ -0,0 +1,371 @@
+import math
+
+import keras
+from keras import ops
+
+from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+from keras_hub.src.models.qwen3_moe.qwen3_moe_layernorm import Qwen3MoeLayerNorm
+from keras_hub.src.utils.keras_utils import clone_initializer
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
+
+
+class Qwen3MoeAttention(keras.layers.Layer):
+    """A multi-head attention layer for Qwen3Moe models
+    This attention implementation supports grouped-query attention (GQA) where
+    the number of key-value heads can be less than the number of query heads.
+
+    Args:
+        num_query_heads: int. Number of query heads.
+        num_key_value_heads: int. Number of key/value heads (for GQA).
+        head_dim: int. The dimension of each attention head.
+        rope_max_wavelength: int. Maximum wavelength for RoPE (Rotary Position
+            Embedding).
+        rope_scaling_factor: float. Scaling factor for RoPE, used for extending
+            context length.
+        kernel_initializer: Initializer for the kernel weights.
+        dropout: float. Dropout rate for attention weights.
+        layer_norm_epsilon: float. The epsilon value for layer normalization.
+        sliding_window_size: int. Size of the sliding window for attention.
+        **kwargs: Additional keyword arguments to pass to the Layer.
+    """
+
+    def __init__(
+        self,
+        num_query_heads,
+        num_key_value_heads,
+        head_dim=None,
+        rope_max_wavelength=10000,
+        rope_scaling_factor=1,
+        kernel_initializer="glorot_uniform",
+        dropout=0.0,
+        layer_norm_epsilon=1e-6,
+        sliding_window_size=None,
+        **kwargs,
+    ):
+        super().__init__(
+            **kwargs,
+        )
+        self.num_query_heads = num_query_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.head_dim = head_dim
+        self.dropout = dropout
+
+        self.layer_norm_epsilon = layer_norm_epsilon
+
+        self.num_key_value_groups = num_query_heads // num_key_value_heads
+        self.rope_max_wavelength = rope_max_wavelength
+
+        self.kernel_initializer = keras.initializers.get(
+            clone_initializer(kernel_initializer)
+        )
+
+        self.rope_scaling_factor = rope_scaling_factor
+        self.sliding_window_size = sliding_window_size
+
+    def build(self, inputs_shape):
+        # Einsum variables:
+        # b = batch size
+        # q = query length
+        # k = key/value length
+        # m = model dim
+        # u = num query heads
+        # v = num key/value heads
+        # h = head dim
+        hidden_dim = inputs_shape[-1]
+        if not self.head_dim:
+            self.head_dim = hidden_dim // self.num_query_heads
+
+        self._inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
+        self._query_dense = keras.layers.EinsumDense(
+            equation="bqm,muh->bquh",
+            output_shape=(None, self.num_query_heads, self.head_dim),
+            kernel_initializer=self.kernel_initializer,
+            dtype=self.dtype_policy,
+            name="query",
+        )
+        self._query_dense.build(inputs_shape)
+
+        self._query_dense_layer_norm = Qwen3MoeLayerNorm(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            head_dim=self.head_dim,
+            name="query_dense_layernorm",
+        )
+        self._query_dense_layer_norm.build(inputs_shape)
+
+        self._key_dense = keras.layers.EinsumDense(
+            equation="bkm,mvh->bkvh",
+            output_shape=(
+                None,
+                self.num_key_value_heads,
+                self.head_dim,
+            ),
+            kernel_initializer=self.kernel_initializer,
+            dtype=self.dtype_policy,
+            name="key",
+        )
+        self._key_dense.build(inputs_shape)
+
+        self._key_dense_layer_norm = Qwen3MoeLayerNorm(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            head_dim=self.head_dim,
+            name="key_dense_layernorm",
+        )
+        self._key_dense_layer_norm.build(inputs_shape)
+
+        self._value_dense = keras.layers.EinsumDense(
+            equation="bkm,mvh->bkvh",
+            output_shape=(
+                None,
+                self.num_key_value_heads,
+                self.head_dim,
+            ),
+            kernel_initializer=self.kernel_initializer,
+            dtype=self.dtype_policy,
+            name="value",
+        )
+        self._value_dense.build(inputs_shape)
+
+        self._softmax = keras.layers.Softmax(
+            axis=-1,
+            dtype="float32",
+            name="attention_softmax",
+        )
+
+        self._dropout_layer = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+        )
+
+        self._output_dense = keras.layers.EinsumDense(
+            equation="bquh,uhm->bqm",
+            output_shape=(None, hidden_dim),
+            kernel_initializer=self.kernel_initializer,
+            dtype=self.dtype_policy,
+            name="attention_output",
+        )
+        self._output_dense.build(
+            (None, None, self.num_query_heads, self.head_dim)
+        )
+
+        self.rotary_embedding_layer = RotaryEmbedding(
+            max_wavelength=self.rope_max_wavelength,
+            scaling_factor=self.rope_scaling_factor,
+            dtype=self.dtype_policy,
+        )
+
+        self._dot_product_equation = "bquh,bkuh->buqk"
+        self._combine_equation = "buqk,bkuh->bquh"
+
+        self.built = True
+
+    def call(
+        self,
+        hidden_states,
+        attention_mask=None,
+        cache=None,
+        cache_update_index=None,
+        training=None,
+    ):
+        """Applies attention mechanism to the input hidden states.
+
+        Args:
+            hidden_states: Input tensor of shape [batch_size, seq_length,
+                hidden_size].
+            attention_mask: Mask tensor of shape [batch_size, seq_length,
+                seq_length].
+            cache: Optional cached key and value tensors.
+            cache_update_index: Index at which to update the cache.
+            training: Boolean indicating whether in training mode.
+
+        Returns:
+            attention_output: Output tensor after applying attention.
+            cache: Updated cache tensors (if cache is provided).
+        """
+        start_index = (
+            cache_update_index if cache_update_index is not None else 0
+        )
+
+        query = self._query_dense(hidden_states)
+        query = self._query_dense_layer_norm(query)
+
+        # Compute RoPE for queries
+        query = self.rotary_embedding_layer(query, start_index=start_index)
+
+        def _compute_key_value(x):
+            key = self._key_dense(x)
+            key = self._key_dense_layer_norm(key)
+            key = self.rotary_embedding_layer(key, start_index=start_index)
+
+            value = self._value_dense(x)
+
+            return key, value
+
+        if cache is not None:
+            key_cache = cache[:, 0, ...]
+            value_cache = cache[:, 1, ...]
+            if cache_update_index is None:
+                key = key_cache
+                value = value_cache
+            else:
+                key_update, value_update = _compute_key_value(hidden_states)
+                start = [0, cache_update_index, 0, 0]
+                key = ops.slice_update(key_cache, start, key_update)
+                value = ops.slice_update(value_cache, start, value_update)
+                cache = ops.stack((key, value), axis=1)
+        else:
+            if cache_update_index is not None:
+                raise ValueError(
+                    "`cache_update_index` should not be set if `cache` is "
+                    f"`None`. Received: cache={cache}, "
+                    f"cache_update_index={cache_update_index}"
+                )
+            key, value = _compute_key_value(hidden_states)
+
+        # [batch_shape, seq_len, num_key_value_heads, head_dim]
+        # -> [batch_shape, seq_len, num_heads, head_dim]
+        key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
+        value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)
+
+        attention_output = self._compute_attention(
+            query,
+            key,
+            value,
+            attention_mask,
+            cache_update_index=cache_update_index,
+        )
+
+        attention_output = self._dropout_layer(
+            attention_output, training=training
+        )
+
+        attention_output = self._output_dense(attention_output)
+
+        if cache is not None:
+            return attention_output, cache
+        return attention_output
+
+    def _masked_softmax(self, attention_scores, attention_mask=None):
+        """Applies softmax with optional masking.
+
+        Args:
+            attention_scores: Attention score tensor.
+            attention_mask: Optional mask tensor.
+
+        Returns:
+            Masked softmax attention weights.
+        """
+        if attention_mask is not None:
+            return self._softmax(
+                attention_scores, attention_mask[:, None, :, :]
+            )
+        return self._softmax(attention_scores)
+
+    def _compute_attention(
+        self, query, key, value, attention_mask=None, cache_update_index=None
+    ):
+        """Computes attention using query, key, and value tensors.
+        Uses Flash Attention when available for better performance.
+
+        Args:
+            query: Query tensor.
+            key: Key tensor.
+            value: Value tensor.
+            attention_mask: Optional mask tensor.
+            cache_update_index: Index for sliding window computation.
+
+        Returns:
+            attention_output: Output tensor after applying attention.
+        """
+        if fused_attention_op_available():
+            # Use `dot_product_attention` with Flash Attention support if
+            # available.
+            if attention_mask is not None:
+                attention_mask = ops.expand_dims(attention_mask, axis=1)
+                attention_mask = ops.cast(attention_mask, dtype="bool")
+            attention_output = ops.dot_product_attention(
+                query,
+                key,
+                value,
+                mask=attention_mask,
+                scale=self._inv_norm_factor,
+            )
+            return attention_output
+
+        attention_scores = ops.einsum(self._dot_product_equation, query, key)
+
+        attention_scores = ops.multiply(
+            attention_scores,
+            ops.cast(self._inv_norm_factor, self.compute_dtype),
+        )
+        if self.sliding_window_size:
+            attention_mask = self._mask_sliding_window(
+                attention_mask,
+                cache_update_index=cache_update_index
+                if cache_update_index is not None
+                else 0,
+            )
+        attention_scores = self._masked_softmax(
+            attention_scores, attention_mask
+        )
+        attention_scores = ops.cast(attention_scores, self.compute_dtype)
+        attention_output = ops.einsum(
+            self._combine_equation, attention_scores, value
+        )
+
+        return attention_output
+
+    def _mask_sliding_window(
+        self,
+        attention_mask,
+        cache_update_index=0,
+    ):
+        """Creates and combines a sliding window mask with the attention mask.
+
+        Args:
+            attention_mask: Original attention mask.
+            cache_update_index: Starting index for the sliding window.
+
+        Returns:
+            Combined attention mask with sliding window constraints.
+        """
+        _, query_len, key_len = ops.shape(attention_mask)
+        # Compute the sliding window for square attention.
+        all_ones = ops.ones((key_len, key_len), "bool")
+        if keras.config.backend() == "tensorflow":
+            # TODO: trui/tril has issues with dynamic shape on the tensorflow
+            # backend. We should fix, but use `band_part` for now.
+            import tensorflow as tf
+
+            band_size = ops.minimum(key_len, self.sliding_window_size - 1)
+            band_size = ops.cast(band_size, "int32")
+            sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size)
+        else:
+            sliding_mask = ops.triu(
+                all_ones, -1 * self.sliding_window_size + 1
+            ) * ops.tril(all_ones, self.sliding_window_size - 1)
+        # Slice the window for short queries during generation.
+        start = (cache_update_index, 0)
+        sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len))
+        sliding_mask = ops.expand_dims(sliding_mask, 0)
+        return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool"))
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_query_heads": self.num_query_heads,
+                "num_key_value_heads": self.num_key_value_heads,
+                "rope_max_wavelength": self.rope_max_wavelength,
+                "rope_scaling_factor": self.rope_scaling_factor,
+                "kernel_initializer": keras.initializers.serialize(
+                    self.kernel_initializer
+                ),
+                "dropout": self.dropout,
+                "sliding_window_size": self.sliding_window_size,
+                "head_dim": self.head_dim,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+            }
+        )
+        return config
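
For orientation, the sketch below exercises the new Qwen3MoeAttention layer on its own, using only the constructor arguments and call signature visible in the diff above. All sizes (hidden dim 256, 8 query heads sharing 2 key/value heads, a 16-slot cache) and both masks are arbitrary illustrative choices, not values from any released preset; in the package the layer is wired up by the Qwen3 MoE decoder rather than called directly.

from keras import ops

from keras_hub.src.models.qwen3_moe.qwen3_moe_attention import Qwen3MoeAttention

batch, seq_len, hidden_dim = 2, 8, 256  # arbitrary illustrative sizes

# Grouped-query attention: 8 query heads share 2 key/value heads, so each
# key/value head serves num_key_value_groups = 4 query heads.
layer = Qwen3MoeAttention(
    num_query_heads=8,
    num_key_value_heads=2,
    head_dim=32,
)

hidden_states = ops.ones((batch, seq_len, hidden_dim))
# `call` expects a [batch, query_len, key_len] mask; build a causal one.
causal = ops.cast(ops.tril(ops.ones((seq_len, seq_len))), "bool")
causal = ops.broadcast_to(causal[None, :, :], (batch, seq_len, seq_len))

out = layer(hidden_states, attention_mask=causal)
print(out.shape)  # (2, 8, 256)

# Cached decoding: `cache` stacks keys and values along axis 1, i.e. it has
# shape [batch, 2, max_len, num_key_value_heads, head_dim], and
# `cache_update_index` is the position being written.
max_len = 16
cache = ops.zeros((batch, 2, max_len, 2, 32))
token = ops.ones((batch, 1, hidden_dim))
step_mask = ops.broadcast_to(
    (ops.arange(max_len) <= 0)[None, None, :], (batch, 1, max_len)
)
token_out, cache = layer(
    token, attention_mask=step_mask, cache=cache, cache_update_index=0
)
print(token_out.shape)  # (2, 1, 256)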
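
The non-TensorFlow branch of _mask_sliding_window builds its band mask by intersecting ops.triu and ops.tril. The standalone sketch below reproduces that construction with arbitrary sizes (6 key positions, window size 3) and int32 ones so the result prints the same on every backend; it is illustrative only and not part of the release.

from keras import ops

key_len, window = 6, 3  # arbitrary illustrative sizes
all_ones = ops.ones((key_len, key_len), "int32")
# Keep positions within `window - 1` steps of the diagonal, mirroring the
# triu/tril combination in the non-TensorFlow branch above.
band = ops.triu(all_ones, -window + 1) * ops.tril(all_ones, window - 1)
print(ops.convert_to_numpy(band))
# [[1 1 1 0 0 0]
#  [1 1 1 1 0 0]
#  [1 1 1 1 1 0]
#  [0 1 1 1 1 1]
#  [0 0 1 1 1 1]
#  [0 0 0 1 1 1]]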