keras-hub-nightly 0.20.0.dev202504020401__py3-none-any.whl → 0.21.0.dev202504040358__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in that registry.
- keras_hub/api/models/__init__.py +5 -20
- keras_hub/api/tokenizers/__init__.py +0 -4
- keras_hub/src/layers/preprocessing/image_converter.py +26 -16
- keras_hub/src/models/gemma/gemma_attention.py +17 -10
- keras_hub/src/models/gemma3/gemma3_attention.py +76 -23
- keras_hub/src/models/gemma3/gemma3_backbone.py +117 -46
- keras_hub/src/models/gemma3/gemma3_causal_lm.py +72 -15
- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +512 -355
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +23 -19
- keras_hub/src/models/gemma3/gemma3_image_converter.py +6 -0
- keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py +56 -16
- keras_hub/src/models/gemma3/gemma3_presets.py +74 -8
- keras_hub/src/models/gemma3/gemma3_tokenizer.py +9 -0
- keras_hub/src/models/gemma3/{gemma3_vit.py → gemma3_vision_encoder.py} +150 -139
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -2
- keras_hub/src/models/llama/llama_attention.py +2 -2
- keras_hub/src/models/mistral/mistral_attention.py +2 -2
- keras_hub/src/models/phi3/phi3_attention.py +2 -2
- keras_hub/src/models/qwen/qwen_attention.py +2 -2
- keras_hub/src/models/qwen/qwen_backbone.py +0 -7
- keras_hub/src/models/qwen/qwen_causal_lm.py +0 -7
- keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +0 -7
- keras_hub/src/models/qwen/qwen_tokenizer.py +0 -9
- keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -1
- keras_hub/src/models/roformer_v2/roformer_v2_text_classifier.py +2 -2
- keras_hub/src/models/stable_diffusion_3/mmdit.py +2 -2
- keras_hub/src/models/vit/vit_image_converter.py +8 -3
- keras_hub/src/tests/test_case.py +4 -0
- keras_hub/src/utils/keras_utils.py +44 -1
- keras_hub/src/utils/tensor_utils.py +6 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.20.0.dev202504020401.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.20.0.dev202504020401.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/RECORD +35 -35
- {keras_hub_nightly-0.20.0.dev202504020401.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.20.0.dev202504020401.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/top_level.txt +0 -0
keras_hub/src/models/gemma3/{gemma3_vit.py → gemma3_vision_encoder.py}
RENAMED
@@ -1,11 +1,150 @@
 import keras
 from keras import ops
 
+from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
 from keras_hub.src.utils.keras_utils import clone_initializer
 
 
-class Gemma3VitEmbeddings(keras.layers.Layer):
+@keras_hub_export("keras_hub.models.Gemma3VisionEncoder")
+class Gemma3VisionEncoder(keras.Model):
+    """Vision Transformer (ViT) model for Gemma3.
+
+    Args:
+        image_size: int. The height/width of the image. Both height and
+            width are expected to be the same.
+        patch_size: int. The size of each square patch in the input image.
+        num_heads: int. The number of attention heads for the vision (image)
+            transformer encoder.
+        hidden_dim: int. The size of the transformer hidden state at the end
+            of each vision transformer layer.
+        num_layers: int. The number of transformer layers.
+        intermediate_dim: int. The output dimension of the first Dense layer
+            in a two-layer feedforward network for the transformer.
+        output_dim: int. The dimension of the output returned by the model.
+        pool_size: int. Factors by which to downscale `(dim1, dim2)` in the
+            average pooling layer. The same value is used for `"strides"`.
+            Defaults to 14.
+        layer_norm_epsilon: float. The epsilon value used for every layer
+            norm in all transformer blocks. Defaults to `1e-6`.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to
+            use for the model's computations and weights. Note that some
+            computations, such as softmax and layer normalization, will
+            always be done in float32 precision regardless of dtype.
+
+    Example:
+    ```python
+    image = np.random.rand(224, 224, 3)
+    vit_model = Gemma3VisionEncoder(image_size=224)
+    # The output will be of shape:
+    # [batch_size, num_vision_tokens_per_image, hidden_dim]
+    output = vit_model([image])
+    ```
+    """
+
+    def __init__(
+        self,
+        image_size,
+        patch_size,
+        num_heads,
+        hidden_dim,
+        num_layers,
+        intermediate_dim,
+        output_dim,
+        pool_size=14,
+        layer_norm_epsilon=1e-6,
+        dtype=None,
+        **kwargs,
+    ):
+        # If the passed dtype is `bfloat16`, use `float32` to maintain parity
+        # with other framework implementations.
+        if dtype == "bfloat16":
+            dtype = "float32"
+
+        # === Functional Model ===
+        image_input = keras.Input(
+            shape=(None, image_size, image_size, 3),
+            name="images",
+        )
+        x = image_input  # Intermediate result.
+        x = Gemma3VisionEncoderBlock(
+            hidden_dim=hidden_dim,
+            num_layers=num_layers,
+            num_heads=num_heads,
+            intermediate_dim=intermediate_dim,
+            patch_size=patch_size,
+            image_size=image_size,
+            dtype=dtype,
+            name="image_encoder",
+        )(x)
+
+        x = Gemma3VisionAveragePooling(
+            image_size=image_size,
+            patch_size=patch_size,
+            pool_size=pool_size,
+            dtype=dtype,
+            name="pooling",
+        )(x)
+
+        x = Gemma3VisionOutput(
+            output_dim=output_dim,
+            layer_norm_epsilon=layer_norm_epsilon,
+            kernel_initializer=keras.initializers.RandomNormal(
+                mean=0.0, stddev=0.01
+            ),
+            dtype=dtype,
+            name="vision_output_encoder",
+        )(x)
+
+        outputs = x
+        super().__init__(
+            inputs=image_input,
+            outputs=outputs,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.num_layers = num_layers
+        self.intermediate_dim = intermediate_dim
+        self.output_dim = output_dim
+        self.pool_size = pool_size
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.num_vision_tokens_per_image = (
+            (image_size // patch_size) ** 2
+        ) // (pool_size**2)
+
+        # Before Keras 3.2, there is no `keras.dtype_policies.get`.
+        if hasattr(keras.dtype_policies, "get"):
+            self.dtype_policy = keras.dtype_policies.get(dtype)
+        else:
+            if isinstance(dtype, keras.dtype_policies.DTypePolicy):
+                dtype = dtype.name
+            dtype = dtype or keras.config.dtype_policy().name
+            self.dtype_policy = keras.dtype_policies.DTypePolicy(dtype)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "num_layers": self.num_layers,
+                "intermediate_dim": self.intermediate_dim,
+                "output_dim": self.output_dim,
+                "pool_size": self.pool_size,
+                "image_size": self.image_size,
+                "patch_size": self.patch_size,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+            }
+        )
+        return config
+
+
+class Gemma3VisionEmbedding(keras.layers.Layer):
     def __init__(
         self,
         image_size,
@@ -62,7 +201,7 @@ class Gemma3VitEmbeddings(keras.layers.Layer):
         )
 
 
-class Gemma3VitAttention(keras.layers.Layer):
+class Gemma3VisionAttention(keras.layers.Layer):
     """
     Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py
     """
@@ -197,7 +336,7 @@ class Gemma3VitAttention(keras.layers.Layer):
         return config
 
 
-class Gemma3VitEncoderBlock(keras.layers.Layer):
+class Gemma3VisionEncoderLayer(keras.layers.Layer):
     def __init__(
         self,
         num_heads,
@@ -217,7 +356,7 @@ class Gemma3VitEncoderBlock(keras.layers.Layer):
 
     def build(self, input_shape):
         hidden_dim = input_shape[-1]
-        self.attn = Gemma3VitAttention(
+        self.attn = Gemma3VisionAttention(
            hidden_dim,
            self.num_heads,
            dtype=self.dtype_policy,
@@ -277,7 +416,7 @@ class Gemma3VitEncoderBlock(keras.layers.Layer):
         return config
 
 
-class Gemma3VitEncoder(keras.layers.Layer):
+class Gemma3VisionEncoderBlock(keras.layers.Layer):
     def __init__(
         self,
         patch_size,
@@ -303,7 +442,7 @@ class Gemma3VitEncoder(keras.layers.Layer):
             dtype=dtype,
             name="encoder_layer_norm",
         )
-        self.vision_embeddings = Gemma3VitEmbeddings(
+        self.vision_embeddings = Gemma3VisionEmbedding(
             hidden_dim=hidden_dim,
             patch_size=patch_size,
             image_size=image_size,
@@ -311,7 +450,7 @@ class Gemma3VitEncoder(keras.layers.Layer):
             name="encoder_embeddings",
         )
         self.resblocks = [
-            Gemma3VitEncoderBlock(
+            Gemma3VisionEncoderLayer(
                 self.num_heads,
                 self.intermediate_dim,
                 dtype=dtype,
@@ -321,7 +460,7 @@ class Gemma3VitEncoder(keras.layers.Layer):
         ]
 
     def build(self, inputs_shape):
-        # Collapse `batch_size`, dummy axis, `
+        # Collapse `batch_size`, dummy axis, `max_images_per_prompt` into one.
         inputs_shape = [None] + list(inputs_shape[2:])
         self.vision_embeddings.build(inputs_shape)
         for block in self.resblocks:
@@ -332,7 +471,7 @@ class Gemma3VitEncoder(keras.layers.Layer):
     def call(self, inputs, mask=None):
         inputs_shape = ops.shape(inputs)
 
-        # Collapse `batch_size`, dummy axis, `
+        # Collapse `batch_size`, dummy axis, `max_images_per_prompt` into one.
         inputs = ops.reshape(
             inputs,
             [inputs_shape[0] * inputs_shape[1]] + list(inputs_shape[2:]),
@@ -372,7 +511,7 @@ class Gemma3VitEncoder(keras.layers.Layer):
         return config
 
 
-class AveragePooling(keras.layers.Layer):
+class Gemma3VisionAveragePooling(keras.layers.Layer):
     def __init__(self, image_size, patch_size, pool_size, **kwargs):
         super().__init__(**kwargs)
 
@@ -425,7 +564,7 @@ class AveragePooling(keras.layers.Layer):
         return config
 
 
-class Gemma3VisionOutputEncoder(keras.layers.Layer):
+class Gemma3VisionOutput(keras.layers.Layer):
     def __init__(
         self,
         output_dim,
@@ -478,131 +617,3 @@ class Gemma3VisionOutputEncoder(keras.layers.Layer):
 
     def compute_output_shape(self, input_shape):
         return input_shape[:-1] + (self.output_dim,)
-
-
-class Gemma3Vit(keras.Model):
-    """Vision Transformer (ViT) model for Gemma3.
-
-    Args:
-        image_size: int. The height/width of the image. Both height and width is
-            expected to be the same.
-        patch_size: int. The size of each square patch in the input image.
-        num_heads: int. The number of attention heads for the vision(image)
-            transformer encoder.
-        hidden_dim: int. The size of the transformer hidden state at the end
-            of each vision transformer layer.
-        num_layers: int. The number of transformer layers.
-        intermediate_dim: int. The output dimension of the first Dense layer in
-            a two-layer feedforward network for transformer.
-        pool_size: int. Factors by which to downscale `(dim1, dim2)` in the
-            average pooling layer. The same value is used for `"strides"`.
-        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
-            for the models computations and weights. Note that some
-            computations, such as softmax and layer normalization will always
-            be done a float32 precision regardless of dtype.
-
-    Example:
-    ```python
-    image = np.random.rand(224, 224, 3)
-    vit_model = Gemma3Vit(image_size=224)
-    # The output will be of shape:
-    # [batch_size, num_vision_tokens_per_image, hidden_dim]
-    output = vit_model([image])
-    ```
-    """
-
-    def __init__(
-        self,
-        image_size,
-        patch_size,
-        num_heads,
-        hidden_dim,
-        num_layers,
-        intermediate_dim,
-        output_dim,
-        pool_size=14,
-        layer_norm_epsilon=1e-6,
-        dtype=None,
-        **kwargs,
-    ):
-        # === Functional Model ===
-        image_input = keras.Input(
-            shape=(None, image_size, image_size, 3),
-            name="images",
-        )
-        x = image_input  # Intermediate result.
-        x = Gemma3VitEncoder(
-            hidden_dim=hidden_dim,
-            num_layers=num_layers,
-            num_heads=num_heads,
-            intermediate_dim=intermediate_dim,
-            patch_size=patch_size,
-            image_size=image_size,
-            dtype=dtype,
-            name="image_encoder",
-        )(x)
-
-        x = AveragePooling(
-            image_size=image_size,
-            patch_size=patch_size,
-            pool_size=pool_size,
-            dtype=dtype,
-            name="pooling",
-        )(x)
-
-        x = Gemma3VisionOutputEncoder(
-            output_dim=output_dim,
-            layer_norm_epsilon=layer_norm_epsilon,
-            kernel_initializer=keras.initializers.RandomNormal(
-                mean=0.0, stddev=0.01
-            ),
-            dtype=dtype,
-            name="vision_output_encoder",
-        )(x)
-
-        outputs = x
-        super().__init__(
-            inputs=image_input,
-            outputs=outputs,
-            **kwargs,
-        )
-
-        # === Config ===
-        self.image_size = image_size
-        self.patch_size = patch_size
-        self.num_heads = num_heads
-        self.hidden_dim = hidden_dim
-        self.num_layers = num_layers
-        self.intermediate_dim = intermediate_dim
-        self.output_dim = output_dim
-        self.pool_size = pool_size
-        self.layer_norm_epsilon = layer_norm_epsilon
-        self.num_vision_tokens_per_image = (
-            (image_size // patch_size) ** 2
-        ) // (pool_size**2)
-
-        # Before Keras 3.2, there is no `keras.dtype_policies.get`.
-        if hasattr(keras.dtype_policies, "get"):
-            self.dtype_policy = keras.dtype_policies.get(dtype)
-        else:
-            if isinstance(dtype, keras.dtype_policies.DTypePolicy):
-                dtype = dtype.name
-            dtype = dtype or keras.config.dtype_policy().name
-            self.dtype_policy = keras.dtype_policies.DTypePolicy(dtype)
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "num_heads": self.num_heads,
-                "hidden_dim": self.hidden_dim,
-                "num_layers": self.num_layers,
-                "intermediate_dim": self.intermediate_dim,
-                "output_dim": self.output_dim,
-                "pool_size": self.pool_size,
-                "image_size": self.image_size,
-                "patch_size": self.patch_size,
-                "layer_norm_epsilon": self.layer_norm_epsilon,
-            }
-        )
-        return config
keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py
CHANGED
@@ -5,7 +5,7 @@ from keras import ops
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 
 
 class GPTNeoXAttention(keras.layers.Layer):
@@ -125,7 +125,7 @@ class GPTNeoXAttention(keras.layers.Layer):
     def _compute_attention(
         self, query, key, value, attention_mask=None, training=None
     ):
-        if has_flash_attention_support() and self.dropout == 0:
+        if fused_attention_op_available() and self.dropout == 0:
             # Use `dot_product_attention` with Flash Attention support if
             # available.
             if attention_mask is not None:
keras_hub/src/models/llama/llama_attention.py
CHANGED
@@ -5,7 +5,7 @@ from keras import ops
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 
 
 class LlamaAttention(keras.layers.Layer):
@@ -185,7 +185,7 @@ class LlamaAttention(keras.layers.Layer):
         return self._softmax(attention_scores)
 
     def _compute_attention(self, query, key, value, attention_mask=None):
-        if has_flash_attention_support():
+        if fused_attention_op_available():
             # Use `dot_product_attention` with Flash Attention support if
             # available.
             if attention_mask is not None:
keras_hub/src/models/mistral/mistral_attention.py
CHANGED
@@ -5,7 +5,7 @@ from keras import ops
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 
 
 # This is just a self-attention layer in Mistral. But it can be generalized
@@ -196,7 +196,7 @@ class CachedMistralAttention(keras.layers.Layer):
         return self._softmax(attention_scores)
 
     def _compute_attention(self, query, key, value, attention_mask=None):
-        if has_flash_attention_support():
+        if fused_attention_op_available():
             # Use `dot_product_attention` with Flash Attention support if
             # available.
             if attention_mask is not None:
keras_hub/src/models/phi3/phi3_attention.py
CHANGED
@@ -8,7 +8,7 @@ from keras_hub.src.models.phi3.phi3_rotary_embedding import (
     Phi3SuScaledRotaryEmbedding,
 )
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 
 
 class Phi3Attention(keras.layers.Layer):
@@ -217,7 +217,7 @@ class Phi3Attention(keras.layers.Layer):
         return self.softmax(attention_scores)
 
     def _compute_attention(self, query, key, value, attention_mask=None):
-        if has_flash_attention_support():
+        if fused_attention_op_available():
             # Use `dot_product_attention` with Flash Attention support if
             # available.
             if attention_mask is not None:
keras_hub/src/models/qwen/qwen_attention.py
CHANGED
@@ -5,7 +5,7 @@ from keras import ops
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 
 
 class QwenAttention(keras.layers.Layer):
@@ -263,7 +263,7 @@ class QwenAttention(keras.layers.Layer):
         Returns:
             attention_output: Output tensor after applying attention.
         """
-        if has_flash_attention_support():
+        if fused_attention_op_available():
             # Use `dot_product_attention` with Flash Attention support if
             # available.
             if attention_mask is not None:
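Across these attention layers the change is a straight rename of the availability check, `has_flash_attention_support()` → `fused_attention_op_available()`; the surrounding logic is untouched. Below is a minimal sketch of the guard pattern those `_compute_attention` methods follow, assuming the `(batch, seq_len, num_heads, head_dim)` layout used by `keras.ops.dot_product_attention`; the unfused fallback is illustrative and simpler than the real layers (no grouped queries, softmax-dtype handling, or dropout).

```python
from keras import ops

from keras_hub.src.utils.keras_utils import fused_attention_op_available


def compute_attention(query, key, value, attention_mask=None):
    """Sketch of the guard used by the attention layers in this diff.

    query/key/value: (batch, seq_len, num_heads, head_dim).
    attention_mask: optional (batch, query_len, key_len) mask.
    """
    if attention_mask is not None:
        # Broadcast over heads and make the mask boolean.
        attention_mask = ops.cast(
            ops.expand_dims(attention_mask, axis=1), dtype="bool"
        )

    if fused_attention_op_available():
        # Fused op, with Flash Attention where the backend provides it.
        return ops.dot_product_attention(
            query, key, value, mask=attention_mask
        )

    # Unfused fallback: explicit scores, masking, softmax, weighted sum.
    scale = 1.0 / float(query.shape[-1]) ** 0.5
    scores = ops.einsum("bqnh,bknh->bnqk", query, key) * scale
    if attention_mask is not None:
        scores = ops.where(attention_mask, scores, ops.cast(-1e9, scores.dtype))
    probs = ops.softmax(scores, axis=-1)
    return ops.einsum("bnqk,bknh->bqnh", probs, value)
```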
keras_hub/src/models/qwen/qwen_backbone.py
CHANGED
@@ -1,7 +1,6 @@
 import keras
 from keras import ops
 
-from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.layers.modeling.reversible_embedding import (
     ReversibleEmbedding,
 )
@@ -14,12 +13,6 @@ def _qwen_kernel_initializer(stddev=0.02):
     return keras.initializers.RandomNormal(stddev=stddev)
 
 
-@keras_hub_export(
-    [
-        "keras_hub.models.QwenBackbone",
-        "keras_hub.models.Qwen2Backbone",
-    ]
-)
 class QwenBackbone(Backbone):
     """
     The Qwen Transformer core architecture with hyperparameters.
keras_hub/src/models/qwen/qwen_causal_lm.py
CHANGED
@@ -1,7 +1,6 @@
 import keras
 from keras import ops
 
-from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.causal_lm import CausalLM
 from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone
 from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import (
@@ -10,12 +9,6 @@ from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import (
 from keras_hub.src.utils.tensor_utils import any_equal
 
 
-@keras_hub_export(
-    [
-        "keras_hub.models.QwenCausalLM",
-        "keras_hub.models.Qwen2CausalLM",
-    ]
-)
 class QwenCausalLM(CausalLM):
     backbone_cls = QwenBackbone
     preprocessor_cls = QwenCausalLMPreprocessor
keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py
CHANGED
@@ -1,15 +1,8 @@
-from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
 from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone
 from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer
 
 
-@keras_hub_export(
-    [
-        "keras_hub.models.QwenCausalLMPreprocessor",
-        "keras_hub.models.Qwen2CausalLMPreprocessor",
-    ]
-)
 class QwenCausalLMPreprocessor(CausalLMPreprocessor):
     backbone_cls = QwenBackbone
     tokenizer_cls = QwenTokenizer
keras_hub/src/models/qwen/qwen_tokenizer.py
CHANGED
@@ -1,16 +1,7 @@
-from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone
 from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
 
 
-@keras_hub_export(
-    [
-        "keras_hub.tokenizers.QwenTokenizer",
-        "keras_hub.tokenizers.Qwen2Tokenizer",
-        "keras_hub.models.QwenTokenizer",
-        "keras_hub.models.Qwen2Tokenizer",
-    ]
-)
 class QwenTokenizer(BytePairTokenizer):
     """Tokenizer for Qwen models.
 
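All four Qwen classes lose their `@keras_hub_export` registrations here, which lines up with the entries removed from `keras_hub/api/models/__init__.py` and `keras_hub/api/tokenizers/__init__.py` in the file list above. Below is a minimal sketch of reaching the classes through the module paths this diff leaves in place; whether any public `keras_hub.models.Qwen*` aliases survive depends on the `api/__init__` changes not shown in this section.

```python
# These module paths are taken from the import lines in the diff above and
# bypass the removed keras_hub_export registration.
from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone
from keras_hub.src.models.qwen.qwen_causal_lm import QwenCausalLM
from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import (
    QwenCausalLMPreprocessor,
)
from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer

# The class-level wiring shown in the diff is unchanged.
assert QwenCausalLM.backbone_cls is QwenBackbone
assert QwenCausalLM.preprocessor_cls is QwenCausalLMPreprocessor
assert QwenCausalLMPreprocessor.tokenizer_cls is QwenTokenizer
```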
keras_hub/src/models/roformer_v2/roformer_v2_backbone.py
CHANGED
@@ -16,7 +16,7 @@ def roformer_kernel_initializer(stddev=0.02):
     return keras.initializers.TruncatedNormal(stddev=stddev)
 
 
-@keras_hub_export("keras_hub.models.
+@keras_hub_export("keras_hub.models.RoformerV2Backbone")
 class RoformerV2Backbone(Backbone):
     """A RoformerV2 encoder network.
 
keras_hub/src/models/roformer_v2/roformer_v2_text_classifier.py
CHANGED
@@ -10,8 +10,8 @@ from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier_preprocessor import (
 )
 
 
-@keras_hub_export("keras_hub.models.
-class
+@keras_hub_export("keras_hub.models.RoformerV2TextClassifier")
+class RoformerV2TextClassifier(RobertaTextClassifier):
     """An end-to-end RoformerV2 model for classification tasks.
 
     This model attaches a classification head to
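With these changes the two RoformerV2 symbols are registered as `keras_hub.models.RoformerV2Backbone` and `keras_hub.models.RoformerV2TextClassifier`. A small sketch, assuming this nightly's public API picks up the exports:

```python
import keras_hub

# Both symbols are registered via the keras_hub_export decorators above.
print(keras_hub.models.RoformerV2Backbone)
print(keras_hub.models.RoformerV2TextClassifier)
```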
keras_hub/src/models/stable_diffusion_3/mmdit.py
CHANGED
@@ -6,8 +6,8 @@ from keras import ops
 
 from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
 from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
 from keras_hub.src.utils.keras_utils import gelu_approximate
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
 from keras_hub.src.utils.keras_utils import standardize_data_format
 
 
@@ -771,7 +771,7 @@ class MMDiTBlock(layers.Layer):
     def _compute_attention(self, query, key, value):
         batch_size = ops.shape(query)[0]
 
-        if has_flash_attention_support():
+        if fused_attention_op_available():
             # Use `dot_product_attention` with Flash Attention support if
             # available.
             encoded = ops.dot_product_attention(
keras_hub/src/models/vit/vit_image_converter.py
CHANGED
@@ -53,12 +53,17 @@ class ViTImageConverter(ImageConverter):
 
     @preprocessing_function
     def call(self, inputs):
+        # TODO: Remove this whole function. Why can just use scale and offset
+        # in the base class.
         x = super().call(inputs)
-        # By default normalize using imagenet mean and std
         if self.norm_mean:
-            …
+            norm_mean = self._expand_non_channel_dims(self.norm_mean, x)
+            x, norm_mean = self._convert_types(x, norm_mean, self.compute_dtype)
+            x = x - norm_mean
         if self.norm_std:
-            …
+            norm_std = self._expand_non_channel_dims(self.norm_std, x)
+            x, norm_std = self._convert_types(x, norm_std, x.dtype)
+            x = x / norm_std
 
         return x
 
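The converter now expands `norm_mean`/`norm_std` to the input's rank and aligns dtypes explicitly before normalizing, rather than relying on implicit broadcasting and casting. The arithmetic itself is unchanged: `x = (x - mean) / std` per channel. A minimal NumPy sketch of that computation for a channels-last batch; the ImageNet statistics below are example values only, and the layer takes them from its own configuration:

```python
import numpy as np

# Example values only; ViTImageConverter reads these from its norm_mean and
# norm_std configuration.
norm_mean = np.array([0.485, 0.456, 0.406], dtype="float32")
norm_std = np.array([0.229, 0.224, 0.225], dtype="float32")

x = np.random.rand(2, 224, 224, 3).astype("float32")  # already scaled to [0, 1]

# Reshape (3,) -> (1, 1, 1, 3) so the statistics broadcast over
# batch/height/width, i.e. over every non-channel dimension.
mean = norm_mean.reshape((1, 1, 1, 3))
std = norm_std.reshape((1, 1, 1, 3))

x = (x - mean) / std
print(x.mean(), x.std())  # roughly standardized
```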
keras_hub/src/tests/test_case.py
CHANGED
@@ -197,6 +197,7 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
         input_data,
         expected_output=None,
         expected_detokenize_output=None,
+        return_output=False,
     ):
         """Run basic tests for a preprocessing layer."""
         layer = cls(**init_kwargs)
@@ -230,6 +231,9 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
         if expected_output:
             self.assertAllClose(output, expected_output)
 
+        if return_output:
+            return output
+
     def run_preprocessor_test(
         self,
         cls,
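The new `return_output` flag lets the preprocessing-layer helper (its name sits outside the hunk; `run_preprocessing_layer_test` below is taken from the keras-hub test utilities) hand its output back to the caller for extra assertions instead of requiring the test to run the layer again. A hedged sketch of how a test might use it; `MyTokenizer` and its `vocabulary` kwarg are hypothetical stand-ins, not part of this diff:

```python
from keras_hub.src.tests.test_case import TestCase


class MyTokenizerTest(TestCase):
    def test_tokenizer_basics(self):
        # `MyTokenizer` and its init kwargs are hypothetical placeholders.
        output = self.run_preprocessing_layer_test(
            cls=MyTokenizer,
            init_kwargs={"vocabulary": ["<pad>", "the", "quick", "fox"]},
            input_data=["the quick fox"],
            return_output=True,  # new flag added in this diff
        )
        # With return_output=True the helper returns the layer output, so the
        # test can make extra assertions beyond the built-in checks.
        self.assertEqual(len(output), 1)
```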
keras_hub/src/utils/keras_utils.py
CHANGED
@@ -55,7 +55,7 @@ def standardize_data_format(data_format):
     return data_format
 
 
-def has_flash_attention_support():
+def fused_attention_op_available():
     if (
         hasattr(keras.config, "is_flash_attention_enabled")
         and keras.config.backend() == "jax"
@@ -104,3 +104,46 @@ def running_on_gpu():
         import torch
 
         return torch.cuda.is_available()
+
+
+def gpu_supports_fused_attention_op():
+    deny_list = ["T4"]
+    for denied_gpu in deny_list:
+        if any(denied_gpu in gpu.upper() for gpu in get_gpu_names()):
+            return False
+    return True
+
+
+def get_gpu_names():
+    """Detects and returns the names of available GPUs based on the backend.
+
+    Note:
+        The format and content of the returned GPU names are **not normalized**
+        and vary significantly depending on the active backend. This function
+        provides the names as reported by the respective backend's API.
+    """
+    backend = keras.config.backend()
+    if backend == "jax":
+        import jax
+
+        devices = jax.devices()
+
+        return [getattr(d, "device_kind", "") for d in devices]
+
+    elif backend == "tensorflow":
+        import tensorflow as tf
+
+        gpus = tf.config.list_physical_devices("GPU")
+        return [
+            tf.config.experimental.get_device_details(gpu)["device_name"]
+            for gpu in gpus
+        ]
+    elif backend == "torch":
+        import torch
+
+        return [
+            torch.cuda.get_device_name(i)
+            for i in range(torch.cuda.device_count())
+        ]
+    else:
+        return [""]
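Two GPU-aware helpers land alongside the renamed `fused_attention_op_available`: `get_gpu_names` reports backend-specific device names, and `gpu_supports_fused_attention_op` checks them against a deny list (currently just `"T4"`). A minimal sketch of how a caller might combine the checks; the combination itself is illustrative, not a pattern taken from this diff:

```python
from keras_hub.src.utils.keras_utils import fused_attention_op_available
from keras_hub.src.utils.keras_utils import get_gpu_names
from keras_hub.src.utils.keras_utils import gpu_supports_fused_attention_op

# Names are backend-specific and unnormalized, e.g. a CUDA device name under
# torch/tensorflow or a device_kind string under jax.
print(get_gpu_names())

use_fused_op = (
    fused_attention_op_available() and gpu_supports_fused_attention_op()
)
print(f"Use fused attention op: {use_fused_op}")
```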