keras-hub-nightly 0.20.0.dev202504020401__py3-none-any.whl → 0.21.0.dev202504040358__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/models/__init__.py +5 -20
 - keras_hub/api/tokenizers/__init__.py +0 -4
 - keras_hub/src/layers/preprocessing/image_converter.py +26 -16
 - keras_hub/src/models/gemma/gemma_attention.py +17 -10
 - keras_hub/src/models/gemma3/gemma3_attention.py +76 -23
 - keras_hub/src/models/gemma3/gemma3_backbone.py +117 -46
 - keras_hub/src/models/gemma3/gemma3_causal_lm.py +72 -15
 - keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +512 -355
 - keras_hub/src/models/gemma3/gemma3_decoder_block.py +23 -19
 - keras_hub/src/models/gemma3/gemma3_image_converter.py +6 -0
 - keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py +56 -16
 - keras_hub/src/models/gemma3/gemma3_presets.py +74 -8
 - keras_hub/src/models/gemma3/gemma3_tokenizer.py +9 -0
 - keras_hub/src/models/gemma3/{gemma3_vit.py → gemma3_vision_encoder.py} +150 -139
 - keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -2
 - keras_hub/src/models/llama/llama_attention.py +2 -2
 - keras_hub/src/models/mistral/mistral_attention.py +2 -2
 - keras_hub/src/models/phi3/phi3_attention.py +2 -2
 - keras_hub/src/models/qwen/qwen_attention.py +2 -2
 - keras_hub/src/models/qwen/qwen_backbone.py +0 -7
 - keras_hub/src/models/qwen/qwen_causal_lm.py +0 -7
 - keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +0 -7
 - keras_hub/src/models/qwen/qwen_tokenizer.py +0 -9
 - keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -1
 - keras_hub/src/models/roformer_v2/roformer_v2_text_classifier.py +2 -2
 - keras_hub/src/models/stable_diffusion_3/mmdit.py +2 -2
 - keras_hub/src/models/vit/vit_image_converter.py +8 -3
 - keras_hub/src/tests/test_case.py +4 -0
 - keras_hub/src/utils/keras_utils.py +44 -1
 - keras_hub/src/utils/tensor_utils.py +6 -0
 - keras_hub/src/version_utils.py +1 -1
 - {keras_hub_nightly-0.20.0.dev202504020401.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/METADATA +1 -1
 - {keras_hub_nightly-0.20.0.dev202504020401.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/RECORD +35 -35
 - {keras_hub_nightly-0.20.0.dev202504020401.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/WHEEL +0 -0
 - {keras_hub_nightly-0.20.0.dev202504020401.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/top_level.txt +0 -0
 
| 
         @@ -1,3 +1,4 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            import numpy as np
         
     | 
| 
       1 
2 
     | 
    
         
             
            from keras import ops
         
     | 
| 
       2 
3 
     | 
    
         | 
| 
       3 
4 
     | 
    
         
             
            from keras_hub.src.api_export import keras_hub_export
         
     | 
| 
         @@ -8,15 +9,21 @@ from keras_hub.src.models.gemma3.gemma3_causal_lm_preprocessor import ( 
     | 
|
| 
       8 
9 
     | 
    
         
             
            )
         
     | 
| 
       9 
10 
     | 
    
         
             
            from keras_hub.src.utils.tensor_utils import any_equal
         
     | 
| 
       10 
11 
     | 
    
         | 
| 
      
 12 
     | 
    
         
            +
            try:
         
     | 
| 
      
 13 
     | 
    
         
            +
                import tensorflow as tf
         
     | 
| 
      
 14 
     | 
    
         
            +
            except ImportError:
         
     | 
| 
      
 15 
     | 
    
         
            +
                tf = None
         
     | 
| 
      
 16 
     | 
    
         
            +
             
     | 
| 
       11 
17 
     | 
    
         | 
| 
       12 
18 
     | 
    
         
             
            @keras_hub_export("keras_hub.models.Gemma3CausalLM")
         
     | 
| 
       13 
19 
     | 
    
         
             
            class Gemma3CausalLM(CausalLM):
         
     | 
| 
       14 
     | 
    
         
            -
                """An end-to-end  
     | 
| 
      
 20 
     | 
    
         
            +
                """An end-to-end multimodal Gemma3 model for causal language modeling.
         
     | 
| 
       15 
21 
     | 
    
         | 
| 
       16 
22 
     | 
    
         
             
                A causal language model (LM) predicts the next token based on previous
         
     | 
| 
       17 
23 
     | 
    
         
             
                tokens. This task setup can be used to train the model unsupervised on
         
     | 
| 
       18 
     | 
    
         
            -
                 
     | 
| 
       19 
     | 
    
         
            -
                similar to the data used for training.
         
     | 
| 
      
 24 
     | 
    
         
            +
                images and plain text inputs, or to autoregressively generate plain text
         
     | 
| 
      
 25 
     | 
    
         
            +
                similar to the data used for training. Note that the model is
         
     | 
| 
      
 26 
     | 
    
         
            +
                image-text in, text out.
         
     | 
| 
       20 
27 
     | 
    
         | 
| 
       21 
28 
     | 
    
         
             
                This model has a `generate()` method, which generates text based on a
         
     | 
| 
       22 
29 
     | 
    
         
             
                prompt. The generation strategy used is controlled by an additional
         
     | 
| 
         @@ -30,10 +37,10 @@ class Gemma3CausalLM(CausalLM): 
     | 
|
| 
       30 
37 
     | 
    
         
             
                when creating the model with `from_preset()`.
         
     | 
| 
       31 
38 
     | 
    
         | 
| 
       32 
39 
     | 
    
         
             
                Args:
         
     | 
| 
       33 
     | 
    
         
            -
                    backbone: A `keras_hub.models.Gemma3Backbone` instance.
         
     | 
| 
       34 
40 
     | 
    
         
             
                    preprocessor: A `keras_hub.models.Gemma3CausalLMPreprocessor` or
         
     | 
| 
       35 
41 
     | 
    
         
             
                        `None`. If `None`, this model will not apply preprocessing, and
         
     | 
| 
       36 
42 
     | 
    
         
             
                        inputs should be preprocessed before calling the model.
         
     | 
| 
      
 43 
     | 
    
         
            +
                    backbone: A `keras_hub.models.Gemma3Backbone` instance.
         
     | 
| 
       37 
44 
     | 
    
         
             
                """
         
     | 
| 
       38 
45 
     | 
    
         | 
| 
       39 
46 
     | 
    
         
             
                backbone_cls = Gemma3Backbone
         
     | 
| 
         @@ -79,15 +86,55 @@ class Gemma3CausalLM(CausalLM): 
     | 
|
| 
       79 
86 
     | 
    
         
             
                        **kwargs,
         
     | 
| 
       80 
87 
     | 
    
         
             
                    )
         
     | 
| 
       81 
88 
     | 
    
         | 
| 
      
 89 
     | 
    
         
            +
                def _normalize_generate_inputs(
         
     | 
| 
      
 90 
     | 
    
         
            +
                    self,
         
     | 
| 
      
 91 
     | 
    
         
            +
                    inputs,
         
     | 
| 
      
 92 
     | 
    
         
            +
                ):
         
     | 
| 
      
 93 
     | 
    
         
            +
                    """Overrides the superclass' method to handle unbatched image inputs."""
         
     | 
| 
      
 94 
     | 
    
         
            +
                    if tf and isinstance(inputs, tf.data.Dataset):
         
     | 
| 
      
 95 
     | 
    
         
            +
                        return inputs.as_numpy_iterator(), False
         
     | 
| 
      
 96 
     | 
    
         
            +
             
     | 
| 
      
 97 
     | 
    
         
            +
                    if self.preprocessor is None:
         
     | 
| 
      
 98 
     | 
    
         
            +
                        return [inputs], False
         
     | 
| 
      
 99 
     | 
    
         
            +
             
     | 
| 
      
 100 
     | 
    
         
            +
                    def normalize(x):
         
     | 
| 
      
 101 
     | 
    
         
            +
                        if isinstance(x, str):
         
     | 
| 
      
 102 
     | 
    
         
            +
                            return [x], True
         
     | 
| 
      
 103 
     | 
    
         
            +
                        if tf and isinstance(x, tf.Tensor) and x.shape.rank == 0:
         
     | 
| 
      
 104 
     | 
    
         
            +
                            return x[tf.newaxis], True
         
     | 
| 
      
 105 
     | 
    
         
            +
                        return x, False
         
     | 
| 
      
 106 
     | 
    
         
            +
             
     | 
| 
      
 107 
     | 
    
         
            +
                    if isinstance(inputs, dict):
         
     | 
| 
      
 108 
     | 
    
         
            +
                        inputs["prompts"], input_is_scalar = normalize(inputs["prompts"])
         
     | 
| 
      
 109 
     | 
    
         
            +
             
     | 
| 
      
 110 
     | 
    
         
            +
                        # If prompt is scalar, images can be either a 3D NumPy array/Tensor,
         
     | 
| 
      
 111 
     | 
    
         
            +
                        # or, list of 3D NumPy arrays. Let's uprank images.
         
     | 
| 
      
 112 
     | 
    
         
            +
                        if input_is_scalar and "images" in inputs:
         
     | 
| 
      
 113 
     | 
    
         
            +
                            x = inputs["images"]
         
     | 
| 
      
 114 
     | 
    
         
            +
                            if isinstance(x, np.ndarray) and len(x.shape) == 3:
         
     | 
| 
      
 115 
     | 
    
         
            +
                                inputs["images"] = [x]
         
     | 
| 
      
 116 
     | 
    
         
            +
                            elif tf and isinstance(x, tf.Tensor) and x.shape.rank == 3:
         
     | 
| 
      
 117 
     | 
    
         
            +
                                inputs["images"] = x[tf.newaxis]
         
     | 
| 
      
 118 
     | 
    
         
            +
                            elif isinstance(x, list):
         
     | 
| 
      
 119 
     | 
    
         
            +
                                inputs["images"] = [x]
         
     | 
| 
      
 120 
     | 
    
         
            +
             
     | 
| 
      
 121 
     | 
    
         
            +
                        if "responses" in inputs:
         
     | 
| 
      
 122 
     | 
    
         
            +
                            inputs["responses"], _ = normalize(inputs["responses"])
         
     | 
| 
      
 123 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 124 
     | 
    
         
            +
                        inputs, input_is_scalar = normalize(inputs)
         
     | 
| 
      
 125 
     | 
    
         
            +
             
     | 
| 
      
 126 
     | 
    
         
            +
                    return [inputs], input_is_scalar
         
     | 
| 
      
 127 
     | 
    
         
            +
             
     | 
| 
       82 
128 
     | 
    
         
             
                def call_with_cache(
         
     | 
| 
       83 
129 
     | 
    
         
             
                    self,
         
     | 
| 
       84 
130 
     | 
    
         
             
                    token_ids,
         
     | 
| 
       85 
131 
     | 
    
         
             
                    cache,
         
     | 
| 
       86 
132 
     | 
    
         
             
                    cache_update_index,
         
     | 
| 
       87 
133 
     | 
    
         
             
                    img_embeddings=None,
         
     | 
| 
       88 
     | 
    
         
            -
                     
     | 
| 
      
 134 
     | 
    
         
            +
                    vision_mask=None,
         
     | 
| 
       89 
135 
     | 
    
         
             
                    padding_mask=None,
         
     | 
| 
       90 
136 
     | 
    
         
             
                    vision_indices=None,
         
     | 
| 
      
 137 
     | 
    
         
            +
                    cache_update_mask=None,
         
     | 
| 
       91 
138 
     | 
    
         
             
                ):
         
     | 
| 
       92 
139 
     | 
    
         
             
                    """Forward pass of `Gemma3CausalLM` with cache.
         
     | 
| 
       93 
140 
     | 
    
         | 
| 
         @@ -139,7 +186,8 @@ class Gemma3CausalLM(CausalLM): 
     | 
|
| 
       139 
186 
     | 
    
         
             
                            cache=current_cache,
         
     | 
| 
       140 
187 
     | 
    
         
             
                            cache_update_index=cache_update_index,
         
     | 
| 
       141 
188 
     | 
    
         
             
                            padding_mask=padding_mask,
         
     | 
| 
       142 
     | 
    
         
            -
                             
     | 
| 
      
 189 
     | 
    
         
            +
                            vision_mask=vision_mask,
         
     | 
| 
      
 190 
     | 
    
         
            +
                            cache_update_mask=cache_update_mask,
         
     | 
| 
       143 
191 
     | 
    
         
             
                        )
         
     | 
| 
       144 
192 
     | 
    
         
             
                        caches.append(next_cache)
         
     | 
| 
       145 
193 
     | 
    
         
             
                    cache = ops.stack(caches, axis=1)
         
     | 
| 
         @@ -151,7 +199,7 @@ class Gemma3CausalLM(CausalLM): 
     | 
|
| 
       151 
199 
     | 
    
         
             
                    self,
         
     | 
| 
       152 
200 
     | 
    
         
             
                    token_ids,
         
     | 
| 
       153 
201 
     | 
    
         
             
                    img_embeddings,
         
     | 
| 
       154 
     | 
    
         
            -
                     
     | 
| 
      
 202 
     | 
    
         
            +
                    vision_mask,
         
     | 
| 
       155 
203 
     | 
    
         
             
                    padding_mask,
         
     | 
| 
       156 
204 
     | 
    
         
             
                    vision_indices,
         
     | 
| 
       157 
205 
     | 
    
         
             
                ):
         
     | 
| 
         @@ -170,11 +218,12 @@ class Gemma3CausalLM(CausalLM): 
     | 
|
| 
       170 
218 
     | 
    
         
             
                    logits, hidden_states, cache = self.call_with_cache(
         
     | 
| 
       171 
219 
     | 
    
         
             
                        token_ids=token_ids,
         
     | 
| 
       172 
220 
     | 
    
         
             
                        img_embeddings=img_embeddings,
         
     | 
| 
       173 
     | 
    
         
            -
                         
     | 
| 
      
 221 
     | 
    
         
            +
                        vision_mask=vision_mask,
         
     | 
| 
       174 
222 
     | 
    
         
             
                        cache=cache,
         
     | 
| 
       175 
223 
     | 
    
         
             
                        cache_update_index=0,
         
     | 
| 
       176 
224 
     | 
    
         
             
                        padding_mask=padding_mask,
         
     | 
| 
       177 
225 
     | 
    
         
             
                        vision_indices=vision_indices,
         
     | 
| 
      
 226 
     | 
    
         
            +
                        cache_update_mask=None,
         
     | 
| 
       178 
227 
     | 
    
         
             
                    )
         
     | 
| 
       179 
228 
     | 
    
         
             
                    return hidden_states, cache
         
     | 
| 
       180 
229 
     | 
    
         | 
| 
         @@ -193,29 +242,33 @@ class Gemma3CausalLM(CausalLM): 
     | 
|
| 
       193 
242 
     | 
    
         
             
                            will stop.
         
     | 
| 
       194 
243 
     | 
    
         
             
                    """
         
     | 
| 
       195 
244 
     | 
    
         | 
| 
       196 
     | 
    
         
            -
                    token_ids, padding_mask, images,  
     | 
| 
      
 245 
     | 
    
         
            +
                    token_ids, padding_mask, images, vision_mask, vision_indices = (
         
     | 
| 
       197 
246 
     | 
    
         
             
                        inputs["token_ids"],
         
     | 
| 
       198 
247 
     | 
    
         
             
                        inputs["padding_mask"],
         
     | 
| 
       199 
248 
     | 
    
         
             
                        inputs.get("images", None),
         
     | 
| 
       200 
     | 
    
         
            -
                        inputs.get(" 
     | 
| 
      
 249 
     | 
    
         
            +
                        inputs.get("vision_mask", None),
         
     | 
| 
       201 
250 
     | 
    
         
             
                        inputs.get("vision_indices", None),
         
     | 
| 
       202 
251 
     | 
    
         
             
                    )
         
     | 
| 
       203 
252 
     | 
    
         
             
                    if not self.backbone.text_only_model:
         
     | 
| 
       204 
     | 
    
         
            -
                         
     | 
| 
       205 
     | 
    
         
            -
             
     | 
| 
       206 
     | 
    
         
            -
             
     | 
| 
      
 253 
     | 
    
         
            +
                        # Handle an unbatched image. Unlike `token_ids` and
         
     | 
| 
      
 254 
     | 
    
         
            +
                        # `padding_mask`, this will not automatically be upranked.
         
     | 
| 
      
 255 
     | 
    
         
            +
                        if len(ops.shape(images)) == 4:
         
     | 
| 
       207 
256 
     | 
    
         
             
                            images = ops.expand_dims(images, axis=0)
         
     | 
| 
      
 257 
     | 
    
         
            +
                        if len(ops.shape(vision_mask)) == 1:
         
     | 
| 
      
 258 
     | 
    
         
            +
                            vision_mask = ops.expand_dims(vision_mask, axis=0)
         
     | 
| 
      
 259 
     | 
    
         
            +
                        if len(ops.shape(vision_indices)) == 1:
         
     | 
| 
      
 260 
     | 
    
         
            +
                            vision_indices = ops.expand_dims(vision_indices, axis=0)
         
     | 
| 
       208 
261 
     | 
    
         
             
                        img_embeddings = self.backbone.vision_encoder(images)
         
     | 
| 
       209 
262 
     | 
    
         
             
                    else:
         
     | 
| 
       210 
263 
     | 
    
         
             
                        img_embeddings = None
         
     | 
| 
       211 
     | 
    
         
            -
                         
     | 
| 
      
 264 
     | 
    
         
            +
                        vision_mask = None
         
     | 
| 
       212 
265 
     | 
    
         
             
                        vision_indices = None
         
     | 
| 
       213 
266 
     | 
    
         | 
| 
       214 
267 
     | 
    
         
             
                    # Create and seed cache with a single forward pass.
         
     | 
| 
       215 
268 
     | 
    
         
             
                    hidden_states, cache = self._build_cache(
         
     | 
| 
       216 
269 
     | 
    
         
             
                        token_ids,
         
     | 
| 
       217 
270 
     | 
    
         
             
                        img_embeddings,
         
     | 
| 
       218 
     | 
    
         
            -
                         
     | 
| 
      
 271 
     | 
    
         
            +
                        vision_mask,
         
     | 
| 
       219 
272 
     | 
    
         
             
                        padding_mask,
         
     | 
| 
       220 
273 
     | 
    
         
             
                        vision_indices,
         
     | 
| 
       221 
274 
     | 
    
         
             
                    )
         
     | 
| 
         @@ -230,10 +283,14 @@ class Gemma3CausalLM(CausalLM): 
     | 
|
| 
       230 
283 
     | 
    
         
             
                        cache_update_index = index - 1
         
     | 
| 
       231 
284 
     | 
    
         
             
                        batch_size = ops.shape(prompt)[0]
         
     | 
| 
       232 
285 
     | 
    
         
             
                        prompt = ops.slice(prompt, [0, index - 1], [batch_size, 1])
         
     | 
| 
      
 286 
     | 
    
         
            +
                        sliced_cache_update_mask = ops.slice(
         
     | 
| 
      
 287 
     | 
    
         
            +
                            ~padding_mask, [0, index - 1], [batch_size, 1]
         
     | 
| 
      
 288 
     | 
    
         
            +
                        )
         
     | 
| 
       233 
289 
     | 
    
         
             
                        logits, hidden_states, cache = self.call_with_cache(
         
     | 
| 
       234 
290 
     | 
    
         
             
                            token_ids=prompt,
         
     | 
| 
       235 
291 
     | 
    
         
             
                            cache=cache,
         
     | 
| 
       236 
292 
     | 
    
         
             
                            cache_update_index=cache_update_index,
         
     | 
| 
      
 293 
     | 
    
         
            +
                            cache_update_mask=sliced_cache_update_mask,
         
     | 
| 
       237 
294 
     | 
    
         
             
                        )
         
     | 
| 
       238 
295 
     | 
    
         
             
                        return (
         
     | 
| 
       239 
296 
     | 
    
         
             
                            ops.squeeze(logits, axis=1),
         
     |