keras-hub-nightly 0.20.0.dev202504030357__py3-none-any.whl → 0.21.0.dev202504050402__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/models/__init__.py +5 -20
- keras_hub/api/tokenizers/__init__.py +0 -4
- keras_hub/src/layers/preprocessing/image_converter.py +26 -16
- keras_hub/src/models/gemma3/gemma3_attention.py +74 -21
- keras_hub/src/models/gemma3/gemma3_backbone.py +117 -46
- keras_hub/src/models/gemma3/gemma3_causal_lm.py +72 -15
- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +512 -355
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +23 -19
- keras_hub/src/models/gemma3/gemma3_image_converter.py +6 -0
- keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py +56 -16
- keras_hub/src/models/gemma3/gemma3_presets.py +74 -8
- keras_hub/src/models/gemma3/gemma3_tokenizer.py +9 -0
- keras_hub/src/models/gemma3/{gemma3_vit.py → gemma3_vision_encoder.py} +150 -139
- keras_hub/src/models/qwen/qwen_backbone.py +0 -7
- keras_hub/src/models/qwen/qwen_causal_lm.py +0 -7
- keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +0 -7
- keras_hub/src/models/qwen/qwen_tokenizer.py +0 -9
- keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -1
- keras_hub/src/models/roformer_v2/roformer_v2_text_classifier.py +2 -2
- keras_hub/src/models/vit/vit_image_converter.py +8 -3
- keras_hub/src/tests/test_case.py +4 -0
- keras_hub/src/utils/tensor_utils.py +6 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504050402.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504050402.dist-info}/RECORD +27 -27
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504050402.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504050402.dist-info}/top_level.txt +0 -0
keras_hub/src/models/gemma3/gemma3_decoder_block.py

@@ -14,16 +14,17 @@ from keras_hub.src.models.gemma3.rms_normalization import RMSNormalization
 class Gemma3DecoderBlock(keras.layers.Layer):
     """Transformer decoder layer for Gemma3.

-    This is …
+    This decoder layer is the same as the layer used for Gemma and Gemma2.
+    However, there are a few key differences. Firstly, image tokens have
+    bidirectional masking. Additionally, this layer exposes the following args:
+
+    `use_query_key_norm`: bool. If True, apply RMS normalization on query
+        and key. For Gemma3, this is True.
+    `rope_wavelength`: float. Configurable value for RoPE wavelength. Gemma3
+        uses 10K for local attention layers and 1M for global attention layers.
+    `gate_dim_reduction`: int. In the gating layers, the output dimension is
+        `intermediate_dim // gate_dim_reduction`. For Gemma and Gemma2, this
+        value is 2. For Gemma3, it is 1.
     """

     def __init__(
@@ -160,9 +161,10 @@ class Gemma3DecoderBlock(keras.layers.Layer):
         # Isometric
         return input_shape

-    def _compute_image_bidirectional_attention_mask(self, …):
-        # …
-        …
+    def _compute_image_bidirectional_attention_mask(self, vision_mask):
+        # vision_mask is False for text, True for images. Shape of
+        # (bsz, seq_len).
+        bidirectional_mask = vision_mask

         # Left pad with 0.
         padded_mask = ops.cast(
@@ -194,7 +196,7 @@ class Gemma3DecoderBlock(keras.layers.Layer):
         self,
         x,
         padding_mask,
-        …,
+        vision_mask,
         cache,
         cache_update_index,
     ):
@@ -216,9 +218,9 @@ class Gemma3DecoderBlock(keras.layers.Layer):

         # Compute bidirectional mask (image tokens can attend to each other
         # in both directions, within the same image).
-        if … is not None:
+        if vision_mask is not None:
             bidirectional_image_mask = (
-                self._compute_image_bidirectional_attention_mask(…)
+                self._compute_image_bidirectional_attention_mask(vision_mask)
             )
             causal_mask = ops.logical_or(causal_mask, bidirectional_image_mask)

@@ -232,14 +234,15 @@ class Gemma3DecoderBlock(keras.layers.Layer):
         self,
         x,
         padding_mask=None,
-        …=None,
+        vision_mask=None,
         cache=None,
         cache_update_index=0,
+        cache_update_mask=None,
     ):
-        # Note: `…` is used only for Gemma3.
+        # Note: `vision_mask` is used only for Gemma3.
         normalized_x = self.pre_attention_norm(x)
         attention_mask = self._compute_attention_mask(
-            normalized_x, padding_mask, …, cache, cache_update_index
+            normalized_x, padding_mask, vision_mask, cache, cache_update_index
         )
         if cache is not None:
             attention, new_cache = self.attention(
@@ -247,6 +250,7 @@ class Gemma3DecoderBlock(keras.layers.Layer):
                 attention_mask=attention_mask,
                 cache=cache,
                 cache_update_index=cache_update_index,
+                cache_update_mask=cache_update_mask,
             )
         else:
             attention = self.attention(
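The behavioral core of this change is the mask composition: text tokens keep the usual causal mask, while image tokens may also attend to other image tokens in both directions. Below is a minimal sketch of that composition with `keras.ops` on a hypothetical 4-token sequence; the real layer additionally restricts bidirectional attention to tokens of the same image and handles padding and the cache path.

```python
import numpy as np
from keras import ops

# Hypothetical sequence: [text, image, image, text]. True marks vision tokens.
vision_mask = ops.convert_to_tensor(np.array([[False, True, True, False]]))
seq_len = 4

# Standard lower-triangular causal mask, shape (1, seq_len, seq_len).
causal_mask = ops.expand_dims(
    ops.cast(ops.tril(ops.ones((seq_len, seq_len), dtype="int32")), "bool"), 0
)

# Image tokens may attend to every other image token, in both directions.
bidirectional_image_mask = ops.logical_and(
    ops.expand_dims(vision_mask, axis=-1),  # query side, (1, seq_len, 1)
    ops.expand_dims(vision_mask, axis=-2),  # key side, (1, 1, seq_len)
)

# Same composition the decoder block performs.
attention_mask = ops.logical_or(causal_mask, bidirectional_image_mask)
print(ops.convert_to_numpy(attention_mask)[0].astype(int))
# The first image token (row 1) can now also see position 2.
```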
keras_hub/src/models/gemma3/gemma3_image_converter.py

@@ -6,3 +6,9 @@ from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone
 @keras_hub_export("keras_hub.layers.Gemma3ImageConverter")
 class Gemma3ImageConverter(ImageConverter):
     backbone_cls = Gemma3Backbone
+
+    def __init__(self, **kwargs):
+        # Always do image preprocessing in float32
+        kwargs.pop("dtype", None)
+        dtype = "float32"
+        super().__init__(dtype=dtype, **kwargs)
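The `__init__` override pins all image preprocessing to float32: any `dtype` a caller passes is dropped before being forwarded to the base `ImageConverter`. A usage sketch; the `image_size` argument comes from the base class and the concrete value here is an assumption, not part of this diff.

```python
import numpy as np
from keras_hub.layers import Gemma3ImageConverter

# dtype="bfloat16" is discarded by the override; preprocessing stays in float32.
converter = Gemma3ImageConverter(image_size=(896, 896), dtype="bfloat16")

images = np.random.randint(0, 255, size=(1, 512, 512, 3), dtype=np.uint8)
outputs = converter(images)  # resized in float32
```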
keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py

@@ -5,12 +5,17 @@ from keras import ops
 class Gemma3InterleaveEmbeddings(keras.layers.Layer):
     """Places image embeddings in the correct position in an embedding sequence.

+    For Gemma3, images can be in any position in the input sequence. In order
+    to accomplish this, we have image placeholder tokens in the input
+    sequence. We fill up these positions with the image embeddings as returned
+    by the vision encoder.
+
     Args:
         num_vision_tokens_per_image: int. Number of soft tokens per image.
     """

-    def __init__(self, num_vision_tokens_per_image, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(self, num_vision_tokens_per_image, dtype=None, **kwargs):
+        super().__init__(dtype=dtype, **kwargs)

         self.num_vision_tokens_per_image = num_vision_tokens_per_image

@@ -19,12 +24,17 @@ class Gemma3InterleaveEmbeddings(keras.layers.Layer):
         Integrates image embeddings into a text embedding sequence.

         Args:
-            image_embeddings: …
-                `…
+            image_embeddings: tensor. Image embeddings as returned by the
+                vision encoder (`Gemma3VisionEncoder`, usually). Shape:
+                `(batch_size * num_images_per_prompt, `
+                `num_vision_tokens_per_image, embedding_dim)`.
+            text_embeddings: tensor. Embeddings returned by the text embedding
+                layer. Shape: `(batch_size, seq_length, embedding_dim)`.
+            vision_indices: tensor. Indexes into `text_embeddings`, used to
+                identify which places are supposed to be replaced by
+                `image_embeddings`. Shape:
+                `(batch_size,`
+                `num_images_per_prompt * num_vision_tokens_per_image)`.

         Returns:
             Tensor of shape `(batch_size, seq_length, embedding_dim)`
@@ -32,32 +42,62 @@ class Gemma3InterleaveEmbeddings(keras.layers.Layer):
         """

         batch_size, seq_length, embedding_dim = ops.shape(text_embeddings)
+        # `num_images` will be 0 for text only inputs, and
+        # `batch_size * max_images_per_prompt` if images are passed.
+        num_images = ops.shape(image_embeddings)[0]

-        # Flatten text embeddings, …
+        # Flatten text embeddings, image embeddings and indices.
         flat_text_embeddings = ops.reshape(
             text_embeddings, (batch_size * seq_length, embedding_dim)
         )
-
-        # …
-        # it will be 0 for text-only.
-        image_batch_size = ops.shape(image_embeddings)[0]
+        # `flat_image_embeddings` is the `updates` tensor and should be of shape
+        # `(num_updates, embedding_dim)`.
         flat_image_embeddings = ops.reshape(
             image_embeddings,
             (
-                …,
+                num_images * self.num_vision_tokens_per_image,
                 embedding_dim,
             ),
         )

-        # …
+        # For vision indices, we need to add values such that the indices
+        # index into a flattened `text_embeddings`.
+        to_add = ops.multiply(
+            keras.ops.arange(batch_size, dtype="int32"), seq_length
+        )
+        to_add = ops.expand_dims(to_add, axis=-1)
+        vision_indices = ops.add(vision_indices, to_add)
+
+        # indices should be of shape `(num_updates, 1)`. `num_updates` is
+        # how many vision tokens there are to update.
         vision_indices_shape = ops.shape(vision_indices)
         flat_vision_indices = ops.reshape(
             vision_indices,
             (vision_indices_shape[0] * vision_indices_shape[1], 1),
         )
         indices = ops.cast(flat_vision_indices, "int32")
+
+        # Before reconstructing, store the 0th index so that we can restore it
+        # later.
+        zeroth_index_text_embeddings = ops.take(
+            flat_text_embeddings,
+            indices=ops.squeeze(to_add, axis=-1),
+            axis=0,
+        )
+
+        # Reconstruct embeddings
+        reconstructed_embedding = ops.scatter_update(
+            inputs=flat_text_embeddings,
+            indices=indices,
+            updates=flat_image_embeddings,
+        )
+
+        # Remember that we pad `vision_indices` with the 0th index. We need to
+        # restore the original value in the reconstructed embedding tensor.
         reconstructed_embedding = ops.scatter_update(
-            …
+            inputs=reconstructed_embedding,
+            indices=to_add,
+            updates=zeroth_index_text_embeddings,
         )

         # Reshape to original dimensions
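The rewritten `call` uses a flatten / offset / scatter pattern: text embeddings are flattened to `(batch_size * seq_length, embedding_dim)`, `vision_indices` are shifted by `batch_index * seq_length` so they address that flattened axis, the image embeddings are scattered in, and finally position 0 of each row is restored because padded (unused) vision indices all point at index 0. Below is a self-contained sketch of the same pattern with toy shapes; every name is illustrative rather than taken from the layer.

```python
import numpy as np
from keras import ops

# Toy sizes: 2 prompts of length 5, embedding_dim 3, and up to 1 image per
# prompt with 2 vision tokens. The second prompt is text-only (padded indices).
batch_size, seq_length, embedding_dim = 2, 5, 3
text_embeddings = ops.zeros((batch_size, seq_length, embedding_dim))
image_embeddings = ops.ones((2 * 2, embedding_dim))  # flattened (num_updates, dim)
vision_indices = ops.convert_to_tensor(np.array([[1, 2], [0, 0]], dtype="int32"))

# Shift indices so they address the flattened (batch_size * seq_length) axis.
to_add = ops.expand_dims(ops.arange(batch_size, dtype="int32") * seq_length, -1)
flat_indices = ops.reshape(vision_indices + to_add, (-1, 1))

flat_text = ops.reshape(text_embeddings, (batch_size * seq_length, embedding_dim))
# Save each row's 0th embedding; padded indices all collapse onto it.
row_starts = ops.take(flat_text, ops.squeeze(to_add, axis=-1), axis=0)

merged = ops.scatter_update(flat_text, flat_indices, image_embeddings)
merged = ops.scatter_update(merged, to_add, row_starts)  # undo the padded writes
merged = ops.reshape(merged, (batch_size, seq_length, embedding_dim))
print(ops.convert_to_numpy(merged)[0])  # rows 1 and 2 now hold image embeddings
```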
keras_hub/src/models/gemma3/gemma3_presets.py

@@ -11,7 +11,7 @@ backbone_presets = {
             "params": 999885952,
             "path": "gemma3",
         },
-        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_1b/…",
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_1b/3",
     },
     "gemma3_instruct_1b": {
         "metadata": {
@@ -22,7 +22,7 @@ backbone_presets = {
             "params": 999885952,
             "path": "gemma3",
         },
-        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_1b/…",
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_1b/3",
     },
     "gemma3_4b_text": {
         "metadata": {
@@ -33,7 +33,7 @@ backbone_presets = {
             "params": 3880099328,
             "path": "gemma3",
         },
-        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_4b_text/…",
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_4b_text/2",
     },
     "gemma3_instruct_4b_text": {
         "metadata": {
@@ -44,7 +44,7 @@ backbone_presets = {
             "params": 3880099328,
             "path": "gemma3",
         },
-        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_4b_text/…",
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_4b_text/3",
     },
     "gemma3_12b_text": {
         "metadata": {
@@ -55,7 +55,7 @@ backbone_presets = {
             "params": 11765788416,
             "path": "gemma3",
         },
-        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_12b_text/…",
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_12b_text/2",
     },
     "gemma3_instruct_12b_text": {
         "metadata": {
@@ -66,7 +66,7 @@ backbone_presets = {
             "params": 11765788416,
             "path": "gemma3",
         },
-        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_12b_text/…",
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_12b_text/2",
     },
     "gemma3_27b_text": {
         "metadata": {
@@ -77,7 +77,7 @@ backbone_presets = {
             "params": 27009002240,
             "path": "gemma3",
         },
-        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_27b_text/…",
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_27b_text/3",
     },
     "gemma3_instruct_27b_text": {
         "metadata": {
@@ -88,6 +88,72 @@ backbone_presets = {
             "params": 27009002240,
             "path": "gemma3",
         },
-        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_27b_text/…",
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_27b_text/2",
+    },
+    "gemma3_4b": {
+        "metadata": {
+            "description": (
+                "4 billion parameter, 34-layer, vision+text pretrained "
+                "Gemma3 model."
+            ),
+            "params": 4299915632,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_4b/1",
+    },
+    "gemma3_instruct_4b": {
+        "metadata": {
+            "description": (
+                "4 billion parameter, 34-layer, vision+text instruction-tuned "
+                "Gemma3 model."
+            ),
+            "params": 4299915632,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_4b/1",
+    },
+    "gemma3_12b": {
+        "metadata": {
+            "description": (
+                "12 billion parameter, 48-layer, vision+text pretrained "
+                "Gemma3 model."
+            ),
+            "params": 12187079280,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_12b/1",
+    },
+    "gemma3_instruct_12b": {
+        "metadata": {
+            "description": (
+                "12 billion parameter, 48-layer, vision+text instruction-tuned "
+                "Gemma3 model."
+            ),
+            "params": 12187079280,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_12b/1",
+    },
+    "gemma3_27b": {
+        "metadata": {
+            "description": (
+                "27 billion parameter, 62-layer, vision+text pretrained "
+                "Gemma3 model."
+            ),
+            "params": 27432062576,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_27b/1",
+    },
+    "gemma3_instruct_27b": {
+        "metadata": {
+            "description": (
+                "27 billion parameter, 62-layer, vision+text instruction-tuned "
+                "Gemma3 model."
+            ),
+            "params": 27432062576,
+            "path": "gemma3",
+        },
+        "kaggle_handle": "kaggle://keras/gemma3/keras/gemma3_instruct_27b/1",
     },
 }
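The new preset entries register vision+text Gemma3 checkpoints (pretrained and instruction-tuned, at 4B, 12B and 27B) alongside the existing text-only ones, and the text-only handles are bumped to newer Kaggle model versions. Loading follows the usual KerasHub preset flow; a sketch, with preset names taken from the table above:

```python
import keras_hub

# Existing text-only preset, now resolving to a newer Kaggle model version.
text_backbone = keras_hub.models.Gemma3Backbone.from_preset("gemma3_1b")

# One of the vision+text presets added in this release.
vlm_backbone = keras_hub.models.Gemma3Backbone.from_preset("gemma3_instruct_4b")
```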
keras_hub/src/models/gemma3/gemma3_tokenizer.py

@@ -4,6 +4,10 @@ from keras_hub.src.tokenizers.sentence_piece_tokenizer import (
     SentencePieceTokenizer,
 )

+START_OF_IMAGE_TOKEN = "<start_of_image>"
+IMAGE_PLACEHOLDER_TOKEN = "<img>"
+END_OF_IMAGE_TOKEN = "<end_of_image>"
+

 @keras_hub_export(
     [
@@ -84,4 +88,9 @@ class Gemma3Tokenizer(SentencePieceTokenizer):
         # Image placeholder token.
         self._add_special_token("<img>", "image_placeholder")

+        # Some tokens which are used in the preprocessor. We need to keep them
+        # here so that the preprocessor works with `tf.data`.
+        self._add_special_token("<start_of_image>", "start_of_image_token")
+        self._add_special_token("<end_of_image>", "end_of_image_token")
+
         super().__init__(proto=proto, **kwargs)
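Registering `<start_of_image>` and `<end_of_image>` as special tokens (alongside the existing `<img>` placeholder) lets the preprocessor reference them from inside a `tf.data` pipeline. A hypothetical check of the resulting attributes; the `*_id` attribute names follow the `_add_special_token(token, name)` convention and are an assumption, not shown in this diff:

```python
from keras_hub.tokenizers import Gemma3Tokenizer

tokenizer = Gemma3Tokenizer.from_preset("gemma3_instruct_4b")

# Assumed attributes derived from the registered special-token names.
print(tokenizer.start_of_image_token_id)
print(tokenizer.end_of_image_token_id)
print(tokenizer.image_placeholder_id)
```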