keras-hub-nightly 0.20.0.dev202504030357__py3-none-any.whl → 0.21.0.dev202504050402__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/models/__init__.py +5 -20
- keras_hub/api/tokenizers/__init__.py +0 -4
- keras_hub/src/layers/preprocessing/image_converter.py +26 -16
- keras_hub/src/models/gemma3/gemma3_attention.py +74 -21
- keras_hub/src/models/gemma3/gemma3_backbone.py +117 -46
- keras_hub/src/models/gemma3/gemma3_causal_lm.py +72 -15
- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +512 -355
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +23 -19
- keras_hub/src/models/gemma3/gemma3_image_converter.py +6 -0
- keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py +56 -16
- keras_hub/src/models/gemma3/gemma3_presets.py +74 -8
- keras_hub/src/models/gemma3/gemma3_tokenizer.py +9 -0
- keras_hub/src/models/gemma3/{gemma3_vit.py → gemma3_vision_encoder.py} +150 -139
- keras_hub/src/models/qwen/qwen_backbone.py +0 -7
- keras_hub/src/models/qwen/qwen_causal_lm.py +0 -7
- keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +0 -7
- keras_hub/src/models/qwen/qwen_tokenizer.py +0 -9
- keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -1
- keras_hub/src/models/roformer_v2/roformer_v2_text_classifier.py +2 -2
- keras_hub/src/models/vit/vit_image_converter.py +8 -3
- keras_hub/src/tests/test_case.py +4 -0
- keras_hub/src/utils/tensor_utils.py +6 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504050402.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504050402.dist-info}/RECORD +27 -27
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504050402.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504050402.dist-info}/top_level.txt +0 -0
@@ -14,16 +14,17 @@ from keras_hub.src.models.gemma3.rms_normalization import RMSNormalization
 class Gemma3DecoderBlock(keras.layers.Layer):
     """Transformer decoder layer for Gemma3.

-    This is
-
-
-
-
-
-
-
-
-
+    This decoder layer is the same as the layer used for Gemma and Gemma2.
+    However, there are a few key differences. Firstly, image tokens have
+    bidirectional masking. Additionally, this layer exposes the following args:
+
+    `use_query_key_norm`: bool. If True, apply RMS normalization on query
+        and key. For Gemma3, this is True.
+    `rope_wavelength`: float. Configurable value for RoPE wavelength. Gemma3
+        uses 10K for local attention layers and 1M for global attention layers.
+    `gate_dim_reduction`: int. In the gating layers, the output dimension is
+        `intermediate_dim // gate_dim_reduction`. For Gemma and Gemma2, this
+        value is 2. For Gemma3, it is 1.
     """

     def __init__(
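
For orientation, here is a minimal sketch of a gated feedforward block whose gating projections are sized as `intermediate_dim // gate_dim_reduction`, matching the semantics the new docstring describes. This is an illustration only; `gated_ffn` is a hypothetical helper, not the keras-hub implementation.

import keras
from keras import ops

# Illustrative only: gate/up projections sized by `gate_dim_reduction`
# (2 for Gemma/Gemma2, 1 for Gemma3, per the docstring above).
def gated_ffn(hidden_dim, intermediate_dim, gate_dim_reduction=1):
    gate_dim = intermediate_dim // gate_dim_reduction
    inputs = keras.Input(shape=(None, hidden_dim))
    gate = keras.layers.Dense(gate_dim, use_bias=False)(inputs)
    up = keras.layers.Dense(gate_dim, use_bias=False)(inputs)
    hidden = ops.multiply(ops.gelu(gate, approximate=True), up)
    outputs = keras.layers.Dense(hidden_dim, use_bias=False)(hidden)
    return keras.Model(inputs, outputs)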
@@ -160,9 +161,10 @@ class Gemma3DecoderBlock(keras.layers.Layer):
         # Isometric
         return input_shape

-    def _compute_image_bidirectional_attention_mask(self,
-        #
-
+    def _compute_image_bidirectional_attention_mask(self, vision_mask):
+        # vision_mask is False for text, True for images. Shape of
+        # (bsz, seq_len).
+        bidirectional_mask = vision_mask

         # Left pad with 0.
         padded_mask = ops.cast(
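
To illustrate how a bidirectional image mask combines with the causal mask (see the `ops.logical_or(causal_mask, bidirectional_image_mask)` call in a later hunk), here is a toy sketch with assumed shapes. The real `_compute_image_bidirectional_attention_mask` also handles padding and per-image grouping, which this sketch skips.

import numpy as np
from keras import ops

# Toy example: one sequence of 6 tokens where positions 1-3 are image tokens.
vision_mask = np.array([[False, True, True, True, False, False]])  # (bsz, seq_len)

causal_mask = ops.tril(ops.ones((1, 6, 6), dtype="bool"))
# Image tokens may attend to each other in both directions.
bidirectional_image_mask = ops.logical_and(
    ops.expand_dims(vision_mask, axis=2),  # query positions
    ops.expand_dims(vision_mask, axis=1),  # key positions
)
attention_mask = ops.logical_or(causal_mask, bidirectional_image_mask)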
@@ -194,7 +196,7 @@ class Gemma3DecoderBlock(keras.layers.Layer):
         self,
         x,
         padding_mask,
-
+        vision_mask,
         cache,
         cache_update_index,
     ):
@@ -216,9 +218,9 @@ class Gemma3DecoderBlock(keras.layers.Layer):

         # Compute bidirectional mask (image tokens can attend to each other
         # in both directions, within the same image).
-        if
+        if vision_mask is not None:
             bidirectional_image_mask = (
-                self._compute_image_bidirectional_attention_mask(
+                self._compute_image_bidirectional_attention_mask(vision_mask)
             )
             causal_mask = ops.logical_or(causal_mask, bidirectional_image_mask)

@@ -232,14 +234,15 @@ class Gemma3DecoderBlock(keras.layers.Layer):
         self,
         x,
         padding_mask=None,
-
+        vision_mask=None,
         cache=None,
         cache_update_index=0,
+        cache_update_mask=None,
     ):
-        # Note: `
+        # Note: `vision_mask` is used only for Gemma3.
        normalized_x = self.pre_attention_norm(x)
         attention_mask = self._compute_attention_mask(
-            normalized_x, padding_mask,
+            normalized_x, padding_mask, vision_mask, cache, cache_update_index
         )
         if cache is not None:
             attention, new_cache = self.attention(
@@ -247,6 +250,7 @@ class Gemma3DecoderBlock(keras.layers.Layer):
                 attention_mask=attention_mask,
                 cache=cache,
                 cache_update_index=cache_update_index,
+                cache_update_mask=cache_update_mask,
             )
         else:
             attention = self.attention(
@@ -6,3 +6,9 @@ from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone
 @keras_hub_export("keras_hub.layers.Gemma3ImageConverter")
 class Gemma3ImageConverter(ImageConverter):
     backbone_cls = Gemma3Backbone
+
+    def __init__(self, **kwargs):
+        # Always do image preprocessing in float32
+        kwargs.pop("dtype", None)
+        dtype = "float32"
+        super().__init__(dtype=dtype, **kwargs)
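
The effect of this new `__init__` is that any user-supplied `dtype` is discarded and preprocessing always runs in float32. An illustrative usage sketch follows; the `image_size` value is only an example of a standard `ImageConverter` argument, not something this diff specifies.

from keras_hub.layers import Gemma3ImageConverter

converter = Gemma3ImageConverter(image_size=(896, 896), dtype="bfloat16")
print(converter.dtype)  # "float32": the passed dtype is popped and ignored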
@@ -5,12 +5,17 @@ from keras import ops
 class Gemma3InterleaveEmbeddings(keras.layers.Layer):
     """Places image embeddings in the correct position in an embedding sequence.

+    For Gemma3, images can be in any position in the input sequence. In order
+    to do accomplish this, we have image placeholder tokens in the input
+    sequence. We fill up these positions with the image embeddings as returned
+    by the vision encoder.
+
     Args:
         num_vision_tokens_per_image: int. Number of soft tokens per image.
     """

-    def __init__(self, num_vision_tokens_per_image, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(self, num_vision_tokens_per_image, dtype=None, **kwargs):
+        super().__init__(dtype=dtype, **kwargs)

         self.num_vision_tokens_per_image = num_vision_tokens_per_image

@@ -19,12 +24,17 @@ class Gemma3InterleaveEmbeddings(keras.layers.Layer):
         Integrates image embeddings into a text embedding sequence.

         Args:
-            image_embeddings:
-                `
-
-
-
-
+            image_embeddings: tensor. Image embeddings as returned by the
+                vision encoder (`Gemma3VisionEncoder`, usually). Shape:
+                `(batch_size * num_images_per_prompt, `
+                `num_vision_tokens_per_image, embedding_dim)`.
+            text_embeddings: tensor. Embeddings returned by the text embedding
+                layer. Shape: `(batch_size, seq_length, embedding_dim)`.
+            vision_indices: tensor. Indexes into `text_embeddings`, used to
+                identify which places are supposed to be replaced by
+                `image_embeddings`. Shape:
+                `(batch_size,`
+                `num_images_per_prompt * num_vision_tokens_per_image)`.

         Returns:
             Tensor of shape `(batch_size, seq_length, embedding_dim)`
@@ -32,32 +42,62 @@ class Gemma3InterleaveEmbeddings(keras.layers.Layer):
         """

         batch_size, seq_length, embedding_dim = ops.shape(text_embeddings)
+        # `num_images` will be 0 for text only inputs, and
+        # `batch_size * max_images_per_prompt` if images are passed.
+        num_images = ops.shape(image_embeddings)[0]

-        # Flatten text embeddings,
+        # Flatten text embeddings, image embeddings and indices.
         flat_text_embeddings = ops.reshape(
             text_embeddings, (batch_size * seq_length, embedding_dim)
         )
-
-        #
-        # it will be 0 for text-only.
-        image_batch_size = ops.shape(image_embeddings)[0]
+        # `flat_image_embeddings` is the `updates` tensor and should be of shape
+        # `(num_updates, embedding_dim)`.
         flat_image_embeddings = ops.reshape(
             image_embeddings,
             (
-
+                num_images * self.num_vision_tokens_per_image,
                 embedding_dim,
             ),
         )

-        #
+        # For vision indices, we need to add values such that the indices
+        # index into a flattened `text_embeddings`.
+        to_add = ops.multiply(
+            keras.ops.arange(batch_size, dtype="int32"), seq_length
+        )
+        to_add = ops.expand_dims(to_add, axis=-1)
+        vision_indices = ops.add(vision_indices, to_add)
+
+        # indices should be of shape `(num_updates, 1)`. `num_updates` is
+        # how many vision tokens there are to update.
         vision_indices_shape = ops.shape(vision_indices)
         flat_vision_indices = ops.reshape(
             vision_indices,
             (vision_indices_shape[0] * vision_indices_shape[1], 1),
         )
         indices = ops.cast(flat_vision_indices, "int32")
+
+        # Before reconstructing, store the 0th index so that we can restore it
+        # later.
+        zeroth_index_text_embeddings = ops.take(
+            flat_text_embeddings,
+            indices=ops.squeeze(to_add, axis=-1),
+            axis=0,
+        )
+
+        # Reconstruct embeddings
+        reconstructed_embedding = ops.scatter_update(
+            inputs=flat_text_embeddings,
+            indices=indices,
+            updates=flat_image_embeddings,
+        )
+
+        # Remember that we pad `vision_indices` with the 0th index. We need to
+        # restore the original value in the reconstructed embedding tensor.
         reconstructed_embedding = ops.scatter_update(
-
+            inputs=reconstructed_embedding,
+            indices=to_add,
+            updates=zeroth_index_text_embeddings,
         )

         # Reshape to original dimensions
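
The scatter logic above can be hard to follow from the diff alone, so here is a small, self-contained walkthrough of the same pattern: flatten the text embeddings, offset the vision indices per batch row, scatter the image embeddings in, then restore position 0, which serves as the padding target for unused `vision_indices`. Shapes and values are made up for illustration.

import numpy as np
from keras import ops

batch_size, seq_length, embedding_dim = 2, 5, 4
num_vision_tokens_per_image = 2

text = np.arange(batch_size * seq_length * embedding_dim, dtype="float32")
text = text.reshape(batch_size, seq_length, embedding_dim)
image = np.full(
    (batch_size, num_vision_tokens_per_image, embedding_dim), -1.0, "float32"
)
# Row 0 has an image at positions 1-2; row 1 is text-only, so its indices are
# padded with 0.
vision_indices = np.array([[1, 2], [0, 0]], dtype="int32")

flat_text = ops.reshape(text, (batch_size * seq_length, embedding_dim))
flat_image = ops.reshape(image, (-1, embedding_dim))
# Offset each row's indices so they index into the flattened text embeddings.
to_add = ops.expand_dims(ops.arange(batch_size, dtype="int32") * seq_length, axis=-1)
flat_indices = ops.reshape(ops.add(vision_indices, to_add), (-1, 1))

# Save the embedding at each row's position 0, scatter the image embeddings,
# then restore position 0 (undoing writes caused by the padded indices).
zeroth = ops.take(flat_text, ops.squeeze(to_add, axis=-1), axis=0)
out = ops.scatter_update(flat_text, flat_indices, flat_image)
out = ops.scatter_update(out, to_add, zeroth)
out = ops.reshape(out, (batch_size, seq_length, embedding_dim))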
@@ -11,7 +11,7 @@ backbone_presets = {
             "params": 999885952,
             "path": "gemma3",
         },
-        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_1b/
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_1b/3",
     },
     "gemma3_instruct_1b": {
         "metadata": {
@@ -22,7 +22,7 @@ backbone_presets = {
             "params": 999885952,
             "path": "gemma3",
         },
-        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_1b/
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_1b/3",
     },
     "gemma3_4b_text": {
         "metadata": {
@@ -33,7 +33,7 @@ backbone_presets = {
             "params": 3880099328,
             "path": "gemma3",
         },
-        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_4b_text/
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_4b_text/2",
     },
     "gemma3_instruct_4b_text": {
         "metadata": {
@@ -44,7 +44,7 @@ backbone_presets = {
             "params": 3880099328,
             "path": "gemma3",
         },
-        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_4b_text/
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_4b_text/3",
     },
     "gemma3_12b_text": {
         "metadata": {
@@ -55,7 +55,7 @@ backbone_presets = {
             "params": 11765788416,
             "path": "gemma3",
         },
-        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_12b_text/
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_12b_text/2",
     },
     "gemma3_instruct_12b_text": {
         "metadata": {
@@ -66,7 +66,7 @@ backbone_presets = {
             "params": 11765788416,
             "path": "gemma3",
         },
-        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_12b_text/
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_12b_text/2",
     },
     "gemma3_27b_text": {
         "metadata": {
@@ -77,7 +77,7 @@ backbone_presets = {
             "params": 27009002240,
             "path": "gemma3",
         },
-        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_27b_text/
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_27b_text/3",
     },
     "gemma3_instruct_27b_text": {
         "metadata": {
@@ -88,6 +88,72 @@ backbone_presets = {
             "params": 27009002240,
             "path": "gemma3",
         },
-        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_27b_text/
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_27b_text/2",
+    },
+    "gemma3_4b": {
+        "metadata": {
+            "description": (
+                "4 billion parameter, 34-layer, vision+text pretrained "
+                "Gemma3 model."
+            ),
+            "params": 4299915632,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_4b/1",
+    },
+    "gemma3_instruct_4b": {
+        "metadata": {
+            "description": (
+                "4 billion parameter, 34-layer, vision+text instruction-tuned "
+                "Gemma3 model."
+            ),
+            "params": 4299915632,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_4b/1",
+    },
+    "gemma3_12b": {
+        "metadata": {
+            "description": (
+                "12 billion parameter, 48-layer, vision+text pretrained "
+                "Gemma3 model."
+            ),
+            "params": 12187079280,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_12b/1",
+    },
+    "gemma3_instruct_12b": {
+        "metadata": {
+            "description": (
+                "12 billion parameter, 48-layer, vision+text instruction-tuned "
+                "Gemma3 model."
+            ),
+            "params": 12187079280,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_12b/1",
+    },
+    "gemma3_27b": {
+        "metadata": {
+            "description": (
+                "27 billion parameter, 62-layer, vision+text pretrained "
+                "Gemma3 model."
+            ),
+            "params": 27432062576,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_27b/1",
+    },
+    "gemma3_instruct_27b": {
+        "metadata": {
+            "description": (
+                "27 billion parameter, 62-layer, vision+text instruction-tuned "
+                "Gemma3 model."
+            ),
+            "params": 27432062576,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_27b/1",
     },
 }
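
Once registered, the new multimodal presets can be loaded by name through the standard `from_preset` API; for example (illustrative, weights are downloaded from Kaggle on first use):

import keras_hub

backbone = keras_hub.models.Gemma3Backbone.from_preset("gemma3_instruct_4b")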
@@ -4,6 +4,10 @@ from keras_hub.src.tokenizers.sentence_piece_tokenizer import (
     SentencePieceTokenizer,
 )

+START_OF_IMAGE_TOKEN = "<start_of_image>"
+IMAGE_PLACEHOLDER_TOKEN = "<img>"
+END_OF_IMAGE_TOKEN = "<end_of_image>"
+

 @keras_hub_export(
     [
@@ -84,4 +88,9 @@ class Gemma3Tokenizer(SentencePieceTokenizer):
         # Image placeholder token.
         self._add_special_token("<img>", "image_placeholder")

+        # Some tokens which are used in the preprocessor. We need to keep them
+        # here so that the preprocessor works with `tf.data`.
+        self._add_special_token("<start_of_image>", "start_of_image_token")
+        self._add_special_token("<end_of_image>", "end_of_image_token")
+
         super().__init__(proto=proto, **kwargs)
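
The module-level constants and the extra `_add_special_token` calls exist so the preprocessor can reference these strings inside `tf.data` pipelines. A hedged usage sketch: it assumes the attribute names follow the `_add_special_token(token, name)` pattern shown above, and the preset name is illustrative.

from keras_hub.tokenizers import Gemma3Tokenizer

tokenizer = Gemma3Tokenizer.from_preset("gemma3_instruct_4b")
print(tokenizer.start_of_image_token)     # "<start_of_image>" (assumed attribute name)
print(tokenizer.end_of_image_token)       # "<end_of_image>" (assumed attribute name)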