keras-hub-nightly 0.20.0.dev202503260356__py3-none-any.whl → 0.20.0.dev202503270400__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,691 @@
+ import keras
+ import tensorflow as tf
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.layers.preprocessing.multi_segment_packer import (
+     MultiSegmentPacker,
+ )
+ from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
+ from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone
+ from keras_hub.src.models.gemma3.gemma3_image_converter import (
+     Gemma3ImageConverter,
+ )
+ from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer
+ from keras_hub.src.utils.tensor_utils import preprocessing_function
+ from keras_hub.src.utils.tensor_utils import strip_to_ragged
+
+ START_OF_IMAGE_TOKEN = "<start_of_image>"
+ IMAGE_PLACEHOLDER_TOKEN = "<img>"
+ END_OF_IMAGE_TOKEN = "<end_of_image>"
+
+
+ @keras_hub_export("keras_hub.models.Gemma3CausalLMPreprocessor")
+ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
+     """Gemma3 Causal LM preprocessor.
+
+     This preprocessing layer is meant for use with
+     `keras_hub.models.Gemma3CausalLM`. By default, it will take in batches of
+     images and strings, and return outputs in a `(x, y, sample_weight)` format,
+     where the `y` label is the next token id in the `x` sequence.
+
+     This layer currently supports only one mode: `image_converter` is `None`,
+     i.e., text-only input. The text is preprocessed like in any other causal
+     LM preprocessor (tokenization, padding, etc.), and the sequence is padded
+     to `sequence_length`.
+
+     For use with generation, the layer also exposes two methods
+     `generate_preprocess()` and `generate_postprocess()`. When this preprocessor
+     is attached to a `keras_hub.models.Gemma3CausalLM` instance, these methods
+     will be called implicitly in `generate()`. They can also be called
+     standalone (e.g. to precompute preprocessing inputs for generation in a
+     separate process).
+
+     Args:
+         tokenizer: A `keras_hub.models.Gemma3Tokenizer` instance.
+         image_converter: A `keras_hub.layers.ImageConverter` instance. Defaults
+             to `None`.
+         sequence_length: The length of the packed inputs. Defaults to 1024.
+         add_start_token: If `True`, the preprocessor will prepend the tokenizer
+             start token to each input sequence. Defaults to `True`.
+         add_end_token: If `True`, the preprocessor will append the tokenizer
+             end token to each input sequence. Defaults to `True`.
+         max_images_per_prompt: int. Permissible number of images per sample in
+             the batch. Defaults to 2.
+         num_vision_tokens_per_image: int. Number of vision placeholder tokens
+             per image. Defaults to 256.
+
+     Call arguments:
+         x: A dict with `"prompts"` and `"responses"` keys, and optionally
+             `"images"` and `"num_valid_images"` for multimodal inputs.
+         y: Label data. Should always be `None` as the layer generates labels.
+         sample_weight: Label weights. Should always be `None` as the layer
+             generates label weights.
+         sequence_length: Pass to override the configured `sequence_length` of
+             the layer.
+
+     Examples:
+     ```python
+     # Load the preprocessor from a preset.
+     preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset(
+         "gemma3_4b_en"
+     )
+
+     # Text-only input.
+     preprocessor(
+         {
+             "prompts": ["The quick brown fox jumped."],
+             "responses": [""],
+         }
+     )
+
+     # Images (pass one valid image, padded to `max_images_per_prompt`).
+     max_images_per_prompt = 2
+     preprocessor(
+         {
+             "prompts": ["The quick brown fox jumped."],
+             "responses": [""],
+             "images": [np.ones((2, 896, 896, 3)).astype("float32")],
+             "num_valid_images": np.array([1,], dtype=np.int32),
+         }
+     )
+     ```
+     """
+
+     backbone_cls = Gemma3Backbone
+     tokenizer_cls = Gemma3Tokenizer
+     image_converter_cls = Gemma3ImageConverter
+
+     def __init__(
+         self,
+         tokenizer,
+         image_converter=None,
+         sequence_length=1024,
+         add_start_token=True,
+         add_end_token=True,
+         max_images_per_prompt=2,
+         num_vision_tokens_per_image=256,
+         **kwargs,
+     ):
+         super().__init__(
+             tokenizer=tokenizer,
+             sequence_length=sequence_length,
+             add_start_token=add_start_token,
+             add_end_token=add_end_token,
+             **kwargs,
+         )
+
+         if image_converter is not None:
+             raise ValueError(
+                 "Currently, only the text version of the Gemma3 model is "
+                 "supported."
+             )
+
+         self.image_converter = image_converter
+         self.max_images_per_prompt = max_images_per_prompt
+         self.num_vision_tokens_per_image = num_vision_tokens_per_image
+
+         self.text_only_model = self.image_converter is None
+
+     def build(self, input_shape):
+         # Defer packer creation to `build()` so that we can be sure tokenizer
+         # assets have loaded when restoring a saved model.
+         self.packer = MultiSegmentPacker(
+             start_value=self.tokenizer.start_token_id,
+             end_value=self.tokenizer.end_token_id,
+             pad_value=self.tokenizer.pad_token_id,
+             sep_value=[],
+             sequence_length=self.sequence_length,
+         )
+         self.built = True
+
+     def _format_output(
+         self,
+         images,
+         token_ids,
+         text_mask,
+         response_mask,
+         padding_mask,
+         return_labels=False,
+         text_only_input=False,
+     ):
+         if return_labels:
+             # Target `y` will be the next token.
+             y = token_ids[..., 1:]
+             # Only compute the loss for labels in the response.
+             sample_weight = response_mask[..., 1:]
+
+             # The last token does not have a next token, so we truncate it
+             # out.
+             token_ids = token_ids[..., :-1]
+             text_mask = text_mask[..., :-1]
+             response_mask = response_mask[..., :-1]
+             padding_mask = padding_mask[..., :-1]
+
+         batch_size, sequence_length = tf.shape(text_mask)
+
+         if text_only_input:
+             vision_indices = tf.ones(
+                 shape=[
+                     batch_size,
+                     0,
+                 ],
+                 dtype=tf.int32,
+             )
+         else:
+             sequence_length = tf.shape(text_mask)[-1]
+             flat_text_mask = tf.reshape(
+                 text_mask, (batch_size * sequence_length,)
+             )
+             vision_indices = tf.where(tf.logical_not(flat_text_mask))
+             vision_indices = tf.reshape(vision_indices, (batch_size, -1))
+
+         x = {
+             # Image
+             "images": images,
+             # Text
+             "token_ids": token_ids,
+             "vision_indices": vision_indices,
+             "text_mask": text_mask,
+             "padding_mask": padding_mask,
+         }
+
+         if return_labels:
+             return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
+         else:
+             return x
+
+     def _get_image_placeholder_ragged_tensor(self, required_length, fill_value):
+         """Identifies the number of dummy placeholder tokens to pad input with.
+
+         Depending on the number of images provided per sample, and the
+         allowed number of images, this method identifies the number of vision
+         placeholder tokens we need to pad tokens with. This is necessary to
+         ensure the same number of image tokens in every sample so as to not
+         cause dynamic shape issues with XLA in the interleaving layer.
+         """
+         required_length = tf.cast(required_length, tf.int32)
+         ones_tensor = tf.ones_like(required_length, dtype=tf.int32)
+         flattened_tensor = tf.repeat(ones_tensor, required_length)
+         row_splits = tf.concat([[0], tf.cumsum(required_length)], axis=0)
+         ragged_tensor = tf.RaggedTensor.from_row_splits(
+             flattened_tensor, row_splits
+         )
+         ragged_tensor = ragged_tensor * fill_value
+         ragged_tensor = tf.cast(ragged_tensor, tf.int32)
+         return ragged_tensor
+
+     @preprocessing_function
+     def call(
+         self,
+         x,
+         y=None,
+         sample_weight=None,
+         sequence_length=None,
+     ):
+         sequence_length = sequence_length or self.sequence_length
+
+         # Extract text part of the input.
+         prompts, responses = x["prompts"], x["responses"]
+
+         # Extract images from the input.
+         images = x.get("images", None)
+         num_valid_images = x.get("num_valid_images", None)
+
+         if self.text_only_model:
+             if images is not None or num_valid_images is not None:
+                 raise ValueError(
+                     "`image_converter` cannot be None when `images` or"
+                     " `num_valid_images` is not None."
+                 )
+         else:
+             # Replace `"<start_of_image>"` in prompts with
+             # `"\n\n<start_of_image> <img> * 256 <end_of_image>\n\n"`.
+             prompts = tf.strings.regex_replace(
+                 prompts,
+                 START_OF_IMAGE_TOKEN,
+                 f"\n\n{START_OF_IMAGE_TOKEN}"
+                 + IMAGE_PLACEHOLDER_TOKEN * self.num_vision_tokens_per_image
+                 + f"{END_OF_IMAGE_TOKEN}\n\n",
+             )
+
+         # Tokenise the inputs.
+         prompts = self.tokenizer(prompts)
+         responses = self.tokenizer(responses)
+
+         # All truncation should happen on the text token IDs and not on
+         # the dummy placeholder image tokens which we will add at the end.
+         # Hence, we use a packer only on the text part first, and then
+         # add the padded dummy placeholder tokens separately.
+         token_ids, segment_ids = self.packer(
+             (prompts, responses),
+             sequence_length=sequence_length
+             if images is not None
+             else sequence_length + 1,
+             add_start_value=self.add_start_token,
+             add_end_value=self.add_end_token,
+         )
+
+         # If it is a text only model, return immediately.
+         if self.text_only_model:
+             # The last token does not have a next token, so we truncate it out.
+             response_mask = segment_ids == 1
+             padding_mask = token_ids != self.tokenizer.pad_token_id
+             x = {
+                 "token_ids": token_ids[..., :-1],
+                 "padding_mask": padding_mask[..., :-1],
+             }
+
+             # Target `y` will be the next token.
+             y = token_ids[..., 1:]
+             # Only compute the loss for labels in the response.
+             sample_weight = response_mask[..., 1:]
+             return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
+
+         # Vision preprocessing
+         batch_size = tf.shape(prompts)[0]
+         if images is None:
+             # To handle the text-only input case, we need to pass an empty
+             # tensor so as to skip the vision part of the model.
+             images = tf.ones(
+                 shape=[
+                     batch_size,
+                     0,
+                     self.image_converter.image_size[0],
+                     self.image_converter.image_size[1],
+                     3,
+                 ],
+                 dtype="float32",
+             )
+
+             text_mask = tf.ones_like(token_ids, dtype=bool)
+             padding_mask = token_ids != self.tokenizer.pad_token_id
+             response_mask = segment_ids == 1
+
+             return self._format_output(
+                 images=images,
+                 token_ids=token_ids,
+                 text_mask=text_mask,
+                 response_mask=response_mask,
+                 padding_mask=padding_mask,
+                 return_labels=True,
+                 text_only_input=True,
+             )
+
+         original_image_shape = tf.shape(images)
+         if num_valid_images is None:
+             num_valid_images = tf.fill(
+                 dims=(batch_size,),
+                 value=self.max_images_per_prompt,
+             )
+
+         # Image inputs checks.
+         if original_image_shape[1] != self.max_images_per_prompt:
+             raise ValueError(
+                 "The number of images per sample should be the same as "
+                 "`max_images_per_prompt`. Received: "
+                 f"images.shape = {original_image_shape}, "
+                 f"max_images_per_prompt = {self.max_images_per_prompt}"
+             )
+         if tf.cast(
+             tf.math.reduce_sum(
+                 tf.cast(
+                     tf.math.greater(
+                         num_valid_images, self.max_images_per_prompt
+                     ),
+                     dtype=tf.int32,
+                 )
+             ),
+             dtype=bool,
+         ):
+             raise ValueError(
+                 "`num_valid_images` should have values <= "
+                 "`max_images_per_prompt`. Received: "
+                 f"num_valid_images = {num_valid_images}, "
+                 f"max_images_per_prompt = {self.max_images_per_prompt}"
+             )
+
+         # Resize, rescale, etc. the images.
+         padded_images_shape = tf.shape(images)
+         images = tf.reshape(
+             images,
+             [
+                 -1,
+                 padded_images_shape[-3],
+                 padded_images_shape[-2],
+                 padded_images_shape[-1],
+             ],
+         )
+         images = self.image_converter(images)
+         height = (
+             self.image_converter.image_size[0]
+             if self.image_converter.image_size
+             else original_image_shape[-3]
+         )
+         width = (
+             self.image_converter.image_size[1]
+             if self.image_converter.image_size
+             else original_image_shape[-2]
+         )
+         images = tf.reshape(
+             images,
+             [
+                 padded_images_shape[0],
+                 self.max_images_per_prompt,
+                 height,
+                 width,
+                 3,
+             ],
+         )
+
+         # Format tokens.
+         padding_mask = token_ids != self.tokenizer.pad_token_id
+         token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
+         segment_ids = tf.ragged.boolean_mask(segment_ids, padding_mask)
+         padding_mask = tf.ragged.boolean_mask(padding_mask, padding_mask)
+         response_mask = segment_ids == 1
+
+         # Using `num_valid_images`, we need to add dummy image tokens at the
+         # end of the tokenized text. Ideally, we could have passed an image
+         # padding mask to the model, but it won't work with XLA since an
+         # `ops.where` on it in the interleaving layer will return different
+         # number of images every time. So, we need to fix the number of images.
+         vision_placeholder_tensor = self._get_image_placeholder_ragged_tensor(
+             (self.max_images_per_prompt - num_valid_images)
+             * self.num_vision_tokens_per_image,
+             self.tokenizer.token_to_id("<img>"),
+         )
+         vision_placeholder_tensor = vision_placeholder_tensor.to_tensor(
+             shape=[
+                 batch_size,
+                 self.max_images_per_prompt * self.num_vision_tokens_per_image,
+             ],
+             default_value=self.tokenizer.pad_token_id,
+         )
+
+         token_ids_with_placeholder = tf.concat(
+             [token_ids, vision_placeholder_tensor], axis=1
+         )
+
+         # Now, pad everything to the same length.
+         desired_length = (
+             sequence_length
+             + self.max_images_per_prompt * self.num_vision_tokens_per_image
+         )
+         token_ids_with_placeholder = token_ids_with_placeholder.to_tensor(
+             shape=[batch_size, desired_length + 1],
+             default_value=self.tokenizer.pad_token_id,
+         )
+         padding_mask_with_placeholder = padding_mask.to_tensor(
+             shape=[batch_size, desired_length + 1],
+             default_value=False,
+         )
+         response_mask_with_placeholder = response_mask.to_tensor(
+             shape=[batch_size, desired_length + 1],
+             default_value=False,
+         )
+
+         text_mask = token_ids_with_placeholder != self.tokenizer.token_to_id(
+             "<img>"
+         )
+
+         return self._format_output(
+             images=images,
+             token_ids=token_ids_with_placeholder,
+             text_mask=text_mask,
+             response_mask=response_mask_with_placeholder,
+             padding_mask=padding_mask_with_placeholder,
+             return_labels=True,
+         )
+
+     @preprocessing_function
+     def generate_preprocess(
+         self,
+         x,
+         sequence_length=None,
+     ):
+         """Convert strings to integer token input for generation.
+
+         Similar to calling the layer for training, this method takes in strings
+         or tensor strings, tokenizes and packs the input, and computes a padding
+         mask masking all inputs not filled in with a padded value.
+
+         Unlike calling the layer for training, this method does not compute
+         labels and will never append a `tokenizer.end_token_id` to the end of
+         the sequence (as generation is expected to continue at the end of the
+         inputted prompt).
+         """
+         if not self.built:
+             self.build(None)
+
+         if isinstance(x, dict):
+             images = x.get("images", None)
+             num_valid_images = x.get("num_valid_images", None)
+             # TODO: do we even need `responses` for generation? Makes sense for
+             # finetuning (i.e., `call()`).
+             responses = x.get("responses", None)
+             prompts = x["prompts"]
+         else:
+             images = None
+             num_valid_images = None
+             responses = None
+             prompts = x
+
+         if self.text_only_model:
+             if images is not None or num_valid_images is not None:
+                 raise ValueError(
+                     "`image_converter` cannot be None when `images` or"
+                     " `num_valid_images` is not None."
+                 )
+         else:
+             # Replace `"<start_of_image>"` in prompts with
+             # `"\n\n<start_of_image> <img> * 256 <end_of_image>\n\n"`.
+             prompts = tf.strings.regex_replace(
+                 prompts,
+                 START_OF_IMAGE_TOKEN,
+                 f"\n\n{START_OF_IMAGE_TOKEN}"
+                 + IMAGE_PLACEHOLDER_TOKEN * self.num_vision_tokens_per_image
+                 + f"{END_OF_IMAGE_TOKEN}\n\n",
+             )
+
+         prompts = self.tokenizer(prompts)
+
+         if responses is not None:
+             responses = self.tokenizer(responses)
+             segments = (prompts, responses)
+         else:
+             segments = (prompts,)
+
+         token_ids, segment_ids = self.packer(
+             segments,
+             sequence_length=sequence_length,
+             add_end_value=False,
+         )
+
+         # If it is a text only model, return immediately.
+         if self.text_only_model:
+             response_mask = segment_ids == 1
+             padding_mask = token_ids != self.tokenizer.pad_token_id
+             return {
+                 "token_ids": token_ids,
+                 "padding_mask": padding_mask,
+             }
+
+         # Vision preprocessing
+         batch_size = tf.shape(prompts)[0]
+         if images is None:
+             # To handle the text-only input case, we need to pass an empty
+             # tensor so as to skip the vision part of the model.
+             images = tf.ones(
+                 shape=[
+                     batch_size,
+                     0,
+                     self.image_converter.image_size[0],
+                     self.image_converter.image_size[1],
+                     3,
+                 ],
+                 dtype="float32",
+             )
+
+             text_mask = tf.ones_like(token_ids, dtype=bool)
+             padding_mask = token_ids != self.tokenizer.pad_token_id
+             response_mask = segment_ids == 1
+
+             return self._format_output(
+                 images=images,
+                 token_ids=token_ids,
+                 text_mask=text_mask,
+                 response_mask=response_mask,
+                 padding_mask=padding_mask,
+                 return_labels=False,
+                 text_only_input=True,
+             )
+
+         # Pad images.
+         original_image_shape = tf.shape(images)
+         if num_valid_images is None:
+             num_valid_images = tf.fill(
+                 dims=(batch_size,),
+                 value=self.max_images_per_prompt,
+             )
+
+         # Image inputs checks.
+         if original_image_shape[1] != self.max_images_per_prompt:
+             raise ValueError(
+                 "The number of images per sample should be the same as "
+                 "`max_images_per_prompt`. Received: "
+                 f"images.shape = {original_image_shape}, "
+                 f"max_images_per_prompt = {self.max_images_per_prompt}"
+             )
+         if tf.cast(
+             tf.math.reduce_sum(
+                 tf.cast(
+                     tf.math.greater(
+                         num_valid_images, self.max_images_per_prompt
+                     ),
+                     dtype=tf.int32,
+                 )
+             ),
+             dtype=bool,
+         ):
+             raise ValueError(
+                 "`num_valid_images` should have values <= "
+                 "`max_images_per_prompt`. Received: "
+                 f"num_valid_images = {num_valid_images}, "
+                 f"max_images_per_prompt = {self.max_images_per_prompt}"
+             )
+
+         # Resize, rescale, etc. the images.
+         padded_images_shape = tf.shape(images)
+         images = tf.reshape(
+             images,
+             [
+                 -1,
+                 padded_images_shape[-3],
+                 padded_images_shape[-2],
+                 padded_images_shape[-1],
+             ],
+         )
+         images = self.image_converter(images)
+         height = (
+             self.image_converter.image_size[0]
+             if self.image_converter.image_size
+             else original_image_shape[-3]
+         )
+         width = (
+             self.image_converter.image_size[1]
+             if self.image_converter.image_size
+             else original_image_shape[-2]
+         )
+         images = tf.reshape(
+             images,
+             [
+                 padded_images_shape[0],
+                 self.max_images_per_prompt,
+                 height,
+                 width,
+                 3,
+             ],
+         )
+
+         padding_mask = token_ids != self.tokenizer.pad_token_id
+         token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
+         segment_ids = tf.ragged.boolean_mask(segment_ids, padding_mask)
+         padding_mask = tf.ragged.boolean_mask(padding_mask, padding_mask)
+         response_mask = segment_ids == 1
+
+         # Using `num_valid_images`, we need to add dummy image tokens at the
+         # end of the tokenized text. Ideally, we could have passed an image
+         # padding mask to the model, but it won't work with XLA since an
+         # `ops.where` on it in the interleaving layer will return different
+         # number of images every time. So, we need to fix the number of images.
+         vision_placeholder_tensor = self._get_image_placeholder_ragged_tensor(
+             (self.max_images_per_prompt - num_valid_images)
+             * self.num_vision_tokens_per_image,
+             self.tokenizer.token_to_id("<img>"),
+         )
+         vision_placeholder_tensor = vision_placeholder_tensor.to_tensor(
+             shape=[
+                 batch_size,
+                 self.max_images_per_prompt * self.num_vision_tokens_per_image,
+             ],
+             default_value=self.tokenizer.pad_token_id,
+         )
+         token_ids_with_placeholder = tf.concat(
+             [token_ids, vision_placeholder_tensor], axis=1
+         )
+
+         # Now, pad everything to the same length.
+         desired_length = (
+             sequence_length
+             + self.max_images_per_prompt * self.num_vision_tokens_per_image
+         )
+         token_ids_with_placeholder = token_ids_with_placeholder.to_tensor(
+             shape=[batch_size, desired_length],
+             default_value=self.tokenizer.pad_token_id,
+         )
+         padding_mask_with_placeholder = padding_mask.to_tensor(
+             shape=[batch_size, desired_length],
+             default_value=False,
+         )
+         response_mask_with_placeholder = response_mask.to_tensor(
+             shape=[batch_size, desired_length],
+             default_value=False,
+         )
+
+         text_mask = token_ids_with_placeholder != self.tokenizer.token_to_id(
+             "<img>"
+         )
+
+         return self._format_output(
+             images=images,
+             token_ids=token_ids_with_placeholder,
+             text_mask=text_mask,
+             response_mask=response_mask_with_placeholder,
+             padding_mask=padding_mask_with_placeholder,
+             return_labels=False,
+         )
+
+     def get_config(self):
+         config = super().get_config()
+
+         config.update(
+             {
+                 "num_vision_tokens_per_image": self.num_vision_tokens_per_image,
+                 "max_images_per_prompt": self.max_images_per_prompt,
+             }
+         )
+         return config
+
+     @preprocessing_function
+     def generate_postprocess(
+         self,
+         x,
+     ):
+         """Convert integer token output to strings for generation.
+
+         This method reverses `generate_preprocess()`, by first removing all
+         padding and start/end tokens, and then converting the integer sequence
+         back to a string.
+         """
+         if not self.built:
+             self.build(None)
+
+         token_ids, padding_mask = x["token_ids"], x["padding_mask"]
+         # Build a new list so the tokenizer's own `special_token_ids` is not
+         # mutated in place on every call.
+         ids_to_strip = self.tokenizer.special_token_ids + [
+             self.tokenizer.token_to_id("<end_of_image>")
+         ]
+         token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip)
+         return self.tokenizer.detokenize(token_ids)
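
For reference, the following is a minimal usage sketch of the preprocessor added in this diff; it is not part of the package contents above. It assumes a text-only Gemma3 preset is available (the preset name is the one used in the class docstring) and that `from_preset()` can download the tokenizer assets.

import keras_hub

# Hypothetical preset name, copied from the class docstring above.
preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset(
    "gemma3_4b_en",
    sequence_length=32,
)

# Training-style call: returns (x, y, sample_weight), where `y` is the token
# id sequence shifted by one position and `sample_weight` is non-zero only on
# the response portion of the packed sequence.
x, y, sample_weight = preprocessor(
    {
        "prompts": ["The quick brown fox jumped."],
        "responses": ["Over the lazy dog."],
    }
)
print(x["token_ids"].shape)  # (1, 32)

# Generation-style call: no labels, and no end token is appended.
features = preprocessor.generate_preprocess(
    ["The quick brown fox jumped."],
    sequence_length=32,
)
print(features["token_ids"].shape)  # (1, 32)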