keras-hub-nightly 0.20.0.dev202504030357__py3-none-any.whl → 0.21.0.dev202504050402__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- keras_hub/api/models/__init__.py +5 -20
- keras_hub/api/tokenizers/__init__.py +0 -4
- keras_hub/src/layers/preprocessing/image_converter.py +26 -16
- keras_hub/src/models/gemma3/gemma3_attention.py +74 -21
- keras_hub/src/models/gemma3/gemma3_backbone.py +117 -46
- keras_hub/src/models/gemma3/gemma3_causal_lm.py +72 -15
- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +512 -355
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +23 -19
- keras_hub/src/models/gemma3/gemma3_image_converter.py +6 -0
- keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py +56 -16
- keras_hub/src/models/gemma3/gemma3_presets.py +74 -8
- keras_hub/src/models/gemma3/gemma3_tokenizer.py +9 -0
- keras_hub/src/models/gemma3/{gemma3_vit.py → gemma3_vision_encoder.py} +150 -139
- keras_hub/src/models/qwen/qwen_backbone.py +0 -7
- keras_hub/src/models/qwen/qwen_causal_lm.py +0 -7
- keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +0 -7
- keras_hub/src/models/qwen/qwen_tokenizer.py +0 -9
- keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -1
- keras_hub/src/models/roformer_v2/roformer_v2_text_classifier.py +2 -2
- keras_hub/src/models/vit/vit_image_converter.py +8 -3
- keras_hub/src/tests/test_case.py +4 -0
- keras_hub/src/utils/tensor_utils.py +6 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504050402.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504050402.dist-info}/RECORD +27 -27
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504050402.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504050402.dist-info}/top_level.txt +0 -0
    
keras_hub/api/models/__init__.py
CHANGED

@@ -183,6 +183,9 @@ from keras_hub.src.models.gemma3.gemma3_causal_lm_preprocessor import (
     Gemma3CausalLMPreprocessor,
 )
 from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer
+from keras_hub.src.models.gemma3.gemma3_vision_encoder import (
+    Gemma3VisionEncoder,
+)
 from keras_hub.src.models.gpt2.gpt2_backbone import GPT2Backbone
 from keras_hub.src.models.gpt2.gpt2_causal_lm import GPT2CausalLM
 from keras_hub.src.models.gpt2.gpt2_causal_lm_preprocessor import (
@@ -273,24 +276,6 @@ from keras_hub.src.models.phi3.phi3_causal_lm_preprocessor import (
 )
 from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer
 from keras_hub.src.models.preprocessor import Preprocessor
-from keras_hub.src.models.qwen.qwen_backbone import QwenBackbone
-from keras_hub.src.models.qwen.qwen_backbone import (
-    QwenBackbone as Qwen2Backbone,
-)
-from keras_hub.src.models.qwen.qwen_causal_lm import QwenCausalLM
-from keras_hub.src.models.qwen.qwen_causal_lm import (
-    QwenCausalLM as Qwen2CausalLM,
-)
-from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import (
-    QwenCausalLMPreprocessor,
-)
-from keras_hub.src.models.qwen.qwen_causal_lm_preprocessor import (
-    QwenCausalLMPreprocessor as Qwen2CausalLMPreprocessor,
-)
-from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer
-from keras_hub.src.models.qwen.qwen_tokenizer import (
-    QwenTokenizer as Qwen2Tokenizer,
-)
 from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
 from keras_hub.src.models.resnet.resnet_image_classifier import (
     ResNetImageClassifier,
@@ -324,7 +309,7 @@ from keras_hub.src.models.roberta.roberta_text_classifier_preprocessor import (
 )
 from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer
 from keras_hub.src.models.roformer_v2.roformer_v2_backbone import (
-    RoformerV2Backbone
+    RoformerV2Backbone,
 )
 from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm import (
     RoformerV2MaskedLM,
@@ -333,7 +318,7 @@ from keras_hub.src.models.roformer_v2.roformer_v2_masked_lm_preprocessor import
     RoformerV2MaskedLMPreprocessor,
 )
 from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier import (
-    RoformerV2TextClassifier
+    RoformerV2TextClassifier,
 )
 from keras_hub.src.models.roformer_v2.roformer_v2_text_classifier_preprocessor import (
     RoformerV2TextClassifierPreprocessor,
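With this change, `Gemma3VisionEncoder` is re-exported from the public models API alongside the other Gemma3 symbols, while the Qwen/Qwen2 aliases are dropped from this module. A minimal smoke-test sketch, assuming the 0.21 nightly above is installed (the attribute path simply mirrors the import added in this file):

    # Hypothetical check; assumes keras-hub-nightly 0.21.0.dev202504050402 is installed.
    import keras_hub

    # The vision tower is now reachable next to the other Gemma3 classes.
    print(keras_hub.models.Gemma3VisionEncoder)
    print(keras_hub.models.Gemma3Backbone)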
keras_hub/api/tokenizers/__init__.py
CHANGED

@@ -30,10 +30,6 @@ from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import (
     PaliGemmaTokenizer,
 )
 from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer
-from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer
-from keras_hub.src.models.qwen.qwen_tokenizer import (
-    QwenTokenizer as Qwen2Tokenizer,
-)
 from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer
 from keras_hub.src.models.roformer_v2.roformer_v2_tokenizer import (
     RoformerV2Tokenizer,
keras_hub/src/layers/preprocessing/image_converter.py
CHANGED

@@ -16,6 +16,7 @@ from keras_hub.src.utils.preset_utils import get_preset_loader
 from keras_hub.src.utils.preset_utils import get_preset_saver
 from keras_hub.src.utils.python_utils import classproperty
 from keras_hub.src.utils.tensor_utils import check_bounding_box_support
+from keras_hub.src.utils.tensor_utils import in_tf_function
 from keras_hub.src.utils.tensor_utils import preprocessing_function
 
 
@@ -270,9 +271,15 @@ class ImageConverter(PreprocessingLayer):
         else:
             x = inputs
         if self.scale is not None:
-            …
+            # If we are scaling always cast to the compute dtype. We can't
+            # leave things as an int type if we are scaling to [0, 1].
+            scale = self._expand_non_channel_dims(self.scale, x)
+            x, scale = self._convert_types(x, scale, self.compute_dtype)
+            x = x * scale
         if self.offset is not None:
-            …
+            offset = self._expand_non_channel_dims(self.offset, x)
+            x, offset = self._convert_types(x, offset, x.dtype)
+            x = x + offset
         if isinstance(inputs, dict):
             inputs["images"] = x
         else:
@@ -280,26 +287,29 @@ class ImageConverter(PreprocessingLayer):
         return inputs
 
     def _expand_non_channel_dims(self, value, inputs):
-        … (2 lines)
+        """Expand non channel dims so value is broadcastable with inputs."""
         unbatched = len(ops.shape(inputs)) == 3
         channels_first = self.data_format == "channels_first"
         if unbatched:
             broadcast_dims = (1, 2) if channels_first else (0, 1)
         else:
             broadcast_dims = (0, 2, 3) if channels_first else (0, 1, 2)
-        # …
-        … (11 lines)
+        # An numpy value will work backend native ops or with tf.data.
+        return np.expand_dims(value, broadcast_dims)
+
+    def _convert_types(self, x, y, dtype):
+        """Make sure x and y have the same dtype and are on ths same device."""
+        if in_tf_function():
+            # This could happen on any backend if we are running in tf.data.
+            import tensorflow as tf
+
+            return tf.cast(x, dtype), tf.cast(y, dtype)
+        x = ops.cast(x, dtype)
+        y = ops.cast(y, dtype)
+        if keras.backend.backend() == "torch":
+            # Place on the same device as x (the image).
+            y = y.to(x.device)
+        return x, y
 
     def get_config(self):
         config = super().get_config()
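The new rescaling path expands per-channel `scale`/`offset` values over every non-channel dimension and casts both operands to a common dtype before multiplying, so integer image inputs are not truncated when scaling to [0, 1]. Below is a minimal standalone NumPy sketch of that broadcasting rule; it is not the library code, and the names are illustrative only.

    import numpy as np

    # Sketch of the broadcasting rule used by the new scale/offset path:
    # per-channel values are expanded over every non-channel dim so they
    # broadcast against NHWC or NCHW images.
    def expand_non_channel_dims(value, images, data_format="channels_last"):
        unbatched = images.ndim == 3
        channels_first = data_format == "channels_first"
        if unbatched:
            broadcast_dims = (1, 2) if channels_first else (0, 1)
        else:
            broadcast_dims = (0, 2, 3) if channels_first else (0, 1, 2)
        return np.expand_dims(value, broadcast_dims)

    images = np.random.randint(0, 256, size=(2, 8, 8, 3), dtype="uint8")
    scale = expand_non_channel_dims(np.float32(1.0 / 255.0), images)
    offset = expand_non_channel_dims(np.array([-0.5, -0.5, -0.5], "float32"), images)

    # Cast to a float compute dtype before scaling, as the diff's comments note,
    # so uint8 inputs end up in [0, 1] rather than being truncated to ints.
    x = images.astype("float32") * scale + offset
    print(x.shape, x.dtype)  # (2, 8, 8, 3) float32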
keras_hub/src/models/gemma3/gemma3_attention.py
CHANGED

@@ -8,19 +8,28 @@ from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
 from keras_hub.src.utils.keras_utils import clone_initializer
 from keras_hub.src.utils.keras_utils import fused_attention_op_available
+from keras_hub.src.utils.keras_utils import gpu_supports_fused_attention_op
+from keras_hub.src.utils.keras_utils import running_on_gpu
 from keras_hub.src.utils.keras_utils import running_on_tpu
 
 
 class CachedGemma3Attention(keras.layers.Layer):
     """A cached grouped query attention layer for Gemma3.
 
-    This is …
+    This is the same as the attention layer used for Gemma and Gemma2. It
+    exposes a few additional args:
 
-    … (5 lines)
+    `use_query_key_norm`: bool. If True, apply RMS normalization on query
+        and key. For Gemma3, this is True.
+    `rope_wavelength`: float. Configurable value for RoPE wavelength. Gemma3
+        uses 10K for local attention layers and 1M for global attention layers.
+    `gate_dim_reduction`: int. In the gating layers, the output dimension is
+        `intermediate_dim // gate_dim_reduction`. For Gemma and Gemma2, this
+        value is 2. For Gemma3, it is 1.
+
+    Moreover, the call() method takes in a `cache_update_mask` so as to make
+    sure that the key-value cache is updated only for the non-prompt tokens
+    during generation.
     """
 
     def __init__(
@@ -139,17 +148,22 @@ class CachedGemma3Attention(keras.layers.Layer):
         x = self.rope_layer(x, start_index=start_index)
         return x
 
-    def …
+    def _use_fused_attention_op(self):
         if not fused_attention_op_available():
             return False
         if self.dropout > 0.0:
             return False
-        if …
-        … (5 lines)
+        if running_on_gpu():
+            # GPU never supports softcap in the fused op.
+            if self.logit_soft_cap is not None:
+                return False
+            return gpu_supports_fused_attention_op()
+        elif running_on_tpu():
+            # TPU supports softcap with on keras >= 3.10.
+            sig = inspect.signature(ops.dot_product_attention)
+            return "attn_logits_soft_cap" in sig.parameters
+        else:
+            return False
 
     def _compute_attention(
         self,
@@ -166,7 +180,14 @@ class CachedGemma3Attention(keras.layers.Layer):
             query_normalization = 1 / np.sqrt(
                 self.hidden_dim // self.num_query_heads
             )
-        …
+
+        if self.use_sliding_window_attention and attention_mask is not None:
+            attention_mask = self._mask_sliding_window(
+                attention_mask,
+                cache_update_index=cache_update_index,
+            )
+
+        if self._use_fused_attention_op():
             if attention_mask is not None:
                 attention_mask = ops.expand_dims(attention_mask, axis=1)
                 attention_mask = ops.cast(attention_mask, dtype="bool")
@@ -205,13 +226,8 @@ class CachedGemma3Attention(keras.layers.Layer):
                 ops.tanh(attention_logits), self.logit_soft_cap
             )
 
-        if …
-            attention_mask = …
-                attention_mask,
-                cache_update_index=cache_update_index,
-            )
-
-        attention_mask = attention_mask[:, None, None, :, :]
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, None, None, :, :]
         orig_dtype = attention_logits.dtype
         attention_softmax = self.softmax(attention_logits, mask=attention_mask)
         attention_softmax = ops.cast(attention_softmax, orig_dtype)
@@ -256,6 +272,7 @@ class CachedGemma3Attention(keras.layers.Layer):
         attention_mask=None,
         cache=None,
         cache_update_index=0,
+        cache_update_mask=None,
         training=False,
     ):
         query = self.query_dense(x)
@@ -275,7 +292,43 @@ class CachedGemma3Attention(keras.layers.Layer):
 
             key_update = self._apply_rope(key_update, cache_update_index)
             value_update = self.value_dense(x)
+
+            # Update cache. Note that the cache is updated only if the
+            # corresponding `cache_update_mask` value is True. This is to
+            # ensure that we don't update the cache at indices corresponding to
+            # the prompt. For Gemma3, in particular, this is useful because
+            # image tokens have bidirectional attention. During generation,
+            # if we have uneven inputs during generation, we might end up having
+            # causal attention between image tokens, which is incorrect. To
+            # avoid this, bidirectional attention is taken care of during
+            # the prefill step, and during generation, the cache is not updated
+            # for the prompt. The shape of `cache_update_mask` is
+            # `(bsz, seq_len)`, where `seq_len` is 1 when we are generating
+            # token-by-token.
             start = [0, cache_update_index, 0, 0]
+            if cache_update_mask is not None:
+                cache_update_mask = ops.expand_dims(
+                    ops.expand_dims(cache_update_mask, axis=-1),
+                    axis=-1,
+                )
+                key_original = ops.slice(
+                    key_cache, start, ops.shape(key_update)
+                )
+                value_original = ops.slice(
+                    value_cache, start, ops.shape(value_update)
+                )
+
+                key_update = ops.where(
+                    cache_update_mask,
+                    key_update,
+                    key_original,
+                )
+                value_update = ops.where(
+                    cache_update_mask,
+                    value_update,
+                    value_original,
+                )
+
             key = ops.slice_update(key_cache, start, key_update)
             value = ops.slice_update(value_cache, start, value_update)
             cache = ops.stack((key, value), axis=1)
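The comment block above explains that cached key/value entries are only overwritten where `cache_update_mask` is True, so prompt positions written during prefill stay untouched while generating. Below is a minimal standalone sketch of that masked update with `keras.ops`; it is not the library code, and the tiny shapes are illustrative, following the `(batch, cache_len, num_heads, head_dim)` cache layout implied by the diff.

    import numpy as np
    from keras import ops

    batch, heads, seq, head_dim, cache_len = 1, 2, 1, 4, 8
    cache_update_index = 3

    key_cache = ops.zeros((batch, cache_len, heads, head_dim))
    key_update = ops.ones((batch, seq, heads, head_dim))
    # One boolean per position in the update; False means "keep the cached value".
    cache_update_mask = ops.convert_to_tensor(np.array([[False]]))

    start = [0, cache_update_index, 0, 0]
    mask = ops.expand_dims(ops.expand_dims(cache_update_mask, axis=-1), axis=-1)
    # Read back what is currently in the cache at the update position.
    key_original = ops.slice(key_cache, start, ops.shape(key_update))
    # Keep the new keys only where the mask allows an update.
    key_update = ops.where(mask, key_update, key_original)
    key_cache = ops.slice_update(key_cache, start, key_update)

    # Because the mask is False, position 3 of the cache keeps its original zeros.
    print(ops.convert_to_numpy(key_cache)[0, cache_update_index, 0])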
keras_hub/src/models/gemma3/gemma3_backbone.py
CHANGED

@@ -19,13 +19,10 @@ class Gemma3Backbone(Backbone):
 
     This backbone implements the Gemma3 model architecture. Gemma3 is a
     vision-language model (image-text in, text out). The text input is encoded
-    using an embedding layer; images are encoded using a vision transformer …
-    After encoding these two modalities, the image embeddings are placed …
-    correct position in the text embedding sequence. The mixed sequence …
-    embeddings is then passed through transformer decoder layers.
-
-    Currently, this model supports only the `vision_encoder = None` case, i.e.,
-    working only with text.
+    using an embedding layer; images are encoded using a vision transformer
+    (ViT). After encoding these two modalities, the image embeddings are placed
+    in the correct position in the text embedding sequence. The mixed sequence
+    of embeddings is then passed through transformer decoder layers.
 
     For a higher-level object for text-generation, see
     `keras_hub.models.Gemma3CausalLM`.
@@ -66,8 +63,9 @@ class Gemma3Backbone(Backbone):
           window attention. Defaults to `False`.
         sliding_window_size: int. Size of the sliding local window. Defaults to
             `4096`.
-        vision_encoder: …
-            takes in images and returns corresponding sequence of embeddings.
+        vision_encoder: A `Gemma3VisionEncoder` instance. `call()`
+            takes in images and returns corresponding sequence of embeddings. If
+            `None`, the model is a text-only model.
         layer_norm_epsilon: float. The epsilon value user for every layer norm
             in all transformer blocks. Defaults to `1e-6`.
         dropout: float. Dropout probability for the Transformer decoder blocks.
@@ -75,10 +73,12 @@ class Gemma3Backbone(Backbone):
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
             for the models computations and weights. Note that some
             computations, such as softmax and layer normalization will always
-            be done …
+            be done in float32 precision regardless of dtype. Defaults to
+            `bfloat16`.
 
     Example:
     ```python
+    # === Language Gemma3 model ===
     input_data = {}
     input_data["token_ids"] = np.ones(shape=(1, 300), dtype="int32")
     input_data["padding_mask"] = (
@@ -86,32 +86,90 @@ class Gemma3Backbone(Backbone):
         .astype(bool)
     )
 
+    # Pretrained Gemma3 decoder.
+    model = keras_hub.models.Gemma3Backbone.from_preset(
+        "gemma3_instruct_4b_text"
+    )
+    model(input_data)
+
+    # Randomly initialized Gemma3 decoder with a custom config.
+    model = keras_hub.models.Gemma3Backbone(
+        vocabulary_size=262144,
+        image_size=896,
+        num_layers=34,
+        num_query_heads=8,
+        num_key_value_heads=4,
+        hidden_dim=2560,
+        intermediate_dim=10240,
+        head_dim=256,
+        query_head_dim_normalize=True,
+        use_post_ffw_norm=True,
+        use_post_attention_norm=True,
+        final_logit_soft_cap=None,
+        attention_logit_soft_cap=None,
+        sliding_window_size=1024,
+        use_sliding_window_attention=True,
+        vision_encoder=None,
+        layer_norm_epsilon=1e-06,
+        dtype="bfloat16",
+    )
+    model(input_data)
+
+    # === Vision + Language Gemma3 model ===
+    input_data = {}
+    input_data["images"] = np.ones(shape=(1, 1, 896, 896, 3))
+    input_data["token_ids"] = np.ones(shape=(1, 300), dtype="int32")
+    # images after the text part of the sequence.
+    input_data["vision_mask"] = np.expand_dims(
+        np.array([0] * 30 + [1] * 256 + [0] * 14),
+        axis=0,
+    ).astype(bool)
+    input_data["vision_indices"] = (
+        np.expand_dims(np.arange(30, 286), axis=0)
+    )
+    input_data["padding_mask"] = (
+        np.expand_dims(np.array([1] * 286 + [0] * (300 - 286)), axis=0)
+        .astype(bool)
+    )
+
     # Pretrained Gemma3 decoder.
     model = keras_hub.models.Gemma3Backbone.from_preset("gemma3_instruct_4b")
     model(input_data)
 
-    … (22 lines)
+    # Randomly initialized Gemma3 decoder with a custom config.
+    vision_encoder = Gemma3VisionEncoder(
+        image_size=896,
+        patch_size=14,
+        num_heads=16,
+        hidden_dim=1152,
+        num_layers=27,
+        intermediate_dim=4304,
+        output_dim=2560,
+        pool_size=4,
+        layer_norm_epsilon=1e-6,
+        dtype="float32",
+    )
+
+    model = keras_hub.models.Gemma3Backbone(
+        vocabulary_size=262144,
+        image_size=896,
+        num_layers=34,
+        num_query_heads=8,
+        num_key_value_heads=4,
+        hidden_dim=2560,
+        intermediate_dim=10240,
+        head_dim=256,
+        query_head_dim_normalize=True,
+        use_post_ffw_norm=True,
+        use_post_attention_norm=True,
+        final_logit_soft_cap=None,
+        attention_logit_soft_cap=None,
+        sliding_window_size=1024,
+        use_sliding_window_attention=True,
+        vision_encoder=vision_encoder,
+        layer_norm_epsilon=1e-06,
+        dtype="bfloat16"
+    )
     model(input_data)
     ```
     """
@@ -134,18 +192,14 @@ class Gemma3Backbone(Backbone):
         final_logit_soft_cap=None,
         use_sliding_window_attention=False,
         sliding_window_size=1024,
+        local_rope_scaling_factor=1.0,
+        global_rope_scaling_factor=1.0,
         vision_encoder=None,
         layer_norm_epsilon=1e-6,
         dropout=0,
         dtype=None,
         **kwargs,
     ):
-        if vision_encoder is not None:
-            raise ValueError(
-                "Currently, only the text version of the Gemma3 model is "
-                "supported."
-            )
-
         # === Layers ===
         self.token_embedding = ReversibleEmbedding(
             input_dim=vocabulary_size,
@@ -176,7 +230,11 @@ class Gemma3Backbone(Backbone):
             # 5 local, 1 global
             sliding_window = use_sliding_window_attention and (i % 6 < 5)
             rope_wavelength = 10_000.0 if sliding_window else 1_000_000.0
-            rope_scaling_factor = …
+            rope_scaling_factor = (
+                local_rope_scaling_factor
+                if sliding_window
+                else global_rope_scaling_factor
+            )
             layer = Gemma3DecoderBlock(
                 hidden_dim=hidden_dim,
                 intermediate_dim=intermediate_dim,
@@ -215,10 +273,11 @@ class Gemma3Backbone(Backbone):
             vision_indices_input = keras.Input(
                 shape=(None,), dtype="int32", name="vision_indices"
            )
-            # …
-            # `vision_indices_input`…
-            … (2 lines)
+            # Truth be told, this is redundant, and we can infer this from
+            # `vision_indices_input`. But it is easier to return this from
+            # the preprocessor than to compute it here.
+            vision_mask_input = keras.Input(
+                shape=(None,), dtype="int32", name="vision_mask"
             )
 
         token_id_input = keras.Input(
@@ -239,7 +298,7 @@ class Gemma3Backbone(Backbone):
         if not text_only_model:
             img_embeddings = self.vision_encoder(image_input)
 
-            …
+            # == Interleaving text and images ==
             # Place image embeddings in the right position in
             # `text_embeddings`.
             x = self.interleave_embeddings(
@@ -255,7 +314,7 @@ class Gemma3Backbone(Backbone):
             x = transformer_layer(
                 x,
                 padding_mask=padding_mask_input,
-                …
+                vision_mask=None if text_only_model else vision_mask_input,
             )
         sequence_output = self.layer_norm(x)
 
@@ -268,7 +327,7 @@ class Gemma3Backbone(Backbone):
                 {
                     "images": image_input,
                     "vision_indices": vision_indices_input,
-                    "…
+                    "vision_mask": vision_mask_input,
                 }
             )
 
@@ -296,6 +355,8 @@ class Gemma3Backbone(Backbone):
         self.final_logit_soft_cap = final_logit_soft_cap
         self.use_sliding_window_attention = use_sliding_window_attention
         self.sliding_window_size = sliding_window_size
+        self.local_rope_scaling_factor = local_rope_scaling_factor
+        self.global_rope_scaling_factor = global_rope_scaling_factor
         self.layer_norm_epsilon = layer_norm_epsilon
         self.dropout = dropout
 
@@ -330,6 +391,8 @@ class Gemma3Backbone(Backbone):
                     self.use_sliding_window_attention
                 ),
                 "sliding_window_size": self.sliding_window_size,
+                "local_rope_scaling_factor": self.local_rope_scaling_factor,
+                "global_rope_scaling_factor": self.global_rope_scaling_factor,
                 "vision_encoder": None
                 if self.vision_encoder is None
                 else keras.layers.serialize(self.vision_encoder),
@@ -339,6 +402,14 @@ class Gemma3Backbone(Backbone):
         )
         return config
 
+    def get_lora_target_names(self):
+        target_names = super().get_lora_target_names()
+
+        # Add these for `Gemma3VITAttention`.
+        if not self.text_only_model:
+            target_names += ["query_proj", "value_proj"]
+        return target_names
+
     @classmethod
     def from_config(cls, config):
         config.update(
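The constructor now threads separate RoPE scaling factors through the 5-local / 1-global layer pattern shown above: sliding-window layers use the 10K wavelength and `local_rope_scaling_factor`, while every sixth layer is global and uses the 1M wavelength and `global_rope_scaling_factor`. A small standalone sketch of that per-layer selection follows; the non-default scaling value is illustrative, not a preset default.

    # Sketch of the per-layer rotary settings computed in the constructor.
    num_layers = 12
    use_sliding_window_attention = True
    local_rope_scaling_factor = 1.0
    global_rope_scaling_factor = 8.0  # illustrative value, not a preset default

    for i in range(num_layers):
        sliding_window = use_sliding_window_attention and (i % 6 < 5)
        rope_wavelength = 10_000.0 if sliding_window else 1_000_000.0
        rope_scaling_factor = (
            local_rope_scaling_factor
            if sliding_window
            else global_rope_scaling_factor
        )
        kind = "local" if sliding_window else "global"
        print(i, kind, rope_wavelength, rope_scaling_factor)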