keras-hub-nightly 0.23.0.dev202508260411__py3-none-any.whl → 0.23.0.dev202508280418__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. keras_hub/layers/__init__.py +6 -0
  2. keras_hub/models/__init__.py +21 -0
  3. keras_hub/src/layers/modeling/position_embedding.py +21 -6
  4. keras_hub/src/layers/modeling/rotary_embedding.py +16 -6
  5. keras_hub/src/layers/modeling/sine_position_encoding.py +21 -8
  6. keras_hub/src/layers/modeling/token_and_position_embedding.py +2 -1
  7. keras_hub/src/models/backbone.py +10 -15
  8. keras_hub/src/models/d_fine/__init__.py +0 -0
  9. keras_hub/src/models/d_fine/d_fine_attention.py +461 -0
  10. keras_hub/src/models/d_fine/d_fine_backbone.py +891 -0
  11. keras_hub/src/models/d_fine/d_fine_decoder.py +944 -0
  12. keras_hub/src/models/d_fine/d_fine_encoder.py +365 -0
  13. keras_hub/src/models/d_fine/d_fine_hybrid_encoder.py +642 -0
  14. keras_hub/src/models/d_fine/d_fine_image_converter.py +8 -0
  15. keras_hub/src/models/d_fine/d_fine_layers.py +1828 -0
  16. keras_hub/src/models/d_fine/d_fine_loss.py +938 -0
  17. keras_hub/src/models/d_fine/d_fine_object_detector.py +875 -0
  18. keras_hub/src/models/d_fine/d_fine_object_detector_preprocessor.py +14 -0
  19. keras_hub/src/models/d_fine/d_fine_presets.py +2 -0
  20. keras_hub/src/models/d_fine/d_fine_utils.py +827 -0
  21. keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +4 -1
  22. keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +3 -2
  23. keras_hub/src/models/hgnetv2/hgnetv2_layers.py +27 -11
  24. keras_hub/src/models/parseq/__init__.py +0 -0
  25. keras_hub/src/models/parseq/parseq_backbone.py +134 -0
  26. keras_hub/src/models/parseq/parseq_causal_lm.py +466 -0
  27. keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py +168 -0
  28. keras_hub/src/models/parseq/parseq_decoder.py +418 -0
  29. keras_hub/src/models/parseq/parseq_image_converter.py +8 -0
  30. keras_hub/src/models/parseq/parseq_tokenizer.py +221 -0
  31. keras_hub/src/tests/test_case.py +37 -1
  32. keras_hub/src/utils/preset_utils.py +49 -0
  33. keras_hub/src/utils/tensor_utils.py +23 -1
  34. keras_hub/src/utils/transformers/convert_vit.py +4 -1
  35. keras_hub/src/version.py +1 -1
  36. keras_hub/tokenizers/__init__.py +3 -0
  37. {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/METADATA +1 -1
  38. {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/RECORD +40 -20
  39. {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/WHEEL +0 -0
  40. {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/top_level.txt +0 -0
keras_hub/src/models/parseq/parseq_causal_lm.py (new file)
@@ -0,0 +1,466 @@
+ import math
+
+ import keras
+ from keras import ops
+ from keras import random
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.causal_lm import CausalLM
+ from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone
+ from keras_hub.src.models.parseq.parseq_causal_lm_preprocessor import (
+     PARSeqCausalLMPreprocessor,
+ )
+ from keras_hub.src.utils.tensor_utils import any_equal
+
+
+ @keras_hub_export("keras_hub.models.PARSeqCausalLM")
+ class PARSeqCausalLM(CausalLM):
+     """Scene Text Recognition with PARSeq.
+
+     Performs OCR in natural scenes using the PARSeq model described in
+     [Scene Text Recognition with Permuted Autoregressive Sequence Models](
+     https://arxiv.org/abs/2207.06966). PARSeq is a ViT-based model that
+     allows iterative decoding by performing an autoregressive decoding
+     phase, followed by a refinement phase.
+
+     Args:
+         preprocessor: A `keras_hub.models.Preprocessor` instance or a
+             `keras.Layer` instance. The preprocessor to use for the model.
+         backbone: A `keras_hub.models.PARSeqBackbone` instance or a
+             `keras.Model`. The backbone model to use for the model.
+         num_perms: int. The number of permutations to generate for training.
+             Defaults to 6.
+         add_forward_perms: bool. Whether to add forward permutations to the
+             generated permutations. Defaults to `True`.
+         add_mirrored_perms: bool. Whether to add mirrored permutations to
+             the generated permutations. Defaults to `True`.
+         seed: int. The random seed to use for generating permutations.
+             Defaults to `None`, which means no seed is set.
+         end_token_id: int. The id of the end-of-sequence token. Defaults to
+             0, the default `tokenizer.end_token_id`.
+         **kwargs: Additional keyword arguments passed to the base
+             `keras_hub.models.CausalLM` constructor.
+
+     Examples:
+
+     Call `generate()` to run inference.
+     ```python
+     # Load preset and run inference.
+     images = np.random.randint(0, 256, size=(2, 32, 128, 3))
+     parseq = keras_hub.models.PARSeqCausalLM.from_preset("parseq_vit")
+     parseq.generate(images)
+
+     # Call `fit()` on a single batch.
+     images = np.random.randint(0, 256, size=(2, 32, 128, 3))
+     token_ids = np.array([[1, 2, 3, 4], [1, 2, 3, 0]])
+     padding_mask = np.array([[1, 1, 1, 1], [1, 1, 1, 0]])
+     parseq = keras_hub.models.PARSeqCausalLM.from_preset("parseq_vit")
+     parseq.fit(
+         x={
+             "images": images,
+             "token_ids": token_ids,
+             "padding_mask": padding_mask,
+         },
+         batch_size=2,
+     )
+     ```
+
+     Call `fit()` with a custom loss, optimizer, and image encoder.
+     ```python
+     # Initialize the image converter, tokenizer, and preprocessor.
+     mean, std = 0.5, 0.5
+     image_converter = PARSeqImageConverter(
+         image_size=(32, 128),
+         offset=-mean / std,
+         scale=1.0 / 255.0 / std,
+         interpolation="bicubic",
+     )
+     tokenizer = PARSeqTokenizer(max_label_length=25)
+     preprocessor = keras_hub.models.PARSeqCausalLMPreprocessor(
+         image_converter=image_converter,
+         tokenizer=tokenizer,
+     )
+
+     # Create the backbone.
+     image_encoder = ViTBackbone(
+         image_shape=(32, 128, 3),
+         patch_size=(4, 8),
+         num_layers=12,
+         num_heads=6,
+         hidden_dim=384,
+         mlp_dim=384 * 4,
+         use_class_token=False,
+         name="encoder",
+     )
+     backbone = PARSeqBackbone(
+         vocabulary_size=97,
+         max_label_length=25,
+         image_encoder=image_encoder,
+         num_decoder_heads=12,
+         num_decoder_layers=1,
+         decoder_hidden_dim=384,
+         decoder_mlp_dim=4 * 384,
+     )
+
+     # Create the PARSeq model.
+     parseq = keras_hub.models.PARSeqCausalLM(
+         backbone=backbone,
+         preprocessor=preprocessor,
+     )
+     parseq.compile(
+         loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+         optimizer=keras.optimizers.Adam(5e-5),
+     )
+     parseq.fit(
+         x={
+             "images": images,
+             "token_ids": token_ids,
+             "padding_mask": padding_mask,
+         },
+         batch_size=2,
+     )
+     ```
+     """
+
+     backbone_cls = PARSeqBackbone
+     preprocessor_cls = PARSeqCausalLMPreprocessor
+
+     def __init__(
+         self,
+         preprocessor,
+         backbone,
+         num_perms=6,
+         add_forward_perms=True,
+         add_mirrored_perms=True,
+         seed=None,
+         end_token_id=0,  # default tokenizer.end_token_id
+         **kwargs,
+     ):
+         # === Layers ===
+         self.preprocessor = preprocessor
+         self.backbone = backbone
+
+         # === Functional Model ===
+         # This must be "backbone.input" i.e. the full input structure,
+         # rather than "backbone.inputs" which is the flattened list of
+         # inputs.
+         inputs = backbone.input
+         outputs = backbone(inputs=inputs)
+         super().__init__(
+             inputs=inputs,
+             outputs=outputs,
+             **kwargs,
+         )
+
+         # === Config ===
+         self.num_perms = num_perms
+         self.add_forward_perms = add_forward_perms
+         self.add_mirrored_perms = add_mirrored_perms
+         self.end_token_id = end_token_id
+         self.seed = seed
+         self.seed_generator = keras.random.SeedGenerator(seed)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "num_perms": self.num_perms,
+                 "add_forward_perms": self.add_forward_perms,
+                 "add_mirrored_perms": self.add_mirrored_perms,
+                 "seed": self.seed,
+                 "end_token_id": self.end_token_id,
+             }
+         )
+
+         return config
+
+     def compile(
+         self,
+         optimizer="auto",
+         loss="auto",
+         *,
+         weighted_metrics="auto",
+         sampler="greedy",
+         **kwargs,
+     ):
+         if loss == "auto":
+             loss = keras.losses.SparseCategoricalCrossentropy(
+                 from_logits=True,
+                 ignore_class=self.preprocessor.tokenizer.pad_token_id,
+             )
+         super().compile(
+             optimizer=optimizer,
+             loss=loss,
+             weighted_metrics=weighted_metrics,
+             sampler=sampler,
+             **kwargs,
+         )
+
+     def compute_loss(
+         self, x, y, y_pred, sample_weight, training=True, *args, **kwargs
+     ):
+         # Keras uses fixed-size inputs for all batches, so here we permute
+         # 23 tokens (excluding the BOS and EOS tokens) rather than the
+         # per-batch max character count used in the torch implementation.
+         # We subtract 1 because the permutation mask is generated over the
+         # tokens that come before the target label.
+         max_num_chars = self.backbone.max_label_length - 1
+         perms = self.generate_training_permutations(max_num_chars)
+         max_label_length = self.backbone.max_label_length
+         memory = self.backbone.image_encoder(x["images"])
+         batch_size = ops.shape(x["images"])[0]
+         losses = []
+         for i in range(ops.shape(perms)[0]):
+             query_mask, content_mask = self.generate_attention_masks(perms[i])
+             query_mask = ops.broadcast_to(
+                 query_mask, (batch_size, max_label_length, max_label_length)
+             )
+             content_mask = ops.broadcast_to(
+                 content_mask, (batch_size, max_label_length, max_label_length)
+             )
+             out = self.backbone.decoder(
+                 x["token_ids"],
+                 memory,
+                 padding_mask=x["padding_mask"],
+                 query_mask=query_mask,
+                 content_mask=content_mask,
+             )
+             y_pred = self.backbone.head(out)
+             loss = super().compute_loss(
+                 x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, **kwargs
+             )
+             losses.append(loss)
+             if i == 1:
+                 # Sample weights are set to zero for end-of-sequence (EOS)
+                 # tokens to prevent them from affecting loss calculations.
+                 # Reference: https://github.com/baudm/parseq/blob/1902db043c029a7e03a3818c616c06600af574be/strhub/models/parseq/system.py#L194 # noqa: E501
+                 sample_weight = ops.logical_and(
+                     y != self.end_token_id, sample_weight
+                 )
+
+         return ops.sum(losses) / ops.shape(perms)[0]
+
+     def generate_training_permutations(self, max_num_chars):
+         max_gen_perms = (
+             self.num_perms // 2 if self.add_mirrored_perms else self.num_perms
+         )
+
+         if max_num_chars == 1:
+             return ops.expand_dims(ops.arange(3), axis=0)
+
+         perms = [ops.arange(max_num_chars)] if self.add_forward_perms else []
+         max_num_perms = math.factorial(max_num_chars)
+         max_gen_perms = min(max_gen_perms, max_num_perms)
+
+         for _ in range(max_gen_perms - len(perms)):
+             perm = random.shuffle(
+                 ops.arange(max_num_chars), seed=self.seed_generator
+             )
+             perms.append(perm)
+
+         perms = ops.stack(perms)
+         comp = ops.flip(perms, axis=-1)
+         perms = ops.stack([perms, comp])
+         perms = ops.reshape(
+             ops.transpose(perms, (1, 0, 2)), (-1, max_num_chars)
+         )
+
+         bos_idx = ops.zeros((ops.shape(perms)[0], 1), dtype="int32")
+         eos_idx = ops.full(
+             (ops.shape(perms)[0], 1), max_num_chars + 1, dtype="int32"
+         )
+         perms = ops.concatenate([bos_idx, perms + 1, eos_idx], axis=1)
+
+         if perms.shape[0] > 1:
+             perms = ops.scatter_update(
+                 perms,
+                 ops.concatenate(
+                     [
+                         ops.ones((max_num_chars + 1, 1), dtype="int32"),
+                         ops.expand_dims(
+                             ops.arange(1, max_num_chars + 2, dtype="int32"),
+                             axis=1,
+                         ),
+                     ],
+                     axis=1,
+                 ),
+                 max_num_chars + 1 - ops.arange(max_num_chars + 1),
+             )
+
+         return perms
+
+     def generate_attention_masks(self, perm):
+         """Generate attention masks given a sequence permutation
+         (includes pos. for BOS and EOS tokens)."""
+         input_length = ops.shape(perm)[0]
+         mask = ops.ones((input_length, input_length))
+         for i in range(input_length - 1):
+             masked_keys = perm[i + 1 : input_length]
+             query_idx = ops.broadcast_to(perm[i], ops.shape(masked_keys))
+             indices = ops.stack((query_idx, masked_keys), axis=1)
+             mask = keras.ops.scatter_update(
+                 mask, indices, keras.ops.zeros(ops.shape(masked_keys)[0])
+             )
+         content_mask = mask[:-1, :-1]
+         mask = mask * (1 - ops.eye(input_length))
+         query_mask = mask[1:, :-1]
+         return query_mask, content_mask
+
+     def call_with_cache(
+         self,
+         token_ids,
+         cache,
+         cache_update_index,
+         img_embeddings,
+         padding_mask=None,
+     ):
+         bs = ops.shape(token_ids)[0]
+         # <bos> stands for the null context. We only supply position
+         # information for characters after <bos>.
+         content = ops.where(
+             cache_update_index == 0,
+             self.backbone.decoder_hidden_dim**0.5
+             * self.backbone.decoder.token_embedding(token_ids),
+             ops.expand_dims(
+                 self.backbone.decoder.pos_query_embeddings[
+                     :, cache_update_index - 1, :
+                 ],
+                 axis=0,
+             )
+             + self.backbone.decoder_hidden_dim**0.5
+             * self.backbone.decoder.token_embedding(token_ids),
+         )
+         content = self.backbone.decoder.dropout(content)
+
+         query = ops.ones((bs, 1, 1)) * ops.expand_dims(
+             self.backbone.decoder.pos_query_embeddings[
+                 :, cache_update_index, :
+             ],
+             axis=0,
+         )
+         query = self.backbone.decoder.dropout(query)
+
+         query_cache = []
+         content_cache = []
+         for i, decoder_layer in enumerate(self.backbone.decoder.decoder_layers):
+             last = i == self.backbone.num_decoder_layers - 1
+             current_query_cache = cache[:, i, 0, ...]
+             current_content_cache = cache[:, i, 1, ...]
+             (
+                 query,
+                 content,
+                 query_self_attention_new_cache,
+                 content_self_attention_cache,
+             ) = decoder_layer(
+                 query=query,
+                 content=content,
+                 memory=img_embeddings,
+                 padding_mask=padding_mask,
+                 update_content=not last,
+                 query_self_attention_cache=current_query_cache,
+                 query_self_attention_cache_update_index=cache_update_index,
+                 content_self_attention_cache=current_content_cache,
+                 content_self_attention_cache_update_index=cache_update_index,
+             )
+             query_cache.append(query_self_attention_new_cache)
+             content_cache.append(content_self_attention_cache)
+
+         query_cache = ops.stack(query_cache, axis=1)
+         content_cache = ops.stack(content_cache, axis=1)
+         cache = ops.stack([query_cache, content_cache], axis=2)
+         hidden_states = self.backbone.decoder.layer_norm(query)
+         logits = self.backbone.head(hidden_states)
+         return logits, hidden_states, cache
+
+     def _build_cache(self, token_ids, img_embeddings, padding_mask):
+         batch_size = ops.shape(token_ids)[0]
+         max_length = ops.shape(token_ids)[1]
+         num_layers = self.backbone.num_decoder_layers
+         head_dim = (
+             self.backbone.decoder_hidden_dim // self.backbone.num_decoder_heads
+         )
+         num_heads = self.backbone.num_decoder_heads
+         shape = [batch_size, num_layers, 2, 2, max_length, num_heads, head_dim]
+         cache = ops.zeros(shape)
+
+         # Seed the cache.
+         logits, hidden_states, cache = self.call_with_cache(
+             token_ids=token_ids,
+             img_embeddings=img_embeddings,
+             cache=cache,
+             cache_update_index=0,
+             padding_mask=padding_mask,
+         )
+         return hidden_states, cache
+
+     def generate_step(self, inputs, stop_token_ids=None):
+         token_ids, padding_mask, images = (
+             inputs["token_ids"],
+             inputs["padding_mask"],
+             inputs["images"],
+         )
+         images_shape = ops.shape(images)
+         if len(images_shape) == 3:
+             # Handle an unbatched image. Unlike `token_ids` and
+             # `padding_mask`, this will not automatically be upranked.
+             images = ops.expand_dims(images, axis=0)
+
+         img_embeddings = self.backbone.image_encoder(images)
+         # Create and seed cache with a single forward pass.
+         hidden_states, cache = self._build_cache(
+             token_ids=token_ids,
+             img_embeddings=img_embeddings,
+             padding_mask=padding_mask,
+         )
+         # Compute the lengths of all user-provided token ids.
+         row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
+         # Start at the first index that has no user-provided id.
+         index = ops.min(row_lengths)
+
+         def next(prompt, cache, index):
+             # The cache index is the index of our previous token.
+             cache_update_index = index - 1
+             batch_size = ops.shape(prompt)[0]
+             prompt = ops.slice(prompt, [0, index - 1], [batch_size, 1])
+             logits, hidden_states, cache = self.call_with_cache(
+                 token_ids=prompt,
+                 cache=cache,
+                 cache_update_index=cache_update_index,
+                 img_embeddings=img_embeddings,
+             )
+             return (
+                 ops.squeeze(logits, axis=1),
+                 ops.squeeze(hidden_states, axis=1),
+                 cache,
+             )
+
+         token_ids = self.sampler(
+             next=next,
+             prompt=token_ids,
+             cache=cache,
+             index=index,
+             mask=padding_mask,
+             stop_token_ids=stop_token_ids,
+             hidden_states=hidden_states,
+             model=self,
+         )
+
+         # Compute an output padding mask with the token ids we updated.
+         if stop_token_ids is not None:
+             # Build a mask of `stop_token_ids` locations not in the original
+             # prompt (not in locations where `padding_mask` is True).
+             end_locations = any_equal(
+                 token_ids, stop_token_ids, ops.logical_not(padding_mask)
+             )
+
+             end_locations = ops.cast(end_locations, "int32")
+             # Use cumsum to get ones in all locations after end_locations.
+             cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
+             overflow = cumsum - end_locations
+             # Our padding mask is the inverse of these overflow locations.
+             padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
+         else:
+             # Without early stopping, all locations will have been updated.
+             padding_mask = ops.ones_like(token_ids, dtype="bool")
+         return {
+             "token_ids": token_ids,
+             "padding_mask": padding_mask,
+             "images": images,
+         }
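
The permuted-autoregressive training above hinges on `generate_attention_masks`: given a permutation over token positions (BOS and EOS included), position `perm[i]` may attend only to positions that come earlier in permutation order. A minimal NumPy sketch of that masking rule, as a hypothetical standalone helper rather than anything shipped in this diff:

```python
import numpy as np


def attention_masks(perm):
    """Build content/query masks for one permutation (BOS/EOS included)."""
    n = len(perm)
    mask = np.ones((n, n))
    for i in range(n - 1):
        # Positions later in permutation order are masked out as keys.
        mask[perm[i], perm[i + 1 :]] = 0.0
    content_mask = mask[:-1, :-1]  # content stream may attend to itself
    query_mask = (mask * (1 - np.eye(n)))[1:, :-1]  # query stream may not
    return query_mask, content_mask


# The identity permutation over [BOS, c1, c2, c3, EOS] reduces to the
# familiar causal masks: each position sees only its left context.
query_mask, content_mask = attention_masks(np.arange(5))
print(content_mask)  # lower-triangular ones
```

Averaging the loss over several such permutations, as `compute_loss` does, is what lets a single decoder support forward, mirrored, and refinement-style decoding orders.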
keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py (new file)
@@ -0,0 +1,168 @@
+ import keras
+ from keras import ops
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
+ from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
+ from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone
+ from keras_hub.src.models.parseq.parseq_image_converter import (
+     PARSeqImageConverter,
+ )
+ from keras_hub.src.models.parseq.parseq_tokenizer import PARSeqTokenizer
+ from keras_hub.src.utils.tensor_utils import preprocessing_function
+ from keras_hub.src.utils.tensor_utils import strip_to_ragged
+
+
+ @keras_hub_export("keras_hub.models.PARSeqCausalLMPreprocessor")
+ class PARSeqCausalLMPreprocessor(CausalLMPreprocessor):
+     backbone_cls = PARSeqBackbone
+     tokenizer_cls = PARSeqTokenizer
+     image_converter_cls = PARSeqImageConverter
+
22
+ def __init__(
23
+ self,
24
+ image_converter=None,
25
+ tokenizer=None,
26
+ sequence_length=25,
27
+ add_start_token=True,
28
+ add_end_token=True,
29
+ **kwargs,
30
+ ):
31
+ super().__init__(
32
+ tokenizer=tokenizer,
33
+ sequence_length=sequence_length,
34
+ add_start_token=add_start_token,
35
+ add_end_token=add_end_token,
36
+ **kwargs,
37
+ )
38
+ self.image_converter = image_converter
39
+
40
+ def build(self, input_shape):
41
+ # Defer packer creation to `build()` so that we can be sure tokenizer
42
+ # assets have loaded when restoring a saved model.
43
+ self.packer = StartEndPacker(
44
+ start_value=self.tokenizer.start_token_id,
45
+ end_value=self.tokenizer.end_token_id,
46
+ pad_value=self.tokenizer.pad_token_id,
47
+ sequence_length=self.sequence_length,
48
+ return_padding_mask=True,
49
+ )
50
+ self.built = True
+
+     @preprocessing_function
+     def call(self, x, y=None, sample_weight=None, sequence_length=None):
+         """Preprocesses the input data for training.
+
+         This method takes a dictionary containing images and text responses,
+         and converts them into a format suitable for training a PARSeq
+         model.
+
+         Args:
+             x: dict. A dictionary containing the input data. Must have keys
+                 "images" and "responses".
+             y: The target data. Defaults to `None`.
+             sample_weight: The sample weights. Defaults to `None`.
+             sequence_length: int. The maximum length of the input sequence.
+                 Defaults to `None`, which uses the pre-defined sequence
+                 length.
+
+         Returns:
+             A tuple containing the preprocessed input data, target data, and
+             sample weights.
+         """
+         sequence_length = sequence_length or self.sequence_length
+         images, responses = x["images"], x["responses"]
+         if self.image_converter:
+             images = self.image_converter(images)
+         token_ids = self.tokenizer(responses)
+         token_ids, padding_mask = self.packer(
+             token_ids,
+             sequence_length=sequence_length + 1,
+             add_start_value=self.add_start_token,
+             add_end_value=self.add_end_token,
+         )
+         x = {
+             "images": images,
+             "token_ids": token_ids[..., :-1],
+             "padding_mask": padding_mask[..., :-1],
+         }
+         # Target `y` will be the next token.
+         y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
+         return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
+
+     @preprocessing_function
+     def generate_preprocess(
+         self,
+         x,
+         sequence_length=None,
+     ):
+         """Convert images to integer token input for generation.
+
+         Unlike calling the layer for training, this method takes images
+         only. It applies the `image_converter` (if present) and builds the
+         generation prompt: a start token followed by `pad_token_id`s, plus
+         a padding mask that is `True` only at the start-token position. No
+         `tokenizer.end_token_id` is ever appended, as generation is
+         expected to fill in the sequence from the start token onward.
+         """
+         if not self.built:
+             self.build(None)
+         sequence_length = sequence_length or self.sequence_length
+         images = x
+         if self.image_converter:
+             images = self.image_converter(images)
+
+         images_shape = keras.ops.shape(images)
+         if len(images_shape) == 3:
+             batch_size = 1
+         else:
+             batch_size = images_shape[0]
+
+         token_ids = ops.concatenate(
+             (
+                 ops.full([batch_size, 1], self.tokenizer.start_token_id),
+                 ops.full(
+                     [batch_size, sequence_length - 1],
+                     self.tokenizer.pad_token_id,
+                 ),
+             ),
+             axis=1,
+         )
+
+         padding_mask = ops.equal(token_ids, self.tokenizer.start_token_id)
+
+         return {
+             "images": images,
+             "token_ids": token_ids,
+             "padding_mask": padding_mask,
+         }
+
+     @preprocessing_function
+     def generate_postprocess(
+         self,
+         x,
+     ):
+         """Convert integer token output to strings for generation.
+
+         This method reverses `generate_preprocess()` by first removing all
+         padding and start/end tokens, and then converting the integer
+         sequence back to a string.
+         """
+         if not self.built:
+             self.build(None)
+
+         token_ids, padding_mask = x["token_ids"], x["padding_mask"]
+         ids_to_strip = self.tokenizer.special_token_ids
+         token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip)
+         return self.tokenizer.detokenize(token_ids)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "sequence_length": self.sequence_length,
+                 "add_start_token": self.add_start_token,
+                 "add_end_token": self.add_end_token,
+             }
+         )
+         return config
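
Read together, the two files give the full train/generate path: `call()` packs labels into shifted next-token targets, while `generate_preprocess()` emits a lone-BOS prompt for `generate_step()` to fill. A sketch of wiring the preprocessor up by hand, reusing the constructor arguments from the `PARSeqCausalLM` docstring example above; the label strings and random images are illustrative only:

```python
import numpy as np

import keras_hub
from keras_hub.src.models.parseq.parseq_image_converter import (
    PARSeqImageConverter,
)
from keras_hub.src.models.parseq.parseq_tokenizer import PARSeqTokenizer

mean, std = 0.5, 0.5
image_converter = PARSeqImageConverter(
    image_size=(32, 128),
    offset=-mean / std,
    scale=1.0 / 255.0 / std,
    interpolation="bicubic",
)
tokenizer = PARSeqTokenizer(max_label_length=25)
preprocessor = keras_hub.models.PARSeqCausalLMPreprocessor(
    image_converter=image_converter,
    tokenizer=tokenizer,
)

# Training path: images plus raw label strings in, (x, y, sample_weight)
# out, where `y` is `token_ids` shifted left by one (next-token targets).
images = np.random.randint(0, 256, size=(2, 32, 128, 3)).astype("float32")
x, y, sample_weight = preprocessor(
    {"images": images, "responses": ["STOP", "3km"]}
)
print(x["token_ids"].shape, y.shape)  # (2, 25) (2, 25)

# Generation path: images only; the prompt is a lone start token, so the
# padding mask is True at position 0 and False elsewhere.
inputs = preprocessor.generate_preprocess(images)
print(inputs["padding_mask"][0, :3])
```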