keras-hub-nightly 0.20.0.dev202504030357__py3-none-any.whl → 0.21.0.dev202504040358__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. keras_hub/api/models/__init__.py +5 -20
  2. keras_hub/api/tokenizers/__init__.py +0 -4
  3. keras_hub/src/layers/preprocessing/image_converter.py +26 -16
  4. keras_hub/src/models/gemma3/gemma3_attention.py +74 -21
  5. keras_hub/src/models/gemma3/gemma3_backbone.py +117 -46
  6. keras_hub/src/models/gemma3/gemma3_causal_lm.py +72 -15
  7. keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +512 -355
  8. keras_hub/src/models/gemma3/gemma3_decoder_block.py +23 -19
  9. keras_hub/src/models/gemma3/gemma3_image_converter.py +6 -0
  10. keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py +56 -16
  11. keras_hub/src/models/gemma3/gemma3_presets.py +74 -8
  12. keras_hub/src/models/gemma3/gemma3_tokenizer.py +9 -0
  13. keras_hub/src/models/gemma3/{gemma3_vit.py → gemma3_vision_encoder.py} +150 -139
  14. keras_hub/src/models/qwen/qwen_backbone.py +0 -7
  15. keras_hub/src/models/qwen/qwen_causal_lm.py +0 -7
  16. keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +0 -7
  17. keras_hub/src/models/qwen/qwen_tokenizer.py +0 -9
  18. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -1
  19. keras_hub/src/models/roformer_v2/roformer_v2_text_classifier.py +2 -2
  20. keras_hub/src/models/vit/vit_image_converter.py +8 -3
  21. keras_hub/src/tests/test_case.py +4 -0
  22. keras_hub/src/utils/tensor_utils.py +6 -0
  23. keras_hub/src/version_utils.py +1 -1
  24. {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/METADATA +1 -1
  25. {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/RECORD +27 -27
  26. {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/WHEEL +0 -0
  27. {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,5 @@
  import keras
+ import numpy as np
  import tensorflow as tf

  from keras_hub.src.api_export import keras_hub_export
@@ -14,24 +15,28 @@ from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer
  from keras_hub.src.utils.tensor_utils import preprocessing_function
  from keras_hub.src.utils.tensor_utils import strip_to_ragged

- START_OF_IMAGE_TOKEN = "<start_of_image>"
- IMAGE_PLACEHOLDER_TOKEN = "<img>"
- END_OF_IMAGE_TOKEN = "<end_of_image>"
-

  @keras_hub_export("keras_hub.models.Gemma3CausalLMPreprocessor")
  class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
  """Gemma3 Causal LM preprocessor.

  This preprocessing layer is meant for use with
- `keras_hub.models.Gemma3CausalLM`. By default, it will take in batches of
- images and strings, and return outputs in a `(x, y, sample_weight)` format,
- where the `y` label is the next token id in the `x` sequence.
-
- There is only one mode this layer currently supports, i.e.,
- `image_converter` is `None`. We preprocess the text like any other
- Causal LM preprocessor, i.e., tokenisation, padding, etc. The sequence
- is padded to `sequence_length`.
+ `keras_hub.models.Gemma3CausalLM`. It can be configured in two ways:
+ text-only and text + vision, based on whether the passed value of
+ `image_converter` is None. For the former, it takes in batches of strings,
+ whereas for the latter, it takes in batches of images and strings. It
+ returns outputs in a `(x, y, sample_weight)` format, where the `y` label is
+ the next token id in the `x` sequence. `sample_weight` is 0 for "prompt"
+ tokens, and 1 for "response" tokens, so that the loss is computed only on
+ the "response" tokens.
+
+ For the text + vision case, this layer replaces each instance of the
+ `<start_of_image>` token in the prompt with `num_vision_tokens_per_image`
+ placeholder tokens. It also returns indices of where these vision tokens
+ are present so that in the model, image embeddings can be placed in the
+ right position in the sequence of text embeddings. Note that if
+ `max_images_per_prompt` is 2, you can pass either 0, 1, or 2 images per
+ sample. The value 0 corresponds to text-only input.

  For use with generation, the layer also exposes two methods
  `generate_preprocess()` and `generate_postprocess()`. When this preprocessor
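The placeholder expansion described in the rewritten docstring can be sketched in plain Python as follows. This is an illustration only, not part of the diff; the real layer does the substitution with `tf.strings.regex_replace`, and the exact placeholder string comes from `Gemma3Tokenizer` (assumed to be `"<img>"` here).

```python
num_vision_tokens_per_image = 256
prompt = "this is a lily <start_of_image>"
expanded = prompt.replace(
    "<start_of_image>",
    "\n\n<start_of_image>"
    + "<img>" * num_vision_tokens_per_image
    + "<end_of_image>\n\n",
)
# Each image now occupies a fixed run of 256 placeholder positions, so image
# embeddings can later be scattered into known slots of the text sequence.
```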
@@ -64,25 +69,170 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):

  Examples:
  ```python
+ # === Language Gemma3 model ===
  # Load the preprocessor from a preset.
  preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset(
- "gemma3_4b_en"
+ "gemma3_instruct_1b"
  )

- # Text-only input.
+ # Unbatched inputs.
  preprocessor(
- "prompts": ["The quick brown fox jumped."],
- "responses": [""],
+ {
+ "prompts": "What is the capital of India?",
+ "responses": "New Delhi",
+ }
  )

- # Images (pass one image)
- max_images_per_prompt = 2
+ # Batched inputs.
  preprocessor(
- "prompts": ["The quick brown fox jumped."],
- "responses": [""],
- "images": [np.ones((2, 896, 896, 3)).astype("float32")],
- "num_valid_images": np.array([1,], dtype=np.int32)
+ {
+ "prompts": [
+ "What is the capital of India?",
+ "What is the capital of Spain?"
+ ],
+ "responses": ["New Delhi", "Madrid"],
+ }
  )
+
+ # Apply preprocessing to a `tf.data.Dataset`.
+ features = {
+ "prompts": [
+ "What is the capital of India?",
+ "What is the capital of Spain?"
+ ],
+ "responses": ["New Delhi", "Madrid"],
+ }
+
+ ds = tf.data.Dataset.from_tensor_slices(features)
+ ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+ # Prepare tokens for generation (no end token).
+ preprocessor.generate_preprocess(["The quick brown fox jumped."])
+
+ # Map generation outputs back to strings.
+ preprocessor.generate_postprocess({
+ 'token_ids': np.array([[2, 818, 3823, 8864, 37423, 32694, 236761, 0]]),
+ 'padding_mask': np.array([[ 1, 1, 1, 1, 1, 1, 1, 0]]),
+ })
+
+ # === Vision and Language Gemma3 model ===
+ # Load the preprocessor from a preset.
+ preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset(
+ "gemma3_instruct_4b"
+ )
+
+ # text-only inputs (unbatched)
+ preprocessor(
+ {
+ "prompts": "What is the capital of India?",
+ "responses": "New Delhi",
+ }
+ )
+
+ # text-only inputs (batched)
+ preprocessor(
+ {
+ "prompts": [
+ "What is the capital of India?",
+ "What is the capital of Spain?"
+ ],
+ "responses": ["New Delhi", "Madrid"],
+ }
+ )
+
+ # Unbatched inputs, with one image.
+ preprocessor(
+ {
+ "prompts": "this is a lily <start_of_image>",
+ "responses": "pristine!",
+ "images": np.ones((896, 896, 3), dtype="float32")
+ }
+ )
+
+ # Unbatched inputs, with two images.
+ preprocessor(
+ {
+ "prompts": "lily: <start_of_image>, sunflower: <start_of_image>",
+ "responses": "pristine!",
+ "images": [
+ np.ones((896, 896, 3), dtype="float32"),
+ np.ones((896, 896, 3), dtype="float32")
+ ],
+ }
+ )
+
+ # Batched inputs, one image per prompt.
+ preprocessor(
+ {
+ "prompts": [
+ "this is a lily: <start_of_image>",
+ "this is a sunflower: <start_of_image>"
+ ],
+ "responses": ["pristine!", "radiant!"],
+ "images": [
+ np.ones((896, 896, 3), dtype="float32"),
+ np.ones((896, 896, 3), dtype="float32")
+ ]
+ }
+ )
+
+ # Can also be written this way.
+ preprocessor(
+ {
+ "prompts": [
+ "this is a lily: <start_of_image>",
+ "this is a sunflower: <start_of_image>"
+ ],
+ "responses": ["pristine!", "radiant!"],
+ "images": [
+ [np.ones((896, 896, 3), dtype="float32")],
+ [np.ones((896, 896, 3), dtype="float32")]
+ ]
+ }
+ )
+
+ # Different number of images in every sample.
+ preprocessor(
+ {
+ "prompts": [
+ "Who is this singer: <start_of_image>?",
+ "Who are these musicians <start_of_image>, <start_of_image>?"
+ ],
+ "responses": ["Arijit Singh", "John Lennon, Paul Mccartney"],
+ "images": [
+ [
+ np.ones((896, 896, 3), dtype="float32"),
+ np.ones((896, 896, 3), dtype="float32")
+ ],
+ [np.ones((896, 896, 3), dtype="float32")]
+ ]
+ }
+ )
+
+ # Apply preprocessing to a `tf.data.Dataset`.
+ inputs = {
+ "prompts": [
+ "Who are these two: <start_of_image>, <start_of_image>",
+ "Who is this: <start_of_image>?",
+ "What is the capital of India?"
+ ],
+ "responses": [
+ "John Lennon, Paul Mccartney",
+ "Arijit Singh",
+ "New Delhi"
+ ],
+ "images": (
+ tf.ragged.constant(
+ [
+ [np.ones((10, 10, 3)), np.ones((10, 10, 3))],
+ [np.ones((10, 10, 3))],
+ [],
+ ]
+ )
+ )
+ }
+ ds = tf.data.Dataset.from_tensor_slices(inputs)
+ ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
  ```
  """

@@ -109,18 +259,34 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
  **kwargs,
  )

- if image_converter is not None:
+ # Ensure `sequence_length` is greater than
+ # `max_images_per_prompt * num_vision_tokens_per_image`.
+ if (
+ image_converter is not None
+ and sequence_length
+ <= max_images_per_prompt * num_vision_tokens_per_image
+ ):
  raise ValueError(
- "Currently, only the text version of the Gemma3 model is "
- "supported."
+ "`sequence_length` should be greater than "
+ "`max_images_per_prompt * num_vision_tokens_per_image`."
+ f"Received: `sequence_length` = {sequence_length}"
+ f"`max_images_per_prompt` = {max_images_per_prompt}"
+ "`num_vision_tokens_per_image` = "
+ f"{num_vision_tokens_per_image}"
  )

  self.image_converter = image_converter
  self.max_images_per_prompt = max_images_per_prompt
  self.num_vision_tokens_per_image = num_vision_tokens_per_image

+ # The preprocessor and model are "text-only" if `self.image_converter`
+ # is `None`.
  self.text_only_model = self.image_converter is None

+ self.image_placeholder = self.tokenizer.image_placeholder
+ self.start_of_image_token = self.tokenizer.start_of_image_token
+ self.end_of_image_token = self.tokenizer.end_of_image_token
+
  def build(self, input_shape):
  # Defer packer creation to `build()` so that we can be sure tokenizer
  # assets have loaded when restoring a saved model.
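As a rough illustration of the new constructor check (values below are assumptions, not taken from the diff): with 256 vision tokens per image, `sequence_length` must leave room for text beyond the image placeholders.

```python
# Hypothetical configuration check mirroring the ValueError above.
sequence_length = 512
max_images_per_prompt = 2
num_vision_tokens_per_image = 256

if sequence_length <= max_images_per_prompt * num_vision_tokens_per_image:
    # 512 <= 2 * 256, so this configuration would be rejected; a value such
    # as sequence_length=1024 would pass.
    print("ValueError would be raised")
```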
@@ -133,15 +299,77 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
  )
  self.built = True

+ def _get_vision_indices(self, vision_mask):
+ """Computes indices given vision mask, and pads with 0.
+
+ If `vision_mask` is
+
+ ```
+ [
+ [False, True, True], [False, True, False], [False, False, False]
+ ]
+ ```
+
+ , then the output will be:
+
+ ```
+ [
+ [1, 2, 0], [1, 0, 0], [0, 0, 0]
+ ]
+ ```
+ """
+ batch_size, sequence_length = vision_mask.shape
+
+ vision_mask_flattened = tf.reshape(vision_mask, [-1])
+ vision_indices = tf.where(vision_mask_flattened)[..., 0]
+ vision_indices = tf.cast(vision_indices, dtype=tf.int32)
+
+ row_lengths = tf.math.reduce_sum(
+ tf.cast(vision_mask, dtype=vision_indices.dtype), axis=1
+ )
+
+ batched_vision_indices = tf.RaggedTensor.from_row_lengths(
+ values=vision_indices,
+ row_lengths=row_lengths,
+ )
+
+ to_subtract = tf.math.scalar_mul(
+ scalar=tf.cast(sequence_length, dtype=tf.int32),
+ x=tf.range(
+ start=0,
+ limit=tf.shape(vision_mask)[0],
+ dtype=tf.int32,
+ ),
+ )
+
+ # All indices should be independent of other samples in the batch. If
+ # not, and if we do sharding along the batch dimension for data
+ # parallel, things might get weird.
+ batched_vision_indices = tf.math.subtract(
+ batched_vision_indices,
+ tf.expand_dims(to_subtract, axis=-1),
+ )
+
+ # Pad the indices.
+ batched_vision_indices = batched_vision_indices.to_tensor(
+ shape=[
+ batch_size,
+ self.max_images_per_prompt * self.num_vision_tokens_per_image,
+ ],
+ default_value=0,
+ )
+ return batched_vision_indices
+
  def _format_output(
  self,
  images,
  token_ids,
- text_mask,
+ vision_mask,
  response_mask,
  padding_mask,
  return_labels=False,
  text_only_input=False,
+ batched=False,
  ):
  if return_labels:
  # Target `y` will be the next token.
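A standalone sketch of the index computation in the new `_get_vision_indices` above, using the example from its docstring (toy values; TensorFlow assumed available):

```python
import tensorflow as tf

vision_mask = tf.constant(
    [[False, True, True], [False, True, False], [False, False, False]]
)
flat = tf.reshape(vision_mask, [-1])
indices = tf.cast(tf.where(flat)[..., 0], tf.int32)
row_lengths = tf.reduce_sum(tf.cast(vision_mask, tf.int32), axis=1)
ragged = tf.RaggedTensor.from_row_lengths(indices, row_lengths)
# Subtract the per-row offset so indices are relative to each sample.
offsets = tf.range(tf.shape(vision_mask)[0]) * tf.shape(vision_mask)[1]
ragged = tf.math.subtract(ragged, tf.expand_dims(offsets, axis=-1))
print(ragged.to_tensor(shape=[3, 3], default_value=0))
# [[1 2 0]
#  [1 0 0]
#  [0 0 0]]
```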
@@ -149,12 +377,13 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
  # Only compute the loss for labels in the response.
  sample_weight = response_mask[..., 1:]

+ # The last token does not have a next token. So, remove it.
  token_ids = token_ids[..., :-1]
- text_mask = text_mask[..., :-1]
+ vision_mask = vision_mask[..., :-1]
  response_mask = response_mask[..., :-1]
  padding_mask = padding_mask[..., :-1]

- batch_size, sequence_length = tf.shape(text_mask)
+ batch_size = tf.shape(vision_mask)[0]

  if text_only_input:
  vision_indices = tf.ones(
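For reference, a toy walkthrough of the masks and the `[..., 1:]` / `[..., :-1]` shift used above and in `call()` (all ids made up; the real pad id comes from `Gemma3Tokenizer`):

```python
import tensorflow as tf

pad_token_id = 0  # assumed value for illustration
token_ids = tf.constant([[2, 10, 11, 12, 13, 0]])
segment_ids = tf.constant([[0, 0, 0, 1, 1, 0]])

response_mask = segment_ids == 1          # True only on response tokens
padding_mask = token_ids != pad_token_id  # False on padded positions

y = token_ids[..., 1:]                    # labels: the next token at each step
sample_weight = response_mask[..., 1:]    # loss computed only on responses
x_token_ids = token_ids[..., :-1]         # the last token has no next token
```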
@@ -165,48 +394,109 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
  dtype=tf.int32,
  )
  else:
- sequence_length = tf.shape(text_mask)[-1]
- flat_text_mask = tf.reshape(
- text_mask, (batch_size * sequence_length)
- )
- vision_indices = tf.where(tf.logical_not(flat_text_mask))
- vision_indices = tf.reshape(vision_indices, (batch_size, -1))
+ vision_indices = self._get_vision_indices(vision_mask=vision_mask)

- # The last token does not have a next token, so we truncate it out.
  x = {
  # Image
- "images": images,
+ "images": images if batched else tf.squeeze(images, axis=0),
  # Text
- "token_ids": token_ids,
- "vision_indices": vision_indices,
- "text_mask": text_mask,
- "padding_mask": padding_mask,
+ "token_ids": (
+ token_ids if batched else tf.squeeze(token_ids, axis=0)
+ ),
+ "vision_indices": (
+ vision_indices
+ if batched
+ else tf.squeeze(vision_indices, axis=0)
+ ),
+ # This mask is redundant information. But easier to compute it here
+ # than the model forward pass.
+ "vision_mask": (
+ vision_mask if batched else tf.squeeze(vision_mask, axis=0)
+ ),
+ "padding_mask": (
+ padding_mask if batched else tf.squeeze(padding_mask, axis=0)
+ ),
  }

  if return_labels:
+ if not batched:
+ y = tf.squeeze(y, axis=0)
+ sample_weight = tf.squeeze(sample_weight, 0)
+
  return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
  else:
  return x

- def _get_image_placeholder_ragged_tensor(self, required_length, fill_value):
- """Identifies the number of dummy placeholder tokens to pad input with.
+ def _preprocess_images(self, images, batched):
+ desired_height = self.image_converter.image_size[0]
+ desired_width = self.image_converter.image_size[1]
+
+ # Images can be lists/ragged tensors. We need to pad them/truncate them.
+ if isinstance(images, (list, np.ndarray)):
+ images = tf.ragged.constant(images)
+ elif isinstance(images, tf.RaggedTensor):
+ pass
+ elif isinstance(images, tf.Tensor):
+ images = tf.RaggedTensor.from_tensor(images)
+ else:
+ # Attempt to convert anyway. This handles the case where
+ # the inputs might be `jax.Array`, `torch.Tensor`. To check the
+ # type, we will have to import all three frameworks, which is
+ # undesirable.
+ try:
+ images = tf.RaggedTensor.from_tensor(images)
+ except: # noqa: E722
+ raise ValueError(
+ "`images` should be a list, ragged tensor, dense tensor."
+ f"Received: `type(images)` = {type(images)}"
+ )

- Depending on the number of images provided per sample, and the
- allowed number of images, this method identifies the number of vision
- placeholder tokens we need to pad tokens with. This is necessary to
- ensure the same number of image tokens in every sample so as to not
- cause dynamic shape issues with XLA in the interleaving layer.
- """
- required_length = tf.cast(required_length, tf.int32)
- ones_tensor = tf.ones_like(required_length, dtype=tf.int32)
- flattened_tensor = tf.repeat(ones_tensor, required_length)
- row_splits = tf.concat([[0], tf.cumsum(required_length)], axis=0)
- ragged_tensor = tf.RaggedTensor.from_row_splits(
- flattened_tensor, row_splits
+ if not batched:
+ images = tf.expand_dims(images, axis=0)
+
+ # If the input is a list of images, instead of list of lists of images.
+ if len(images.shape) == 4:
+ images = tf.expand_dims(images, axis=1)
+
+ # Convert to dense tensor.
+ images = images.to_tensor(
+ shape=[None, self.max_images_per_prompt, None, None, 3],
+ default_value=0,
+ )
+
+ # Resize, rescale, etc. the images.
+ original_images_shape = tf.shape(images)
+
+ # Before passing through image converter, we need to collapse the
+ # first two dimensions (`batch_size`, `max_images_per_prompt`) into one.
+ images = tf.reshape(
+ images,
+ [
+ -1,
+ original_images_shape[-3],
+ original_images_shape[-2],
+ original_images_shape[-1],
+ ],
+ )
+ images = self.image_converter(images)
+
+ if keras.config.backend() == "torch" and not isinstance(
+ images, tf.Tensor
+ ):
+ images = images.cpu()
+
+ # Recover the rank.
+ images = tf.reshape(
+ images,
+ [
+ original_images_shape[0],
+ self.max_images_per_prompt,
+ desired_height,
+ desired_width,
+ original_images_shape[-1],
+ ],
  )
- ragged_tensor = ragged_tensor * fill_value
- ragged_tensor = tf.cast(ragged_tensor, tf.int32)
- return ragged_tensor
+ return images

  @preprocessing_function
  def call(
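The padding step inside `_preprocess_images` can be reproduced in isolation as follows (toy 8x8 images, two samples with a different number of images each; shapes are illustrative only):

```python
import numpy as np
import tensorflow as tf

max_images_per_prompt = 2
images = tf.ragged.constant(
    [
        [np.ones((8, 8, 3)), np.ones((8, 8, 3))],  # sample with two images
        [np.ones((8, 8, 3))],                      # sample with one image
    ]
)
# Pad every sample with zero-images up to `max_images_per_prompt`.
dense = images.to_tensor(
    shape=[None, max_images_per_prompt, None, None, 3], default_value=0
)
print(dense.shape)  # (2, 2, 8, 8, 3)
```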
@@ -218,52 +508,76 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
  ):
  sequence_length = sequence_length or self.sequence_length

+ # === Input extraction and validation ===
+
  # Extract text part of the input.
  prompts, responses = x["prompts"], x["responses"]

+ # Find out if the input is batched/not batched. Uprank if not batched.
+ # In other preprocessors, we don't have to do this, but here, all
+ # the following logic (indices, etc.) uses tensors with a batch dim.
+ # We will squeeze these back at the end.
+ batched = True
+ if isinstance(prompts, str):
+ batched = False
+ prompts = [prompts]
+ responses = [responses]
+ if isinstance(prompts, tf.Tensor) and len(prompts.shape) == 0:
+ batched = False
+ prompts = tf.expand_dims(prompts, axis=0)
+ responses = tf.expand_dims(responses, axis=0)
+
  # Extract images from the input.
  images = x.get("images", None)
- num_valid_images = x.get("num_valid_images", None)

- if self.text_only_model:
- if images is not None or num_valid_images is not None:
- raise ValueError(
- "`image_converter` cannot be None when `images` or"
- " `num_valid_images` is not None."
- )
- else:
- # Replace `"<start_of_image>"` in prompts with
- # `"\n\n<start_of_image> <img> * 256 <end_of_image>\n\n"`.
+ # There are 8 cases, based on values of
+ # a = `self.text_only_model`, b = `images` is `None`, and whether
+ # c = `<start_of_image>` token is present in `prompts`.
+ # F F F, F F T -> Raise error if #`<start_of_image>` <0, or
+ # > `max_images_per_prompt`.
+ # F T F -> Return empty images and vision indices
+ # F T T -> Return empty images and vision indices to the model.
+ # T F F, T F T -> Raise error.
+ # T T F -> Only token IDs and padding mask are returned.
+ # T T T -> Only token IDs and padding mask are returned.
+
+ if self.text_only_model and images is not None:
+ raise ValueError(
+ "The initialized preprocessor/model is text-only, but "
+ " `images` is not `None`."
+ )
+
+ # Add image placeholder tokens. Replace `"<start_of_image>"` in
+ # prompts with
+ # `"\n\n<start_of_image> <img> * 256 <end_of_image>\n\n"`.
+ if not self.text_only_model:
  prompts = tf.strings.regex_replace(
  prompts,
- START_OF_IMAGE_TOKEN,
- f"\n\n{START_OF_IMAGE_TOKEN}"
- + IMAGE_PLACEHOLDER_TOKEN * self.num_vision_tokens_per_image
- + f"{END_OF_IMAGE_TOKEN}\n\n",
+ self.start_of_image_token,
+ f"\n\n{self.start_of_image_token}"
+ + self.image_placeholder * self.num_vision_tokens_per_image
+ + f"{self.end_of_image_token}\n\n",
  )

+ # === Tokenization, padding, etc. ===
+
  # Tokenise the inputs.
  prompts = self.tokenizer(prompts)
  responses = self.tokenizer(responses)

- # All truncation should happen on the text token IDs and not on
- # the dummy placeholder image tokens which we will add at the end.
- # Hence, we use a packer only on the text part first, and then
- # add the padded dummy placeholder tokens separately.
+ # Padding.
  token_ids, segment_ids = self.packer(
  (prompts, responses),
- sequence_length=sequence_length
- if images is not None
- else sequence_length + 1,
+ sequence_length=sequence_length + 1,
  add_start_value=self.add_start_token,
  add_end_value=self.add_end_token,
  )
+ response_mask = segment_ids == 1
+ padding_mask = token_ids != self.tokenizer.pad_token_id

- # If it is a text only model, return immediately.
+ # === Text Model ===
  if self.text_only_model:
  # The last token does not have a next token, so we truncate it out.
- response_mask = segment_ids == 1
- padding_mask = token_ids != self.tokenizer.pad_token_id
  x = {
  "token_ids": token_ids[..., :-1],
  "padding_mask": padding_mask[..., :-1],
@@ -273,162 +587,68 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
  y = token_ids[..., 1:]
  # Only compute the loss for labels in the response.
  sample_weight = response_mask[..., 1:]
+
+ # Squeeze if not batched.
+ if not batched:
+ x["token_ids"] = tf.squeeze(x["token_ids"], axis=0)
+ x["padding_mask"] = tf.squeeze(x["padding_mask"], axis=0)
+ y = tf.squeeze(y, axis=0)
+ sample_weight = tf.squeeze(sample_weight, axis=0)
+
  return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)

- # Vision preprocessing
+ # === Vision processing ===
+
  batch_size = tf.shape(prompts)[0]
+ desired_height = self.image_converter.image_size[0]
+ desired_width = self.image_converter.image_size[1]
  if images is None:
+ # == Branch: vision model, with `None` value for `images` ==
+
  # To handle the text-only input case, we need to pass an empty
- # tensor so as to skip the vision part of the model.
+ # tensor so as to skip the vision layers of the model.
+
+ # TODO: Once functional models accept `None` inputs, consider
+ # passing this as `None` directly.
  images = tf.ones(
  shape=[
  batch_size,
  0,
- self.image_converter.image_size[0],
- self.image_converter.image_size[1],
+ desired_height,
+ desired_width,
  3,
  ],
  dtype="float32",
  )

- text_mask = tf.ones_like(token_ids, dtype=bool)
- padding_mask = token_ids != self.tokenizer.pad_token_id
- response_mask = segment_ids == 1
+ vision_mask = tf.zeros_like(token_ids, dtype=bool)

  return self._format_output(
  images=images,
  token_ids=token_ids,
- text_mask=text_mask,
+ vision_mask=vision_mask,
  response_mask=response_mask,
  padding_mask=padding_mask,
  return_labels=True,
  text_only_input=True,
+ batched=batched,
  )

- original_image_shape = tf.shape(images)
- if num_valid_images is None:
- num_valid_images = tf.fill(
- dims=(batch_size,),
- value=self.max_images_per_prompt,
- )
+ # == Branch: vision model, with non-`None` value for `images` ==

- # Image inputs checks.
- if original_image_shape[1] != self.max_images_per_prompt:
- raise ValueError(
- "The number of images per sample should be the same as "
- "`max_images_per_prompt`. Received: "
- f"images.shape = {original_image_shape}, "
- f"max_images_per_prompt = {self.max_images_per_prompt}"
- )
- if tf.cast(
- tf.math.reduce_sum(
- tf.cast(
- tf.math.greater(
- num_valid_images, self.max_images_per_prompt
- ),
- dtype=tf.int32,
- )
- ),
- dtype=bool,
- ):
- raise ValueError(
- "`num_valid_images` should have values <= "
- "self.max_images_per_prompt. Received: "
- f"num_valid_images = {num_valid_images}, ",
- f"max_images_per_prompt = {self.max_images_per_prompt}",
- )
+ images = self._preprocess_images(images=images, batched=batched)

- # Resize, rescale, etc. the images.
- padded_images_shape = tf.shape(images)
- images = tf.reshape(
- images,
- [
- -1,
- padded_images_shape[-3],
- padded_images_shape[-2],
- padded_images_shape[-1],
- ],
- )
- images = self.image_converter(images)
- height = (
- self.image_size[0]
- if self.image_converter.image_size
- else original_image_shape[-3]
- )
- width = (
- self.image_size[1]
- if self.image_converter.image_size
- else original_image_shape[-2]
- )
- images = tf.reshape(
- images,
- [
- padded_images_shape[0],
- self.max_images_per_prompt,
- height,
- width,
- 3,
- ],
- )
-
- # Format tokens.
- padding_mask = token_ids != self.tokenizer.pad_token_id
- token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
- segment_ids = tf.ragged.boolean_mask(segment_ids, padding_mask)
- padding_mask = tf.ragged.boolean_mask(padding_mask, padding_mask)
- response_mask = segment_ids == 1
-
- # Using `num_valid_images`, we need to add dummy image tokens at the
- # end of the tokenized text. Ideally, we could have passed an image
- # padding mask to the model, but it won't work with XLA since an
- # `ops.where` on it in the interleaving layer will return different
- # number of images every time. So, we need to fix the number of images.
- vision_placeholder_tensor = self._get_image_placeholder_ragged_tensor(
- (self.max_images_per_prompt - num_valid_images)
- * self.num_vision_tokens_per_image,
- self.tokenizer.token_to_id("<img>"),
- )
- vision_placeholder_tensor = vision_placeholder_tensor.to_tensor(
- shape=[
- batch_size,
- self.max_images_per_prompt * self.num_vision_tokens_per_image,
- ],
- default_value=self.tokenizer.pad_token_id,
- )
-
- token_ids_with_placeholder = tf.concat(
- [token_ids, vision_placeholder_tensor], axis=1
- )
-
- # Now, pad everything to the same length.
- desired_length = (
- sequence_length
- + self.max_images_per_prompt * self.num_vision_tokens_per_image
- )
- token_ids_with_placeholder = token_ids_with_placeholder.to_tensor(
- shape=[batch_size, desired_length + 1],
- default_value=self.tokenizer.pad_token_id,
- )
- padding_mask_with_placeholder = padding_mask.to_tensor(
- shape=[batch_size, desired_length + 1],
- default_value=False,
- )
- response_mask_with_placeholder = response_mask.to_tensor(
- shape=[batch_size, desired_length + 1],
- default_value=False,
- )
-
- text_mask = token_ids_with_placeholder != self.tokenizer.token_to_id(
- "<img>"
- )
+ vision_mask = token_ids == self.tokenizer.image_placeholder_id

  return self._format_output(
  images=images,
- token_ids=token_ids_with_placeholder,
- text_mask=text_mask,
- response_mask=response_mask_with_placeholder,
- padding_mask=padding_mask_with_placeholder,
+ token_ids=token_ids,
+ vision_mask=vision_mask,
+ response_mask=response_mask,
+ padding_mask=padding_mask,
  return_labels=True,
+ text_only_input=False,
+ batched=batched,
  )

  @preprocessing_function
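The new `vision_mask` above is simply an equality test against the tokenizer's image placeholder id; sketched here with a made-up id value:

```python
import tensorflow as tf

image_placeholder_id = 262144  # illustrative value, not the real id
token_ids = tf.constant([[2, 9, image_placeholder_id, image_placeholder_id, 7]])
vision_mask = token_ids == image_placeholder_id
# [[False, False, True, True, False]] -> positions to fill with image embeddings
```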
@@ -448,39 +668,59 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
  the sequence (as generation is expected to continue at the end of the
  inputted prompt).
  """
+
  if not self.built:
  self.build(None)

+ # Extract inputs.
  if isinstance(x, dict):
  images = x.get("images", None)
- num_valid_images = x.get("num_valid_images", None)
+
  # TODO: do we even need `responses` for generation? Makes sense for
- # finetuning (i.e., `call()`).
+ # finetuning only (i.e., `call()`).
  responses = x.get("responses", None)
  prompts = x["prompts"]
  else:
  images = None
- num_valid_images = None
  responses = None
  prompts = x

- if self.text_only_model:
- if images is not None or num_valid_images is not None:
- raise ValueError(
- "`image_converter` cannot be None when `images` or"
- " `num_valid_images` is not None."
- )
- else:
- # Replace `"<start_of_image>"` in prompts with
- # `"\n\n<start_of_image> <img> * 256 <end_of_image>\n\n"`.
+ # Find out if the input is batched/not batched. Uprank if not batched.
+ # In other preprocessors, we don't have to do this, but here, all
+ # the following logic (indices, etc.) uses tensors with a batch dim.
+ # We will squeeze these back at the end.
+ batched = True
+ if isinstance(prompts, str):
+ batched = False
+ prompts = [prompts]
+ if responses is not None:
+ responses = [responses]
+ if isinstance(prompts, tf.Tensor) and len(prompts.shape) == 0:
+ batched = False
+ prompts = tf.expand_dims(prompts, axis=0)
+ if responses is not None:
+ responses = tf.expand_dims(responses, axis=0)
+
+ # We have the same 8 cases here, as in `call()`.
+ if self.text_only_model and images is not None:
+ raise ValueError(
+ "The initialized preprocessor/model is text-only, but "
+ " `images` is not `None`."
+ )
+
+ # Add image placeholder tokens. Replace `"<start_of_image>"` in
+ # prompts with
+ # `"\n\n<start_of_image> <img> * 256 <end_of_image>\n\n"`.
+ if not self.text_only_model:
  prompts = tf.strings.regex_replace(
  prompts,
- START_OF_IMAGE_TOKEN,
- f"\n\n{START_OF_IMAGE_TOKEN}"
- + IMAGE_PLACEHOLDER_TOKEN * self.num_vision_tokens_per_image
- + f"{END_OF_IMAGE_TOKEN}\n\n",
+ self.start_of_image_token,
+ f"\n\n{self.start_of_image_token}"
+ + self.image_placeholder * self.num_vision_tokens_per_image
+ + f"{self.end_of_image_token}\n\n",
  )

+ # === Tokenization, padding, etc. ===
  prompts = self.tokenizer(prompts)

  if responses is not None:
@@ -489,174 +729,79 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
  else:
  segments = (prompts,)

+ # Padding.
  token_ids, segment_ids = self.packer(
  segments,
  sequence_length=sequence_length,
  add_end_value=False,
  )
+ response_mask = segment_ids == 1
+ padding_mask = token_ids != self.tokenizer.pad_token_id

- # If it is a text only model, return immediately.
+ # === Text Model ===
  if self.text_only_model:
- response_mask = segment_ids == 1
- padding_mask = token_ids != self.tokenizer.pad_token_id
  return {
- "token_ids": token_ids,
- "padding_mask": padding_mask,
+ "token_ids": (
+ token_ids if batched else tf.squeeze(token_ids, axis=0)
+ ),
+ "padding_mask": (
+ padding_mask
+ if batched
+ else tf.squeeze(padding_mask, axis=0)
+ ),
  }

- # Vision preprocessing
+ # === Vision processing ===
+
  batch_size = tf.shape(prompts)[0]
+ desired_height = self.image_converter.image_size[0]
+ desired_width = self.image_converter.image_size[1]
  if images is None:
+ # == Branch: vision model, with `None` value for `images` ==
+
  # To handle the text-only input case, we need to pass an empty
- # tensor so as to skip the vision part of the model.
+ # tensor so as to skip the vision layers of the model.
+
+ # TODO: Once functional models accept `None` inputs, consider
+ # passing this as `None` directly.
  images = tf.ones(
  shape=[
  batch_size,
  0,
- self.image_converter.image_size[0],
- self.image_converter.image_size[1],
+ desired_height,
+ desired_width,
  3,
  ],
  dtype="float32",
  )

- text_mask = tf.ones_like(token_ids, dtype=bool)
- padding_mask = token_ids != self.tokenizer.pad_token_id
- response_mask = segment_ids == 1
+ vision_mask = tf.zeros_like(token_ids, dtype=bool)

  return self._format_output(
  images=images,
  token_ids=token_ids,
- text_mask=text_mask,
+ vision_mask=vision_mask,
  response_mask=response_mask,
  padding_mask=padding_mask,
  return_labels=False,
  text_only_input=True,
+ batched=batched,
  )

- # Pad images.
- original_image_shape = tf.shape(images)
- if num_valid_images is None:
- num_valid_images = tf.fill(
- dims=(batch_size,),
- value=self.max_images_per_prompt,
- )
-
- # Image inputs checks.
- if original_image_shape[1] != self.max_images_per_prompt:
- raise ValueError(
- "The number of images per sample should be the same as "
- "`max_images_per_prompt`. Received: "
- f"images.shape = {original_image_shape}, "
- f"max_images_per_prompt = {self.max_images_per_prompt}"
- )
- if tf.cast(
- tf.math.reduce_sum(
- tf.cast(
- tf.math.greater(
- num_valid_images, self.max_images_per_prompt
- ),
- dtype=tf.int32,
- )
- ),
- dtype=bool,
- ):
- raise ValueError(
- "`num_valid_images` should have values <= "
- "self.max_images_per_prompt. Received: "
- f"num_valid_images = {num_valid_images}, ",
- f"max_images_per_prompt = {self.max_images_per_prompt}",
- )
-
- # Resize, rescale, etc. the images.
- padded_images_shape = tf.shape(images)
- images = tf.reshape(
- images,
- [
- -1,
- padded_images_shape[-3],
- padded_images_shape[-2],
- padded_images_shape[-1],
- ],
- )
- images = self.image_converter(images)
- height = (
- self.image_size[0]
- if self.image_converter.image_size
- else original_image_shape[-3]
- )
- width = (
- self.image_size[1]
- if self.image_converter.image_size
- else original_image_shape[-2]
- )
- images = tf.reshape(
- images,
- [
- padded_images_shape[0],
- self.max_images_per_prompt,
- height,
- width,
- 3,
- ],
- )
-
- padding_mask = token_ids != self.tokenizer.pad_token_id
- token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
- segment_ids = tf.ragged.boolean_mask(segment_ids, padding_mask)
- padding_mask = tf.ragged.boolean_mask(padding_mask, padding_mask)
- response_mask = segment_ids == 1
-
- # Using `num_valid_images`, we need to add dummy image tokens at the
- # end of the tokenized text. Ideally, we could have passed an image
- # padding mask to the model, but it won't work with XLA since an
- # `ops.where` on it in the interleaving layer will return different
- # number of images every time. So, we need to fix the number of images.
- vision_placeholder_tensor = self._get_image_placeholder_ragged_tensor(
- (self.max_images_per_prompt - num_valid_images)
- * self.num_vision_tokens_per_image,
- self.tokenizer.token_to_id("<img>"),
- )
- vision_placeholder_tensor = vision_placeholder_tensor.to_tensor(
- shape=[
- batch_size,
- self.max_images_per_prompt * self.num_vision_tokens_per_image,
- ],
- default_value=self.tokenizer.pad_token_id,
- )
- token_ids_with_placeholder = tf.concat(
- [token_ids, vision_placeholder_tensor], axis=1
- )
-
- # Now, pad everything to the same length.
- desired_length = (
- sequence_length
- + self.max_images_per_prompt * self.num_vision_tokens_per_image
- )
- token_ids_with_placeholder = token_ids_with_placeholder.to_tensor(
- shape=[batch_size, desired_length],
- default_value=self.tokenizer.pad_token_id,
- )
- padding_mask_with_placeholder = padding_mask.to_tensor(
- shape=[batch_size, desired_length],
- default_value=False,
- )
- response_mask_with_placeholder = response_mask.to_tensor(
- shape=[batch_size, desired_length],
- default_value=False,
- )
+ # == Branch: vision model, with non-`None` value for `images` ==
+ images = self._preprocess_images(images=images, batched=batched)

- text_mask = token_ids_with_placeholder != self.tokenizer.token_to_id(
- "<img>"
- )
+ vision_mask = token_ids == self.tokenizer.image_placeholder_id

  return self._format_output(
  images=images,
- token_ids=token_ids_with_placeholder,
- text_mask=text_mask,
- response_mask=response_mask_with_placeholder,
- padding_mask=padding_mask_with_placeholder,
+ token_ids=token_ids,
+ vision_mask=vision_mask,
+ response_mask=response_mask,
+ padding_mask=padding_mask,
  return_labels=False,
+ text_only_input=False,
+ batched=batched,
  )

  def get_config(self):
@@ -686,6 +831,18 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):

  token_ids, padding_mask = x["token_ids"], x["padding_mask"]
  ids_to_strip = self.tokenizer.special_token_ids
- ids_to_strip += [self.tokenizer.token_to_id("<end_of_image>")]
+
+ # We do not want to strip SoI token because it is provided by the user.
+ if self.tokenizer.start_of_image_token_id in ids_to_strip:
+ ids_to_strip.remove(self.tokenizer.start_of_image_token_id)
+
  token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip)
  return self.tokenizer.detokenize(token_ids)
+
+ @property
+ def max_images_per_prompt(self):
+ return self._max_images_per_prompt
+
+ @max_images_per_prompt.setter
+ def max_images_per_prompt(self, value):
+ self._max_images_per_prompt = value
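Sketch of the `generate_postprocess` change in this last hunk: every special token is stripped from generated output except `<start_of_image>`, since the user put it in the prompt and it should survive detokenization. Token id values below are placeholders for illustration:

```python
start_of_image_token_id = 255999  # assumed id for illustration
special_token_ids = [0, 1, 2, start_of_image_token_id]

ids_to_strip = list(special_token_ids)
if start_of_image_token_id in ids_to_strip:
    ids_to_strip.remove(start_of_image_token_id)
# ids_to_strip -> [0, 1, 2]; the start-of-image token is kept.
```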