keras-hub-nightly 0.20.0.dev202503260356__py3-none-any.whl → 0.20.0.dev202503270400__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +3 -0
- keras_hub/api/models/__init__.py +6 -0
- keras_hub/api/tokenizers/__init__.py +1 -0
- keras_hub/src/models/gemma3/__init__.py +5 -0
- keras_hub/src/models/gemma3/gemma3_attention.py +315 -0
- keras_hub/src/models/gemma3/gemma3_backbone.py +352 -0
- keras_hub/src/models/gemma3/gemma3_causal_lm.py +306 -0
- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +691 -0
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +305 -0
- keras_hub/src/models/gemma3/gemma3_image_converter.py +8 -0
- keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py +79 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +93 -0
- keras_hub/src/models/gemma3/gemma3_tokenizer.py +87 -0
- keras_hub/src/models/gemma3/gemma3_vit.py +608 -0
- keras_hub/src/models/gemma3/rms_normalization.py +26 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.20.0.dev202503260356.dist-info → keras_hub_nightly-0.20.0.dev202503270400.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.20.0.dev202503260356.dist-info → keras_hub_nightly-0.20.0.dev202503270400.dist-info}/RECORD +20 -8
- {keras_hub_nightly-0.20.0.dev202503260356.dist-info → keras_hub_nightly-0.20.0.dev202503270400.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.20.0.dev202503260356.dist-info → keras_hub_nightly-0.20.0.dev202503270400.dist-info}/top_level.txt +0 -0
keras_hub/src/models/gemma3/gemma3_decoder_block.py
@@ -0,0 +1,305 @@
import keras
from keras import ops

from keras_hub.src.layers.modeling.transformer_layer_utils import (
    compute_causal_mask,
)
from keras_hub.src.layers.modeling.transformer_layer_utils import (
    merge_padding_and_attention_mask,
)
from keras_hub.src.models.gemma3.gemma3_attention import CachedGemma3Attention
from keras_hub.src.models.gemma3.rms_normalization import RMSNormalization


class Gemma3DecoderBlock(keras.layers.Layer):
    """Transformer decoder layer for Gemma3.

    This is different from Gemma and Gemma2 in several ways:

    - `use_query_key_norm`: Applies RMS Norm on query, key.
    - `rope_wavelength`: RoPE wavelength differs from local to global attention
      layers.
    - `rope_scaling_factor`: RoPE scaling factor differs from local to global
      attention layers.
    - `gate_dim_reduction`: In the gating layers, Gemma and Gemma2 reduce
      intermediate dimension by 2. For Gemma3, no such reduction happens.
    - Uses bidirectional attention for images, and causal for everything else.
    """

    def __init__(
        self,
        hidden_dim,
        intermediate_dim,
        head_dim,
        num_query_heads,
        num_key_value_heads,
        query_head_dim_normalize=True,
        use_query_key_norm=False,
        use_post_ffw_norm=False,
        use_post_attention_norm=False,
        gate_dim_reduction=2,
        logit_soft_cap=None,
        use_sliding_window_attention=False,
        sliding_window_size=4096,
        layer_norm_epsilon=1e-6,
        rope_wavelength=10_000.0,
        rope_scaling_factor=1.0,
        dropout=0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.head_dim = head_dim
        self.num_query_heads = num_query_heads
        self.num_key_value_heads = num_key_value_heads
        self.query_head_dim_normalize = query_head_dim_normalize
        self.use_query_key_norm = use_query_key_norm
        self.use_post_ffw_norm = use_post_ffw_norm
        self.use_post_attention_norm = use_post_attention_norm
        self.gate_dim_reduction = gate_dim_reduction
        self.logit_soft_cap = logit_soft_cap
        self.use_sliding_window_attention = use_sliding_window_attention
        self.sliding_window_size = sliding_window_size
        self.layer_norm_epsilon = layer_norm_epsilon
        self.rope_wavelength = rope_wavelength
        self.rope_scaling_factor = rope_scaling_factor
        self.dropout = dropout

        self.pre_attention_norm = RMSNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="pre_attention_norm",
        )

        if use_post_attention_norm:
            self.post_attention_norm = RMSNormalization(
                epsilon=self.layer_norm_epsilon,
                dtype=self.dtype_policy,
                name="post_attention_norm",
            )

        self.attention = CachedGemma3Attention(
            head_dim=head_dim,
            num_query_heads=num_query_heads,
            num_key_value_heads=num_key_value_heads,
            use_query_key_norm=use_query_key_norm,
            logit_soft_cap=logit_soft_cap,
            use_sliding_window_attention=use_sliding_window_attention,
            sliding_window_size=sliding_window_size,
            query_head_dim_normalize=True,
            rope_wavelength=rope_wavelength,
            rope_scaling_factor=rope_scaling_factor,
            dropout=dropout,
            dtype=self.dtype_policy,
            name="attention",
        )

        if self.dropout > 0:
            self.attention_dropout = keras.layers.Dropout(rate=dropout)
            self.feedforward_dropout = keras.layers.Dropout(rate=dropout)

        self.pre_ffw_norm = RMSNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="pre_ffw_norm",
        )

        if use_post_ffw_norm:
            self.post_ffw_norm = RMSNormalization(
                epsilon=self.layer_norm_epsilon,
                dtype=self.dtype_policy,
                name="post_ffw_norm",
            )

        self.gating_ffw = keras.layers.EinsumDense(
            equation="btd,df->btf",
            output_shape=(None, self.intermediate_dim // gate_dim_reduction),
            dtype=self.dtype_policy,
            name="ffw_gating",
        )

        self.gating_ffw_2 = keras.layers.EinsumDense(
            equation="btd,df->btf",
            output_shape=(None, self.intermediate_dim // gate_dim_reduction),
            dtype=self.dtype_policy,
            name="ffw_gating_2",
        )

        self.ffw_linear = keras.layers.EinsumDense(
            equation="btf,fd->btd",
            output_shape=(None, self.hidden_dim),
            dtype=self.dtype_policy,
            name="ffw_linear",
        )

    def build(self, input_shape):
        self.pre_attention_norm.build(input_shape)
        self.attention.build(input_shape)

        if self.use_post_attention_norm:
            shape = self.attention.compute_output_shape(input_shape)
            self.post_attention_norm.build(shape)

        shape = input_shape
        self.pre_ffw_norm.build(shape)
        self.gating_ffw.build(shape)
        self.gating_ffw_2.build(shape)

        shape = self.gating_ffw.compute_output_shape(shape)
        self.ffw_linear.build(shape)

        if self.use_post_ffw_norm:
            shape = self.ffw_linear.compute_output_shape(shape)
            self.post_ffw_norm.build(shape)

        self.built = True

    def compute_output_shape(self, input_shape):
        # Isometric
        return input_shape

    def _compute_image_bidirectional_attention_mask(self, text_mask):
        # text_mask is True for text, False for images. Shape of (bsz, seq_len).
        bidirectional_mask = ops.logical_not(text_mask)

        # Left pad with 0.
        padded_mask = ops.cast(
            ops.pad(bidirectional_mask, [(0, 0), (1, 0)], constant_values=0),
            dtype="int32",
        )

        # Assign unique indices to every contiguous span of True.
        boundary = ops.cast(
            ops.greater(padded_mask[..., 1:], padded_mask[..., :-1]),
            dtype="int32",
        )
        numbered_boundary = ops.cumsum(boundary, -1)
        indices = ops.multiply(bidirectional_mask, numbered_boundary)

        indices_expanded_1 = ops.expand_dims(indices, 1)
        indices_expanded_2 = ops.expand_dims(indices, -1)

        mask = ops.logical_and(
            ops.equal(
                indices_expanded_1,
                indices_expanded_2,
            ),
            indices_expanded_2,
        )
        return mask

    def _compute_attention_mask(
        self,
        x,
        padding_mask,
        text_mask,
        cache,
        cache_update_index,
    ):
        decoder_mask = merge_padding_and_attention_mask(
            inputs=x, padding_mask=padding_mask, attention_mask=None
        )

        batch_size = ops.shape(x)[0]
        input_length = output_length = ops.shape(x)[1]
        if cache is not None:
            input_length = ops.shape(cache)[2]

        causal_mask = compute_causal_mask(
            batch_size=batch_size,
            input_length=input_length,
            output_length=output_length,
            cache_index=cache_update_index,
        )

        # Compute bidirectional mask (image tokens can attend to each other
        # in both directions, within the same image).
        if text_mask is not None:
            bidirectional_image_mask = (
                self._compute_image_bidirectional_attention_mask(text_mask)
            )
            causal_mask = ops.logical_or(causal_mask, bidirectional_image_mask)

        # Respect the padding mask.
        if decoder_mask is not None:
            causal_mask = ops.minimum(decoder_mask, causal_mask)

        return causal_mask

    def call(
        self,
        x,
        padding_mask=None,
        text_mask=None,
        cache=None,
        cache_update_index=0,
    ):
        # Note: `text_mask` is used only for Gemma3.
        normalized_x = self.pre_attention_norm(x)
        attention_mask = self._compute_attention_mask(
            normalized_x, padding_mask, text_mask, cache, cache_update_index
        )
        if cache is not None:
            attention, new_cache = self.attention(
                normalized_x,
                attention_mask=attention_mask,
                cache=cache,
                cache_update_index=cache_update_index,
            )
        else:
            attention = self.attention(
                normalized_x,
                attention_mask=attention_mask,
            )

        if self.use_post_attention_norm:
            attention = self.post_attention_norm(attention)

        if self.dropout:
            attention = self.attention_dropout(attention)

        attention_x = x + attention
        normalized_x = self.pre_ffw_norm(attention_x)

        x1 = self.gating_ffw(normalized_x)
        x2 = self.gating_ffw_2(normalized_x)
        x = keras.activations.gelu(x1, approximate=True) * x2
        x = self.ffw_linear(x)

        if self.use_post_ffw_norm:
            x = self.post_ffw_norm(x)

        x = x + attention_x

        if cache is not None:
            return x, new_cache
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "head_dim": self.head_dim,
                "num_query_heads": self.num_query_heads,
                "num_key_value_heads": self.num_key_value_heads,
                "query_head_dim_normalize": self.query_head_dim_normalize,
                "use_query_key_norm": self.use_query_key_norm,
                "use_post_ffw_norm": self.use_post_ffw_norm,
                "use_post_attention_norm": self.use_post_attention_norm,
                "gate_dim_reduction": self.gate_dim_reduction,
                "logit_soft_cap": self.logit_soft_cap,
                "use_sliding_window_attention": (
                    self.use_sliding_window_attention
                ),
                "sliding_window_size": self.sliding_window_size,
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "dropout": self.dropout,
                "rope_wavelength": self.rope_wavelength,
                "rope_scaling_factor": self.rope_scaling_factor,
            }
        )
        return config
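Note on the mask construction above: `_compute_image_bidirectional_attention_mask` numbers each contiguous run of image tokens by detecting 0→1 boundaries in the padded mask and taking a cumulative sum, so tokens of the same image share an id and attend to each other bidirectionally while text keeps its causal mask. The sketch below is not part of the diff; it uses toy values and only the `keras.ops` calls that appear in the file, to trace the intermediate tensors.

```python
import numpy as np
from keras import ops

# True = text token, False = image token. Two separate images in one sequence.
text_mask = np.array([[True, False, False, True, False, False, True]])

bidirectional_mask = ops.logical_not(text_mask)
# Left pad with 0 so a span starting at position 0 still produces a boundary.
padded_mask = ops.cast(
    ops.pad(bidirectional_mask, [(0, 0), (1, 0)], constant_values=0), "int32"
)
# A "boundary" is a 0 -> 1 transition; cumsum numbers each contiguous span.
boundary = ops.cast(
    ops.greater(padded_mask[..., 1:], padded_mask[..., :-1]), "int32"
)
indices = ops.multiply(bidirectional_mask, ops.cumsum(boundary, -1))
print(ops.convert_to_numpy(indices))  # [[0 1 1 0 2 2 0]]

# Image tokens attend bidirectionally only to tokens with the same nonzero
# span id; text tokens (id 0) get no bidirectional entries here.
mask = ops.logical_and(
    ops.equal(ops.expand_dims(indices, 1), ops.expand_dims(indices, -1)),
    ops.expand_dims(indices, -1),
)
```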
keras_hub/src/models/gemma3/gemma3_image_converter.py
@@ -0,0 +1,8 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone


@keras_hub_export("keras_hub.layers.Gemma3ImageConverter")
class Gemma3ImageConverter(ImageConverter):
    backbone_cls = Gemma3Backbone
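The class body above is intentionally empty: all preprocessing behavior is inherited from `ImageConverter`, and setting `backbone_cls` is what lets the converter resolve Gemma3 presets. A rough usage sketch follows; the constructor arguments are assumptions for illustration, not values taken from a Gemma3 preset.

```python
import numpy as np

import keras_hub

# Direct construction with assumed settings; presets would normally supply
# these via `Gemma3ImageConverter.from_preset(...)`.
converter = keras_hub.layers.Gemma3ImageConverter(
    image_size=(896, 896),  # assumed target resolution
    scale=1.0 / 255.0,
)
images = np.random.randint(0, 256, size=(1, 512, 512, 3)).astype("float32")
outputs = converter(images)  # resized to `image_size` and rescaled by `scale`
```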
keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py
@@ -0,0 +1,79 @@
import keras
from keras import ops


class Gemma3InterleaveEmbeddings(keras.layers.Layer):
    """Places image embeddings in the correct position in an embedding sequence.

    Args:
        num_vision_tokens_per_image: int. Number of soft tokens per image.
    """

    def __init__(self, num_vision_tokens_per_image, **kwargs):
        super().__init__(**kwargs)

        self.num_vision_tokens_per_image = num_vision_tokens_per_image

    def call(self, image_embeddings, text_embeddings, vision_indices):
        """
        Integrates image embeddings into a text embedding sequence.

        Args:
            image_embeddings: Tensor of shape
                `(batch_size * num_images_per_prompt,
                num_vision_tokens_per_image, embedding_dim)`.
            text_embeddings: Tensor of shape
                `(batch_size, seq_length, embedding_dim)`.
            vision_indices: Tensor of indices marking where image embeddings
                are placed in the flattened text sequence.

        Returns:
            Tensor of shape `(batch_size, seq_length, embedding_dim)`
            representing the reconstructed embeddings.
        """

        batch_size, seq_length, embedding_dim = ops.shape(text_embeddings)

        # Flatten text embeddings, text mask and image embeddings.
        flat_text_embeddings = ops.reshape(
            text_embeddings, (batch_size * seq_length, embedding_dim)
        )

        # The image batch size might be different when we pass only text,
        # i.e., it will be 0 for text-only.
        image_batch_size = ops.shape(image_embeddings)[0]
        flat_image_embeddings = ops.reshape(
            image_embeddings,
            (
                image_batch_size * self.num_vision_tokens_per_image,
                embedding_dim,
            ),
        )

        # Reconstruct embeddings.
        vision_indices_shape = ops.shape(vision_indices)
        flat_vision_indices = ops.reshape(
            vision_indices,
            (vision_indices_shape[0] * vision_indices_shape[1], 1),
        )
        indices = ops.cast(flat_vision_indices, "int32")
        reconstructed_embedding = ops.scatter_update(
            flat_text_embeddings, indices, flat_image_embeddings
        )

        # Reshape to original dimensions.
        reconstructed_embedding = ops.reshape(
            reconstructed_embedding, (batch_size, seq_length, embedding_dim)
        )
        return reconstructed_embedding

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_vision_tokens_per_image": self.num_vision_tokens_per_image,
            }
        )
        return config
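The core of the layer above is a single `ops.scatter_update` on the flattened `(batch_size * seq_length, embedding_dim)` text embeddings. A minimal sketch with toy shapes is below (not part of the diff; the dimensions and values are illustrative, not the real Gemma3 sizes).

```python
import numpy as np
from keras import ops

batch_size, seq_length, embedding_dim = 1, 6, 2
num_vision_tokens_per_image = 2

# Toy inputs: a zeroed text sequence and one image worth of "soft tokens".
text_embeddings = np.zeros((batch_size, seq_length, embedding_dim), "float32")
image_embeddings = np.ones(
    (1, num_vision_tokens_per_image, embedding_dim), "float32"
)
# Flat positions of the image placeholder tokens in the (batch * seq_length)
# sequence; produced by the preprocessor in the real pipeline.
vision_indices = np.array([[2, 3]], dtype="int32")

flat_text = ops.reshape(
    text_embeddings, (batch_size * seq_length, embedding_dim)
)
flat_images = ops.reshape(image_embeddings, (-1, embedding_dim))
flat_indices = ops.reshape(vision_indices, (-1, 1))
merged = ops.scatter_update(flat_text, flat_indices, flat_images)
merged = ops.reshape(merged, (batch_size, seq_length, embedding_dim))
# Positions 2 and 3 now carry the image embeddings; every other position
# keeps its original text embedding.
```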
keras_hub/src/models/gemma3/gemma3_presets.py
@@ -0,0 +1,93 @@
"""Gemma3 model preset configurations."""

# Metadata for loading pretrained model weights.
backbone_presets = {
    "gemma3_1b": {
        "metadata": {
            "description": (
                "1 billion parameter, 26-layer, text-only pretrained "
                "Gemma3 model."
            ),
            "params": 999885952,
            "path": "gemma3",
        },
        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_1b/1",
    },
    "gemma3_instruct_1b": {
        "metadata": {
            "description": (
                "1 billion parameter, 26-layer, text-only instruction-tuned "
                "Gemma3 model."
            ),
            "params": 999885952,
            "path": "gemma3",
        },
        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_1b/1",
    },
    "gemma3_4b_text": {
        "metadata": {
            "description": (
                "4 billion parameter, 34-layer, text-only pretrained "
                "Gemma3 model."
            ),
            "params": 3880099328,
            "path": "gemma3",
        },
        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_4b_text/1",
    },
    "gemma3_instruct_4b_text": {
        "metadata": {
            "description": (
                "4 billion parameter, 34-layer, text-only instruction-tuned "
                "Gemma3 model."
            ),
            "params": 3880099328,
            "path": "gemma3",
        },
        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_4b_text/2",
    },
    "gemma3_12b_text": {
        "metadata": {
            "description": (
                "12 billion parameter, 48-layer, text-only pretrained "
                "Gemma3 model."
            ),
            "params": 11765788416,
            "path": "gemma3",
        },
        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_12b_text/1",
    },
    "gemma3_instruct_12b_text": {
        "metadata": {
            "description": (
                "12 billion parameter, 48-layer, text-only instruction-tuned "
                "Gemma3 model."
            ),
            "params": 11765788416,
            "path": "gemma3",
        },
        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_12b_text/1",
    },
    "gemma3_27b_text": {
        "metadata": {
            "description": (
                "27 billion parameter, 62-layer, text-only pretrained "
                "Gemma3 model."
            ),
            "params": 27009002240,
            "path": "gemma3",
        },
        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_27b_text/1",
    },
    "gemma3_instruct_27b_text": {
        "metadata": {
            "description": (
                "27 billion parameter, 62-layer, text-only instruction-tuned "
                "Gemma3 model."
            ),
            "params": 27009002240,
            "path": "gemma3",
        },
        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_27b_text/1",
    },
}
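The keys of `backbone_presets` are the strings users pass to `from_preset()`. A minimal usage sketch, assuming the `Gemma3CausalLM` task class added elsewhere in this release and access to the Kaggle handles above:

```python
import keras_hub

# Downloads the weights and vocabulary from the matching Kaggle handle above.
gemma3_lm = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_instruct_1b")
gemma3_lm.generate("What is Keras?", max_length=64)
```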
keras_hub/src/models/gemma3/gemma3_tokenizer.py
@@ -0,0 +1,87 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone
from keras_hub.src.tokenizers.sentence_piece_tokenizer import (
    SentencePieceTokenizer,
)


@keras_hub_export(
    [
        "keras_hub.tokenizers.Gemma3Tokenizer",
        "keras_hub.models.Gemma3Tokenizer",
    ]
)
class Gemma3Tokenizer(SentencePieceTokenizer):
    """Gemma3 tokenizer layer based on SentencePiece.

    This tokenizer class will tokenize raw strings into integer sequences and
    is based on `keras_hub.tokenizers.SentencePieceTokenizer`. Unlike the
    underlying tokenizer, it will check for all special tokens needed by
    Gemma3 models and provides a `from_preset()` method to automatically
    download a matching vocabulary for a Gemma3 preset.

    If input is a batch of strings (rank > 0), the layer will output a
    `tf.RaggedTensor` where the last dimension of the output is ragged.

    If input is a scalar string (rank == 0), the layer will output a dense
    `tf.Tensor` with static shape `[None]`.

    Args:
        proto: Either a `string` path to a SentencePiece proto file, or a
            `bytes` object with a serialized SentencePiece proto. See the
            [SentencePiece repository](https://github.com/google/sentencepiece)
            for more details on the format.

    Examples:

    ```python
    # Unbatched input.
    tokenizer = keras_hub.models.Gemma3Tokenizer.from_preset(
        "gemma3_instruct_1b"
    )
    tokenizer("The quick brown fox jumped.")

    # Batched input.
    tokenizer(["The quick brown fox jumped.", "The fox slept."])

    # Detokenization.
    tokenizer.detokenize(tokenizer("The quick brown fox jumped."))

    # Custom vocabulary.
    bytes_io = io.BytesIO()
    ds = tf.data.Dataset.from_tensor_slices(["The quick brown fox jumped."])
    sentencepiece.SentencePieceTrainer.train(
        sentence_iterator=ds.as_numpy_iterator(),
        model_writer=bytes_io,
        vocab_size=8,
        model_type="WORD",
        pad_id=0,
        bos_id=1,
        eos_id=2,
        unk_id=3,
        pad_piece="<pad>",
        bos_piece="<bos>",
        eos_piece="<eos>",
        unk_piece="<unk>",
    )
    tokenizer = keras_hub.models.Gemma3Tokenizer(
        proto=bytes_io.getvalue(),
    )
    tokenizer("The quick brown fox jumped.")
    ```
    """

    backbone_cls = Gemma3Backbone

    def __init__(self, proto, **kwargs):
        # Add special tokens.

        # The usual tokens.
        self._add_special_token("<bos>", "start_token")
        self._add_special_token("<eos>", "end_token")
        self._add_special_token("<pad>", "pad_token")

        # Image placeholder token.
        self._add_special_token("<img>", "image_placeholder")

        super().__init__(proto=proto, **kwargs)