keras-hub-nightly 0.20.0.dev202503260356__py3-none-any.whl → 0.20.0.dev202503270400__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,305 @@
+ import keras
+ from keras import ops
+
+ from keras_hub.src.layers.modeling.transformer_layer_utils import (
+     compute_causal_mask,
+ )
+ from keras_hub.src.layers.modeling.transformer_layer_utils import (
+     merge_padding_and_attention_mask,
+ )
+ from keras_hub.src.models.gemma3.gemma3_attention import CachedGemma3Attention
+ from keras_hub.src.models.gemma3.rms_normalization import RMSNormalization
+
+
+ class Gemma3DecoderBlock(keras.layers.Layer):
+     """Transformer decoder layer for Gemma3.
+
+     This is different from Gemma and Gemma2 in several ways:
+
+     - `use_query_key_norm`: Applies RMS Norm on query, key.
+     - `rope_wavelength`: RoPE wavelength differs from local to global attention
+       layers.
+     - `rope_scaling_factor`: RoPE scaling factor differs from local to global
+       attention layers.
+     - `gate_dim_reduction`: In the gating layers, Gemma and Gemma2 reduce
+       intermediate dimension by 2. For Gemma3, no such reduction happens.
+     - Uses bidirectional attention for images, and causal for everything else.
+     """
+
+     def __init__(
+         self,
+         hidden_dim,
+         intermediate_dim,
+         head_dim,
+         num_query_heads,
+         num_key_value_heads,
+         query_head_dim_normalize=True,
+         use_query_key_norm=False,
+         use_post_ffw_norm=False,
+         use_post_attention_norm=False,
+         gate_dim_reduction=2,
+         logit_soft_cap=None,
+         use_sliding_window_attention=False,
+         sliding_window_size=4096,
+         layer_norm_epsilon=1e-6,
+         rope_wavelength=10_000.0,
+         rope_scaling_factor=1.0,
+         dropout=0,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.hidden_dim = hidden_dim
+         self.intermediate_dim = intermediate_dim
+         self.head_dim = head_dim
+         self.num_query_heads = num_query_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.query_head_dim_normalize = query_head_dim_normalize
+         self.use_query_key_norm = use_query_key_norm
+         self.use_post_ffw_norm = use_post_ffw_norm
+         self.use_post_attention_norm = use_post_attention_norm
+         self.gate_dim_reduction = gate_dim_reduction
+         self.logit_soft_cap = logit_soft_cap
+         self.use_sliding_window_attention = use_sliding_window_attention
+         self.sliding_window_size = sliding_window_size
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.rope_wavelength = rope_wavelength
+         self.rope_scaling_factor = rope_scaling_factor
+         self.dropout = dropout
+
+         self.pre_attention_norm = RMSNormalization(
+             epsilon=self.layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="pre_attention_norm",
+         )
+
+         if use_post_attention_norm:
+             self.post_attention_norm = RMSNormalization(
+                 epsilon=self.layer_norm_epsilon,
+                 dtype=self.dtype_policy,
+                 name="post_attention_norm",
+             )
+
+         self.attention = CachedGemma3Attention(
+             head_dim=head_dim,
+             num_query_heads=num_query_heads,
+             num_key_value_heads=num_key_value_heads,
+             use_query_key_norm=use_query_key_norm,
+             logit_soft_cap=logit_soft_cap,
+             use_sliding_window_attention=use_sliding_window_attention,
+             sliding_window_size=sliding_window_size,
+             query_head_dim_normalize=True,
+             rope_wavelength=rope_wavelength,
+             rope_scaling_factor=rope_scaling_factor,
+             dropout=dropout,
+             dtype=self.dtype_policy,
+             name="attention",
+         )
+
+         if self.dropout > 0:
+             self.attention_dropout = keras.layers.Dropout(rate=dropout)
+             self.feedforward_dropout = keras.layers.Dropout(rate=dropout)
+
+         self.pre_ffw_norm = RMSNormalization(
+             epsilon=self.layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="pre_ffw_norm",
+         )
+
+         if use_post_ffw_norm:
+             self.post_ffw_norm = RMSNormalization(
+                 epsilon=self.layer_norm_epsilon,
+                 dtype=self.dtype_policy,
+                 name="post_ffw_norm",
+             )
+
+         self.gating_ffw = keras.layers.EinsumDense(
+             equation="btd,df->btf",
+             output_shape=(None, self.intermediate_dim // gate_dim_reduction),
+             dtype=self.dtype_policy,
+             name="ffw_gating",
+         )
+
+         self.gating_ffw_2 = keras.layers.EinsumDense(
+             equation="btd,df->btf",
+             output_shape=(None, self.intermediate_dim // gate_dim_reduction),
+             dtype=self.dtype_policy,
+             name="ffw_gating_2",
+         )
+
+         self.ffw_linear = keras.layers.EinsumDense(
+             equation="btf,fd->btd",
+             output_shape=(None, self.hidden_dim),
+             dtype=self.dtype_policy,
+             name="ffw_linear",
+         )
+
+     def build(self, input_shape):
+         self.pre_attention_norm.build(input_shape)
+         self.attention.build(input_shape)
+
+         if self.use_post_attention_norm:
+             shape = self.attention.compute_output_shape(input_shape)
+             self.post_attention_norm.build(shape)
+
+         shape = input_shape
+         self.pre_ffw_norm.build(shape)
+         self.gating_ffw.build(shape)
+         self.gating_ffw_2.build(shape)
+
+         shape = self.gating_ffw.compute_output_shape(shape)
+         self.ffw_linear.build(shape)
+
+         if self.use_post_ffw_norm:
+             shape = self.ffw_linear.compute_output_shape(shape)
+             self.post_ffw_norm.build(shape)
+
+         self.built = True
+
+     def compute_output_shape(self, input_shape):
+         # Isometric
+         return input_shape
+
+     def _compute_image_bidirectional_attention_mask(self, text_mask):
+         # text_mask is True for text, False for images. Shape of (bsz, seq_len).
+         bidirectional_mask = ops.logical_not(text_mask)
+
+         # Left pad with 0.
+         padded_mask = ops.cast(
+             ops.pad(bidirectional_mask, [(0, 0), (1, 0)], constant_values=0),
+             dtype="int32",
+         )
+
+         # Assign unique indices to every contiguous span of True.
+         boundary = ops.cast(
+             ops.greater(padded_mask[..., 1:], padded_mask[..., :-1]),
+             dtype="int32",
+         )
+         numbered_boundary = ops.cumsum(boundary, -1)
+         indices = ops.multiply(bidirectional_mask, numbered_boundary)
+
+         indices_expanded_1 = ops.expand_dims(indices, 1)
+         indices_expanded_2 = ops.expand_dims(indices, -1)
+
+         mask = ops.logical_and(
+             ops.equal(
+                 indices_expanded_1,
+                 indices_expanded_2,
+             ),
+             indices_expanded_2,
+         )
+         return mask
+
+     def _compute_attention_mask(
+         self,
+         x,
+         padding_mask,
+         text_mask,
+         cache,
+         cache_update_index,
+     ):
+         decoder_mask = merge_padding_and_attention_mask(
+             inputs=x, padding_mask=padding_mask, attention_mask=None
+         )
+
+         batch_size = ops.shape(x)[0]
+         input_length = output_length = ops.shape(x)[1]
+         if cache is not None:
+             input_length = ops.shape(cache)[2]
+
+         causal_mask = compute_causal_mask(
+             batch_size=batch_size,
+             input_length=input_length,
+             output_length=output_length,
+             cache_index=cache_update_index,
+         )
+
+         # Compute bidirectional mask (image tokens can attend to each other
+         # in both directions, within the same image).
+         if text_mask is not None:
+             bidirectional_image_mask = (
+                 self._compute_image_bidirectional_attention_mask(text_mask)
+             )
+             causal_mask = ops.logical_or(causal_mask, bidirectional_image_mask)
+
+         # Respect the padding mask.
+         if decoder_mask is not None:
+             causal_mask = ops.minimum(decoder_mask, causal_mask)
+
+         return causal_mask
+
+     def call(
+         self,
+         x,
+         padding_mask=None,
+         text_mask=None,
+         cache=None,
+         cache_update_index=0,
+     ):
+         # Note: `text_mask` is used only for Gemma3.
+         normalized_x = self.pre_attention_norm(x)
+         attention_mask = self._compute_attention_mask(
+             normalized_x, padding_mask, text_mask, cache, cache_update_index
+         )
+         if cache is not None:
+             attention, new_cache = self.attention(
+                 normalized_x,
+                 attention_mask=attention_mask,
+                 cache=cache,
+                 cache_update_index=cache_update_index,
+             )
+         else:
+             attention = self.attention(
+                 normalized_x,
+                 attention_mask=attention_mask,
+             )
+
+         if self.use_post_attention_norm:
+             attention = self.post_attention_norm(attention)
+
+         if self.dropout:
+             attention = self.attention_dropout(attention)
+
+         attention_x = x + attention
+         normalized_x = self.pre_ffw_norm(attention_x)
+
+         x1 = self.gating_ffw(normalized_x)
+         x2 = self.gating_ffw_2(normalized_x)
+         x = keras.activations.gelu(x1, approximate=True) * x2
+         x = self.ffw_linear(x)
+
+         if self.use_post_ffw_norm:
+             x = self.post_ffw_norm(x)
+
+         x = x + attention_x
+
+         if cache is not None:
+             return x, new_cache
+         return x
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_dim": self.hidden_dim,
+                 "intermediate_dim": self.intermediate_dim,
+                 "head_dim": self.head_dim,
+                 "num_query_heads": self.num_query_heads,
+                 "num_key_value_heads": self.num_key_value_heads,
+                 "query_head_dim_normalize": self.query_head_dim_normalize,
+                 "use_query_key_norm": self.use_query_key_norm,
+                 "use_post_ffw_norm": self.use_post_ffw_norm,
+                 "use_post_attention_norm": self.use_post_attention_norm,
+                 "gate_dim_reduction": self.gate_dim_reduction,
+                 "logit_soft_cap": self.logit_soft_cap,
+                 "use_sliding_window_attention": (
+                     self.use_sliding_window_attention
+                 ),
+                 "sliding_window_size": self.sliding_window_size,
+                 "layer_norm_epsilon": self.layer_norm_epsilon,
+                 "dropout": self.dropout,
+                 "rope_wavelength": self.rope_wavelength,
+                 "rope_scaling_factor": self.rope_scaling_factor,
+             }
+         )
+         return config
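
The least obvious piece of this file is the span-numbering trick in `_compute_image_bidirectional_attention_mask`. The sketch below is illustrative only and not part of the package; it walks the same `keras.ops` calls over a toy `text_mask` containing two two-token images, with made-up variable names:

```python
import numpy as np
from keras import ops

# True = text token, False = image token; two images of two tokens each.
text_mask = np.array([[True, False, False, True, False, False]])

bidirectional = ops.logical_not(text_mask)  # [[0, 1, 1, 0, 1, 1]]
padded = ops.cast(
    ops.pad(bidirectional, [(0, 0), (1, 0)], constant_values=0), "int32"
)
# A boundary marks the first token of each contiguous image span.
boundary = ops.cast(ops.greater(padded[..., 1:], padded[..., :-1]), "int32")
# Cumulative-summing the boundaries numbers the spans; multiplying by the
# image mask zeroes text positions, giving span ids [[0, 1, 1, 0, 2, 2]].
span_ids = ops.multiply(bidirectional, ops.cumsum(boundary, -1))
# Two positions attend bidirectionally only if their span ids match and are
# nonzero, i.e. they belong to the same image.
mask = ops.logical_and(
    ops.equal(ops.expand_dims(span_ids, 1), ops.expand_dims(span_ids, -1)),
    ops.expand_dims(span_ids, -1),
)
```

`_compute_attention_mask` then ORs this block-diagonal image mask into the causal mask, so text tokens stay causal while tokens of the same image can see each other in both directions.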
@@ -0,0 +1,8 @@
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+ from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone
+
+
+ @keras_hub_export("keras_hub.layers.Gemma3ImageConverter")
+ class Gemma3ImageConverter(ImageConverter):
+     backbone_cls = Gemma3Backbone
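
`Gemma3ImageConverter` adds no logic of its own; `backbone_cls` simply ties the converter to Gemma3 presets so preprocessing config can be loaded alongside the backbone. A hedged construction sketch follows; the `image_size`, `scale`, and `offset` arguments belong to the base `ImageConverter`, and the concrete values are illustrative assumptions, not taken from this diff:

```python
import numpy as np

# Assumed base-class arguments; check `ImageConverter` for the exact signature.
converter = Gemma3ImageConverter(
    image_size=(896, 896),  # illustrative target resolution
    scale=1.0 / 127.5,      # illustrative rescale toward roughly [-1, 1]
    offset=-1.0,
)
images = np.zeros((1, 224, 224, 3), dtype="float32")
resized = converter(images)  # resized and rescaled batch
```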
@@ -0,0 +1,79 @@
+ import keras
+ from keras import ops
+
+
+ class Gemma3InterleaveEmbeddings(keras.layers.Layer):
+     """Places image embeddings in the correct position in an embedding sequence.
+
+     Args:
+         num_vision_tokens_per_image: int. Number of soft tokens per image.
+     """
+
+     def __init__(self, num_vision_tokens_per_image, **kwargs):
+         super().__init__(**kwargs)
+
+         self.num_vision_tokens_per_image = num_vision_tokens_per_image
+
+     def call(self, image_embeddings, text_embeddings, vision_indices):
+         """
+         Integrates image embeddings into a text embedding sequence.
+
+         Args:
+             image_embeddings: Tensor of shape
+                 `(batch_size * num_images_per_prompt,
+                 num_vision_tokens_per_image, embedding_dim)`.
+             text_embeddings: Tensor of shape
+                 `(batch_size, seq_length, embedding_dim)`.
+             vision_indices: Tensor of indices into the flattened
+                 `(batch_size * seq_length)` axis, marking the positions where
+                 image embeddings should be placed.
+
+         Returns:
+             Tensor of shape `(batch_size, seq_length, embedding_dim)`
+             representing the reconstructed embeddings.
+         """
+
+         batch_size, seq_length, embedding_dim = ops.shape(text_embeddings)
+
+         # Flatten text embeddings and image embeddings.
+         flat_text_embeddings = ops.reshape(
+             text_embeddings, (batch_size * seq_length, embedding_dim)
+         )
+
+         # The image batch size might be different when we pass only text, i.e.,
+         # it will be 0 for text-only inputs.
+         image_batch_size = ops.shape(image_embeddings)[0]
+         flat_image_embeddings = ops.reshape(
+             image_embeddings,
+             (
+                 image_batch_size * self.num_vision_tokens_per_image,
+                 embedding_dim,
+             ),
+         )
+
+         # Reconstruct embeddings.
+         vision_indices_shape = ops.shape(vision_indices)
+         flat_vision_indices = ops.reshape(
+             vision_indices,
+             (vision_indices_shape[0] * vision_indices_shape[1], 1),
+         )
+         indices = ops.cast(flat_vision_indices, "int32")
+         reconstructed_embedding = ops.scatter_update(
+             flat_text_embeddings, indices, flat_image_embeddings
+         )
+
+         # Reshape to the original dimensions.
+         reconstructed_embedding = ops.reshape(
+             reconstructed_embedding, (batch_size, seq_length, embedding_dim)
+         )
+         return reconstructed_embedding
+
+     def compute_output_shape(self, input_shape):
+         return input_shape
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "num_vision_tokens_per_image": self.num_vision_tokens_per_image,
+             }
+         )
+         return config
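
A toy run of the layer above, illustrative only and assuming the class is in scope: image embeddings are scattered into the flattened `(batch_size * seq_length)` axis at the positions named by `vision_indices`, then reshaped back.

```python
import numpy as np

layer = Gemma3InterleaveEmbeddings(num_vision_tokens_per_image=2)

batch_size, seq_length, embedding_dim = 1, 6, 4
text_embeddings = np.zeros((batch_size, seq_length, embedding_dim), "float32")
# One image -> shape (1, num_vision_tokens_per_image, embedding_dim).
image_embeddings = np.ones((1, 2, embedding_dim), "float32")
# Flattened positions that hold the image placeholder tokens.
vision_indices = np.array([[2, 3]])

out = layer(image_embeddings, text_embeddings, vision_indices)
# Positions 2 and 3 of the sequence now carry the image embeddings; all other
# positions keep their original text embeddings.
```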
@@ -0,0 +1,93 @@
+ """Gemma3 model preset configurations."""
+
+ # Metadata for loading pretrained model weights.
+ backbone_presets = {
+     "gemma3_1b": {
+         "metadata": {
+             "description": (
+                 "1 billion parameter, 26-layer, text-only pretrained "
+                 "Gemma3 model."
+             ),
+             "params": 999885952,
+             "path": "gemma3",
+         },
+         "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_1b/1",
+     },
+     "gemma3_instruct_1b": {
+         "metadata": {
+             "description": (
+                 "1 billion parameter, 26-layer, text-only instruction-tuned "
+                 "Gemma3 model."
+             ),
+             "params": 999885952,
+             "path": "gemma3",
+         },
+         "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_1b/1",
+     },
+     "gemma3_4b_text": {
+         "metadata": {
+             "description": (
+                 "4 billion parameter, 34-layer, text-only pretrained "
+                 "Gemma3 model."
+             ),
+             "params": 3880099328,
+             "path": "gemma3",
+         },
+         "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_4b_text/1",
+     },
+     "gemma3_instruct_4b_text": {
+         "metadata": {
+             "description": (
+                 "4 billion parameter, 34-layer, text-only instruction-tuned "
+                 "Gemma3 model."
+             ),
+             "params": 3880099328,
+             "path": "gemma3",
+         },
+         "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_4b_text/2",
+     },
+     "gemma3_12b_text": {
+         "metadata": {
+             "description": (
+                 "12 billion parameter, 48-layer, text-only pretrained "
+                 "Gemma3 model."
+             ),
+             "params": 11765788416,
+             "path": "gemma3",
+         },
+         "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_12b_text/1",
+     },
+     "gemma3_instruct_12b_text": {
+         "metadata": {
+             "description": (
+                 "12 billion parameter, 48-layer, text-only instruction-tuned "
+                 "Gemma3 model."
+             ),
+             "params": 11765788416,
+             "path": "gemma3",
+         },
+         "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_12b_text/1",
+     },
+     "gemma3_27b_text": {
+         "metadata": {
+             "description": (
+                 "27 billion parameter, 62-layer, text-only pretrained "
+                 "Gemma3 model."
+             ),
+             "params": 27009002240,
+             "path": "gemma3",
+         },
+         "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_27b_text/1",
+     },
+     "gemma3_instruct_27b_text": {
+         "metadata": {
+             "description": (
+                 "27 billion parameter, 62-layer, text-only instruction-tuned "
+                 "Gemma3 model."
+             ),
+             "params": 27009002240,
+             "path": "gemma3",
+         },
+         "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_27b_text/1",
+     },
+ }
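
The keys above are the names accepted by `from_preset`. A minimal loading sketch, assuming `Gemma3Backbone` and `Gemma3Tokenizer` are exported under `keras_hub.models` (the tokenizer export further down this diff confirms the latter); weights are fetched from the listed Kaggle handles on first use:

```python
import keras_hub

# Text-only 1B checkpoint from the table above.
backbone = keras_hub.models.Gemma3Backbone.from_preset("gemma3_1b")
tokenizer = keras_hub.models.Gemma3Tokenizer.from_preset("gemma3_1b")

token_ids = tokenizer(["Gemma3 is a multimodal model family."])
```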
@@ -0,0 +1,87 @@
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone
+ from keras_hub.src.tokenizers.sentence_piece_tokenizer import (
+     SentencePieceTokenizer,
+ )
+
+
+ @keras_hub_export(
+     [
+         "keras_hub.tokenizers.Gemma3Tokenizer",
+         "keras_hub.models.Gemma3Tokenizer",
+     ]
+ )
+ class Gemma3Tokenizer(SentencePieceTokenizer):
+     """Gemma3 tokenizer layer based on SentencePiece.
+
+     This tokenizer class will tokenize raw strings into integer sequences and
+     is based on `keras_hub.tokenizers.SentencePieceTokenizer`. Unlike the
+     underlying tokenizer, it will check for all special tokens needed by
+     Gemma3 models and provides a `from_preset()` method to automatically
+     download a matching vocabulary for a Gemma3 preset.
+
+     If input is a batch of strings (rank > 0), the layer will output a
+     `tf.RaggedTensor` where the last dimension of the output is ragged.
+
+     If input is a scalar string (rank == 0), the layer will output a dense
+     `tf.Tensor` with static shape `[None]`.
+
+     Args:
+         proto: Either a `string` path to a SentencePiece proto file, or a
+             `bytes` object with a serialized SentencePiece proto. See the
+             [SentencePiece repository](https://github.com/google/sentencepiece)
+             for more details on the format.
+
+     Examples:
+
+     ```python
+     # Unbatched input.
+     tokenizer = keras_hub.models.Gemma3Tokenizer.from_preset(
+         "gemma3_instruct_1b"
+     )
+     tokenizer("The quick brown fox jumped.")
+
+     # Batched input.
+     tokenizer(["The quick brown fox jumped.", "The fox slept."])
+
+     # Detokenization.
+     tokenizer.detokenize(tokenizer("The quick brown fox jumped."))
+
+     # Custom vocabulary.
+     bytes_io = io.BytesIO()
+     ds = tf.data.Dataset.from_tensor_slices(["The quick brown fox jumped."])
+     sentencepiece.SentencePieceTrainer.train(
+         sentence_iterator=ds.as_numpy_iterator(),
+         model_writer=bytes_io,
+         vocab_size=8,
+         model_type="WORD",
+         pad_id=0,
+         bos_id=1,
+         eos_id=2,
+         unk_id=3,
+         pad_piece="<pad>",
+         bos_piece="<bos>",
+         eos_piece="<eos>",
+         unk_piece="<unk>",
+     )
+     tokenizer = keras_hub.models.Gemma3Tokenizer(
+         proto=bytes_io.getvalue(),
+     )
+     tokenizer("The quick brown fox jumped.")
+     ```
+     """
+
+     backbone_cls = Gemma3Backbone
+
+     def __init__(self, proto, **kwargs):
+         # Add special tokens.
+
+         # The usual tokens.
+         self._add_special_token("<bos>", "start_token")
+         self._add_special_token("<eos>", "end_token")
+         self._add_special_token("<pad>", "pad_token")
+
+         # Image placeholder token.
+         self._add_special_token("<img>", "image_placeholder")
+
+         super().__init__(proto=proto, **kwargs)