keras-hub-nightly 0.20.0.dev202503260356__py3-none-any.whl → 0.20.0.dev202503270400__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only.
- keras_hub/api/layers/__init__.py +3 -0
- keras_hub/api/models/__init__.py +6 -0
- keras_hub/api/tokenizers/__init__.py +1 -0
- keras_hub/src/models/gemma3/__init__.py +5 -0
- keras_hub/src/models/gemma3/gemma3_attention.py +315 -0
- keras_hub/src/models/gemma3/gemma3_backbone.py +352 -0
- keras_hub/src/models/gemma3/gemma3_causal_lm.py +306 -0
- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +691 -0
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +305 -0
- keras_hub/src/models/gemma3/gemma3_image_converter.py +8 -0
- keras_hub/src/models/gemma3/gemma3_interleave_embeddings.py +79 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +93 -0
- keras_hub/src/models/gemma3/gemma3_tokenizer.py +87 -0
- keras_hub/src/models/gemma3/gemma3_vit.py +608 -0
- keras_hub/src/models/gemma3/rms_normalization.py +26 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.20.0.dev202503260356.dist-info → keras_hub_nightly-0.20.0.dev202503270400.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.20.0.dev202503260356.dist-info → keras_hub_nightly-0.20.0.dev202503270400.dist-info}/RECORD +20 -8
- {keras_hub_nightly-0.20.0.dev202503260356.dist-info → keras_hub_nightly-0.20.0.dev202503270400.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.20.0.dev202503260356.dist-info → keras_hub_nightly-0.20.0.dev202503270400.dist-info}/top_level.txt +0 -0
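The files listed above add a new Gemma3 model family (backbone, causal LM, preprocessor, tokenizer, image converter, ViT and presets) and register it in the public `keras_hub` API. For orientation, here is a minimal, hypothetical usage sketch based only on the preprocessor diff shown below: the `gemma3_4b_en` preset name is the one used in that file's docstring, and only the text-only path (`image_converter=None`) is supported in this build, so treat it as an assumption rather than documented API.

```python
import keras_hub

# Hypothetical sketch: load the newly exported preprocessor from a preset.
# The preset name is taken from the docstring in the diff below.
preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset(
    "gemma3_4b_en"
)

# Text-only preprocessing returns (x, y, sample_weight), where `y` holds the
# next-token labels and `sample_weight` masks out everything but the response.
x, y, sample_weight = preprocessor(
    {
        "prompts": ["The quick brown fox jumped."],
        "responses": ["over the lazy dog."],
    }
)
# With the default sequence_length of 1024, both tensors have shape (1, 1024).
print(x["token_ids"].shape, x["padding_mask"].shape)
```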
keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py (added)
@@ -0,0 +1,691 @@
import keras
import tensorflow as tf

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.multi_segment_packer import (
    MultiSegmentPacker,
)
from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone
from keras_hub.src.models.gemma3.gemma3_image_converter import (
    Gemma3ImageConverter,
)
from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer
from keras_hub.src.utils.tensor_utils import preprocessing_function
from keras_hub.src.utils.tensor_utils import strip_to_ragged

START_OF_IMAGE_TOKEN = "<start_of_image>"
IMAGE_PLACEHOLDER_TOKEN = "<img>"
END_OF_IMAGE_TOKEN = "<end_of_image>"


@keras_hub_export("keras_hub.models.Gemma3CausalLMPreprocessor")
class Gemma3CausalLMPreprocessor(CausalLMPreprocessor):
    """Gemma3 Causal LM preprocessor.

    This preprocessing layer is meant for use with
    `keras_hub.models.Gemma3CausalLM`. By default, it will take in batches of
    images and strings, and return outputs in a `(x, y, sample_weight)` format,
    where the `y` label is the next token id in the `x` sequence.

    There is only one mode this layer currently supports, i.e.,
    `image_converter` is `None`. We preprocess the text like any other
    Causal LM preprocessor, i.e., tokenisation, padding, etc. The sequence
    is padded to `sequence_length`.

    For use with generation, the layer also exposes two methods
    `generate_preprocess()` and `generate_postprocess()`. When this preprocessor
    is attached to a `keras_hub.models.GemmaCausalLM` instance, these methods
    will be called implicitly in `generate()`. They can also be called
    standalone (e.g. to precompute preprocessing inputs for generation in a
    separate process).

    Args:
        tokenizer: A `keras_hub.models.GemmaTokenizer` instance.
        image_converter: A `keras_hub.layers.ImageConverter` instance. Defaults
            to `None`.
        sequence_length: The length of the packed inputs. Defaults to 1024.
        add_start_token: If `True`, the preprocessor will prepend the tokenizer
            start token to each input sequence. Defaults to `True`.
        add_end_token: If `True`, the preprocessor will append the tokenizer
            end token to each input sequence. Defaults to `True`.
        max_images_per_prompt: int. Permissible number of images per sample in
            the batch. Defaults to 2.
        num_vision_tokens_per_image: int. Number of vision placeholder tokens
            per image. Defaults to 256.

    Call arguments:
        x: A string, `tf.Tensor` or list of python strings.
        y: Label data. Should always be `None` as the layer generates labels.
        sample_weight: Label weights. Should always be `None` as the layer
            generates label weights.
        sequence_length: Pass to override the configured `sequence_length` of
            the layer.

    Examples:
    ```python
    # Load the preprocessor from a preset.
    preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset(
        "gemma3_4b_en"
    )

    # Text-only input.
    preprocessor({
        "prompts": ["The quick brown fox jumped."],
        "responses": [""],
    })

    # Images (pass one image)
    max_images_per_prompt = 2
    preprocessor({
        "prompts": ["The quick brown fox jumped."],
        "responses": [""],
        "images": [np.ones((2, 896, 896, 3)).astype("float32")],
        "num_valid_images": np.array([1,], dtype=np.int32)
    })
    ```
    """

    backbone_cls = Gemma3Backbone
    tokenizer_cls = Gemma3Tokenizer
    image_converter_cls = Gemma3ImageConverter

    def __init__(
        self,
        tokenizer,
        image_converter=None,
        sequence_length=1024,
        add_start_token=True,
        add_end_token=True,
        max_images_per_prompt=2,
        num_vision_tokens_per_image=256,
        **kwargs,
    ):
        super().__init__(
            tokenizer=tokenizer,
            sequence_length=sequence_length,
            add_start_token=add_start_token,
            add_end_token=add_end_token,
            **kwargs,
        )

        if image_converter is not None:
            raise ValueError(
                "Currently, only the text version of the Gemma3 model is "
                "supported."
            )

        self.image_converter = image_converter
        self.max_images_per_prompt = max_images_per_prompt
        self.num_vision_tokens_per_image = num_vision_tokens_per_image

        self.text_only_model = self.image_converter is None

    def build(self, input_shape):
        # Defer packer creation to `build()` so that we can be sure tokenizer
        # assets have loaded when restoring a saved model.
        self.packer = MultiSegmentPacker(
            start_value=self.tokenizer.start_token_id,
            end_value=self.tokenizer.end_token_id,
            pad_value=self.tokenizer.pad_token_id,
            sep_value=[],
            sequence_length=self.sequence_length,
        )
        self.built = True

    def _format_output(
        self,
        images,
        token_ids,
        text_mask,
        response_mask,
        padding_mask,
        return_labels=False,
        text_only_input=False,
    ):
        if return_labels:
            # Target `y` will be the next token.
            y = token_ids[..., 1:]
            # Only compute the loss for labels in the response.
            sample_weight = response_mask[..., 1:]

            token_ids = token_ids[..., :-1]
            text_mask = text_mask[..., :-1]
            response_mask = response_mask[..., :-1]
            padding_mask = padding_mask[..., :-1]

        batch_size, sequence_length = tf.shape(text_mask)

        if text_only_input:
            vision_indices = tf.ones(
                shape=[
                    batch_size,
                    0,
                ],
                dtype=tf.int32,
            )
        else:
            sequence_length = tf.shape(text_mask)[-1]
            flat_text_mask = tf.reshape(
                text_mask, (batch_size * sequence_length)
            )
            vision_indices = tf.where(tf.logical_not(flat_text_mask))
            vision_indices = tf.reshape(vision_indices, (batch_size, -1))

        # The last token does not have a next token, so we truncate it out.
        x = {
            # Image
            "images": images,
            # Text
            "token_ids": token_ids,
            "vision_indices": vision_indices,
            "text_mask": text_mask,
            "padding_mask": padding_mask,
        }

        if return_labels:
            return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
        else:
            return x

    def _get_image_placeholder_ragged_tensor(self, required_length, fill_value):
        """Identifies the number of dummy placeholder tokens to pad input with.

        Depending on the number of images provided per sample, and the
        allowed number of images, this method identifies the number of vision
        placeholder tokens we need to pad tokens with. This is necessary to
        ensure the same number of image tokens in every sample so as to not
        cause dynamic shape issues with XLA in the interleaving layer.
        """
        required_length = tf.cast(required_length, tf.int32)
        ones_tensor = tf.ones_like(required_length, dtype=tf.int32)
        flattened_tensor = tf.repeat(ones_tensor, required_length)
        row_splits = tf.concat([[0], tf.cumsum(required_length)], axis=0)
        ragged_tensor = tf.RaggedTensor.from_row_splits(
            flattened_tensor, row_splits
        )
        ragged_tensor = ragged_tensor * fill_value
        ragged_tensor = tf.cast(ragged_tensor, tf.int32)
        return ragged_tensor

    @preprocessing_function
    def call(
        self,
        x,
        y=None,
        sample_weight=None,
        sequence_length=None,
    ):
        sequence_length = sequence_length or self.sequence_length

        # Extract text part of the input.
        prompts, responses = x["prompts"], x["responses"]

        # Extract images from the input.
        images = x.get("images", None)
        num_valid_images = x.get("num_valid_images", None)

        if self.text_only_model:
            if images is not None or num_valid_images is not None:
                raise ValueError(
                    "`image_converter` cannot be None when `images` or"
                    " `num_valid_images` is not None."
                )
        else:
            # Replace `"<start_of_image>"` in prompts with
            # `"\n\n<start_of_image> <img> * 256 <end_of_image>\n\n"`.
            prompts = tf.strings.regex_replace(
                prompts,
                START_OF_IMAGE_TOKEN,
                f"\n\n{START_OF_IMAGE_TOKEN}"
                + IMAGE_PLACEHOLDER_TOKEN * self.num_vision_tokens_per_image
                + f"{END_OF_IMAGE_TOKEN}\n\n",
            )

        # Tokenise the inputs.
        prompts = self.tokenizer(prompts)
        responses = self.tokenizer(responses)

        # All truncation should happen on the text token IDs and not on
        # the dummy placeholder image tokens which we will add at the end.
        # Hence, we use a packer only on the text part first, and then
        # add the padded dummy placeholder tokens separately.
        token_ids, segment_ids = self.packer(
            (prompts, responses),
            sequence_length=sequence_length
            if images is not None
            else sequence_length + 1,
            add_start_value=self.add_start_token,
            add_end_value=self.add_end_token,
        )

        # If it is a text only model, return immediately.
        if self.text_only_model:
            # The last token does not have a next token, so we truncate it out.
            response_mask = segment_ids == 1
            padding_mask = token_ids != self.tokenizer.pad_token_id
            x = {
                "token_ids": token_ids[..., :-1],
                "padding_mask": padding_mask[..., :-1],
            }

            # Target `y` will be the next token.
            y = token_ids[..., 1:]
            # Only compute the loss for labels in the response.
            sample_weight = response_mask[..., 1:]
            return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)

        # Vision preprocessing
        batch_size = tf.shape(prompts)[0]
        if images is None:
            # To handle the text-only input case, we need to pass an empty
            # tensor so as to skip the vision part of the model.
            images = tf.ones(
                shape=[
                    batch_size,
                    0,
                    self.image_converter.image_size[0],
                    self.image_converter.image_size[1],
                    3,
                ],
                dtype="float32",
            )

            text_mask = tf.ones_like(token_ids, dtype=bool)
            padding_mask = token_ids != self.tokenizer.pad_token_id
            response_mask = segment_ids == 1

            return self._format_output(
                images=images,
                token_ids=token_ids,
                text_mask=text_mask,
                response_mask=response_mask,
                padding_mask=padding_mask,
                return_labels=True,
                text_only_input=True,
            )

        original_image_shape = tf.shape(images)
        if num_valid_images is None:
            num_valid_images = tf.fill(
                dims=(batch_size,),
                value=self.max_images_per_prompt,
            )

        # Image inputs checks.
        if original_image_shape[1] != self.max_images_per_prompt:
            raise ValueError(
                "The number of images per sample should be the same as "
                "`max_images_per_prompt`. Received: "
                f"images.shape = {original_image_shape}, "
                f"max_images_per_prompt = {self.max_images_per_prompt}"
            )
        if tf.cast(
            tf.math.reduce_sum(
                tf.cast(
                    tf.math.greater(
                        num_valid_images, self.max_images_per_prompt
                    ),
                    dtype=tf.int32,
                )
            ),
            dtype=bool,
        ):
            raise ValueError(
                "`num_valid_images` should have values <= "
                "self.max_images_per_prompt. Received: "
                f"num_valid_images = {num_valid_images}, ",
                f"max_images_per_prompt = {self.max_images_per_prompt}",
            )

        # Resize, rescale, etc. the images.
        padded_images_shape = tf.shape(images)
        images = tf.reshape(
            images,
            [
                -1,
                padded_images_shape[-3],
                padded_images_shape[-2],
                padded_images_shape[-1],
            ],
        )
        images = self.image_converter(images)
        height = (
            self.image_size[0]
            if self.image_converter.image_size
            else original_image_shape[-3]
        )
        width = (
            self.image_size[1]
            if self.image_converter.image_size
            else original_image_shape[-2]
        )
        images = tf.reshape(
            images,
            [
                padded_images_shape[0],
                self.max_images_per_prompt,
                height,
                width,
                3,
            ],
        )

        # Format tokens.
        padding_mask = token_ids != self.tokenizer.pad_token_id
        token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
        segment_ids = tf.ragged.boolean_mask(segment_ids, padding_mask)
        padding_mask = tf.ragged.boolean_mask(padding_mask, padding_mask)
        response_mask = segment_ids == 1

        # Using `num_valid_images`, we need to add dummy image tokens at the
        # end of the tokenized text. Ideally, we could have passed an image
        # padding mask to the model, but it won't work with XLA since an
        # `ops.where` on it in the interleaving layer will return different
        # number of images every time. So, we need to fix the number of images.
        vision_placeholder_tensor = self._get_image_placeholder_ragged_tensor(
            (self.max_images_per_prompt - num_valid_images)
            * self.num_vision_tokens_per_image,
            self.tokenizer.token_to_id("<img>"),
        )
        vision_placeholder_tensor = vision_placeholder_tensor.to_tensor(
            shape=[
                batch_size,
                self.max_images_per_prompt * self.num_vision_tokens_per_image,
            ],
            default_value=self.tokenizer.pad_token_id,
        )

        token_ids_with_placeholder = tf.concat(
            [token_ids, vision_placeholder_tensor], axis=1
        )

        # Now, pad everything to the same length.
        desired_length = (
            sequence_length
            + self.max_images_per_prompt * self.num_vision_tokens_per_image
        )
        token_ids_with_placeholder = token_ids_with_placeholder.to_tensor(
            shape=[batch_size, desired_length + 1],
            default_value=self.tokenizer.pad_token_id,
        )
        padding_mask_with_placeholder = padding_mask.to_tensor(
            shape=[batch_size, desired_length + 1],
            default_value=False,
        )
        response_mask_with_placeholder = response_mask.to_tensor(
            shape=[batch_size, desired_length + 1],
            default_value=False,
        )

        text_mask = token_ids_with_placeholder != self.tokenizer.token_to_id(
            "<img>"
        )

        return self._format_output(
            images=images,
            token_ids=token_ids_with_placeholder,
            text_mask=text_mask,
            response_mask=response_mask_with_placeholder,
            padding_mask=padding_mask_with_placeholder,
            return_labels=True,
        )

    @preprocessing_function
    def generate_preprocess(
        self,
        x,
        sequence_length=None,
    ):
        """Convert strings to integer token input for generation.

        Similar to calling the layer for training, this method takes in strings
        or tensor strings, tokenizes and packs the input, and computes a padding
        mask masking all inputs not filled in with a padded value.

        Unlike calling the layer for training, this method does not compute
        labels and will never append a `tokenizer.end_token_id` to the end of
        the sequence (as generation is expected to continue at the end of the
        inputted prompt).
        """
        if not self.built:
            self.build(None)

        if isinstance(x, dict):
            images = x.get("images", None)
            num_valid_images = x.get("num_valid_images", None)
            # TODO: do we even need `responses` for generation? Makes sense for
            # finetuning (i.e., `call()`).
            responses = x.get("responses", None)
            prompts = x["prompts"]
        else:
            images = None
            num_valid_images = None
            responses = None
            prompts = x

        if self.text_only_model:
            if images is not None or num_valid_images is not None:
                raise ValueError(
                    "`image_converter` cannot be None when `images` or"
                    " `num_valid_images` is not None."
                )
        else:
            # Replace `"<start_of_image>"` in prompts with
            # `"\n\n<start_of_image> <img> * 256 <end_of_image>\n\n"`.
            prompts = tf.strings.regex_replace(
                prompts,
                START_OF_IMAGE_TOKEN,
                f"\n\n{START_OF_IMAGE_TOKEN}"
                + IMAGE_PLACEHOLDER_TOKEN * self.num_vision_tokens_per_image
                + f"{END_OF_IMAGE_TOKEN}\n\n",
            )

        prompts = self.tokenizer(prompts)

        if responses is not None:
            responses = self.tokenizer(responses)
            segments = (prompts, responses)
        else:
            segments = (prompts,)

        token_ids, segment_ids = self.packer(
            segments,
            sequence_length=sequence_length,
            add_end_value=False,
        )

        # If it is a text only model, return immediately.
        if self.text_only_model:
            response_mask = segment_ids == 1
            padding_mask = token_ids != self.tokenizer.pad_token_id
            return {
                "token_ids": token_ids,
                "padding_mask": padding_mask,
            }

        # Vision preprocessing
        batch_size = tf.shape(prompts)[0]
        if images is None:
            # To handle the text-only input case, we need to pass an empty
            # tensor so as to skip the vision part of the model.
            images = tf.ones(
                shape=[
                    batch_size,
                    0,
                    self.image_converter.image_size[0],
                    self.image_converter.image_size[1],
                    3,
                ],
                dtype="float32",
            )

            text_mask = tf.ones_like(token_ids, dtype=bool)
            padding_mask = token_ids != self.tokenizer.pad_token_id
            response_mask = segment_ids == 1

            return self._format_output(
                images=images,
                token_ids=token_ids,
                text_mask=text_mask,
                response_mask=response_mask,
                padding_mask=padding_mask,
                return_labels=False,
                text_only_input=True,
            )

        # Pad images.
        original_image_shape = tf.shape(images)
        if num_valid_images is None:
            num_valid_images = tf.fill(
                dims=(batch_size,),
                value=self.max_images_per_prompt,
            )

        # Image inputs checks.
        if original_image_shape[1] != self.max_images_per_prompt:
            raise ValueError(
                "The number of images per sample should be the same as "
                "`max_images_per_prompt`. Received: "
                f"images.shape = {original_image_shape}, "
                f"max_images_per_prompt = {self.max_images_per_prompt}"
            )
        if tf.cast(
            tf.math.reduce_sum(
                tf.cast(
                    tf.math.greater(
                        num_valid_images, self.max_images_per_prompt
                    ),
                    dtype=tf.int32,
                )
            ),
            dtype=bool,
        ):
            raise ValueError(
                "`num_valid_images` should have values <= "
                "self.max_images_per_prompt. Received: "
                f"num_valid_images = {num_valid_images}, ",
                f"max_images_per_prompt = {self.max_images_per_prompt}",
            )

        # Resize, rescale, etc. the images.
        padded_images_shape = tf.shape(images)
        images = tf.reshape(
            images,
            [
                -1,
                padded_images_shape[-3],
                padded_images_shape[-2],
                padded_images_shape[-1],
            ],
        )
        images = self.image_converter(images)
        height = (
            self.image_size[0]
            if self.image_converter.image_size
            else original_image_shape[-3]
        )
        width = (
            self.image_size[1]
            if self.image_converter.image_size
            else original_image_shape[-2]
        )
        images = tf.reshape(
            images,
            [
                padded_images_shape[0],
                self.max_images_per_prompt,
                height,
                width,
                3,
            ],
        )

        padding_mask = token_ids != self.tokenizer.pad_token_id
        token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
        segment_ids = tf.ragged.boolean_mask(segment_ids, padding_mask)
        padding_mask = tf.ragged.boolean_mask(padding_mask, padding_mask)
        response_mask = segment_ids == 1

        # Using `num_valid_images`, we need to add dummy image tokens at the
        # end of the tokenized text. Ideally, we could have passed an image
        # padding mask to the model, but it won't work with XLA since an
        # `ops.where` on it in the interleaving layer will return different
        # number of images every time. So, we need to fix the number of images.
        vision_placeholder_tensor = self._get_image_placeholder_ragged_tensor(
            (self.max_images_per_prompt - num_valid_images)
            * self.num_vision_tokens_per_image,
            self.tokenizer.token_to_id("<img>"),
        )
        vision_placeholder_tensor = vision_placeholder_tensor.to_tensor(
            shape=[
                batch_size,
                self.max_images_per_prompt * self.num_vision_tokens_per_image,
            ],
            default_value=self.tokenizer.pad_token_id,
        )
        token_ids_with_placeholder = tf.concat(
            [token_ids, vision_placeholder_tensor], axis=1
        )

        # Now, pad everything to the same length.
        desired_length = (
            sequence_length
            + self.max_images_per_prompt * self.num_vision_tokens_per_image
        )
        token_ids_with_placeholder = token_ids_with_placeholder.to_tensor(
            shape=[batch_size, desired_length],
            default_value=self.tokenizer.pad_token_id,
        )
        padding_mask_with_placeholder = padding_mask.to_tensor(
            shape=[batch_size, desired_length],
            default_value=False,
        )
        response_mask_with_placeholder = response_mask.to_tensor(
            shape=[batch_size, desired_length],
            default_value=False,
        )

        text_mask = token_ids_with_placeholder != self.tokenizer.token_to_id(
            "<img>"
        )

        return self._format_output(
            images=images,
            token_ids=token_ids_with_placeholder,
            text_mask=text_mask,
            response_mask=response_mask_with_placeholder,
            padding_mask=padding_mask_with_placeholder,
            return_labels=False,
        )

    def get_config(self):
        config = super().get_config()

        config.update(
            {
                "num_vision_tokens_per_image": self.num_vision_tokens_per_image,
                "max_images_per_prompt": self.max_images_per_prompt,
            }
        )
        return config

    @preprocessing_function
    def generate_postprocess(
        self,
        x,
    ):
        """Convert integer token output to strings for generation.

        This method reverses `generate_preprocess()`, by first removing all
        padding and start/end tokens, and then converting the integer sequence
        back to a string.
        """
        if not self.built:
            self.build(None)

        token_ids, padding_mask = x["token_ids"], x["padding_mask"]
        ids_to_strip = self.tokenizer.special_token_ids
        ids_to_strip += [self.tokenizer.token_to_id("<end_of_image>")]
        token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip)
        return self.tokenizer.detokenize(token_ids)
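The docstring above notes that `generate_preprocess()` and `generate_postprocess()` can also be called standalone. The following is a hedged sketch of that round trip for the text-only model; the `gemma3_4b_en` preset name is again taken from the docstring, and the exact detokenized output is only approximate.

```python
import keras_hub

# Hypothetical standalone use of the generation-side methods defined above.
preprocessor = keras_hub.models.Gemma3CausalLMPreprocessor.from_preset(
    "gemma3_4b_en"  # preset name taken from the docstring above
)

# For the text-only model, `generate_preprocess()` tokenizes and packs the
# prompt without appending an end token, returning token ids and a padding
# mask padded to `sequence_length`.
features = preprocessor.generate_preprocess(
    ["The quick brown fox"], sequence_length=64
)
print(sorted(features.keys()))  # ['padding_mask', 'token_ids']

# `generate_postprocess()` strips padding and special tokens (including
# "<end_of_image>") and detokenizes back to strings.
print(preprocessor.generate_postprocess(features))
```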