keras-hub-nightly 0.20.0.dev202504030357__py3-none-any.whl → 0.21.0.dev202504050402__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/models/__init__.py +5 -20
- keras_hub/api/tokenizers/__init__.py +0 -4
- keras_hub/src/layers/preprocessing/image_converter.py +26 -16
- keras_hub/src/models/gemma3/gemma3_attention.py +74 -21
- keras_hub/src/models/gemma3/gemma3_backbone.py +117 -46
- keras_hub/src/models/gemma3/gemma3_causal_lm.py +72 -15
- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +512 -355
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +23 -19
- keras_hub/src/models/gemma3/gemma3_image_converter.py +6 -0
- keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py +56 -16
- keras_hub/src/models/gemma3/gemma3_presets.py +74 -8
- keras_hub/src/models/gemma3/gemma3_tokenizer.py +9 -0
- keras_hub/src/models/gemma3/{gemma3_vit.py → gemma3_vision_encoder.py} +150 -139
- keras_hub/src/models/qwen/qwen_backbone.py +0 -7
- keras_hub/src/models/qwen/qwen_causal_lm.py +0 -7
- keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +0 -7
- keras_hub/src/models/qwen/qwen_tokenizer.py +0 -9
- keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -1
- keras_hub/src/models/roformer_v2/roformer_v2_text_classifier.py +2 -2
- keras_hub/src/models/vit/vit_image_converter.py +8 -3
- keras_hub/src/tests/test_case.py +4 -0
- keras_hub/src/utils/tensor_utils.py +6 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504050402.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504050402.dist-info}/RECORD +27 -27
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504050402.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504050402.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,5 @@
 import keras
+import numpy as np
 import tensorflow as tf
 
 from keras_hub.src.api_export import keras_hub_export
@@ -14,24 +15,28 @@ from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer
 from keras_hub.src.utils.tensor_utils import preprocessing_function
 from keras_hub.src.utils.tensor_utils import strip_to_ragged
 
-START_OF_IMAGE_TOKEN = "<start_of_image>"
-IMAGE_PLACEHOLDER_TOKEN = "<img>"
-END_OF_IMAGE_TOKEN = "<end_of_image>"
-
 
 @keras_hub_export("keras_hub.models.Gemma3CausalLMPreprocessor")
 class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
     """Gemma3 Causal LM preprocessor.
 
     This preprocessing layer is meant for use with
-    `keras_hub.models.Gemma3CausalLM`.
-
-
-
-
-
-
-
+    `keras_hub.models.Gemma3CausalLM`. It can be configured in two ways:
+    text-only and text + vision, based on whether the passed value of
+    `image_converter` is None. For the former, it takes in batches of strings,
+    whereas for the latter, it takes in batches of images and strings. It
+    returns outputs in a `(x, y, sample_weight)` format, where the `y` label is
+    the next token id in the `x` sequence. `sample_weight` is 0 for "prompt"
+    tokens, and 1 for "response" tokens, so that the loss is computed only on
+    the "response" tokens.
+
+    For the text + vision case, this layer replaces instance of
+    `<start_of_image>` token in the prompt with `num_vision_tokens_per_image`
+    placeholder tokens. It also returns indices of where these vision tokens
+    are present so that in the model, image embeddings can be placed in the
+    right position in the sequence of text embeddings. Note that if
+    `max_images_per_prompt` is 2, you can pass either 0, 1, 2 images per sample.
+    The value 0 corresponds to text-only input.
 
     For use with generation, the layer also exposes two methods
     `generate_preprocess()` and `generate_postprocess()`. When this preprocessor
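The new docstring describes how each `<start_of_image>` marker in a prompt is expanded into a run of placeholder tokens, whose positions are later reported back as `vision_indices` so image embeddings can be scattered into them. Below is a minimal standalone sketch of that expansion in plain TensorFlow; it is not code from the package, the token strings simply mirror the ones named in this diff, and `num_vision_tokens_per_image` is shrunk from the usual 256 to 4 for readability.

```python
import tensorflow as tf

# Token strings as named in the diff; the per-image count is shrunk to 4 here.
start_of_image_token = "<start_of_image>"
end_of_image_token = "<end_of_image>"
image_placeholder = "<img>"
num_vision_tokens_per_image = 4

prompts = tf.constant(["this is a lily <start_of_image>"])
expanded = tf.strings.regex_replace(
    prompts,
    start_of_image_token,
    f"\n\n{start_of_image_token}"
    + image_placeholder * num_vision_tokens_per_image
    + f"{end_of_image_token}\n\n",
)
print(expanded.numpy()[0])
# b'this is a lily \n\n<start_of_image><img><img><img><img><end_of_image>\n\n'
```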
@@ -64,25 +69,170 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
 
     Examples:
     ```python
+    # === Language Gemma3 model ===
     # Load the preprocessor from a preset.
     preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset(
-        "
+        "gemma3_instruct_1b"
     )
 
-    #
+    # Unbatched inputs.
     preprocessor(
-
-
+        {
+            "prompts": "What is the capital of India?",
+            "responses": "New Delhi",
+        }
     )
 
-    #
-    max_images_per_prompt = 2
+    # Batched inputs.
     preprocessor(
-
-
-
-
+        {
+            "prompts": [
+                "What is the capital of India?",
+                "What is the capital of Spain?"
+            ],
+            "responses": ["New Delhi", "Madrid"],
+        }
     )
+
+    # Apply preprocessing to a `tf.data.Dataset`.
+    features = {
+        "prompts": [
+            "What is the capital of India?",
+            "What is the capital of Spain?"
+        ],
+        "responses": ["New Delhi", "Madrid"],
+    }
+
+    ds = tf.data.Dataset.from_tensor_slices(features)
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Prepare tokens for generation (no end token).
+    preprocessor.generate_preprocess(["The quick brown fox jumped."])
+
+    # Map generation outputs back to strings.
+    preprocessor.generate_postprocess({
+        'token_ids': np.array([[2, 818, 3823, 8864, 37423, 32694, 236761, 0]]),
+        'padding_mask': np.array([[ 1, 1, 1, 1, 1, 1, 1, 0]]),
+    })
+
+    # === Vision and Language Gemma3 model ===
+    # Load the preprocessor from a preset.
+    preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset(
+        "gemma3_instruct_4b"
+    )
+
+    # text-only inputs (unbatched)
+    preprocessor(
+        {
+            "prompts": "What is the capital of India?",
+            "responses": "New Delhi",
+        }
+    )
+
+    # text-only inputs (batched)
+    preprocessor(
+        {
+            "prompts": [
+                "What is the capital of India?",
+                "What is the capital of Spain?"
+            ],
+            "responses": ["New Delhi", "Madrid"],
+        }
+    )
+
+    # Unbatched inputs, with one image.
+    preprocessor(
+        {
+            "prompts": "this is a lily <start_of_image>",
+            "responses": "pristine!",
+            "images": np.ones((896, 896, 3), dtype="float32")
+        }
+    )
+
+    # Unbatched inputs, with two images.
+    preprocessor(
+        {
+            "prompts": "lily: <start_of_image>, sunflower: <start_of_image>",
+            "responses": "pristine!",
+            "images": [
+                np.ones((896, 896, 3), dtype="float32"),
+                np.ones((896, 896, 3), dtype="float32")
+            ],
+        }
+    )
+
+    # Batched inputs, one image per prompt.
+    preprocessor(
+        {
+            "prompts": [
+                "this is a lily: <start_of_image>",
+                "this is a sunflower: <start_of_image>"
+            ],
+            "responses": ["pristine!", "radiant!"],
+            "images": [
+                np.ones((896, 896, 3), dtype="float32"),
+                np.ones((896, 896, 3), dtype="float32")
+            ]
+        }
+    )
+
+    # Can also be written this way.
+    preprocessor(
+        {
+            "prompts": [
+                "this is a lily: <start_of_image>",
+                "this is a sunflower: <start_of_image>"
+            ],
+            "responses": ["pristine!", "radiant!"],
+            "images": [
+                [np.ones((896, 896, 3), dtype="float32")],
+                [np.ones((896, 896, 3), dtype="float32")]
+            ]
+        }
+    )
+
+    # Different number of images in every sample.
+    preprocessor(
+        {
+            "prompts": [
+                "Who is this singer: <start_of_image>?",
+                "Who are these musicians <start_of_image>, <start_of_image>?"
+            ],
+            "responses": ["Arijit Singh", "John Lennon, Paul Mccartney"],
+            "images": [
+                [
+                    np.ones((896, 896, 3), dtype="float32"),
+                    np.ones((896, 896, 3), dtype="float32")
+                ],
+                [np.ones((896, 896, 3), dtype="float32")]
+            ]
+        }
+    )
+
+    # Apply preprocessing to a `tf.data.Dataset`.
+    inputs = {
+        "prompts": [
+            "Who are these two: <start_of_image>, <start_of_image>",
+            "Who is this: <start_of_image>?",
+            "What is the capital of India?"
+        ],
+        "responses": [
+            "John Lennon, Paul Mccartney",
+            "Arijit Singh",
+            "New Delhi"
+        ],
+        "images": (
+            tf.ragged.constant(
+                [
+                    [np.ones((10, 10, 3)), np.ones((10, 10, 3))],
+                    [np.ones((10, 10, 3))],
+                    [],
+                ]
+            )
+        )
+    }
+    ds = tf.data.Dataset.from_tensor_slices(inputs)
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
     ```
     """
 
@@ -109,18 +259,34 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
             **kwargs,
         )
 
-
+        # Ensure `max_images_per_prompt * num_vision_tokens_per_image` is
+        # greater than `sequence_length`.
+        if (
+            image_converter is not None
+            and sequence_length
+            <= max_images_per_prompt * num_vision_tokens_per_image
+        ):
             raise ValueError(
-                "
-                "
+                "`sequence_length` should be greater than "
+                "`max_images_per_prompt * num_vision_tokens_per_image`."
+                f"Received: `sequence_length` = {sequence_length}"
+                f"`max_images_per_prompt` = {max_images_per_prompt}"
+                "`num_vision_tokens_per_image` = "
+                f"{num_vision_tokens_per_image}"
             )
 
         self.image_converter = image_converter
         self.max_images_per_prompt = max_images_per_prompt
         self.num_vision_tokens_per_image = num_vision_tokens_per_image
 
+        # The preprocessor and model are "text-only" if `self.image_converter`
+        # is `None`.
         self.text_only_model = self.image_converter is None
 
+        self.image_placeholder = self.tokenizer.image_placeholder
+        self.start_of_image_token = self.tokenizer.start_of_image_token
+        self.end_of_image_token = self.tokenizer.end_of_image_token
+
     def build(self, input_shape):
         # Defer packer creation to `build()` so that we can be sure tokenizer
         # assets have loaded when restoring a saved model.
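As a quick sanity check on the constraint enforced above, using the 2 images per prompt and 256 vision tokens per image that appear elsewhere in this diff (the `sequence_length` of 1024 is just an assumed example, not a package default):

```python
# Assumed example values: 2 and 256 appear in this diff, 1024 is illustrative.
max_images_per_prompt = 2
num_vision_tokens_per_image = 256
sequence_length = 1024

# 2 * 256 = 512 placeholder positions, which leaves 1024 - 512 = 512 positions
# for text tokens, so this configuration passes the new constructor check.
assert sequence_length > max_images_per_prompt * num_vision_tokens_per_image
```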
@@ -133,15 +299,77 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
         )
         self.built = True
 
+    def _get_vision_indices(self, vision_mask):
+        """Computes indices given vision mask, and pads with 0.
+
+        If `vision_mask` is
+
+        ```
+        [
+            [False, True, True], [False, True, False], [False, False, False]
+        ]
+        ```
+
+        , then the output will be:
+
+        ```
+        [
+            [1, 2, 0], [1, 0, 0], [0, 0, 0]
+        ]
+        ```
+        """
+        batch_size, sequence_length = vision_mask.shape
+
+        vision_mask_flattened = tf.reshape(vision_mask, [-1])
+        vision_indices = tf.where(vision_mask_flattened)[..., 0]
+        vision_indices = tf.cast(vision_indices, dtype=tf.int32)
+
+        row_lengths = tf.math.reduce_sum(
+            tf.cast(vision_mask, dtype=vision_indices.dtype), axis=1
+        )
+
+        batched_vision_indices = tf.RaggedTensor.from_row_lengths(
+            values=vision_indices,
+            row_lengths=row_lengths,
+        )
+
+        to_subtract = tf.math.scalar_mul(
+            scalar=tf.cast(sequence_length, dtype=tf.int32),
+            x=tf.range(
+                start=0,
+                limit=tf.shape(vision_mask)[0],
+                dtype=tf.int32,
+            ),
+        )
+
+        # All indices should be independent of other samples in the batch. If
+        # not, and if we do sharding along the batch dimension for data
+        # parallel, things might get weird.
+        batched_vision_indices = tf.math.subtract(
+            batched_vision_indices,
+            tf.expand_dims(to_subtract, axis=-1),
+        )
+
+        # Pad the indices.
+        batched_vision_indices = batched_vision_indices.to_tensor(
+            shape=[
+                batch_size,
+                self.max_images_per_prompt * self.num_vision_tokens_per_image,
+            ],
+            default_value=0,
+        )
+        return batched_vision_indices
+
     def _format_output(
         self,
         images,
         token_ids,
-
+        vision_mask,
         response_mask,
         padding_mask,
         return_labels=False,
         text_only_input=False,
+        batched=False,
     ):
         if return_labels:
             # Target `y` will be the next token.
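The `_get_vision_indices` helper added above turns a boolean vision mask into per-sample index lists padded to a fixed width. The standalone sketch below reproduces its docstring example outside the class; the pad width of 3 stands in for `max_images_per_prompt * num_vision_tokens_per_image`.

```python
import tensorflow as tf

vision_mask = tf.constant(
    [[False, True, True], [False, True, False], [False, False, False]]
)
batch_size, sequence_length = vision_mask.shape

# Gather flat indices of True positions, then regroup them per row.
flat_indices = tf.cast(tf.where(tf.reshape(vision_mask, [-1]))[..., 0], tf.int32)
row_lengths = tf.reduce_sum(tf.cast(vision_mask, tf.int32), axis=1)
ragged = tf.RaggedTensor.from_row_lengths(flat_indices, row_lengths)

# Subtract `row * sequence_length` so indices are relative to each sample.
offsets = sequence_length * tf.range(batch_size, dtype=tf.int32)
ragged = ragged - tf.expand_dims(offsets, axis=-1)

# Pad to a fixed width (3 here) with 0, as the helper does.
print(ragged.to_tensor(shape=[batch_size, 3], default_value=0).numpy())
# [[1 2 0]
#  [1 0 0]
#  [0 0 0]]
```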
@@ -149,12 +377,13 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
             # Only compute the loss for labels in the response.
             sample_weight = response_mask[..., 1:]
 
+            # The last token does not have a next token. So, remove it.
             token_ids = token_ids[..., :-1]
-
+            vision_mask = vision_mask[..., :-1]
             response_mask = response_mask[..., :-1]
             padding_mask = padding_mask[..., :-1]
 
-        batch_size
+        batch_size = tf.shape(vision_mask)[0]
 
         if text_only_input:
             vision_indices = tf.ones(
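For context on the shift above when `return_labels=True`: inputs drop the last position, labels drop the first, and the sample weight keeps only positions whose label token belongs to the response. A toy NumPy illustration with made-up ids:

```python
import numpy as np

# Made-up ids: [start], prompt tokens, response tokens, [end], padding.
token_ids = np.array([2, 10, 11, 12, 20, 21, 1, 0])
response_mask = np.array([0, 0, 0, 0, 1, 1, 1, 0])

x_tokens = token_ids[:-1]          # model inputs
y = token_ids[1:]                  # next-token targets
sample_weight = response_mask[1:]  # loss only where the label is a response token

print(x_tokens)       # [ 2 10 11 12 20 21  1]
print(y)              # [10 11 12 20 21  1  0]
print(sample_weight)  # [0 0 0 1 1 1 0]
```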
@@ -165,48 +394,109 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
                 dtype=tf.int32,
             )
         else:
-
-            flat_text_mask = tf.reshape(
-                text_mask, (batch_size * sequence_length)
-            )
-            vision_indices = tf.where(tf.logical_not(flat_text_mask))
-            vision_indices = tf.reshape(vision_indices, (batch_size, -1))
+            vision_indices = self._get_vision_indices(vision_mask=vision_mask)
 
-        # The last token does not have a next token, so we truncate it out.
         x = {
             # Image
-            "images": images,
+            "images": images if batched else tf.squeeze(images, axis=0),
             # Text
-            "token_ids":
-
-
-            "
+            "token_ids": (
+                token_ids if batched else tf.squeeze(token_ids, axis=0)
+            ),
+            "vision_indices": (
+                vision_indices
+                if batched
+                else tf.squeeze(vision_indices, axis=0)
+            ),
+            # This mask is redundant information. But easier to compute it here
+            # than the model forward pass.
+            "vision_mask": (
+                vision_mask if batched else tf.squeeze(vision_mask, axis=0)
+            ),
+            "padding_mask": (
+                padding_mask if batched else tf.squeeze(padding_mask, axis=0)
+            ),
         }
 
         if return_labels:
+            if not batched:
+                y = tf.squeeze(y, axis=0)
+                sample_weight = tf.squeeze(sample_weight, 0)
+
             return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
         else:
             return x
 
-    def
-
+    def _preprocess_images(self, images, batched):
+        desired_height = self.image_converter.image_size[0]
+        desired_width = self.image_converter.image_size[1]
+
+        # Images can be lists/ragged tensors. We need to pad them/truncate them.
+        if isinstance(images, (list, np.ndarray)):
+            images = tf.ragged.constant(images)
+        elif isinstance(images, tf.RaggedTensor):
+            pass
+        elif isinstance(images, tf.Tensor):
+            images = tf.RaggedTensor.from_tensor(images)
+        else:
+            # Attempt to convert anyway. This handles the case where
+            # the inputs might be `jax.Array`, `torch.Tensor`. To check the
+            # type, we will have to import all three frameworks, which is
+            # undesirable.
+            try:
+                images = tf.RaggedTensor.from_tensor(images)
+            except:  # noqa: E722
+                raise ValueError(
+                    "`images` should be a list, ragged tensor, dense tensor."
+                    f"Received: `type(images)` = {type(images)}"
+                )
 
-
-
-
-
-
-
-
-
-
-
-
-
+        if not batched:
+            images = tf.expand_dims(images, axis=0)
+
+        # If the input is a list of images, instead of list of lists of images.
+        if len(images.shape) == 4:
+            images = tf.expand_dims(images, axis=1)
+
+        # Convert to dense tensor.
+        images = images.to_tensor(
+            shape=[None, self.max_images_per_prompt, None, None, 3],
+            default_value=0,
+        )
+
+        # Resize, rescale, etc. the images.
+        original_images_shape = tf.shape(images)
+
+        # Before passing through image converter, we need to collapse the
+        # first two dimensions (`batch_size`, `max_images_per_prompt`) into one.
+        images = tf.reshape(
+            images,
+            [
+                -1,
+                original_images_shape[-3],
+                original_images_shape[-2],
+                original_images_shape[-1],
+            ],
+        )
+        images = self.image_converter(images)
+
+        if keras.config.backend() == "torch" and not isinstance(
+            images, tf.Tensor
+        ):
+            images = images.cpu()
+
+        # Recover the rank.
+        images = tf.reshape(
+            images,
+            [
+                original_images_shape[0],
+                self.max_images_per_prompt,
+                desired_height,
+                desired_width,
+                original_images_shape[-1],
+            ],
         )
-
-        ragged_tensor = tf.cast(ragged_tensor, tf.int32)
-        return ragged_tensor
+        return images
 
     @preprocessing_function
     def call(
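The key step in the new `_preprocess_images` is padding a ragged per-sample image list to a static `[batch, max_images_per_prompt, height, width, 3]` shape so later stages see fixed shapes. A small sketch of just that padding step, with illustrative 2x2 dummy images (the `ragged_rank=1` argument and the sizes are choices for this sketch, not the package code):

```python
import numpy as np
import tensorflow as tf

max_images_per_prompt = 2
images = tf.ragged.constant(
    [
        [np.ones((2, 2, 3)), np.ones((2, 2, 3))],  # two images
        [np.ones((2, 2, 3))],                      # one image
        [],                                        # text-only sample
    ],
    ragged_rank=1,
)
# Pad (or truncate) every sample to exactly `max_images_per_prompt` images,
# filling the missing slots with zero-valued images.
dense = images.to_tensor(
    shape=[None, max_images_per_prompt, None, None, 3],
    default_value=0,
)
print(dense.shape)  # (3, 2, 2, 2, 3)
```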
@@ -218,52 +508,76 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
     ):
         sequence_length = sequence_length or self.sequence_length
 
+        # === Input extraction and validation ===
+
         # Extract text part of the input.
         prompts, responses = x["prompts"], x["responses"]
 
+        # Find out if the input is batched/not batched. Uprank if not batched.
+        # In other preprocessors, we don't have to do this, but here, all
+        # the following logic (indices, etc.) uses tensors with a batch dim.
+        # We will squeeze these back at the end.
+        batched = True
+        if isinstance(prompts, str):
+            batched = False
+            prompts = [prompts]
+            responses = [responses]
+        if isinstance(prompts, tf.Tensor) and len(prompts.shape) == 0:
+            batched = False
+            prompts = tf.expand_dims(prompts, axis=0)
+            responses = tf.expand_dims(responses, axis=0)
+
         # Extract images from the input.
         images = x.get("images", None)
-        num_valid_images = x.get("num_valid_images", None)
 
-
-
-
-
-
-
-
-
-
+        # There are 8 cases, based on values of
+        # a = `self.text_only_model`, b = `images` is `None`, and whether
+        # c = `<start_of_image>` token is present in `prompts`.
+        # F F F, F F T -> Raise error if #`<start_of_image>` <0,  or
+        # > `max_images_per_prompt`.
+        # F T F -> Return empty images and vision indices
+        # F T T -> Return empty images and vision indices to the model.
+        # T F F, T F T -> Raise error.
+        # T T F -> Only token IDs and padding mask are returned.
+        # T T T -> Only token IDs and padding mask are returned.
+
+        if self.text_only_model and images is not None:
+            raise ValueError(
+                "The initialized preprocessor/model is text-only, but "
+                " `images` is not `None`."
+            )
+
+        # Add image placeholder tokens. Replace `"<start_of_image>"` in
+        # prompts with
+        # `"\n\n<start_of_image> <img> * 256 <end_of_image>\n\n"`.
+        if not self.text_only_model:
             prompts = tf.strings.regex_replace(
                 prompts,
-
-                f"\n\n{
-                +
-                + f"{
+                self.start_of_image_token,
+                f"\n\n{self.start_of_image_token}"
+                + self.image_placeholder * self.num_vision_tokens_per_image
+                + f"{self.end_of_image_token}\n\n",
             )
 
+        # === Tokenization, padding, etc. ===
+
         # Tokenise the inputs.
         prompts = self.tokenizer(prompts)
         responses = self.tokenizer(responses)
 
-        #
-        # the dummy placeholder image tokens which we will add at the end.
-        # Hence, we use a packer only on the text part first, and then
-        # add the padded dummy placeholder tokens separately.
+        # Padding.
         token_ids, segment_ids = self.packer(
             (prompts, responses),
-            sequence_length=sequence_length
-            if images is not None
-            else sequence_length + 1,
+            sequence_length=sequence_length + 1,
             add_start_value=self.add_start_token,
             add_end_value=self.add_end_token,
         )
+        response_mask = segment_ids == 1
+        padding_mask = token_ids != self.tokenizer.pad_token_id
 
-        #
+        # === Text Model ===
         if self.text_only_model:
             # The last token does not have a next token, so we truncate it out.
-            response_mask = segment_ids == 1
-            padding_mask = token_ids != self.tokenizer.pad_token_id
             x = {
                 "token_ids": token_ids[..., :-1],
                 "padding_mask": padding_mask[..., :-1],
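The rewritten `call()` now derives both masks directly from the packer outputs, as shown above (`response_mask = segment_ids == 1`, `padding_mask = token_ids != pad_token_id`). A toy NumPy illustration of those two lines, with made-up ids and an assumed pad id of 0:

```python
import numpy as np

pad_token_id = 0  # assumed pad id for this illustration
token_ids = np.array([[2, 10, 11, 20, 21, 1, 0, 0]])
segment_ids = np.array([[0, 0, 0, 1, 1, 1, 0, 0]])

response_mask = segment_ids == 1          # tokens from the response segment
padding_mask = token_ids != pad_token_id  # everything that is not padding

print(response_mask.astype(int))  # [[0 0 0 1 1 1 0 0]]
print(padding_mask.astype(int))   # [[1 1 1 1 1 1 0 0]]
```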
@@ -273,162 +587,68 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
             y = token_ids[..., 1:]
             # Only compute the loss for labels in the response.
             sample_weight = response_mask[..., 1:]
+
+            # Squeeze if not batched.
+            if not batched:
+                x["token_ids"] = tf.squeeze(x["token_ids"], axis=0)
+                x["padding_mask"] = tf.squeeze(x["padding_mask"], axis=0)
+                y = tf.squeeze(y, axis=0)
+                sample_weight = tf.squeeze(sample_weight, axis=0)
+
             return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
 
-        # Vision
+        # === Vision processing ===
+
         batch_size = tf.shape(prompts)[0]
+        desired_height = self.image_converter.image_size[0]
+        desired_width = self.image_converter.image_size[1]
         if images is None:
+            # == Branch: vision model, with `None` value for `images` ==
+
             # To handle the text-only input case, we need to pass an empty
-            # tensor so as to skip the vision
+            # tensor so as to skip the vision layers of the model.
+
+            # TODO: Once functional models accept `None` inputs, consider
+            # passing this as `None` directly.
             images = tf.ones(
                 shape=[
                     batch_size,
                     0,
-
-
+                    desired_height,
+                    desired_width,
                     3,
                 ],
                 dtype="float32",
             )
 
-
-            padding_mask = token_ids != self.tokenizer.pad_token_id
-            response_mask = segment_ids == 1
+            vision_mask = tf.zeros_like(token_ids, dtype=bool)
 
             return self._format_output(
                 images=images,
                 token_ids=token_ids,
-
+                vision_mask=vision_mask,
                 response_mask=response_mask,
                 padding_mask=padding_mask,
                 return_labels=True,
                 text_only_input=True,
+                batched=batched,
             )
 
-
-        if num_valid_images is None:
-            num_valid_images = tf.fill(
-                dims=(batch_size,),
-                value=self.max_images_per_prompt,
-            )
+        # == Branch: vision model, with non-`None` value for `images` ==
 
-
-        if original_image_shape[1] != self.max_images_per_prompt:
-            raise ValueError(
-                "The number of images per sample should be the same as "
-                "`max_images_per_prompt`. Received: "
-                f"images.shape = {original_image_shape}, "
-                f"max_images_per_prompt = {self.max_images_per_prompt}"
-            )
-        if tf.cast(
-            tf.math.reduce_sum(
-                tf.cast(
-                    tf.math.greater(
-                        num_valid_images, self.max_images_per_prompt
-                    ),
-                    dtype=tf.int32,
-                )
-            ),
-            dtype=bool,
-        ):
-            raise ValueError(
-                "`num_valid_images` should have values <= "
-                "self.max_images_per_prompt. Received: "
-                f"num_valid_images = {num_valid_images}, ",
-                f"max_images_per_prompt = {self.max_images_per_prompt}",
-            )
+        images = self._preprocess_images(images=images, batched=batched)
 
-
-        padded_images_shape = tf.shape(images)
-        images = tf.reshape(
-            images,
-            [
-                -1,
-                padded_images_shape[-3],
-                padded_images_shape[-2],
-                padded_images_shape[-1],
-            ],
-        )
-        images = self.image_converter(images)
-        height = (
-            self.image_size[0]
-            if self.image_converter.image_size
-            else original_image_shape[-3]
-        )
-        width = (
-            self.image_size[1]
-            if self.image_converter.image_size
-            else original_image_shape[-2]
-        )
-        images = tf.reshape(
-            images,
-            [
-                padded_images_shape[0],
-                self.max_images_per_prompt,
-                height,
-                width,
-                3,
-            ],
-        )
-
-        # Format tokens.
-        padding_mask = token_ids != self.tokenizer.pad_token_id
-        token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
-        segment_ids = tf.ragged.boolean_mask(segment_ids, padding_mask)
-        padding_mask = tf.ragged.boolean_mask(padding_mask, padding_mask)
-        response_mask = segment_ids == 1
-
-        # Using `num_valid_images`, we need to add dummy image tokens at the
-        # end of the tokenized text. Ideally, we could have passed an image
-        # padding mask to the model, but it won't work with XLA since an
-        # `ops.where` on it in the interleaving layer will return different
-        # number of images every time. So, we need to fix the number of images.
-        vision_placeholder_tensor = self._get_image_placeholder_ragged_tensor(
-            (self.max_images_per_prompt - num_valid_images)
-            * self.num_vision_tokens_per_image,
-            self.tokenizer.token_to_id("<img>"),
-        )
-        vision_placeholder_tensor = vision_placeholder_tensor.to_tensor(
-            shape=[
-                batch_size,
-                self.max_images_per_prompt * self.num_vision_tokens_per_image,
-            ],
-            default_value=self.tokenizer.pad_token_id,
         | 
| 397 | 
            -
                    )
         | 
| 398 | 
            -
             | 
| 399 | 
            -
                    token_ids_with_placeholder = tf.concat(
         | 
| 400 | 
            -
                        [token_ids, vision_placeholder_tensor], axis=1
         | 
| 401 | 
            -
                    )
         | 
| 402 | 
            -
             | 
| 403 | 
            -
                    # Now, pad everything to the same length.
         | 
| 404 | 
            -
                    desired_length = (
         | 
| 405 | 
            -
                        sequence_length
         | 
| 406 | 
            -
                        + self.max_images_per_prompt * self.num_vision_tokens_per_image
         | 
| 407 | 
            -
                    )
         | 
| 408 | 
            -
                    token_ids_with_placeholder = token_ids_with_placeholder.to_tensor(
         | 
| 409 | 
            -
                        shape=[batch_size, desired_length + 1],
         | 
| 410 | 
            -
                        default_value=self.tokenizer.pad_token_id,
         | 
| 411 | 
            -
                    )
         | 
| 412 | 
            -
                    padding_mask_with_placeholder = padding_mask.to_tensor(
         | 
| 413 | 
            -
                        shape=[batch_size, desired_length + 1],
         | 
| 414 | 
            -
                        default_value=False,
         | 
| 415 | 
            -
                    )
         | 
| 416 | 
            -
                    response_mask_with_placeholder = response_mask.to_tensor(
         | 
| 417 | 
            -
                        shape=[batch_size, desired_length + 1],
         | 
| 418 | 
            -
                        default_value=False,
         | 
| 419 | 
            -
                    )
         | 
| 420 | 
            -
             | 
| 421 | 
            -
                    text_mask = token_ids_with_placeholder != self.tokenizer.token_to_id(
         | 
| 422 | 
            -
                        "<img>"
         | 
| 423 | 
            -
                    )
         | 
| 641 | 
            +
                    vision_mask = token_ids == self.tokenizer.image_placeholder_id
         | 
| 424 642 |  | 
| 425 643 | 
             
                    return self._format_output(
         | 
| 426 644 | 
             
                        images=images,
         | 
| 427 | 
            -
                        token_ids= | 
| 428 | 
            -
                         | 
| 429 | 
            -
                        response_mask= | 
| 430 | 
            -
                        padding_mask= | 
| 645 | 
            +
                        token_ids=token_ids,
         | 
| 646 | 
            +
                        vision_mask=vision_mask,
         | 
| 647 | 
            +
                        response_mask=response_mask,
         | 
| 648 | 
            +
                        padding_mask=padding_mask,
         | 
| 431 649 | 
             
                        return_labels=True,
         | 
| 650 | 
            +
                        text_only_input=False,
         | 
| 651 | 
            +
                        batched=batched,
         | 
| 432 652 | 
             
                    )
         | 
| 433 653 |  | 
| 434 654 | 
             
                @preprocessing_function
         | 
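The refactored `call()` above derives a boolean `vision_mask` directly from the packed token ids (`token_ids == self.tokenizer.image_placeholder_id`) instead of concatenating a padded placeholder tensor as the removed lines did. A minimal standalone sketch of that masking step, using made-up token ids and a hypothetical placeholder id rather than the real Gemma3 vocabulary:

```python
import tensorflow as tf

# Hypothetical ids; the real values come from the Gemma3 tokenizer.
IMG_PLACEHOLDER_ID = 262144
PAD_ID = 0

token_ids = tf.constant([[2, 10, IMG_PLACEHOLDER_ID, IMG_PLACEHOLDER_ID, 7, PAD_ID]])

# True wherever an `<img>` placeholder sits, i.e. where image embeddings
# will be interleaved into the text embeddings downstream.
vision_mask = token_ids == IMG_PLACEHOLDER_ID
# True for every non-padding position.
padding_mask = token_ids != PAD_ID

print(vision_mask.numpy())   # [[False False  True  True False False]]
print(padding_mask.numpy())  # [[ True  True  True  True  True False]]
```

Because the mask is computed from token ids that already contain the placeholders, no separate padded placeholder tensor has to be built and concatenated, which is what the removed lines above were doing.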
| @@ -448,39 +668,59 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor): | |
| 448 668 | 
             
                    the sequence (as generation is expected to continue at the end of the
         | 
| 449 669 | 
             
                    inputted prompt).
         | 
| 450 670 | 
             
                    """
         | 
| 671 | 
            +
             | 
| 451 672 | 
             
                    if not self.built:
         | 
| 452 673 | 
             
                        self.build(None)
         | 
| 453 674 |  | 
| 675 | 
            +
                    # Extract inputs.
         | 
| 454 676 | 
             
                    if isinstance(x, dict):
         | 
| 455 677 | 
             
                        images = x.get("images", None)
         | 
| 456 | 
            -
             | 
| 678 | 
            +
             | 
| 457 679 | 
             
                        # TODO: do we even need `responses` for generation? Makes sense for
         | 
| 458 | 
            -
                        # finetuning (i.e., `call()`).
         | 
| 680 | 
            +
                        # finetuning only (i.e., `call()`).
         | 
| 459 681 | 
             
                        responses = x.get("responses", None)
         | 
| 460 682 | 
             
                        prompts = x["prompts"]
         | 
| 461 683 | 
             
                    else:
         | 
| 462 684 | 
             
                        images = None
         | 
| 463 | 
            -
                        num_valid_images = None
         | 
| 464 685 | 
             
                        responses = None
         | 
| 465 686 | 
             
                        prompts = x
         | 
| 466 687 |  | 
| 467 | 
            -
                    if  | 
| 468 | 
            -
             | 
| 469 | 
            -
             | 
| 470 | 
            -
             | 
| 471 | 
            -
             | 
| 472 | 
            -
             | 
| 473 | 
            -
             | 
| 474 | 
            -
                         | 
| 475 | 
            -
                         | 
| 688 | 
            +
                    # Find out if the input is batched/not batched. Uprank if not batched.
         | 
| 689 | 
            +
                    # In other preprocessors, we don't have to do this, but here, all
         | 
| 690 | 
            +
                    # the following logic (indices, etc.) uses tensors with a batch dim.
         | 
| 691 | 
            +
                    # We will squeeze these back at the end.
         | 
| 692 | 
            +
                    batched = True
         | 
| 693 | 
            +
                    if isinstance(prompts, str):
         | 
| 694 | 
            +
                        batched = False
         | 
| 695 | 
            +
                        prompts = [prompts]
         | 
| 696 | 
            +
                        if responses is not None:
         | 
| 697 | 
            +
                            responses = [responses]
         | 
| 698 | 
            +
                    if isinstance(prompts, tf.Tensor) and len(prompts.shape) == 0:
         | 
| 699 | 
            +
                        batched = False
         | 
| 700 | 
            +
                        prompts = tf.expand_dims(prompts, axis=0)
         | 
| 701 | 
            +
                        if responses is not None:
         | 
| 702 | 
            +
                            responses = tf.expand_dims(responses, axis=0)
         | 
| 703 | 
            +
             | 
| 704 | 
            +
                    # We have the same 8 cases here, as in `call()`.
         | 
| 705 | 
            +
                    if self.text_only_model and images is not None:
         | 
| 706 | 
            +
                        raise ValueError(
         | 
| 707 | 
            +
                            "The initialized preprocessor/model is text-only, but "
         | 
| 708 | 
            +
                            " `images` is not `None`."
         | 
| 709 | 
            +
                        )
         | 
| 710 | 
            +
             | 
| 711 | 
            +
                    # Add image placeholder tokens. Replace `"<start_of_image>"` in
         | 
| 712 | 
            +
                    # prompts with
         | 
| 713 | 
            +
                    # `"\n\n<start_of_image> <img> * 256 <end_of_image>\n\n"`.
         | 
| 714 | 
            +
                    if not self.text_only_model:
         | 
| 476 715 | 
             
                        prompts = tf.strings.regex_replace(
         | 
| 477 716 | 
             
                            prompts,
         | 
| 478 | 
            -
                             | 
| 479 | 
            -
                            f"\n\n{ | 
| 480 | 
            -
                            +  | 
| 481 | 
            -
                            + f"{ | 
| 717 | 
            +
                            self.start_of_image_token,
         | 
| 718 | 
            +
                            f"\n\n{self.start_of_image_token}"
         | 
| 719 | 
            +
                            + self.image_placeholder * self.num_vision_tokens_per_image
         | 
| 720 | 
            +
                            + f"{self.end_of_image_token}\n\n",
         | 
| 482 721 | 
             
                        )
         | 
| 483 722 |  | 
| 723 | 
            +
                    # === Tokenization, padding, etc. ===
         | 
| 484 724 | 
             
                    prompts = self.tokenizer(prompts)
         | 
| 485 725 |  | 
| 486 726 | 
             
                    if responses is not None:
         | 
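The hunk above rewrites every `<start_of_image>` occurrence in a prompt into a fixed-size block of `<img>` placeholders before tokenization (256 per image, per the comment in the diff). A self-contained sketch of that expansion, shrunk to 4 placeholders so the output stays readable; the token strings mirror the ones used above:

```python
import tensorflow as tf

# Illustrative value; the layer above uses `num_vision_tokens_per_image`
# (256 according to the comment in the diff).
num_vision_tokens_per_image = 4

prompts = tf.constant(["Describe this image: <start_of_image>"])
expanded = tf.strings.regex_replace(
    prompts,
    "<start_of_image>",
    "\n\n<start_of_image>"
    + "<img>" * num_vision_tokens_per_image
    + "<end_of_image>\n\n",
)
print(expanded.numpy()[0].decode("utf-8"))
# -> "Describe this image: \n\n<start_of_image><img><img><img><img><end_of_image>\n\n"
```

Expanding each image marker to a fixed count keeps the number of image token positions per image static, in line with the XLA-related comments in the removed code above.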
| @@ -489,174 +729,79 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor): | |
| 489 729 | 
             
                    else:
         | 
| 490 730 | 
             
                        segments = (prompts,)
         | 
| 491 731 |  | 
| 732 | 
            +
                    # Padding.
         | 
| 492 733 | 
             
                    token_ids, segment_ids = self.packer(
         | 
| 493 734 | 
             
                        segments,
         | 
| 494 735 | 
             
                        sequence_length=sequence_length,
         | 
| 495 736 | 
             
                        add_end_value=False,
         | 
| 496 737 | 
             
                    )
         | 
| 738 | 
            +
                    response_mask = segment_ids == 1
         | 
| 739 | 
            +
                    padding_mask = token_ids != self.tokenizer.pad_token_id
         | 
| 497 740 |  | 
| 498 | 
            -
                    #  | 
| 741 | 
            +
                    # === Text Model ===
         | 
| 499 742 | 
             
                    if self.text_only_model:
         | 
| 500 | 
            -
                        response_mask = segment_ids == 1
         | 
| 501 | 
            -
                        padding_mask = token_ids != self.tokenizer.pad_token_id
         | 
| 502 743 | 
             
                        return {
         | 
| 503 | 
            -
                            "token_ids":  | 
| 504 | 
            -
             | 
| 744 | 
            +
                            "token_ids": (
         | 
| 745 | 
            +
                                token_ids if batched else tf.squeeze(token_ids, axis=0)
         | 
| 746 | 
            +
                            ),
         | 
| 747 | 
            +
                            "padding_mask": (
         | 
| 748 | 
            +
                                padding_mask
         | 
| 749 | 
            +
                                if batched
         | 
| 750 | 
            +
                                else tf.squeeze(padding_mask, axis=0)
         | 
| 751 | 
            +
                            ),
         | 
| 505 752 | 
             
                        }
         | 
| 506 753 |  | 
| 507 | 
            -
                    # Vision  | 
| 754 | 
            +
                    # === Vision processing ===
         | 
| 755 | 
            +
             | 
| 508 756 | 
             
                    batch_size = tf.shape(prompts)[0]
         | 
| 757 | 
            +
                    desired_height = self.image_converter.image_size[0]
         | 
| 758 | 
            +
                    desired_width = self.image_converter.image_size[1]
         | 
| 509 759 | 
             
                    if images is None:
         | 
| 760 | 
            +
                        # == Branch: vision model, with `None` value for `images` ==
         | 
| 761 | 
            +
             | 
| 510 762 | 
             
                        # To handle the text-only input case, we need to pass an empty
         | 
| 511 | 
            -
                        # tensor so as to skip the vision  | 
| 763 | 
            +
                        # tensor so as to skip the vision layers of the model.
         | 
| 764 | 
            +
             | 
| 765 | 
            +
                        # TODO: Once functional models accept `None` inputs, consider
         | 
| 766 | 
            +
                        # passing this as `None` directly.
         | 
| 512 767 | 
             
                        images = tf.ones(
         | 
| 513 768 | 
             
                            shape=[
         | 
| 514 769 | 
             
                                batch_size,
         | 
| 515 770 | 
             
                                0,
         | 
| 516 | 
            -
                                 | 
| 517 | 
            -
                                 | 
| 771 | 
            +
                                desired_height,
         | 
| 772 | 
            +
                                desired_width,
         | 
| 518 773 | 
             
                                3,
         | 
| 519 774 | 
             
                            ],
         | 
| 520 775 | 
             
                            dtype="float32",
         | 
| 521 776 | 
             
                        )
         | 
| 522 777 |  | 
| 523 | 
            -
                         | 
| 524 | 
            -
                        padding_mask = token_ids != self.tokenizer.pad_token_id
         | 
| 525 | 
            -
                        response_mask = segment_ids == 1
         | 
| 778 | 
            +
                        vision_mask = tf.zeros_like(token_ids, dtype=bool)
         | 
| 526 779 |  | 
| 527 780 | 
             
                        return self._format_output(
         | 
| 528 781 | 
             
                            images=images,
         | 
| 529 782 | 
             
                            token_ids=token_ids,
         | 
| 530 | 
            -
                             | 
| 783 | 
            +
                            vision_mask=vision_mask,
         | 
| 531 784 | 
             
                            response_mask=response_mask,
         | 
| 532 785 | 
             
                            padding_mask=padding_mask,
         | 
| 533 786 | 
             
                            return_labels=False,
         | 
| 534 787 | 
             
                            text_only_input=True,
         | 
| 788 | 
            +
                            batched=batched,
         | 
| 535 789 | 
             
                        )
         | 
| 536 790 |  | 
| 537 | 
            -
                    #  | 
| 538 | 
            -
                     | 
| 539 | 
            -
                    if num_valid_images is None:
         | 
| 540 | 
            -
                        num_valid_images = tf.fill(
         | 
| 541 | 
            -
                            dims=(batch_size,),
         | 
| 542 | 
            -
                            value=self.max_images_per_prompt,
         | 
| 543 | 
            -
                        )
         | 
| 544 | 
            -
             | 
| 545 | 
            -
                    # Image inputs checks.
         | 
| 546 | 
            -
                    if original_image_shape[1] != self.max_images_per_prompt:
         | 
| 547 | 
            -
                        raise ValueError(
         | 
| 548 | 
            -
                            "The number of images per sample should be the same as "
         | 
| 549 | 
            -
                            "`max_images_per_prompt`. Received: "
         | 
| 550 | 
            -
                            f"images.shape = {original_image_shape}, "
         | 
| 551 | 
            -
                            f"max_images_per_prompt = {self.max_images_per_prompt}"
         | 
| 552 | 
            -
                        )
         | 
| 553 | 
            -
                    if tf.cast(
         | 
| 554 | 
            -
                        tf.math.reduce_sum(
         | 
| 555 | 
            -
                            tf.cast(
         | 
| 556 | 
            -
                                tf.math.greater(
         | 
| 557 | 
            -
                                    num_valid_images, self.max_images_per_prompt
         | 
| 558 | 
            -
                                ),
         | 
| 559 | 
            -
                                dtype=tf.int32,
         | 
| 560 | 
            -
                            )
         | 
| 561 | 
            -
                        ),
         | 
| 562 | 
            -
                        dtype=bool,
         | 
| 563 | 
            -
                    ):
         | 
| 564 | 
            -
                        raise ValueError(
         | 
| 565 | 
            -
                            "`num_valid_images` should have values <= "
         | 
| 566 | 
            -
                            "self.max_images_per_prompt. Received: "
         | 
| 567 | 
            -
                            f"num_valid_images = {num_valid_images}, ",
         | 
| 568 | 
            -
                            f"max_images_per_prompt = {self.max_images_per_prompt}",
         | 
| 569 | 
            -
                        )
         | 
| 570 | 
            -
             | 
| 571 | 
            -
                    # Resize, rescale, etc. the images.
         | 
| 572 | 
            -
                    padded_images_shape = tf.shape(images)
         | 
| 573 | 
            -
                    images = tf.reshape(
         | 
| 574 | 
            -
                        images,
         | 
| 575 | 
            -
                        [
         | 
| 576 | 
            -
                            -1,
         | 
| 577 | 
            -
                            padded_images_shape[-3],
         | 
| 578 | 
            -
                            padded_images_shape[-2],
         | 
| 579 | 
            -
                            padded_images_shape[-1],
         | 
| 580 | 
            -
                        ],
         | 
| 581 | 
            -
                    )
         | 
| 582 | 
            -
                    images = self.image_converter(images)
         | 
| 583 | 
            -
                    height = (
         | 
| 584 | 
            -
                        self.image_size[0]
         | 
| 585 | 
            -
                        if self.image_converter.image_size
         | 
| 586 | 
            -
                        else original_image_shape[-3]
         | 
| 587 | 
            -
                    )
         | 
| 588 | 
            -
                    width = (
         | 
| 589 | 
            -
                        self.image_size[1]
         | 
| 590 | 
            -
                        if self.image_converter.image_size
         | 
| 591 | 
            -
                        else original_image_shape[-2]
         | 
| 592 | 
            -
                    )
         | 
| 593 | 
            -
                    images = tf.reshape(
         | 
| 594 | 
            -
                        images,
         | 
| 595 | 
            -
                        [
         | 
| 596 | 
            -
                            padded_images_shape[0],
         | 
| 597 | 
            -
                            self.max_images_per_prompt,
         | 
| 598 | 
            -
                            height,
         | 
| 599 | 
            -
                            width,
         | 
| 600 | 
            -
                            3,
         | 
| 601 | 
            -
                        ],
         | 
| 602 | 
            -
                    )
         | 
| 603 | 
            -
             | 
| 604 | 
            -
                    padding_mask = token_ids != self.tokenizer.pad_token_id
         | 
| 605 | 
            -
                    token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
         | 
| 606 | 
            -
                    segment_ids = tf.ragged.boolean_mask(segment_ids, padding_mask)
         | 
| 607 | 
            -
                    padding_mask = tf.ragged.boolean_mask(padding_mask, padding_mask)
         | 
| 608 | 
            -
                    response_mask = segment_ids == 1
         | 
| 609 | 
            -
             | 
| 610 | 
            -
                    # Using `num_valid_images`, we need to add dummy image tokens at the
         | 
| 611 | 
            -
                    # end of the tokenized text. Ideally, we could have passed an image
         | 
| 612 | 
            -
                    # padding mask to the model, but it won't work with XLA since an
         | 
| 613 | 
            -
                    # `ops.where` on it in the interleaving layer will return different
         | 
| 614 | 
            -
                    # number of images every time. So, we need to fix the number of images.
         | 
| 615 | 
            -
                    vision_placeholder_tensor = self._get_image_placeholder_ragged_tensor(
         | 
| 616 | 
            -
                        (self.max_images_per_prompt - num_valid_images)
         | 
| 617 | 
            -
                        * self.num_vision_tokens_per_image,
         | 
| 618 | 
            -
                        self.tokenizer.token_to_id("<img>"),
         | 
| 619 | 
            -
                    )
         | 
| 620 | 
            -
                    vision_placeholder_tensor = vision_placeholder_tensor.to_tensor(
         | 
| 621 | 
            -
                        shape=[
         | 
| 622 | 
            -
                            batch_size,
         | 
| 623 | 
            -
                            self.max_images_per_prompt * self.num_vision_tokens_per_image,
         | 
| 624 | 
            -
                        ],
         | 
| 625 | 
            -
                        default_value=self.tokenizer.pad_token_id,
         | 
| 626 | 
            -
                    )
         | 
| 627 | 
            -
                    token_ids_with_placeholder = tf.concat(
         | 
| 628 | 
            -
                        [token_ids, vision_placeholder_tensor], axis=1
         | 
| 629 | 
            -
                    )
         | 
| 630 | 
            -
             | 
| 631 | 
            -
                    # Now, pad everything to the same length.
         | 
| 632 | 
            -
                    desired_length = (
         | 
| 633 | 
            -
                        sequence_length
         | 
| 634 | 
            -
                        + self.max_images_per_prompt * self.num_vision_tokens_per_image
         | 
| 635 | 
            -
                    )
         | 
| 636 | 
            -
                    token_ids_with_placeholder = token_ids_with_placeholder.to_tensor(
         | 
| 637 | 
            -
                        shape=[batch_size, desired_length],
         | 
| 638 | 
            -
                        default_value=self.tokenizer.pad_token_id,
         | 
| 639 | 
            -
                    )
         | 
| 640 | 
            -
                    padding_mask_with_placeholder = padding_mask.to_tensor(
         | 
| 641 | 
            -
                        shape=[batch_size, desired_length],
         | 
| 642 | 
            -
                        default_value=False,
         | 
| 643 | 
            -
                    )
         | 
| 644 | 
            -
                    response_mask_with_placeholder = response_mask.to_tensor(
         | 
| 645 | 
            -
                        shape=[batch_size, desired_length],
         | 
| 646 | 
            -
                        default_value=False,
         | 
| 647 | 
            -
                    )
         | 
| 791 | 
            +
                    # == Branch: vision model, with non-`None` value for `images` ==
         | 
| 792 | 
            +
                    images = self._preprocess_images(images=images, batched=batched)
         | 
| 648 793 |  | 
| 649 | 
            -
                     | 
| 650 | 
            -
                        "<img>"
         | 
| 651 | 
            -
                    )
         | 
| 794 | 
            +
                    vision_mask = token_ids == self.tokenizer.image_placeholder_id
         | 
| 652 795 |  | 
| 653 796 | 
             
                    return self._format_output(
         | 
| 654 797 | 
             
                        images=images,
         | 
| 655 | 
            -
                        token_ids= | 
| 656 | 
            -
                         | 
| 657 | 
            -
                        response_mask= | 
| 658 | 
            -
                        padding_mask= | 
| 798 | 
            +
                        token_ids=token_ids,
         | 
| 799 | 
            +
                        vision_mask=vision_mask,
         | 
| 800 | 
            +
                        response_mask=response_mask,
         | 
| 801 | 
            +
                        padding_mask=padding_mask,
         | 
| 659 802 | 
             
                        return_labels=False,
         | 
| 803 | 
            +
                        text_only_input=False,
         | 
| 804 | 
            +
                        batched=batched,
         | 
| 660 805 | 
             
                    )
         | 
| 661 806 |  | 
| 662 807 | 
             
                def get_config(self):
         | 
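For a vision-capable preprocessor that receives no images, the branch above feeds the model an image tensor whose per-prompt image dimension is zero, so the vision layers are effectively skipped while the functional model's input signature stays fixed. A standalone sketch of that shape trick, with illustrative sizes (the code above takes height and width from `image_converter.image_size`):

```python
import tensorflow as tf

batch_size = 2
height, width = 896, 896  # illustrative; the diff reads these from `image_converter.image_size`

# Zero images per prompt: the tensor is well-formed but carries no data.
images = tf.ones(shape=[batch_size, 0, height, width, 3], dtype="float32")

print(images.shape)             # (2, 0, 896, 896, 3)
print(tf.size(images).numpy())  # 0 -> nothing for the vision encoder to process
```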
| @@ -686,6 +831,18 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor): | |
| 686 831 |  | 
| 687 832 | 
             
                    token_ids, padding_mask = x["token_ids"], x["padding_mask"]
         | 
| 688 833 | 
             
                    ids_to_strip = self.tokenizer.special_token_ids
         | 
| 689 | 
            -
             | 
| 834 | 
            +
             | 
| 835 | 
            +
                    # We do not want to strip SoI token because it is provided by the user.
         | 
| 836 | 
            +
                    if self.tokenizer.start_of_image_token_id in ids_to_strip:
         | 
| 837 | 
            +
                        ids_to_strip.remove(self.tokenizer.start_of_image_token_id)
         | 
| 838 | 
            +
             | 
| 690 839 | 
             
                    token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip)
         | 
| 691 840 | 
             
                    return self.tokenizer.detokenize(token_ids)
         | 
| 841 | 
            +
             | 
| 842 | 
            +
                @property
         | 
| 843 | 
            +
                def max_images_per_prompt(self):
         | 
| 844 | 
            +
                    return self._max_images_per_prompt
         | 
| 845 | 
            +
             | 
| 846 | 
            +
                @max_images_per_prompt.setter
         | 
| 847 | 
            +
                def max_images_per_prompt(self, value):
         | 
| 848 | 
            +
                    self._max_images_per_prompt = value
         |
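Finally, `generate_preprocess` above upranks a bare string prompt to a batch of one, runs the batched packing logic, and squeezes the batch dimension back out before returning. A toy, library-free sketch of that pattern (tokenization is faked with a whitespace split; the real layer uses its tokenizer and packer, and `<pad>` stands in for the pad token):

```python
import tensorflow as tf

def fake_preprocess(prompts):
    # Uprank a single string to a batch of one, remembering that we did so.
    batched = True
    if isinstance(prompts, str):
        batched = False
        prompts = [prompts]
    # Stand-in for tokenization + packing to a fixed sequence length of 8.
    token_ids = tf.strings.split(prompts).to_tensor(
        default_value="<pad>", shape=[None, 8]
    )
    padding_mask = token_ids != "<pad>"
    # Squeeze the batch dimension back out for unbatched inputs.
    if not batched:
        token_ids = tf.squeeze(token_ids, axis=0)
        padding_mask = tf.squeeze(padding_mask, axis=0)
    return {"token_ids": token_ids, "padding_mask": padding_mask}

print(fake_preprocess("what is keras hub")["token_ids"].shape)   # (8,)
print(fake_preprocess(["a", "b c"])["token_ids"].shape)          # (2, 8)
```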