keras-hub-nightly 0.20.0.dev202504030357__py3-none-any.whl → 0.21.0.dev202504050402__py3-none-any.whl
This diff shows the changes between two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- keras_hub/api/models/__init__.py +5 -20
- keras_hub/api/tokenizers/__init__.py +0 -4
- keras_hub/src/layers/preprocessing/image_converter.py +26 -16
- keras_hub/src/models/gemma3/gemma3_attention.py +74 -21
- keras_hub/src/models/gemma3/gemma3_backbone.py +117 -46
- keras_hub/src/models/gemma3/gemma3_causal_lm.py +72 -15
- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +512 -355
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +23 -19
- keras_hub/src/models/gemma3/gemma3_image_converter.py +6 -0
- keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py +56 -16
- keras_hub/src/models/gemma3/gemma3_presets.py +74 -8
- keras_hub/src/models/gemma3/gemma3_tokenizer.py +9 -0
- keras_hub/src/models/gemma3/{gemma3_vit.py → gemma3_vision_encoder.py} +150 -139
- keras_hub/src/models/qwen/qwen_backbone.py +0 -7
- keras_hub/src/models/qwen/qwen_causal_lm.py +0 -7
- keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +0 -7
- keras_hub/src/models/qwen/qwen_tokenizer.py +0 -9
- keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -1
- keras_hub/src/models/roformer_v2/roformer_v2_text_classifier.py +2 -2
- keras_hub/src/models/vit/vit_image_converter.py +8 -3
- keras_hub/src/tests/test_case.py +4 -0
- keras_hub/src/utils/tensor_utils.py +6 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504050402.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504050402.dist-info}/RECORD +27 -27
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504050402.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504050402.dist-info}/top_level.txt +0 -0
| @@ -1,3 +1,4 @@ | |
| 1 | 
            +
            import numpy as np
         | 
| 1 2 | 
             
            from keras import ops
         | 
| 2 3 |  | 
| 3 4 | 
             
            from keras_hub.src.api_export import keras_hub_export
         | 
| @@ -8,15 +9,21 @@ from keras_hub.src.models.gemma3.gemma3_causal_lm_preprocessor import ( | |
| 8 9 | 
             
            )
         | 
| 9 10 | 
             
            from keras_hub.src.utils.tensor_utils import any_equal
         | 
| 10 11 |  | 
| 12 | 
            +
            try:
         | 
| 13 | 
            +
                import tensorflow as tf
         | 
| 14 | 
            +
            except ImportError:
         | 
| 15 | 
            +
                tf = None
         | 
| 16 | 
            +
             | 
| 11 17 |  | 
| 12 18 | 
             
            @keras_hub_export("keras_hub.models.Gemma3CausalLM")
         | 
| 13 19 | 
             
            class Gemma3CausalLM(CausalLM):
         | 
| 14 | 
            -
                """An end-to-end  | 
| 20 | 
            +
                """An end-to-end multimodal Gemma3 model for causal language modeling.
         | 
| 15 21 |  | 
| 16 22 | 
             
                A causal language model (LM) predicts the next token based on previous
         | 
| 17 23 | 
             
                tokens. This task setup can be used to train the model unsupervised on
         | 
| 18 | 
            -
                 | 
| 19 | 
            -
                similar to the data used for training.
         | 
| 24 | 
            +
                images and plain text inputs, or to autoregressively generate plain text
         | 
| 25 | 
            +
                similar to the data used for training. Note that the model is
         | 
| 26 | 
            +
                image-text in, text out.
         | 
| 20 27 |  | 
| 21 28 | 
             
                This model has a `generate()` method, which generates text based on a
         | 
| 22 29 | 
             
                prompt. The generation strategy used is controlled by an additional
         | 
| @@ -30,10 +37,10 @@ class Gemma3CausalLM(CausalLM): | |
| 30 37 | 
             
                when creating the model with `from_preset()`.
         | 
| 31 38 |  | 
| 32 39 | 
             
                Args:
         | 
| 33 | 
            -
                    backbone: A `keras_hub.models.Gemma3Backbone` instance.
         | 
| 34 40 | 
             
                    preprocessor: A `keras_hub.models.Gemma3CausalLMPreprocessor` or
         | 
| 35 41 | 
             
                        `None`. If `None`, this model will not apply preprocessing, and
         | 
| 36 42 | 
             
                        inputs should be preprocessed before calling the model.
         | 
| 43 | 
            +
                    backbone: A `keras_hub.models.Gemma3Backbone` instance.
         | 
| 37 44 | 
             
                """
         | 
| 38 45 |  | 
| 39 46 | 
             
                backbone_cls = Gemma3Backbone
         | 
| @@ -79,15 +86,55 @@ class Gemma3CausalLM(CausalLM): | |
| 79 86 | 
             
                        **kwargs,
         | 
| 80 87 | 
             
                    )
         | 
| 81 88 |  | 
| 89 | 
            +
                def _normalize_generate_inputs(
         | 
| 90 | 
            +
                    self,
         | 
| 91 | 
            +
                    inputs,
         | 
| 92 | 
            +
                ):
         | 
| 93 | 
            +
                    """Overrides the superclass' method to handle unbatched image inputs."""
         | 
| 94 | 
            +
                    if tf and isinstance(inputs, tf.data.Dataset):
         | 
| 95 | 
            +
                        return inputs.as_numpy_iterator(), False
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                    if self.preprocessor is None:
         | 
| 98 | 
            +
                        return [inputs], False
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                    def normalize(x):
         | 
| 101 | 
            +
                        if isinstance(x, str):
         | 
| 102 | 
            +
                            return [x], True
         | 
| 103 | 
            +
                        if tf and isinstance(x, tf.Tensor) and x.shape.rank == 0:
         | 
| 104 | 
            +
                            return x[tf.newaxis], True
         | 
| 105 | 
            +
                        return x, False
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                    if isinstance(inputs, dict):
         | 
| 108 | 
            +
                        inputs["prompts"], input_is_scalar = normalize(inputs["prompts"])
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                        # If prompt is scalar, images can be either a 3D NumPy array/Tensor,
         | 
| 111 | 
            +
                        # or, list of 3D NumPy arrays. Let's uprank images.
         | 
| 112 | 
            +
                        if input_is_scalar and "images" in inputs:
         | 
| 113 | 
            +
                            x = inputs["images"]
         | 
| 114 | 
            +
                            if isinstance(x, np.ndarray) and len(x.shape) == 3:
         | 
| 115 | 
            +
                                inputs["images"] = [x]
         | 
| 116 | 
            +
                            elif tf and isinstance(x, tf.Tensor) and x.shape.rank == 3:
         | 
| 117 | 
            +
                                inputs["images"] = x[tf.newaxis]
         | 
| 118 | 
            +
                            elif isinstance(x, list):
         | 
| 119 | 
            +
                                inputs["images"] = [x]
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                        if "responses" in inputs:
         | 
| 122 | 
            +
                            inputs["responses"], _ = normalize(inputs["responses"])
         | 
| 123 | 
            +
                    else:
         | 
| 124 | 
            +
                        inputs, input_is_scalar = normalize(inputs)
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                    return [inputs], input_is_scalar
         | 
| 127 | 
            +
             | 
| 82 128 | 
             
                def call_with_cache(
         | 
| 83 129 | 
             
                    self,
         | 
| 84 130 | 
             
                    token_ids,
         | 
| 85 131 | 
             
                    cache,
         | 
| 86 132 | 
             
                    cache_update_index,
         | 
| 87 133 | 
             
                    img_embeddings=None,
         | 
| 88 | 
            -
                     | 
| 134 | 
            +
                    vision_mask=None,
         | 
| 89 135 | 
             
                    padding_mask=None,
         | 
| 90 136 | 
             
                    vision_indices=None,
         | 
| 137 | 
            +
                    cache_update_mask=None,
         | 
| 91 138 | 
             
                ):
         | 
| 92 139 | 
             
                    """Forward pass of `Gemma3CausalLM` with cache.
         | 
| 93 140 |  | 
| @@ -139,7 +186,8 @@ class Gemma3CausalLM(CausalLM): | |
| 139 186 | 
             
                            cache=current_cache,
         | 
| 140 187 | 
             
                            cache_update_index=cache_update_index,
         | 
| 141 188 | 
             
                            padding_mask=padding_mask,
         | 
| 142 | 
            -
                             | 
| 189 | 
            +
                            vision_mask=vision_mask,
         | 
| 190 | 
            +
                            cache_update_mask=cache_update_mask,
         | 
| 143 191 | 
             
                        )
         | 
| 144 192 | 
             
                        caches.append(next_cache)
         | 
| 145 193 | 
             
                    cache = ops.stack(caches, axis=1)
         | 
| @@ -151,7 +199,7 @@ class Gemma3CausalLM(CausalLM): | |
| 151 199 | 
             
                    self,
         | 
| 152 200 | 
             
                    token_ids,
         | 
| 153 201 | 
             
                    img_embeddings,
         | 
| 154 | 
            -
                     | 
| 202 | 
            +
                    vision_mask,
         | 
| 155 203 | 
             
                    padding_mask,
         | 
| 156 204 | 
             
                    vision_indices,
         | 
| 157 205 | 
             
                ):
         | 
| @@ -170,11 +218,12 @@ class Gemma3CausalLM(CausalLM): | |
| 170 218 | 
             
                    logits, hidden_states, cache = self.call_with_cache(
         | 
| 171 219 | 
             
                        token_ids=token_ids,
         | 
| 172 220 | 
             
                        img_embeddings=img_embeddings,
         | 
| 173 | 
            -
                         | 
| 221 | 
            +
                        vision_mask=vision_mask,
         | 
| 174 222 | 
             
                        cache=cache,
         | 
| 175 223 | 
             
                        cache_update_index=0,
         | 
| 176 224 | 
             
                        padding_mask=padding_mask,
         | 
| 177 225 | 
             
                        vision_indices=vision_indices,
         | 
| 226 | 
            +
                        cache_update_mask=None,
         | 
| 178 227 | 
             
                    )
         | 
| 179 228 | 
             
                    return hidden_states, cache
         | 
| 180 229 |  | 
| @@ -193,29 +242,33 @@ class Gemma3CausalLM(CausalLM): | |
| 193 242 | 
             
                            will stop.
         | 
| 194 243 | 
             
                    """
         | 
| 195 244 |  | 
| 196 | 
            -
                    token_ids, padding_mask, images,  | 
| 245 | 
            +
                    token_ids, padding_mask, images, vision_mask, vision_indices = (
         | 
| 197 246 | 
             
                        inputs["token_ids"],
         | 
| 198 247 | 
             
                        inputs["padding_mask"],
         | 
| 199 248 | 
             
                        inputs.get("images", None),
         | 
| 200 | 
            -
                        inputs.get(" | 
| 249 | 
            +
                        inputs.get("vision_mask", None),
         | 
| 201 250 | 
             
                        inputs.get("vision_indices", None),
         | 
| 202 251 | 
             
                    )
         | 
| 203 252 | 
             
                    if not self.backbone.text_only_model:
         | 
| 204 | 
            -
                         | 
| 205 | 
            -
             | 
| 206 | 
            -
             | 
| 253 | 
            +
                        # Handle an unbatched image. Unlike `token_ids` and
         | 
| 254 | 
            +
                        # `padding_mask`, this will not automatically be upranked.
         | 
| 255 | 
            +
                        if len(ops.shape(images)) == 4:
         | 
| 207 256 | 
             
                            images = ops.expand_dims(images, axis=0)
         | 
| 257 | 
            +
                        if len(ops.shape(vision_mask)) == 1:
         | 
| 258 | 
            +
                            vision_mask = ops.expand_dims(vision_mask, axis=0)
         | 
| 259 | 
            +
                        if len(ops.shape(vision_indices)) == 1:
         | 
| 260 | 
            +
                            vision_indices = ops.expand_dims(vision_indices, axis=0)
         | 
| 208 261 | 
             
                        img_embeddings = self.backbone.vision_encoder(images)
         | 
| 209 262 | 
             
                    else:
         | 
| 210 263 | 
             
                        img_embeddings = None
         | 
| 211 | 
            -
                         | 
| 264 | 
            +
                        vision_mask = None
         | 
| 212 265 | 
             
                        vision_indices = None
         | 
| 213 266 |  | 
| 214 267 | 
             
                    # Create and seed cache with a single forward pass.
         | 
| 215 268 | 
             
                    hidden_states, cache = self._build_cache(
         | 
| 216 269 | 
             
                        token_ids,
         | 
| 217 270 | 
             
                        img_embeddings,
         | 
| 218 | 
            -
                         | 
| 271 | 
            +
                        vision_mask,
         | 
| 219 272 | 
             
                        padding_mask,
         | 
| 220 273 | 
             
                        vision_indices,
         | 
| 221 274 | 
             
                    )
         | 
| @@ -230,10 +283,14 @@ class Gemma3CausalLM(CausalLM): | |
| 230 283 | 
             
                        cache_update_index = index - 1
         | 
| 231 284 | 
             
                        batch_size = ops.shape(prompt)[0]
         | 
| 232 285 | 
             
                        prompt = ops.slice(prompt, [0, index - 1], [batch_size, 1])
         | 
| 286 | 
            +
                        sliced_cache_update_mask = ops.slice(
         | 
| 287 | 
            +
                            ~padding_mask, [0, index - 1], [batch_size, 1]
         | 
| 288 | 
            +
                        )
         | 
| 233 289 | 
             
                        logits, hidden_states, cache = self.call_with_cache(
         | 
| 234 290 | 
             
                            token_ids=prompt,
         | 
| 235 291 | 
             
                            cache=cache,
         | 
| 236 292 | 
             
                            cache_update_index=cache_update_index,
         | 
| 293 | 
            +
                            cache_update_mask=sliced_cache_update_mask,
         | 
| 237 294 | 
             
                        )
         | 
| 238 295 | 
             
                        return (
         | 
| 239 296 | 
             
                            ops.squeeze(logits, axis=1),
         |