keras-hub-nightly 0.20.0.dev202504030357__py3-none-any.whl → 0.21.0.dev202504040358__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/models/__init__.py +5 -20
- keras_hub/api/tokenizers/__init__.py +0 -4
- keras_hub/src/layers/preprocessing/image_converter.py +26 -16
- keras_hub/src/models/gemma3/gemma3_attention.py +74 -21
- keras_hub/src/models/gemma3/gemma3_backbone.py +117 -46
- keras_hub/src/models/gemma3/gemma3_causal_lm.py +72 -15
- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +512 -355
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +23 -19
- keras_hub/src/models/gemma3/gemma3_image_converter.py +6 -0
- keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py +56 -16
- keras_hub/src/models/gemma3/gemma3_presets.py +74 -8
- keras_hub/src/models/gemma3/gemma3_tokenizer.py +9 -0
- keras_hub/src/models/gemma3/{gemma3_vit.py → gemma3_vision_encoder.py} +150 -139
- keras_hub/src/models/qwen/qwen_backbone.py +0 -7
- keras_hub/src/models/qwen/qwen_causal_lm.py +0 -7
- keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +0 -7
- keras_hub/src/models/qwen/qwen_tokenizer.py +0 -9
- keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -1
- keras_hub/src/models/roformer_v2/roformer_v2_text_classifier.py +2 -2
- keras_hub/src/models/vit/vit_image_converter.py +8 -3
- keras_hub/src/tests/test_case.py +4 -0
- keras_hub/src/utils/tensor_utils.py +6 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/RECORD +27 -27
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.20.0.dev202504030357.dist-info → keras_hub_nightly-0.21.0.dev202504040358.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,5 @@
 import keras
+import numpy as np
 import tensorflow as tf
 
 from keras_hub.src.api_export import keras_hub_export
@@ -14,24 +15,28 @@ from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer
 from keras_hub.src.utils.tensor_utils import preprocessing_function
 from keras_hub.src.utils.tensor_utils import strip_to_ragged
 
-START_OF_IMAGE_TOKEN = "<start_of_image>"
-IMAGE_PLACEHOLDER_TOKEN = "<img>"
-END_OF_IMAGE_TOKEN = "<end_of_image>"
-
 
 @keras_hub_export("keras_hub.models.Gemma3CausalLMPreprocessor")
 class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
     """Gemma3 Causal LM preprocessor.
 
     This preprocessing layer is meant for use with
-    `keras_hub.models.Gemma3CausalLM`.
-
-
-
-
-
-
-
+    `keras_hub.models.Gemma3CausalLM`. It can be configured in two ways:
+    text-only and text + vision, based on whether the passed value of
+    `image_converter` is None. For the former, it takes in batches of strings,
+    whereas for the latter, it takes in batches of images and strings. It
+    returns outputs in a `(x, y, sample_weight)` format, where the `y` label is
+    the next token id in the `x` sequence. `sample_weight` is 0 for "prompt"
+    tokens, and 1 for "response" tokens, so that the loss is computed only on
+    the "response" tokens.
+
+    For the text + vision case, this layer replaces instance of
+    `<start_of_image>` token in the prompt with `num_vision_tokens_per_image`
+    placeholder tokens. It also returns indices of where these vision tokens
+    are present so that in the model, image embeddings can be placed in the
+    right position in the sequence of text embeddings. Note that if
+    `max_images_per_prompt` is 2, you can pass either 0, 1, 2 images per sample.
+    The value 0 corresponds to text-only input.
 
     For use with generation, the layer also exposes two methods
     `generate_preprocess()` and `generate_postprocess()`. When this preprocessor
@@ -64,25 +69,170 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
 
     Examples:
     ```python
+    # === Language Gemma3 model ===
     # Load the preprocessor from a preset.
     preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset(
-        "
+        "gemma3_instruct_1b"
     )
 
-    #
+    # Unbatched inputs.
     preprocessor(
-
-
+        {
+            "prompts": "What is the capital of India?",
+            "responses": "New Delhi",
+        }
     )
 
-    #
-    max_images_per_prompt = 2
+    # Batched inputs.
     preprocessor(
-
-
-
-
+        {
+            "prompts": [
+                "What is the capital of India?",
+                "What is the capital of Spain?"
+            ],
+            "responses": ["New Delhi", "Madrid"],
+        }
     )
+
+    # Apply preprocessing to a `tf.data.Dataset`.
+    features = {
+        "prompts": [
+            "What is the capital of India?",
+            "What is the capital of Spain?"
+        ],
+        "responses": ["New Delhi", "Madrid"],
+    }
+
+    ds = tf.data.Dataset.from_tensor_slices(features)
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Prepare tokens for generation (no end token).
+    preprocessor.generate_preprocess(["The quick brown fox jumped."])
+
+    # Map generation outputs back to strings.
+    preprocessor.generate_postprocess({
+        'token_ids': np.array([[2, 818, 3823, 8864, 37423, 32694, 236761, 0]]),
+        'padding_mask': np.array([[ 1, 1, 1, 1, 1, 1, 1, 0]]),
+    })
+
+    # === Vision and Language Gemma3 model ===
+    # Load the preprocessor from a preset.
+    preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset(
+        "gemma3_instruct_4b"
+    )
+
+    # text-only inputs (unbatched)
+    preprocessor(
+        {
+            "prompts": "What is the capital of India?",
+            "responses": "New Delhi",
+        }
+    )
+
+    # text-only inputs (batched)
+    preprocessor(
+        {
+            "prompts": [
+                "What is the capital of India?",
+                "What is the capital of Spain?"
+            ],
+            "responses": ["New Delhi", "Madrid"],
+        }
+    )
+
+    # Unbatched inputs, with one image.
+    preprocessor(
+        {
+            "prompts": "this is a lily <start_of_image>",
+            "responses": "pristine!",
+            "images": np.ones((896, 896, 3), dtype="float32")
+        }
+    )
+
+    # Unbatched inputs, with two images.
+    preprocessor(
+        {
+            "prompts": "lily: <start_of_image>, sunflower: <start_of_image>",
+            "responses": "pristine!",
+            "images": [
+                np.ones((896, 896, 3), dtype="float32"),
+                np.ones((896, 896, 3), dtype="float32")
+            ],
+        }
+    )
+
+    # Batched inputs, one image per prompt.
+    preprocessor(
+        {
+            "prompts": [
+                "this is a lily: <start_of_image>",
+                "this is a sunflower: <start_of_image>"
+            ],
+            "responses": ["pristine!", "radiant!"],
+            "images": [
+                np.ones((896, 896, 3), dtype="float32"),
+                np.ones((896, 896, 3), dtype="float32")
+            ]
+        }
+    )
+
+    # Can also be written this way.
+    preprocessor(
+        {
+            "prompts": [
+                "this is a lily: <start_of_image>",
+                "this is a sunflower: <start_of_image>"
+            ],
+            "responses": ["pristine!", "radiant!"],
+            "images": [
+                [np.ones((896, 896, 3), dtype="float32")],
+                [np.ones((896, 896, 3), dtype="float32")]
+            ]
+        }
+    )
+
+    # Different number of images in every sample.
+    preprocessor(
+        {
+            "prompts": [
+                "Who is this singer: <start_of_image>?",
+                "Who are these musicians <start_of_image>, <start_of_image>?"
+            ],
+            "responses": ["Arijit Singh", "John Lennon, Paul Mccartney"],
+            "images": [
+                [
+                    np.ones((896, 896, 3), dtype="float32"),
+                    np.ones((896, 896, 3), dtype="float32")
+                ],
+                [np.ones((896, 896, 3), dtype="float32")]
+            ]
+        }
+    )
+
+    # Apply preprocessing to a `tf.data.Dataset`.
+    inputs = {
+        "prompts": [
+            "Who are these two: <start_of_image>, <start_of_image>",
+            "Who is this: <start_of_image>?",
+            "What is the capital of India?"
+        ],
+        "responses": [
+            "John Lennon, Paul Mccartney",
+            "Arijit Singh",
+            "New Delhi"
+        ],
+        "images": (
+            tf.ragged.constant(
+                [
+                    [np.ones((10, 10, 3)), np.ones((10, 10, 3))],
+                    [np.ones((10, 10, 3))],
+                    [],
+                ]
+            )
+        )
+    }
+    ds = tf.data.Dataset.from_tensor_slices(inputs)
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
     ```
     """
 
@@ -109,18 +259,34 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
             **kwargs,
         )
 
-
+        # Ensure `max_images_per_prompt * num_vision_tokens_per_image` is
+        # greater than `sequence_length`.
+        if (
+            image_converter is not None
+            and sequence_length
+            <= max_images_per_prompt * num_vision_tokens_per_image
+        ):
             raise ValueError(
-                "
-                "
+                "`sequence_length` should be greater than "
+                "`max_images_per_prompt * num_vision_tokens_per_image`."
+                f"Received: `sequence_length` = {sequence_length}"
+                f"`max_images_per_prompt` = {max_images_per_prompt}"
+                "`num_vision_tokens_per_image` = "
+                f"{num_vision_tokens_per_image}"
            )
 
         self.image_converter = image_converter
         self.max_images_per_prompt = max_images_per_prompt
         self.num_vision_tokens_per_image = num_vision_tokens_per_image
 
+        # The preprocessor and model are "text-only" if `self.image_converter`
+        # is `None`.
         self.text_only_model = self.image_converter is None
 
+        self.image_placeholder = self.tokenizer.image_placeholder
+        self.start_of_image_token = self.tokenizer.start_of_image_token
+        self.end_of_image_token = self.tokenizer.end_of_image_token
+
     def build(self, input_shape):
         # Defer packer creation to `build()` so that we can be sure tokenizer
         # assets have loaded when restoring a saved model.
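Annotation (not part of the package diff): the constructor check reconstructed in the hunk above requires that the padded sequence can hold every possible vision token. A tiny worked example, using 2 images per prompt and 256 vision tokens per image (values taken from the defaults shown elsewhere in this diff, assumed here for illustration):

```python
# Worked example for the `sequence_length` check above (assumed values):
# all vision placeholder tokens must fit inside the sequence, so
# `sequence_length` must be strictly greater than their total count.
max_images_per_prompt = 2
num_vision_tokens_per_image = 256

total_vision_tokens = max_images_per_prompt * num_vision_tokens_per_image
print(total_vision_tokens)      # 512
print(total_vision_tokens + 1)  # 513, the smallest valid `sequence_length`
```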
@@ -133,15 +299,77 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
         )
         self.built = True
 
+    def _get_vision_indices(self, vision_mask):
+        """Computes indices given vision mask, and pads with 0.
+
+        If `vision_mask` is
+
+        ```
+        [
+            [False, True, True], [False, True, False], [False, False, False]
+        ]
+        ```
+
+        , then the output will be:
+
+        ```
+        [
+            [1, 2, 0], [1, 0, 0], [0, 0, 0]
+        ]
+        ```
+        """
+        batch_size, sequence_length = vision_mask.shape
+
+        vision_mask_flattened = tf.reshape(vision_mask, [-1])
+        vision_indices = tf.where(vision_mask_flattened)[..., 0]
+        vision_indices = tf.cast(vision_indices, dtype=tf.int32)
+
+        row_lengths = tf.math.reduce_sum(
+            tf.cast(vision_mask, dtype=vision_indices.dtype), axis=1
+        )
+
+        batched_vision_indices = tf.RaggedTensor.from_row_lengths(
+            values=vision_indices,
+            row_lengths=row_lengths,
+        )
+
+        to_subtract = tf.math.scalar_mul(
+            scalar=tf.cast(sequence_length, dtype=tf.int32),
+            x=tf.range(
+                start=0,
+                limit=tf.shape(vision_mask)[0],
+                dtype=tf.int32,
+            ),
+        )
+
+        # All indices should be independent of other samples in the batch. If
+        # not, and if we do sharding along the batch dimension for data
+        # parallel, things might get weird.
+        batched_vision_indices = tf.math.subtract(
+            batched_vision_indices,
+            tf.expand_dims(to_subtract, axis=-1),
+        )
+
+        # Pad the indices.
+        batched_vision_indices = batched_vision_indices.to_tensor(
+            shape=[
+                batch_size,
+                self.max_images_per_prompt * self.num_vision_tokens_per_image,
+            ],
+            default_value=0,
+        )
+        return batched_vision_indices
+
     def _format_output(
         self,
         images,
         token_ids,
-
+        vision_mask,
         response_mask,
         padding_mask,
         return_labels=False,
         text_only_input=False,
+        batched=False,
     ):
         if return_labels:
             # Target `y` will be the next token.
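Annotation (not part of the package diff): a standalone sketch of the index computation that the new `_get_vision_indices` helper performs, reproducing the example documented in its docstring. It assumes only TensorFlow and uses a fixed `pad_width` argument in place of `max_images_per_prompt * num_vision_tokens_per_image`:

```python
import tensorflow as tf

# Sketch of the vision-index computation above (not the packaged
# implementation): flatten the mask, take flat indices, re-bucket them per
# row, shift back to per-sample offsets, and pad with 0.
def get_vision_indices(vision_mask, pad_width):
    batch_size, sequence_length = vision_mask.shape
    flat = tf.reshape(vision_mask, [-1])
    indices = tf.cast(tf.where(flat)[..., 0], tf.int32)
    row_lengths = tf.math.reduce_sum(tf.cast(vision_mask, tf.int32), axis=1)
    ragged = tf.RaggedTensor.from_row_lengths(indices, row_lengths)
    # Make indices relative to each sample, not to the flattened batch.
    offsets = sequence_length * tf.range(batch_size, dtype=tf.int32)
    ragged = tf.math.subtract(ragged, tf.expand_dims(offsets, axis=-1))
    return ragged.to_tensor(shape=[batch_size, pad_width], default_value=0)

mask = tf.constant(
    [[False, True, True], [False, True, False], [False, False, False]]
)
print(get_vision_indices(mask, pad_width=3).numpy())
# [[1 2 0]
#  [1 0 0]
#  [0 0 0]]
```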
@@ -149,12 +377,13 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
             # Only compute the loss for labels in the response.
             sample_weight = response_mask[..., 1:]
 
+            # The last token does not have a next token. So, remove it.
             token_ids = token_ids[..., :-1]
-
+            vision_mask = vision_mask[..., :-1]
             response_mask = response_mask[..., :-1]
             padding_mask = padding_mask[..., :-1]
 
-        batch_size
+        batch_size = tf.shape(vision_mask)[0]
 
         if text_only_input:
             vision_indices = tf.ones(
@@ -165,48 +394,109 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
                 dtype=tf.int32,
             )
         else:
-
-            flat_text_mask = tf.reshape(
-                text_mask, (batch_size * sequence_length)
-            )
-            vision_indices = tf.where(tf.logical_not(flat_text_mask))
-            vision_indices = tf.reshape(vision_indices, (batch_size, -1))
+            vision_indices = self._get_vision_indices(vision_mask=vision_mask)
 
-        # The last token does not have a next token, so we truncate it out.
         x = {
             # Image
-            "images": images,
+            "images": images if batched else tf.squeeze(images, axis=0),
             # Text
-            "token_ids":
-
-
-            "
+            "token_ids": (
+                token_ids if batched else tf.squeeze(token_ids, axis=0)
+            ),
+            "vision_indices": (
+                vision_indices
+                if batched
+                else tf.squeeze(vision_indices, axis=0)
+            ),
+            # This mask is redundant information. But easier to compute it here
+            # than the model forward pass.
+            "vision_mask": (
+                vision_mask if batched else tf.squeeze(vision_mask, axis=0)
+            ),
+            "padding_mask": (
+                padding_mask if batched else tf.squeeze(padding_mask, axis=0)
+            ),
         }
 
         if return_labels:
+            if not batched:
+                y = tf.squeeze(y, axis=0)
+                sample_weight = tf.squeeze(sample_weight, 0)
+
            return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
         else:
             return x
 
-    def
-
+    def _preprocess_images(self, images, batched):
+        desired_height = self.image_converter.image_size[0]
+        desired_width = self.image_converter.image_size[1]
+
+        # Images can be lists/ragged tensors. We need to pad them/truncate them.
+        if isinstance(images, (list, np.ndarray)):
+            images = tf.ragged.constant(images)
+        elif isinstance(images, tf.RaggedTensor):
+            pass
+        elif isinstance(images, tf.Tensor):
+            images = tf.RaggedTensor.from_tensor(images)
+        else:
+            # Attempt to convert anyway. This handles the case where
+            # the inputs might be `jax.Array`, `torch.Tensor`. To check the
+            # type, we will have to import all three frameworks, which is
+            # undesirable.
+            try:
+                images = tf.RaggedTensor.from_tensor(images)
+            except:  # noqa: E722
+                raise ValueError(
+                    "`images` should be a list, ragged tensor, dense tensor."
+                    f"Received: `type(images)` = {type(images)}"
+                )
 
-
-
-
-
-
-
-
-
-
-
-
-
+        if not batched:
+            images = tf.expand_dims(images, axis=0)
+
+        # If the input is a list of images, instead of list of lists of images.
+        if len(images.shape) == 4:
+            images = tf.expand_dims(images, axis=1)
+
+        # Convert to dense tensor.
+        images = images.to_tensor(
+            shape=[None, self.max_images_per_prompt, None, None, 3],
+            default_value=0,
+        )
+
+        # Resize, rescale, etc. the images.
+        original_images_shape = tf.shape(images)
+
+        # Before passing through image converter, we need to collapse the
+        # first two dimensions (`batch_size`, `max_images_per_prompt`) into one.
+        images = tf.reshape(
+            images,
+            [
+                -1,
+                original_images_shape[-3],
+                original_images_shape[-2],
+                original_images_shape[-1],
+            ],
+        )
+        images = self.image_converter(images)
+
+        if keras.config.backend() == "torch" and not isinstance(
+            images, tf.Tensor
+        ):
+            images = images.cpu()
+
+        # Recover the rank.
+        images = tf.reshape(
+            images,
+            [
+                original_images_shape[0],
+                self.max_images_per_prompt,
+                desired_height,
+                desired_width,
+                original_images_shape[-1],
+            ],
         )
-
-        ragged_tensor = tf.cast(ragged_tensor, tf.int32)
-        return ragged_tensor
+        return images
 
     @preprocessing_function
     def call(
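Annotation (not part of the package diff): the core of the new `_preprocess_images` method is a ragged-to-dense padding step that brings every sample up to `max_images_per_prompt` images before the image converter runs. A minimal sketch of just that step, assuming tiny 8x8 images and skipping the resize/rescale pass:

```python
import tensorflow as tf

# Sketch of the padding step in `_preprocess_images` above: a batch where
# samples have different numbers of images is held as a RaggedTensor and
# padded with all-zero images up to `max_images_per_prompt`.
max_images_per_prompt = 2

flat_images = tf.ones((3, 8, 8, 3), dtype="float32")  # 3 images in total
images = tf.RaggedTensor.from_row_lengths(flat_images, row_lengths=[2, 1])
dense = images.to_tensor(
    shape=[None, max_images_per_prompt, None, None, 3], default_value=0
)
print(dense.shape)  # (2, 2, 8, 8, 3); the second sample's 2nd image is zeros
```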
@@ -218,52 +508,76 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
     ):
         sequence_length = sequence_length or self.sequence_length
 
+        # === Input extraction and validation ===
+
         # Extract text part of the input.
         prompts, responses = x["prompts"], x["responses"]
 
+        # Find out if the input is batched/not batched. Uprank if not batched.
+        # In other preprocessors, we don't have to do this, but here, all
+        # the following logic (indices, etc.) uses tensors with a batch dim.
+        # We will squeeze these back at the end.
+        batched = True
+        if isinstance(prompts, str):
+            batched = False
+            prompts = [prompts]
+            responses = [responses]
+        if isinstance(prompts, tf.Tensor) and len(prompts.shape) == 0:
+            batched = False
+            prompts = tf.expand_dims(prompts, axis=0)
+            responses = tf.expand_dims(responses, axis=0)
+
         # Extract images from the input.
         images = x.get("images", None)
-        num_valid_images = x.get("num_valid_images", None)
 
-
-
-
-
-
-
-
-
-
+        # There are 8 cases, based on values of
+        # a = `self.text_only_model`, b = `images` is `None`, and whether
+        # c = `<start_of_image>` token is present in `prompts`.
+        # F F F, F F T -> Raise error if #`<start_of_image>` <0, or
+        # > `max_images_per_prompt`.
+        # F T F -> Return empty images and vision indices
+        # F T T -> Return empty images and vision indices to the model.
+        # T F F, T F T -> Raise error.
+        # T T F -> Only token IDs and padding mask are returned.
+        # T T T -> Only token IDs and padding mask are returned.
+
+        if self.text_only_model and images is not None:
+            raise ValueError(
+                "The initialized preprocessor/model is text-only, but "
+                " `images` is not `None`."
+            )
+
+        # Add image placeholder tokens. Replace `"<start_of_image>"` in
+        # prompts with
+        # `"\n\n<start_of_image> <img> * 256 <end_of_image>\n\n"`.
+        if not self.text_only_model:
             prompts = tf.strings.regex_replace(
                 prompts,
-
-                f"\n\n{
-                +
-                + f"{
+                self.start_of_image_token,
+                f"\n\n{self.start_of_image_token}"
+                + self.image_placeholder * self.num_vision_tokens_per_image
+                + f"{self.end_of_image_token}\n\n",
             )
 
+        # === Tokenization, padding, etc. ===
+
         # Tokenise the inputs.
         prompts = self.tokenizer(prompts)
         responses = self.tokenizer(responses)
 
-        #
-        # the dummy placeholder image tokens which we will add at the end.
-        # Hence, we use a packer only on the text part first, and then
-        # add the padded dummy placeholder tokens separately.
+        # Padding.
         token_ids, segment_ids = self.packer(
             (prompts, responses),
-            sequence_length=sequence_length
-            if images is not None
-            else sequence_length + 1,
+            sequence_length=sequence_length + 1,
             add_start_value=self.add_start_token,
             add_end_value=self.add_end_token,
         )
+        response_mask = segment_ids == 1
+        padding_mask = token_ids != self.tokenizer.pad_token_id
 
-        #
+        # === Text Model ===
         if self.text_only_model:
             # The last token does not have a next token, so we truncate it out.
-            response_mask = segment_ids == 1
-            padding_mask = token_ids != self.tokenizer.pad_token_id
             x = {
                 "token_ids": token_ids[..., :-1],
                 "padding_mask": padding_mask[..., :-1],
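Annotation (not part of the package diff): a minimal sketch of the placeholder expansion that `call()` now performs before tokenization. The token strings and the count of 4 placeholders below are assumptions chosen for readability; the real layer reads the strings from its tokenizer and inserts `num_vision_tokens_per_image` (256) placeholders per image:

```python
import tensorflow as tf

# Sketch of the `<start_of_image>` expansion done before tokenization above,
# with assumed token strings and a small placeholder count.
start_of_image_token = "<start_of_image>"
image_placeholder = "<img>"
end_of_image_token = "<end_of_image>"
num_vision_tokens_per_image = 4

prompts = tf.constant(["this is a lily <start_of_image>"])
prompts = tf.strings.regex_replace(
    prompts,
    start_of_image_token,
    f"\n\n{start_of_image_token}"
    + image_placeholder * num_vision_tokens_per_image
    + f"{end_of_image_token}\n\n",
)
# Prints the prompt with the marker expanded to four <img> placeholders
# wrapped in <start_of_image>/<end_of_image> and surrounding newlines.
print(prompts.numpy()[0].decode("utf-8"))
```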
@@ -273,162 +587,68 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
             y = token_ids[..., 1:]
             # Only compute the loss for labels in the response.
             sample_weight = response_mask[..., 1:]
+
+            # Squeeze if not batched.
+            if not batched:
+                x["token_ids"] = tf.squeeze(x["token_ids"], axis=0)
+                x["padding_mask"] = tf.squeeze(x["padding_mask"], axis=0)
+                y = tf.squeeze(y, axis=0)
+                sample_weight = tf.squeeze(sample_weight, axis=0)
+
             return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
 
-        # Vision
+        # === Vision processing ===
+
         batch_size = tf.shape(prompts)[0]
+        desired_height = self.image_converter.image_size[0]
+        desired_width = self.image_converter.image_size[1]
         if images is None:
+            # == Branch: vision model, with `None` value for `images` ==
+
             # To handle the text-only input case, we need to pass an empty
-            # tensor so as to skip the vision
+            # tensor so as to skip the vision layers of the model.
+
+            # TODO: Once functional models accept `None` inputs, consider
+            # passing this as `None` directly.
             images = tf.ones(
                 shape=[
                     batch_size,
                     0,
-
-
+                    desired_height,
+                    desired_width,
                     3,
                 ],
                 dtype="float32",
             )
 
-
-            padding_mask = token_ids != self.tokenizer.pad_token_id
-            response_mask = segment_ids == 1
+            vision_mask = tf.zeros_like(token_ids, dtype=bool)
 
             return self._format_output(
                 images=images,
                 token_ids=token_ids,
-
+                vision_mask=vision_mask,
                 response_mask=response_mask,
                 padding_mask=padding_mask,
                 return_labels=True,
                 text_only_input=True,
+                batched=batched,
             )
 
-
-        if num_valid_images is None:
-            num_valid_images = tf.fill(
-                dims=(batch_size,),
-                value=self.max_images_per_prompt,
-            )
+        # == Branch: vision model, with non-`None` value for `images` ==
 
-
-        if original_image_shape[1] != self.max_images_per_prompt:
-            raise ValueError(
-                "The number of images per sample should be the same as "
-                "`max_images_per_prompt`. Received: "
-                f"images.shape = {original_image_shape}, "
-                f"max_images_per_prompt = {self.max_images_per_prompt}"
-            )
-        if tf.cast(
-            tf.math.reduce_sum(
-                tf.cast(
-                    tf.math.greater(
-                        num_valid_images, self.max_images_per_prompt
-                    ),
-                    dtype=tf.int32,
-                )
-            ),
-            dtype=bool,
-        ):
-            raise ValueError(
-                "`num_valid_images` should have values <= "
-                "self.max_images_per_prompt. Received: "
-                f"num_valid_images = {num_valid_images}, ",
-                f"max_images_per_prompt = {self.max_images_per_prompt}",
-            )
+        images = self._preprocess_images(images=images, batched=batched)
 
-
-        padded_images_shape = tf.shape(images)
-        images = tf.reshape(
-            images,
-            [
-                -1,
-                padded_images_shape[-3],
-                padded_images_shape[-2],
-                padded_images_shape[-1],
-            ],
-        )
-        images = self.image_converter(images)
-        height = (
-            self.image_size[0]
-            if self.image_converter.image_size
-            else original_image_shape[-3]
-        )
-        width = (
-            self.image_size[1]
-            if self.image_converter.image_size
-            else original_image_shape[-2]
-        )
-        images = tf.reshape(
-            images,
-            [
-                padded_images_shape[0],
-                self.max_images_per_prompt,
-                height,
-                width,
-                3,
-            ],
-        )
-
-        # Format tokens.
-        padding_mask = token_ids != self.tokenizer.pad_token_id
-        token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
-        segment_ids = tf.ragged.boolean_mask(segment_ids, padding_mask)
-        padding_mask = tf.ragged.boolean_mask(padding_mask, padding_mask)
-        response_mask = segment_ids == 1
-
-        # Using `num_valid_images`, we need to add dummy image tokens at the
-        # end of the tokenized text. Ideally, we could have passed an image
-        # padding mask to the model, but it won't work with XLA since an
-        # `ops.where` on it in the interleaving layer will return different
-        # number of images every time. So, we need to fix the number of images.
-        vision_placeholder_tensor = self._get_image_placeholder_ragged_tensor(
-            (self.max_images_per_prompt - num_valid_images)
-            * self.num_vision_tokens_per_image,
-            self.tokenizer.token_to_id("<img>"),
-        )
-        vision_placeholder_tensor = vision_placeholder_tensor.to_tensor(
-            shape=[
-                batch_size,
-                self.max_images_per_prompt * self.num_vision_tokens_per_image,
-            ],
-            default_value=self.tokenizer.pad_token_id,
-        )
-
-        token_ids_with_placeholder = tf.concat(
-            [token_ids, vision_placeholder_tensor], axis=1
-        )
-
-        # Now, pad everything to the same length.
-        desired_length = (
-            sequence_length
-            + self.max_images_per_prompt * self.num_vision_tokens_per_image
-        )
-        token_ids_with_placeholder = token_ids_with_placeholder.to_tensor(
-            shape=[batch_size, desired_length + 1],
-            default_value=self.tokenizer.pad_token_id,
-        )
-        padding_mask_with_placeholder = padding_mask.to_tensor(
-            shape=[batch_size, desired_length + 1],
-            default_value=False,
-        )
-        response_mask_with_placeholder = response_mask.to_tensor(
-            shape=[batch_size, desired_length + 1],
-            default_value=False,
-        )
-
-        text_mask = token_ids_with_placeholder != self.tokenizer.token_to_id(
-            "<img>"
-        )
+        vision_mask = token_ids == self.tokenizer.image_placeholder_id
 
         return self._format_output(
             images=images,
-            token_ids=
-
-            response_mask=
-            padding_mask=
+            token_ids=token_ids,
+            vision_mask=vision_mask,
+            response_mask=response_mask,
+            padding_mask=padding_mask,
             return_labels=True,
+            text_only_input=False,
+            batched=batched,
         )
 
     @preprocessing_function
@@ -448,39 +668,59 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
         the sequence (as generation is expected to continue at the end of the
         inputted prompt).
         """
+
         if not self.built:
             self.build(None)
 
+        # Extract inputs.
         if isinstance(x, dict):
             images = x.get("images", None)
-
+
             # TODO: do we even need `responses` for generation? Makes sense for
-            # finetuning (i.e., `call()`).
+            # finetuning only (i.e., `call()`).
             responses = x.get("responses", None)
             prompts = x["prompts"]
         else:
             images = None
-            num_valid_images = None
             responses = None
             prompts = x
 
-        if
-
-
-
-
-
-
-
-
+        # Find out if the input is batched/not batched. Uprank if not batched.
+        # In other preprocessors, we don't have to do this, but here, all
+        # the following logic (indices, etc.) uses tensors with a batch dim.
+        # We will squeeze these back at the end.
+        batched = True
+        if isinstance(prompts, str):
+            batched = False
+            prompts = [prompts]
+            if responses is not None:
+                responses = [responses]
+        if isinstance(prompts, tf.Tensor) and len(prompts.shape) == 0:
+            batched = False
+            prompts = tf.expand_dims(prompts, axis=0)
+            if responses is not None:
+                responses = tf.expand_dims(responses, axis=0)
+
+        # We have the same 8 cases here, as in `call()`.
+        if self.text_only_model and images is not None:
+            raise ValueError(
+                "The initialized preprocessor/model is text-only, but "
+                " `images` is not `None`."
+            )
+
+        # Add image placeholder tokens. Replace `"<start_of_image>"` in
+        # prompts with
+        # `"\n\n<start_of_image> <img> * 256 <end_of_image>\n\n"`.
+        if not self.text_only_model:
             prompts = tf.strings.regex_replace(
                 prompts,
-
-                f"\n\n{
-                +
-                + f"{
+                self.start_of_image_token,
+                f"\n\n{self.start_of_image_token}"
+                + self.image_placeholder * self.num_vision_tokens_per_image
+                + f"{self.end_of_image_token}\n\n",
            )
 
+        # === Tokenization, padding, etc. ===
         prompts = self.tokenizer(prompts)
 
         if responses is not None:
@@ -489,174 +729,79 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
         else:
             segments = (prompts,)
 
+        # Padding.
         token_ids, segment_ids = self.packer(
             segments,
             sequence_length=sequence_length,
             add_end_value=False,
         )
+        response_mask = segment_ids == 1
+        padding_mask = token_ids != self.tokenizer.pad_token_id
 
-        #
+        # === Text Model ===
         if self.text_only_model:
-            response_mask = segment_ids == 1
-            padding_mask = token_ids != self.tokenizer.pad_token_id
             return {
-                "token_ids":
-
+                "token_ids": (
+                    token_ids if batched else tf.squeeze(token_ids, axis=0)
+                ),
+                "padding_mask": (
+                    padding_mask
+                    if batched
+                    else tf.squeeze(padding_mask, axis=0)
+                ),
             }
 
-        # Vision
+        # === Vision processing ===
+
         batch_size = tf.shape(prompts)[0]
+        desired_height = self.image_converter.image_size[0]
+        desired_width = self.image_converter.image_size[1]
         if images is None:
+            # == Branch: vision model, with `None` value for `images` ==
+
             # To handle the text-only input case, we need to pass an empty
-            # tensor so as to skip the vision
+            # tensor so as to skip the vision layers of the model.
+
+            # TODO: Once functional models accept `None` inputs, consider
+            # passing this as `None` directly.
            images = tf.ones(
                 shape=[
                     batch_size,
                     0,
-
-
+                    desired_height,
+                    desired_width,
                     3,
                 ],
                 dtype="float32",
             )
 
-
-            padding_mask = token_ids != self.tokenizer.pad_token_id
-            response_mask = segment_ids == 1
+            vision_mask = tf.zeros_like(token_ids, dtype=bool)
 
             return self._format_output(
                 images=images,
                 token_ids=token_ids,
-
+                vision_mask=vision_mask,
                 response_mask=response_mask,
                 padding_mask=padding_mask,
                 return_labels=False,
                 text_only_input=True,
+                batched=batched,
            )
 
-        #
-
-        if num_valid_images is None:
-            num_valid_images = tf.fill(
-                dims=(batch_size,),
-                value=self.max_images_per_prompt,
-            )
-
-        # Image inputs checks.
-        if original_image_shape[1] != self.max_images_per_prompt:
-            raise ValueError(
-                "The number of images per sample should be the same as "
-                "`max_images_per_prompt`. Received: "
-                f"images.shape = {original_image_shape}, "
-                f"max_images_per_prompt = {self.max_images_per_prompt}"
-            )
-        if tf.cast(
-            tf.math.reduce_sum(
-                tf.cast(
-                    tf.math.greater(
-                        num_valid_images, self.max_images_per_prompt
-                    ),
-                    dtype=tf.int32,
-                )
-            ),
-            dtype=bool,
-        ):
-            raise ValueError(
-                "`num_valid_images` should have values <= "
-                "self.max_images_per_prompt. Received: "
-                f"num_valid_images = {num_valid_images}, ",
-                f"max_images_per_prompt = {self.max_images_per_prompt}",
-            )
-
-        # Resize, rescale, etc. the images.
-        padded_images_shape = tf.shape(images)
-        images = tf.reshape(
-            images,
-            [
-                -1,
-                padded_images_shape[-3],
-                padded_images_shape[-2],
-                padded_images_shape[-1],
-            ],
-        )
-        images = self.image_converter(images)
-        height = (
-            self.image_size[0]
-            if self.image_converter.image_size
-            else original_image_shape[-3]
-        )
-        width = (
-            self.image_size[1]
-            if self.image_converter.image_size
-            else original_image_shape[-2]
-        )
-        images = tf.reshape(
-            images,
-            [
-                padded_images_shape[0],
-                self.max_images_per_prompt,
-                height,
-                width,
-                3,
-            ],
-        )
-
-        padding_mask = token_ids != self.tokenizer.pad_token_id
-        token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
-        segment_ids = tf.ragged.boolean_mask(segment_ids, padding_mask)
-        padding_mask = tf.ragged.boolean_mask(padding_mask, padding_mask)
-        response_mask = segment_ids == 1
-
-        # Using `num_valid_images`, we need to add dummy image tokens at the
-        # end of the tokenized text. Ideally, we could have passed an image
-        # padding mask to the model, but it won't work with XLA since an
-        # `ops.where` on it in the interleaving layer will return different
-        # number of images every time. So, we need to fix the number of images.
-        vision_placeholder_tensor = self._get_image_placeholder_ragged_tensor(
-            (self.max_images_per_prompt - num_valid_images)
-            * self.num_vision_tokens_per_image,
-            self.tokenizer.token_to_id("<img>"),
-        )
-        vision_placeholder_tensor = vision_placeholder_tensor.to_tensor(
-            shape=[
-                batch_size,
-                self.max_images_per_prompt * self.num_vision_tokens_per_image,
-            ],
-            default_value=self.tokenizer.pad_token_id,
-        )
-        token_ids_with_placeholder = tf.concat(
-            [token_ids, vision_placeholder_tensor], axis=1
-        )
-
-        # Now, pad everything to the same length.
-        desired_length = (
-            sequence_length
-            + self.max_images_per_prompt * self.num_vision_tokens_per_image
-        )
-        token_ids_with_placeholder = token_ids_with_placeholder.to_tensor(
-            shape=[batch_size, desired_length],
-            default_value=self.tokenizer.pad_token_id,
-        )
-        padding_mask_with_placeholder = padding_mask.to_tensor(
-            shape=[batch_size, desired_length],
-            default_value=False,
-        )
-        response_mask_with_placeholder = response_mask.to_tensor(
-            shape=[batch_size, desired_length],
-            default_value=False,
-        )
+        # == Branch: vision model, with non-`None` value for `images` ==
+        images = self._preprocess_images(images=images, batched=batched)
 
-
-            "<img>"
-        )
+        vision_mask = token_ids == self.tokenizer.image_placeholder_id
 
         return self._format_output(
             images=images,
-            token_ids=
-
-            response_mask=
-            padding_mask=
+            token_ids=token_ids,
+            vision_mask=vision_mask,
+            response_mask=response_mask,
+            padding_mask=padding_mask,
             return_labels=False,
+            text_only_input=False,
+            batched=batched,
         )
 
     def get_config(self):
@@ -686,6 +831,18 @@ class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
 
         token_ids, padding_mask = x["token_ids"], x["padding_mask"]
         ids_to_strip = self.tokenizer.special_token_ids
-
+
+        # We do not want to strip SoI token because it is provided by the user.
+        if self.tokenizer.start_of_image_token_id in ids_to_strip:
+            ids_to_strip.remove(self.tokenizer.start_of_image_token_id)
+
         token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip)
         return self.tokenizer.detokenize(token_ids)
+
+    @property
+    def max_images_per_prompt(self):
+        return self._max_images_per_prompt
+
+    @max_images_per_prompt.setter
+    def max_images_per_prompt(self, value):
+        self._max_images_per_prompt = value
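Annotation (not part of the package diff): the `max_images_per_prompt` property and setter added in the final hunk make the attribute assignable on an existing preprocessor. A hedged usage sketch; whether reconfiguring a loaded preset this way is an intended workflow is an assumption:

```python
import keras_hub

# Hypothetical usage of the new property/setter; assumes the
# "gemma3_instruct_4b" preset shown earlier in this diff is available.
preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset(
    "gemma3_instruct_4b"
)
preprocessor.max_images_per_prompt = 1  # routed through the new setter
print(preprocessor.max_images_per_prompt)  # 1
```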