keras-hub-nightly 0.20.0.dev202504020401__py3-none-any.whl → 0.21.0.dev202504040358__py3-none-any.whl

This diff compares the contents of two package versions that were publicly released to a supported registry. It is provided for informational purposes only and reflects the changes as they appear in the public registry.
Files changed (35)
  1. keras_hub/api/models/__init__.py +5 -20
  2. keras_hub/api/tokenizers/__init__.py +0 -4
  3. keras_hub/src/layers/preprocessing/image_converter.py +26 -16
  4. keras_hub/src/models/gemma/gemma_attention.py +17 -10
  5. keras_hub/src/models/gemma3/gemma3_attention.py +76 -23
  6. keras_hub/src/models/gemma3/gemma3_backbone.py +117 -46
  7. keras_hub/src/models/gemma3/gemma3_causal_lm.py +72 -15
  8. keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +512 -355
  9. keras_hub/src/models/gemma3/gemma3_decoder_block.py +23 -19
  10. keras_hub/src/models/gemma3/gemma3_image_converter.py +6 -0
  11. keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py +56 -16
  12. keras_hub/src/models/gemma3/gemma3_presets.py +74 -8
  13. keras_hub/src/models/gemma3/gemma3_tokenizer.py +9 -0
  14. keras_hub/src/models/gemma3/{gemma3_vit.py → gemma3_vision_encoder.py} +150 -139
  15. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -2
  16. keras_hub/src/models/llama/llama_attention.py +2 -2
  17. keras_hub/src/models/mistral/mistral_attention.py +2 -2
  18. keras_hub/src/models/phi3/phi3_attention.py +2 -2
  19. keras_hub/src/models/qwen/qwen_attention.py +2 -2
  20. keras_hub/src/models/qwen/qwen_backbone.py +0 -7
  21. keras_hub/src/models/qwen/qwen_causal_lm.py +0 -7
  22. keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +0 -7
  23. keras_hub/src/models/qwen/qwen_tokenizer.py +0 -9
  24. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -1
  25. keras_hub/src/models/roformer_v2/roformer_v2_text_classifier.py +2 -2
  26. keras_hub/src/models/stable_diffusion_3/mmdit.py +2 -2
  27. keras_hub/src/models/vit/vit_image_converter.py +8 -3
  28. keras_hub/src/tests/test_case.py +4 -0
  29. keras_hub/src/utils/keras_utils.py +44 -1
  30. keras_hub/src/utils/tensor_utils.py +6 -0
  31. keras_hub/src/version_utils.py +1 -1
  32. {keras_hub_nightly-0.20.0.dev202504020401.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/METADATA +1 -1
  33. {keras_hub_nightly-0.20.0.dev202504020401.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/RECORD +35 -35
  34. {keras_hub_nightly-0.20.0.dev202504020401.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/WHEEL +0 -0
  35. {keras_hub_nightly-0.20.0.dev202504020401.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/top_level.txt +0 -0
keras_hub/src/models/gemma3/{gemma3_vit.py → gemma3_vision_encoder.py}
@@ -1,11 +1,150 @@
 import keras
 from keras import ops

+from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
 from keras_hub.src.utils.keras_utils import clone_initializer


-class Gemma3VitEmbeddings(keras.layers.Layer):
+@keras_hub_export("keras_hub.models.Gemma3VisionEncoder")
+class Gemma3VisionEncoder(keras.Model):
+    """Vision Transformer (ViT) model for Gemma3.
+
+    Args:
+        image_size: int. The height/width of the image. Both height and width is
+            expected to be the same.
+        patch_size: int. The size of each square patch in the input image.
+        num_heads: int. The number of attention heads for the vision(image)
+            transformer encoder.
+        hidden_dim: int. The size of the transformer hidden state at the end
+            of each vision transformer layer.
+        num_layers: int. The number of transformer layers.
+        intermediate_dim: int. The output dimension of the first Dense layer in
+            a two-layer feedforward network for transformer.
+        output_dim: int. The odimension of the output returned by the model.
+        pool_size: int. Factors by which to downscale `(dim1, dim2)` in the
+            average pooling layer. The same value is used for `"strides"`.
+            Defaults to 14.
+        layer_norm_epsilon: float. The epsilon value user for every layer norm
+            in all transformer blocks. Defaults to `1e-6`.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for the models computations and weights. Note that some
+            computations, such as softmax and layer normalization will always
+            be done a float32 precision regardless of dtype.
+
+    Example:
+    ```python
+    image = np.random.rand(224, 224, 3)
+    vit_model = Gemma3VisionEncoder(image_size=224)
+    # The output will be of shape:
+    # [batch_size, num_vision_tokens_per_image, hidden_dim]
+    output = vit_model([image])
+    ```
+    """
+
+    def __init__(
+        self,
+        image_size,
+        patch_size,
+        num_heads,
+        hidden_dim,
+        num_layers,
+        intermediate_dim,
+        output_dim,
+        pool_size=14,
+        layer_norm_epsilon=1e-6,
+        dtype=None,
+        **kwargs,
+    ):
+        # If the passed dtype is `bfloat16`, use `float32` to maintain parity
+        # with other framework implementations.
+        if dtype == "bfloat16":
+            dtype = "float32"
+
+        # === Functional Model ===
+        image_input = keras.Input(
+            shape=(None, image_size, image_size, 3),
+            name="images",
+        )
+        x = image_input  # Intermediate result.
+        x = Gemma3VisionEncoderBlock(
+            hidden_dim=hidden_dim,
+            num_layers=num_layers,
+            num_heads=num_heads,
+            intermediate_dim=intermediate_dim,
+            patch_size=patch_size,
+            image_size=image_size,
+            dtype=dtype,
+            name="image_encoder",
+        )(x)
+
+        x = Gemma3VisionAveragePooling(
+            image_size=image_size,
+            patch_size=patch_size,
+            pool_size=pool_size,
+            dtype=dtype,
+            name="pooling",
+        )(x)
+
+        x = Gemma3VisionOutput(
+            output_dim=output_dim,
+            layer_norm_epsilon=layer_norm_epsilon,
+            kernel_initializer=keras.initializers.RandomNormal(
+                mean=0.0, stddev=0.01
+            ),
+            dtype=dtype,
+            name="vision_output_encoder",
+        )(x)
+
+        outputs = x
+        super().__init__(
+            inputs=image_input,
+            outputs=outputs,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.num_layers = num_layers
+        self.intermediate_dim = intermediate_dim
+        self.output_dim = output_dim
+        self.pool_size = pool_size
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.num_vision_tokens_per_image = (
+            (image_size // patch_size) ** 2
+        ) // (pool_size**2)
+
+        # Before Keras 3.2, there is no `keras.dtype_policies.get`.
+        if hasattr(keras.dtype_policies, "get"):
+            self.dtype_policy = keras.dtype_policies.get(dtype)
+        else:
+            if isinstance(dtype, keras.dtype_policies.DTypePolicy):
+                dtype = dtype.name
+            dtype = dtype or keras.config.dtype_policy().name
+            self.dtype_policy = keras.dtype_policies.DTypePolicy(dtype)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "num_layers": self.num_layers,
+                "intermediate_dim": self.intermediate_dim,
+                "output_dim": self.output_dim,
+                "pool_size": self.pool_size,
+                "image_size": self.image_size,
+                "patch_size": self.patch_size,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+            }
+        )
+        return config
+
+
+class Gemma3VisionEmbedding(keras.layers.Layer):
     def __init__(
         self,
         image_size,
@@ -62,7 +201,7 @@ class Gemma3VitEmbeddings(keras.layers.Layer):
         )


-class Gemma3VitAttention(keras.layers.Layer):
+class Gemma3VisionAttention(keras.layers.Layer):
     """
     Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py
     """
@@ -197,7 +336,7 @@ class Gemma3VitAttention(keras.layers.Layer):
         return config


-class Gemma3VitEncoderBlock(keras.layers.Layer):
+class Gemma3VisionEncoderLayer(keras.layers.Layer):
     def __init__(
         self,
         num_heads,
@@ -217,7 +356,7 @@ class Gemma3VitEncoderBlock(keras.layers.Layer):

     def build(self, input_shape):
         hidden_dim = input_shape[-1]
-        self.attn = Gemma3VitAttention(
+        self.attn = Gemma3VisionAttention(
             hidden_dim,
             self.num_heads,
             dtype=self.dtype_policy,
@@ -277,7 +416,7 @@ class Gemma3VitEncoderBlock(keras.layers.Layer):
         return config


-class Gemma3VitEncoder(keras.layers.Layer):
+class Gemma3VisionEncoderBlock(keras.layers.Layer):
     def __init__(
         self,
         patch_size,
@@ -303,7 +442,7 @@ class Gemma3VitEncoder(keras.layers.Layer):
             dtype=dtype,
             name="encoder_layer_norm",
         )
-        self.vision_embeddings = Gemma3VitEmbeddings(
+        self.vision_embeddings = Gemma3VisionEmbedding(
             hidden_dim=hidden_dim,
             patch_size=patch_size,
             image_size=image_size,
@@ -311,7 +450,7 @@ class Gemma3VitEncoder(keras.layers.Layer):
             name="encoder_embeddings",
         )
         self.resblocks = [
-            Gemma3VitEncoderBlock(
+            Gemma3VisionEncoderLayer(
                 self.num_heads,
                 self.intermediate_dim,
                 dtype=dtype,
@@ -321,7 +460,7 @@ class Gemma3VitEncoder(keras.layers.Layer):
         ]

     def build(self, inputs_shape):
-        # Collapse `batch_size`, dummy axis, `image_max_length` into one.
+        # Collapse `batch_size`, dummy axis, `max_images_per_prompt` into one.
         inputs_shape = [None] + list(inputs_shape[2:])
         self.vision_embeddings.build(inputs_shape)
         for block in self.resblocks:
@@ -332,7 +471,7 @@ class Gemma3VitEncoder(keras.layers.Layer):
     def call(self, inputs, mask=None):
         inputs_shape = ops.shape(inputs)

-        # Collapse `batch_size`, dummy axis, `image_max_length` into one.
+        # Collapse `batch_size`, dummy axis, `max_images_per_prompt` into one.
         inputs = ops.reshape(
             inputs,
             [inputs_shape[0] * inputs_shape[1]] + list(inputs_shape[2:]),
@@ -372,7 +511,7 @@ class Gemma3VitEncoder(keras.layers.Layer):
         return config


-class AveragePooling(keras.layers.Layer):
+class Gemma3VisionAveragePooling(keras.layers.Layer):
     def __init__(self, image_size, patch_size, pool_size, **kwargs):
         super().__init__(**kwargs)

@@ -425,7 +564,7 @@ class AveragePooling(keras.layers.Layer):
         return config


-class Gemma3VisionOutputEncoder(keras.layers.Layer):
+class Gemma3VisionOutput(keras.layers.Layer):
     def __init__(
         self,
         output_dim,
@@ -478,131 +617,3 @@ class Gemma3VisionOutputEncoder(keras.layers.Layer):

     def compute_output_shape(self, input_shape):
         return input_shape[:-1] + (self.output_dim,)
-
-
-class Gemma3Vit(keras.Model):
-    """Vision Transformer (ViT) model for Gemma3.
-
-    Args:
-        image_size: int. The height/width of the image. Both height and width is
-            expected to be the same.
-        patch_size: int. The size of each square patch in the input image.
-        num_heads: int. The number of attention heads for the vision(image)
-            transformer encoder.
-        hidden_dim: int. The size of the transformer hidden state at the end
-            of each vision transformer layer.
-        num_layers: int. The number of transformer layers.
-        intermediate_dim: int. The output dimension of the first Dense layer in
-            a two-layer feedforward network for transformer.
-        pool_size: int. Factors by which to downscale `(dim1, dim2)` in the
-            average pooling layer. The same value is used for `"strides"`.
-        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
-            for the models computations and weights. Note that some
-            computations, such as softmax and layer normalization will always
-            be done a float32 precision regardless of dtype.
-
-    Example:
-    ```python
-    image = np.random.rand(224, 224, 3)
-    vit_model = Gemma3Vit(image_size=224)
-    # The output will be of shape:
-    # [batch_size, num_vision_tokens_per_image, hidden_dim]
-    output = vit_model([image])
-    ```
-    """
-
-    def __init__(
-        self,
-        image_size,
-        patch_size,
-        num_heads,
-        hidden_dim,
-        num_layers,
-        intermediate_dim,
-        output_dim,
-        pool_size=14,
-        layer_norm_epsilon=1e-6,
-        dtype=None,
-        **kwargs,
-    ):
-        # === Functional Model ===
-        image_input = keras.Input(
-            shape=(None, image_size, image_size, 3),
-            name="images",
-        )
-        x = image_input  # Intermediate result.
-        x = Gemma3VitEncoder(
-            hidden_dim=hidden_dim,
-            num_layers=num_layers,
-            num_heads=num_heads,
-            intermediate_dim=intermediate_dim,
-            patch_size=patch_size,
-            image_size=image_size,
-            dtype=dtype,
-            name="image_encoder",
-        )(x)
-
-        x = AveragePooling(
-            image_size=image_size,
-            patch_size=patch_size,
-            pool_size=pool_size,
-            dtype=dtype,
-            name="pooling",
-        )(x)
-
-        x = Gemma3VisionOutputEncoder(
-            output_dim=output_dim,
-            layer_norm_epsilon=layer_norm_epsilon,
-            kernel_initializer=keras.initializers.RandomNormal(
-                mean=0.0, stddev=0.01
-            ),
-            dtype=dtype,
-            name="vision_output_encoder",
-        )(x)
-
-        outputs = x
-        super().__init__(
-            inputs=image_input,
-            outputs=outputs,
-            **kwargs,
-        )
-
-        # === Config ===
-        self.image_size = image_size
-        self.patch_size = patch_size
-        self.num_heads = num_heads
-        self.hidden_dim = hidden_dim
-        self.num_layers = num_layers
-        self.intermediate_dim = intermediate_dim
-        self.output_dim = output_dim
-        self.pool_size = pool_size
-        self.layer_norm_epsilon = layer_norm_epsilon
-        self.num_vision_tokens_per_image = (
-            (image_size // patch_size) ** 2
-        ) // (pool_size**2)
-
-        # Before Keras 3.2, there is no `keras.dtype_policies.get`.
-        if hasattr(keras.dtype_policies, "get"):
-            self.dtype_policy = keras.dtype_policies.get(dtype)
-        else:
-            if isinstance(dtype, keras.dtype_policies.DTypePolicy):
-                dtype = dtype.name
-            dtype = dtype or keras.config.dtype_policy().name
-            self.dtype_policy = keras.dtype_policies.DTypePolicy(dtype)
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "num_heads": self.num_heads,
-                "hidden_dim": self.hidden_dim,
-                "num_layers": self.num_layers,
-                "intermediate_dim": self.intermediate_dim,
-                "output_dim": self.output_dim,
-                "pool_size": self.pool_size,
-                "image_size": self.image_size,
-                "patch_size": self.patch_size,
-                "layer_norm_epsilon": self.layer_norm_epsilon,
-            }
-        )
-        return config
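The renamed encoder above is now a public symbol via `@keras_hub_export("keras_hub.models.Gemma3VisionEncoder")`, consistent with the `+5` lines in `keras_hub/api/models/__init__.py`. Below is a minimal, hedged sketch of constructing it directly; the hyperparameter values are illustrative assumptions rather than a shipped preset, and only the token-count formula is taken from the constructor shown above.

```python
# Hedged sketch, not taken from the package: instantiate the newly exported
# encoder with small illustrative hyperparameters.
import numpy as np

from keras_hub.models import Gemma3VisionEncoder

encoder = Gemma3VisionEncoder(
    image_size=224,     # assumed divisible by patch_size
    patch_size=14,
    num_heads=8,
    hidden_dim=512,
    num_layers=2,
    intermediate_dim=1024,
    output_dim=1280,
    pool_size=4,        # (224 // 14) = 16 patches per side, pooled 4x
)

# Token count follows the formula in `__init__` above:
# ((224 // 14) ** 2) // (4 ** 2) = 256 // 16 = 16 vision tokens per image.
print(encoder.num_vision_tokens_per_image)  # 16

# Input follows `keras.Input(shape=(None, image_size, image_size, 3))`, i.e.
# (batch, images_per_prompt, height, width, channels).
images = np.random.rand(1, 2, 224, 224, 3).astype("float32")
features = encoder(images)
```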
keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py
@@ -5,7 +5,7 @@ from keras import ops

 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available


 class GPTNeoXAttention(keras.layers.Layer):
@@ -125,7 +125,7 @@ class GPTNeoXAttention(keras.layers.Layer):
     def _compute_attention(
         self, query, key, value, attention_mask=None, training=None
     ):
-        if has_flash_attention_support() and self.dropout == 0:
+        if fused_attention_op_available() and self.dropout == 0:
             # Use `dot_product_attention` with Flash Attention support if
             # available.
             if attention_mask is not None:
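The import rename above is the first of several identical call-site updates; the Llama, Mistral, Phi-3, Qwen, and StableDiffusion 3 (MMDiT) files below carry the same change. For orientation, here is a hedged sketch of the shared gating pattern: the fused branch mirrors the diffed code, while the unfused fallback is illustrative rather than copied from any one model, and the mask layout is an assumption.

```python
# Hedged sketch of the guard shared by the `_compute_attention` methods touched
# in this release: try the fused Keras op first, otherwise do explicit attention.
from keras import ops

from keras_hub.src.utils.keras_utils import fused_attention_op_available


def compute_attention(query, key, value, attention_mask=None):
    # query/key/value: (batch, seq, num_heads, head_dim).
    # attention_mask: (batch, query_len, key_len), truthy where attention is allowed.
    if fused_attention_op_available():
        # Keras can dispatch this to a fused (flash) attention kernel.
        return ops.dot_product_attention(query, key, value, mask=attention_mask)
    # Illustrative unfused fallback: explicit scaled dot-product attention.
    head_dim = ops.shape(query)[-1]
    scale = 1.0 / ops.sqrt(ops.cast(head_dim, query.dtype))
    scores = ops.einsum("bqhd,bkhd->bhqk", query, key) * scale
    if attention_mask is not None:
        big_neg = ops.cast(-1e9, scores.dtype)
        mask = ops.cast(attention_mask[:, None, :, :], "bool")
        scores = ops.where(mask, scores, big_neg)
    probs = ops.softmax(scores, axis=-1)
    return ops.einsum("bhqk,bkhd->bqhd", probs, value)
```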
keras_hub/src/models/llama/llama_attention.py
@@ -5,7 +5,7 @@ from keras import ops

 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available


 class LlamaAttention(keras.layers.Layer):
@@ -185,7 +185,7 @@ class LlamaAttention(keras.layers.Layer):
         return self._softmax(attention_scores)

     def _compute_attention(self, query, key, value, attention_mask=None):
-        if has_flash_attention_support():
+        if fused_attention_op_available():
             # Use `dot_product_attention` with Flash Attention support if
             # available.
             if attention_mask is not None:
keras_hub/src/models/mistral/mistral_attention.py
@@ -5,7 +5,7 @@ from keras import ops

 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available


 # This is just a self-attention layer in Mistral. But it can be generalized
@@ -196,7 +196,7 @@ class CachedMistralAttention(keras.layers.Layer):
         return self._softmax(attention_scores)

     def _compute_attention(self, query, key, value, attention_mask=None):
-        if has_flash_attention_support():
+        if fused_attention_op_available():
             # Use `dot_product_attention` with Flash Attention support if
             # available.
             if attention_mask is not None:
keras_hub/src/models/phi3/phi3_attention.py
@@ -8,7 +8,7 @@ from keras_hub.src.models.phi3.phi3_rotary_embedding import (
     Phi3SuScaledRotaryEmbedding,
 )
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available


 class Phi3Attention(keras.layers.Layer):
@@ -217,7 +217,7 @@ class Phi3Attention(keras.layers.Layer):
         return self.softmax(attention_scores)

     def _compute_attention(self, query, key, value, attention_mask=None):
-        if has_flash_attention_support():
+        if fused_attention_op_available():
             # Use `dot_product_attention` with Flash Attention support if
             # available.
             if attention_mask is not None:
keras_hub/src/models/qwen/qwen_attention.py
@@ -5,7 +5,7 @@ from keras import ops

 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available


 class QwenAttention(keras.layers.Layer):
@@ -263,7 +263,7 @@ class QwenAttention(keras.layers.Layer):
         Returns:
             attention_output: Output tensor after applying attention.
         """
-        if has_flash_attention_support():
+        if fused_attention_op_available():
             # Use `dot_product_attention` with Flash Attention support if
             # available.
             if attention_mask is not None:
keras_hub/src/models/qwen/qwen_backbone.py
@@ -1,7 +1,6 @@
 import keras
 from keras import ops

-from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.modeling.reversible_embedding import (
     ReversibleEmbedding,
 )
@@ -14,12 +13,6 @@ def _qwen_kernel_initializer(stddev=0.02):
     return keras.initializers.RandomNormal(stddev=stddev)


-@keras_hub_export(
-    [
-        "keras_hub.models.QwenBackbone",
-        "keras_hub.models.Qwen2Backbone",
-    ]
-)
 class QwenBackbone(Backbone):
     """
     The Qwen Transformer core architecture with hyperparameters.
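The `keras_hub_export` decorator is removed here and in the three Qwen files that follow (causal LM, preprocessor, tokenizer), matching the shrinkage of `keras_hub/api/models/__init__.py` (-20 lines) and `keras_hub/api/tokenizers/__init__.py` (-4 lines). A hedged illustration of the practical effect, assuming no replacement export is registered elsewhere in this nightly:

```python
# Assumption for illustration: with the decorator above removed, the public
# aliases `keras_hub.models.QwenBackbone` / `keras_hub.models.Qwen2Backbone`
# are no longer generated from this module, but the class itself is unchanged
# and still importable from its source path.
from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone

print(QwenBackbone.__name__)  # "QwenBackbone"
```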
keras_hub/src/models/qwen/qwen_causal_lm.py
@@ -1,7 +1,6 @@
 import keras
 from keras import ops

-from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.causal_lm import CausalLM
 from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone
 from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import (
@@ -10,12 +9,6 @@ from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import (
 from keras_hub.src.utils.tensor_utils import any_equal


-@keras_hub_export(
-    [
-        "keras_hub.models.QwenCausalLM",
-        "keras_hub.models.Qwen2CausalLM",
-    ]
-)
 class QwenCausalLM(CausalLM):
     backbone_cls = QwenBackbone
     preprocessor_cls = QwenCausalLMPreprocessor
keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py
@@ -1,15 +1,8 @@
-from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
 from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone
 from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer


-@keras_hub_export(
-    [
-        "keras_hub.models.QwenCausalLMPreprocessor",
-        "keras_hub.models.Qwen2CausalLMPreprocessor",
-    ]
-)
 class QwenCausalLMPreprocessor(CausalLMPreprocessor):
     backbone_cls = QwenBackbone
     tokenizer_cls = QwenTokenizer
keras_hub/src/models/qwen/qwen_tokenizer.py
@@ -1,16 +1,7 @@
-from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone
 from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer


-@keras_hub_export(
-    [
-        "keras_hub.tokenizers.QwenTokenizer",
-        "keras_hub.tokenizers.Qwen2Tokenizer",
-        "keras_hub.models.QwenTokenizer",
-        "keras_hub.models.Qwen2Tokenizer",
-    ]
-)
 class QwenTokenizer(BytePairTokenizer):
     """Tokenizer for Qwen models.

keras_hub/src/models/roformer_v2/roformer_v2_backbone.py
@@ -16,7 +16,7 @@ def roformer_kernel_initializer(stddev=0.02):
     return keras.initializers.TruncatedNormal(stddev=stddev)


-@keras_hub_export("keras_hub.models.RorformerV2Backbone")
+@keras_hub_export("keras_hub.models.RoformerV2Backbone")
 class RoformerV2Backbone(Backbone):
     """A RoformerV2 encoder network.

keras_hub/src/models/roformer_v2/roformer_v2_text_classifier.py
@@ -10,8 +10,8 @@ from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier_preprocessor i
 )


-@keras_hub_export("keras_hub.models.RorformerV2TextClassifier")
-class RorformerV2TextClassifier(RobertaTextClassifier):
+@keras_hub_export("keras_hub.models.RoformerV2TextClassifier")
+class RoformerV2TextClassifier(RobertaTextClassifier):
     """An end-to-end RoformerV2 model for classification tasks.

     This model attaches a classification head to
keras_hub/src/models/stable_diffusion_3/mmdit.py
@@ -6,8 +6,8 @@ from keras import ops

 from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
 from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 from keras_hub.src.utils.keras_utils import gelu_approximate
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
 from keras_hub.src.utils.keras_utils import standardize_data_format


@@ -771,7 +771,7 @@ class MMDiTBlock(layers.Layer):
     def _compute_attention(self, query, key, value):
         batch_size = ops.shape(query)[0]

-        if has_flash_attention_support():
+        if fused_attention_op_available():
             # Use `dot_product_attention` with Flash Attention support if
             # available.
             encoded = ops.dot_product_attention(
keras_hub/src/models/vit/vit_image_converter.py
@@ -53,12 +53,17 @@ class ViTImageConverter(ImageConverter):

     @preprocessing_function
     def call(self, inputs):
+        # TODO: Remove this whole function. Why can just use scale and offset
+        # in the base class.
         x = super().call(inputs)
-        # By default normalize using imagenet mean and std
         if self.norm_mean:
-            x = x - self._expand_non_channel_dims(self.norm_mean, x)
+            norm_mean = self._expand_non_channel_dims(self.norm_mean, x)
+            x, norm_mean = self._convert_types(x, norm_mean, self.compute_dtype)
+            x = x - norm_mean
         if self.norm_std:
-            x = x / self._expand_non_channel_dims(self.norm_std, x)
+            norm_std = self._expand_non_channel_dims(self.norm_std, x)
+            x, norm_std = self._convert_types(x, norm_std, x.dtype)
+            x = x / norm_std

         return x

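The converter now casts the mean/std constants to the image tensor's dtype before subtracting and dividing, via the `_convert_types` helper added in `tensor_utils.py` (listed above). A self-contained sketch of the idea using plain `keras.ops`; the ImageNet mean/std values and the helper-free implementation are assumptions for illustration, not the library code.

```python
# Hedged sketch: normalize in the tensor's own dtype so a bfloat16/float16
# pipeline does not hit dtype mismatches against float32 constants.
import numpy as np
from keras import ops


def normalize(images, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    # Broadcast the per-channel constants and match the image dtype first.
    mean = ops.cast(ops.reshape(ops.convert_to_tensor(mean), (1, 1, 1, 3)), images.dtype)
    std = ops.cast(ops.reshape(ops.convert_to_tensor(std), (1, 1, 1, 3)), images.dtype)
    return (images - mean) / std


x = ops.cast(ops.convert_to_tensor(np.random.rand(2, 224, 224, 3)), "bfloat16")
print(normalize(x).dtype)  # bfloat16, no implicit upcast
```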
keras_hub/src/tests/test_case.py
@@ -197,6 +197,7 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
         input_data,
         expected_output=None,
         expected_detokenize_output=None,
+        return_output=False,
     ):
         """Run basic tests for a preprocessing layer."""
         layer = cls(**init_kwargs)
@@ -230,6 +231,9 @@
         if expected_output:
             self.assertAllClose(output, expected_output)

+        if return_output:
+            return output
+
     def run_preprocessor_test(
         self,
         cls,
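The new `return_output` flag lets a test reuse the standard preprocessing checks and then make additional assertions on the produced output. A hedged usage sketch follows; the layer under test and its kwargs are placeholders, not symbols from the package.

```python
# Hedged sketch of a test using the new `return_output` flag.
from keras_hub.src.tests.test_case import TestCase


class MyPreprocessingLayerTest(TestCase):  # hypothetical test class
    def test_layer_output(self):
        output = self.run_preprocessing_layer_test(
            cls=MyPreprocessingLayer,            # placeholder layer under test
            init_kwargs={"sequence_length": 8},  # placeholder kwargs
            input_data=["the quick brown fox"],
            return_output=True,                  # new in this release
        )
        # The helper now hands the output back for follow-up assertions.
        self.assertIsNotNone(output)
```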
keras_hub/src/utils/keras_utils.py
@@ -55,7 +55,7 @@ def standardize_data_format(data_format):
     return data_format


-def has_flash_attention_support():
+def fused_attention_op_available():
     if (
         hasattr(keras.config, "is_flash_attention_enabled")
         and keras.config.backend() == "jax"
@@ -104,3 +104,46 @@ def running_on_gpu():
         import torch

         return torch.cuda.is_available()
+
+
+def gpu_supports_fused_attention_op():
+    deny_list = ["T4"]
+    for denied_gpu in deny_list:
+        if any(denied_gpu in gpu.upper() for gpu in get_gpu_names()):
+            return False
+    return True
+
+
+def get_gpu_names():
+    """Detects and returns the names of available GPUs based on the backend.
+
+    Note:
+        The format and content of the returned GPU names are **not normalized**
+        and vary significantly depending on the active backend. This function
+        provides the names as reported by the respective backend's API."
+    """
+    backend = keras.config.backend()
+    if backend == "jax":
+        import jax
+
+        devices = jax.devices()
+
+        return [getattr(d, "device_kind", "") for d in devices]
+
+    elif backend == "tensorflow":
+        import tensorflow as tf
+
+        gpus = tf.config.list_physical_devices("GPU")
+        return [
+            tf.config.experimental.get_device_details(gpu)["device_name"]
+            for gpu in gpus
+        ]
+    elif backend == "torch":
+        import torch
+
+        return [
+            torch.cuda.get_device_name(i)
+            for i in range(torch.cuda.device_count())
+        ]
+    else:
+        return [""]
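A hedged example of combining the new helpers defined above; the combined guard reflects an assumed usage (the deny-list currently only excludes T4-class GPUs) rather than a call site shown in this diff.

```python
# Usage sketch for the new utilities in keras_hub/src/utils/keras_utils.py.
from keras_hub.src.utils.keras_utils import (
    fused_attention_op_available,
    get_gpu_names,
    gpu_supports_fused_attention_op,
)

# Backend-specific, non-normalized names, e.g. ["NVIDIA A100-SXM4-40GB"]
# depending on the active backend and driver.
print(get_gpu_names())

# A caller might gate the fused attention path on both checks: the op being
# usable with the installed Keras/backend, and the GPU not being deny-listed.
use_fused = fused_attention_op_available() and gpu_supports_fused_attention_op()
print(use_fused)
```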