keras-hub-nightly 0.20.0.dev202504020401__py3-none-any.whl → 0.21.0.dev202504040358__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. keras_hub/api/models/__init__.py +5 -20
  2. keras_hub/api/tokenizers/__init__.py +0 -4
  3. keras_hub/src/layers/preprocessing/image_converter.py +26 -16
  4. keras_hub/src/models/gemma/gemma_attention.py +17 -10
  5. keras_hub/src/models/gemma3/gemma3_attention.py +76 -23
  6. keras_hub/src/models/gemma3/gemma3_backbone.py +117 -46
  7. keras_hub/src/models/gemma3/gemma3_causal_lm.py +72 -15
  8. keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +512 -355
  9. keras_hub/src/models/gemma3/gemma3_decoder_block.py +23 -19
  10. keras_hub/src/models/gemma3/gemma3_image_converter.py +6 -0
  11. keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py +56 -16
  12. keras_hub/src/models/gemma3/gemma3_presets.py +74 -8
  13. keras_hub/src/models/gemma3/gemma3_tokenizer.py +9 -0
  14. keras_hub/src/models/gemma3/{gemma3_vit.py → gemma3_vision_encoder.py} +150 -139
  15. keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -2
  16. keras_hub/src/models/llama/llama_attention.py +2 -2
  17. keras_hub/src/models/mistral/mistral_attention.py +2 -2
  18. keras_hub/src/models/phi3/phi3_attention.py +2 -2
  19. keras_hub/src/models/qwen/qwen_attention.py +2 -2
  20. keras_hub/src/models/qwen/qwen_backbone.py +0 -7
  21. keras_hub/src/models/qwen/qwen_causal_lm.py +0 -7
  22. keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +0 -7
  23. keras_hub/src/models/qwen/qwen_tokenizer.py +0 -9
  24. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -1
  25. keras_hub/src/models/roformer_v2/roformer_v2_text_classifier.py +2 -2
  26. keras_hub/src/models/stable_diffusion_3/mmdit.py +2 -2
  27. keras_hub/src/models/vit/vit_image_converter.py +8 -3
  28. keras_hub/src/tests/test_case.py +4 -0
  29. keras_hub/src/utils/keras_utils.py +44 -1
  30. keras_hub/src/utils/tensor_utils.py +6 -0
  31. keras_hub/src/version_utils.py +1 -1
  32. {keras_hub_nightly-0.20.0.dev202504020401.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/METADATA +1 -1
  33. {keras_hub_nightly-0.20.0.dev202504020401.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/RECORD +35 -35
  34. {keras_hub_nightly-0.20.0.dev202504020401.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/WHEEL +0 -0
  35. {keras_hub_nightly-0.20.0.dev202504020401.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,4 @@
1
+ import numpy as np
1
2
  from keras import ops
2
3
 
3
4
  from keras_hub.src.api_export import keras_hub_export
@@ -8,15 +9,21 @@ from keras_hub.src.models.gemma3.gemma3_causal_lm_preprocessor import (
8
9
  )
9
10
  from keras_hub.src.utils.tensor_utils import any_equal
10
11
 
12
+ try:
13
+ import tensorflow as tf
14
+ except ImportError:
15
+ tf = None
16
+
11
17
 
12
18
  @keras_hub_export("keras_hub.models.Gemma3CausalLM")
13
19
  class Gemma3CausalLM(CausalLM):
14
- """An end-to-end multi modal Gemma3 model for causal language modeling.
20
+ """An end-to-end multimodal Gemma3 model for causal language modeling.
15
21
 
16
22
  A causal language model (LM) predicts the next token based on previous
17
23
  tokens. This task setup can be used to train the model unsupervised on
18
- image and plain text input, or to autoregressively generate plain text
19
- similar to the data used for training.
24
+ images and plain text inputs, or to autoregressively generate plain text
25
+ similar to the data used for training. Note that the model is
26
+ image-text in, text out.
20
27
 
21
28
  This model has a `generate()` method, which generates text based on a
22
29
  prompt. The generation strategy used is controlled by an additional
@@ -30,10 +37,10 @@ class Gemma3CausalLM(CausalLM):
30
37
  when creating the model with `from_preset()`.
31
38
 
32
39
  Args:
33
- backbone: A `keras_hub.models.Gemma3Backbone` instance.
34
40
  preprocessor: A `keras_hub.models.Gemma3CausalLMPreprocessor` or
35
41
  `None`. If `None`, this model will not apply preprocessing, and
36
42
  inputs should be preprocessed before calling the model.
43
+ backbone: A `keras_hub.models.Gemma3Backbone` instance.
37
44
  """
38
45
 
39
46
  backbone_cls = Gemma3Backbone
@@ -79,15 +86,55 @@ class Gemma3CausalLM(CausalLM):
79
86
  **kwargs,
80
87
  )
81
88
 
89
+ def _normalize_generate_inputs(
90
+ self,
91
+ inputs,
92
+ ):
93
+ """Overrides the superclass' method to handle unbatched image inputs."""
94
+ if tf and isinstance(inputs, tf.data.Dataset):
95
+ return inputs.as_numpy_iterator(), False
96
+
97
+ if self.preprocessor is None:
98
+ return [inputs], False
99
+
100
+ def normalize(x):
101
+ if isinstance(x, str):
102
+ return [x], True
103
+ if tf and isinstance(x, tf.Tensor) and x.shape.rank == 0:
104
+ return x[tf.newaxis], True
105
+ return x, False
106
+
107
+ if isinstance(inputs, dict):
108
+ inputs["prompts"], input_is_scalar = normalize(inputs["prompts"])
109
+
110
+ # If prompt is scalar, images can be either a 3D NumPy array/Tensor,
111
+ # or, list of 3D NumPy arrays. Let's uprank images.
112
+ if input_is_scalar and "images" in inputs:
113
+ x = inputs["images"]
114
+ if isinstance(x, np.ndarray) and len(x.shape) == 3:
115
+ inputs["images"] = [x]
116
+ elif tf and isinstance(x, tf.Tensor) and x.shape.rank == 3:
117
+ inputs["images"] = x[tf.newaxis]
118
+ elif isinstance(x, list):
119
+ inputs["images"] = [x]
120
+
121
+ if "responses" in inputs:
122
+ inputs["responses"], _ = normalize(inputs["responses"])
123
+ else:
124
+ inputs, input_is_scalar = normalize(inputs)
125
+
126
+ return [inputs], input_is_scalar
127
+
82
128
  def call_with_cache(
83
129
  self,
84
130
  token_ids,
85
131
  cache,
86
132
  cache_update_index,
87
133
  img_embeddings=None,
88
- text_mask=None,
134
+ vision_mask=None,
89
135
  padding_mask=None,
90
136
  vision_indices=None,
137
+ cache_update_mask=None,
91
138
  ):
92
139
  """Forward pass of `Gemma3CausalLM` with cache.
93
140
 
@@ -139,7 +186,8 @@ class Gemma3CausalLM(CausalLM):
139
186
  cache=current_cache,
140
187
  cache_update_index=cache_update_index,
141
188
  padding_mask=padding_mask,
142
- text_mask=text_mask,
189
+ vision_mask=vision_mask,
190
+ cache_update_mask=cache_update_mask,
143
191
  )
144
192
  caches.append(next_cache)
145
193
  cache = ops.stack(caches, axis=1)
@@ -151,7 +199,7 @@ class Gemma3CausalLM(CausalLM):
151
199
  self,
152
200
  token_ids,
153
201
  img_embeddings,
154
- text_mask,
202
+ vision_mask,
155
203
  padding_mask,
156
204
  vision_indices,
157
205
  ):
@@ -170,11 +218,12 @@ class Gemma3CausalLM(CausalLM):
170
218
  logits, hidden_states, cache = self.call_with_cache(
171
219
  token_ids=token_ids,
172
220
  img_embeddings=img_embeddings,
173
- text_mask=text_mask,
221
+ vision_mask=vision_mask,
174
222
  cache=cache,
175
223
  cache_update_index=0,
176
224
  padding_mask=padding_mask,
177
225
  vision_indices=vision_indices,
226
+ cache_update_mask=None,
178
227
  )
179
228
  return hidden_states, cache
180
229
 
@@ -193,29 +242,33 @@ class Gemma3CausalLM(CausalLM):
193
242
  will stop.
194
243
  """
195
244
 
196
- token_ids, padding_mask, images, text_mask, vision_indices = (
245
+ token_ids, padding_mask, images, vision_mask, vision_indices = (
197
246
  inputs["token_ids"],
198
247
  inputs["padding_mask"],
199
248
  inputs.get("images", None),
200
- inputs.get("text_mask", None),
249
+ inputs.get("vision_mask", None),
201
250
  inputs.get("vision_indices", None),
202
251
  )
203
252
  if not self.backbone.text_only_model:
204
- if len(ops.shape(images)) == 3:
205
- # Handle an unbatched image. Unlike `token_ids` and
206
- # `padding_mask` this will not automatically be upranked.
253
+ # Handle an unbatched image. Unlike `token_ids` and
254
+ # `padding_mask`, this will not automatically be upranked.
255
+ if len(ops.shape(images)) == 4:
207
256
  images = ops.expand_dims(images, axis=0)
257
+ if len(ops.shape(vision_mask)) == 1:
258
+ vision_mask = ops.expand_dims(vision_mask, axis=0)
259
+ if len(ops.shape(vision_indices)) == 1:
260
+ vision_indices = ops.expand_dims(vision_indices, axis=0)
208
261
  img_embeddings = self.backbone.vision_encoder(images)
209
262
  else:
210
263
  img_embeddings = None
211
- text_mask = None
264
+ vision_mask = None
212
265
  vision_indices = None
213
266
 
214
267
  # Create and seed cache with a single forward pass.
215
268
  hidden_states, cache = self._build_cache(
216
269
  token_ids,
217
270
  img_embeddings,
218
- text_mask,
271
+ vision_mask,
219
272
  padding_mask,
220
273
  vision_indices,
221
274
  )
@@ -230,10 +283,14 @@ class Gemma3CausalLM(CausalLM):
230
283
  cache_update_index = index - 1
231
284
  batch_size = ops.shape(prompt)[0]
232
285
  prompt = ops.slice(prompt, [0, index - 1], [batch_size, 1])
286
+ sliced_cache_update_mask = ops.slice(
287
+ ~padding_mask, [0, index - 1], [batch_size, 1]
288
+ )
233
289
  logits, hidden_states, cache = self.call_with_cache(
234
290
  token_ids=prompt,
235
291
  cache=cache,
236
292
  cache_update_index=cache_update_index,
293
+ cache_update_mask=sliced_cache_update_mask,
237
294
  )
238
295
  return (
239
296
  ops.squeeze(logits, axis=1),