keras-hub-nightly 0.23.0.dev202508260411__py3-none-any.whl → 0.23.0.dev202508280418__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +6 -0
- keras_hub/models/__init__.py +21 -0
- keras_hub/src/layers/modeling/position_embedding.py +21 -6
- keras_hub/src/layers/modeling/rotary_embedding.py +16 -6
- keras_hub/src/layers/modeling/sine_position_encoding.py +21 -8
- keras_hub/src/layers/modeling/token_and_position_embedding.py +2 -1
- keras_hub/src/models/backbone.py +10 -15
- keras_hub/src/models/d_fine/__init__.py +0 -0
- keras_hub/src/models/d_fine/d_fine_attention.py +461 -0
- keras_hub/src/models/d_fine/d_fine_backbone.py +891 -0
- keras_hub/src/models/d_fine/d_fine_decoder.py +944 -0
- keras_hub/src/models/d_fine/d_fine_encoder.py +365 -0
- keras_hub/src/models/d_fine/d_fine_hybrid_encoder.py +642 -0
- keras_hub/src/models/d_fine/d_fine_image_converter.py +8 -0
- keras_hub/src/models/d_fine/d_fine_layers.py +1828 -0
- keras_hub/src/models/d_fine/d_fine_loss.py +938 -0
- keras_hub/src/models/d_fine/d_fine_object_detector.py +875 -0
- keras_hub/src/models/d_fine/d_fine_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/d_fine/d_fine_presets.py +2 -0
- keras_hub/src/models/d_fine/d_fine_utils.py +827 -0
- keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +4 -1
- keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +3 -2
- keras_hub/src/models/hgnetv2/hgnetv2_layers.py +27 -11
- keras_hub/src/models/parseq/__init__.py +0 -0
- keras_hub/src/models/parseq/parseq_backbone.py +134 -0
- keras_hub/src/models/parseq/parseq_causal_lm.py +466 -0
- keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py +168 -0
- keras_hub/src/models/parseq/parseq_decoder.py +418 -0
- keras_hub/src/models/parseq/parseq_image_converter.py +8 -0
- keras_hub/src/models/parseq/parseq_tokenizer.py +221 -0
- keras_hub/src/tests/test_case.py +37 -1
- keras_hub/src/utils/preset_utils.py +49 -0
- keras_hub/src/utils/tensor_utils.py +23 -1
- keras_hub/src/utils/transformers/convert_vit.py +4 -1
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +3 -0
- {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/RECORD +40 -20
- {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.23.0.dev202508260411.dist-info → keras_hub_nightly-0.23.0.dev202508280418.dist-info}/top_level.txt +0 -0
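Most of the new code in this release is the PARSeq scene text recognition model (`keras_hub/src/models/parseq/`) and the D-FINE object detector (`keras_hub/src/models/d_fine/`). As orientation for the PARSeq hunks reproduced below, here is a minimal usage sketch assembled from the docstring of the added `parseq_causal_lm.py`; the `parseq_vit` preset name comes from that docstring and is an assumption, not verified against the published preset registry.

```python
import numpy as np
import keras_hub

# Two dummy 32x128 RGB images; PARSeq expects (height, width) = (32, 128).
images = np.random.randint(0, 256, size=(2, 32, 128, 3))

# Preset name taken from the docstring in the diff below; it may not be
# published alongside this nightly build.
parseq = keras_hub.models.PARSeqCausalLM.from_preset("parseq_vit")

# Returns the recognized text for each image.
parseq.generate(images)
```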
keras_hub/src/models/parseq/parseq_causal_lm.py
@@ -0,0 +1,466 @@
+import math
+
+import keras
+from keras import ops
+from keras import random
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.causal_lm import CausalLM
+from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone
+from keras_hub.src.models.parseq.parseq_causal_lm_preprocessor import (
+    PARSeqCausalLMPreprocessor,
+)
+from keras_hub.src.utils.tensor_utils import any_equal
+
+
+@keras_hub_export("keras_hub.models.PARSeqCausalLM")
+class PARSeqCausalLM(CausalLM):
+    """Scene Text Recognition with PARSeq.
+    Performs OCR in natural scenes using the PARSeq model described in
+    [Scene Text Recognition with Permuted Autoregressive Sequence Models](
+    https://arxiv.org/abs/2207.06966). PARSeq is a ViT-based model that allows
+    iterative decoding by performing an autoregressive decoding phase, followed
+    by a refinement phase.
+    Args:
+        preprocessor: A `keras_hub.models.Preprocessor` instance or a
+            `keras.Layer` instance. The preprocessor to use for the model.
+        backbone: A `keras_hub.models.PARSeqBackbone` instance or a
+            `keras.Model`. The backbone model to use for the model.
+        num_perms: int. The number of permutations to generate for training.
+            Defaults to 6.
+        add_forward_perms: bool. Whether to add forward permutations to the
+            generated permutations. Defaults to `True`.
+        add_mirrored_perms: bool. Whether to add mirrored permutations to the
+            generated permutations. Defaults to `True`.
+        seed: int. The random seed to use for generating permutations.
+            Defaults to `None`, which means no seed is set.
+        **kwargs: Additional keyword arguments passed to the base
+            `keras_hub.models.CausalLM` constructor.
+
+    Examples:
+
+    Call `predict()` to run inference.
+    ```python
+    # Load preset and run inference
+    images = np.random.randint(0, 256, size=(2, 32, 128, 3))
+    parseq = keras_hub.models.PARSeqCausalLM.from_preset(
+        "parseq_vit"
+    )
+    parseq.generate(images)
+
+    # Call `fit()` on a single batch.
+    images = np.random.randint(0, 256, size=(2, 32, 128, 3))
+    token_ids = np.array([[1, 2, 3, 4], [1, 2, 3, 0]])
+    padding_mask = np.array([[1, 1, 1, 1], [1, 1, 1, 0]])
+    parseq = keras_hub.models.PARSeqCausalLM.from_preset(
+        "parseq_vit"
+    )
+    parseq.fit(
+        x={
+            "images": images,
+            "token_ids": token_ids,
+            "padding_mask": padding_mask
+        },
+        batch_size=2,
+    )
+    ```
+    # Call `fit()` with custom loss, optimizer and image encoder.
+    ```python
+    # Initialize the image encoder, preprocessor and tokenizer
+    mean, std = 0.5, 0.5
+    image_converter = PARSeqImageConverter(
+        image_size=(32, 128),
+        offset=-mean / std,
+        scale=1.0 / 255.0 / std,
+        interpolation="bicubic",
+    )
+    tokenizer = PARSeqTokenizer(max_label_length=25)
+    preprocessor = keras_hub.models.PARSeqCausalLMPreprocessor(
+        image_converter=image_converter,
+        tokenizer=tokenizer,
+    )
+
+    # Create the backbone
+    image_encoder = ViTBackbone(
+        image_shape=(32, 128, 3),
+        patch_size=(4, 8),
+        num_layers=12,
+        num_heads=6,
+        hidden_dim=384,
+        mlp_dim=384 * 4,
+        use_class_token=False,
+        name="encoder",
+    )
+    backbone = PARSeqBackbone(
+        vocabulary_size=97,
+        max_label_length=25,
+        image_encoder=image_encoder,
+        num_decoder_heads=12,
+        num_decoder_layers=1,
+        decoder_hidden_dim=384,
+        decoder_mlp_dim=4 * 384,
+    )
+    # Create the PARSeq model
+    parseq = keras_hub.models.PARSeqCausalLM(
+        backbone=backbone,
+        preprocessor=preprocessor,
+    )
+    parseq.compile(
+        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        optimizer=keras.optimizers.Adam(5e-5),
+    )
+    parseq.fit(
+        x={
+            "images": images,
+            "token_ids": token_ids,
+            "padding_mask": padding_mask
+        },
+        batch_size=2,
+    )
+    ```
+    """
+
+    backbone_cls = PARSeqBackbone
+    preprocessor_cls = PARSeqCausalLMPreprocessor
+
+    def __init__(
+        self,
+        preprocessor,
+        backbone,
+        num_perms=6,
+        add_forward_perms=True,
+        add_mirrored_perms=True,
+        seed=None,
+        end_token_id=0,  # default tokenizer.end_token_id
+        **kwargs,
+    ):
+        # === Layers ===
+        self.preprocessor = preprocessor
+        self.backbone = backbone
+
+        # === Functional Model ===
+        # This must be "backbone.input" i.e. the full input structure,
+        # rather than "backbone.inputs" which is the flattened list of inputs.
+        inputs = backbone.input
+        outputs = backbone(inputs=inputs)
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.num_perms = num_perms
+        self.add_forward_perms = add_forward_perms
+        self.add_mirrored_perms = add_mirrored_perms
+        self.end_token_id = end_token_id
+        self.seed = seed
+        self.seed_generator = keras.random.SeedGenerator(seed)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_perms": self.num_perms,
+                "add_forward_perms": self.add_forward_perms,
+                "add_mirrored_perms": self.add_mirrored_perms,
+                "seed": self.seed,
+                "end_token_id": self.end_token_id,
+            }
+        )
+
+        return config
+
+    def compile(
+        self,
+        optimizer="auto",
+        loss="auto",
+        *,
+        weighted_metrics="auto",
+        sampler="greedy",
+        **kwargs,
+    ):
+        if loss == "auto":
+            loss = keras.losses.SparseCategoricalCrossentropy(
+                from_logits=True,
+                ignore_class=self.preprocessor.tokenizer.pad_token_id,
+            )
+        super().compile(
+            optimizer=optimizer,
+            loss=loss,
+            weighted_metrics=weighted_metrics,
+            sampler=sampler,
+            **kwargs,
+        )
+
+    def compute_loss(
+        self, x, y, y_pred, sample_weight, training=True, *args, **kwargs
+    ):
+        # For keras we have fixed input for all batches, so in this case
+        # we permute 23 tokens excluding BOS and EOS tokens instead of max
+        # characters for current batch used in torch implementation
+        # -1 because we will be generating permutation mask for considering
+        # tokens before creating target label.
+        max_num_chars = self.backbone.max_label_length - 1
+        perms = self.generate_training_permutations(max_num_chars)
+        max_label_length = self.backbone.max_label_length
+        memory = self.backbone.image_encoder(x["images"])
+        batch_size = ops.shape(x["images"])[0]
+        losses = []
+        for i in range(ops.shape(perms)[0]):
+            query_mask, content_mask = self.generate_attention_masks(perms[i])
+            query_mask = ops.broadcast_to(
+                query_mask, (batch_size, max_label_length, max_label_length)
+            )
+            content_mask = ops.broadcast_to(
+                content_mask, (batch_size, max_label_length, max_label_length)
+            )
+            out = self.backbone.decoder(
+                x["token_ids"],
+                memory,
+                padding_mask=x["padding_mask"],
+                query_mask=query_mask,
+                content_mask=content_mask,
+            )
+            y_pred = self.backbone.head(out)
+            loss = super().compute_loss(
+                x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, **kwargs
+            )
+            losses.append(loss)
+            if i == 1:
+                # Sample weights are set to zero for end-of-sequence (EOS)
+                # tokens to prevent them from affecting loss calculations.
+                # reference: https://github.com/baudm/parseq/blob/1902db043c029a7e03a3818c616c06600af574be/strhub/models/parseq/system.py#L194 # noqa: E501
+                sample_weight = ops.logical_and(
+                    y != self.end_token_id, sample_weight
+                )
+
+        return ops.sum(losses) / ops.shape(perms)[0]
+
+    def generate_training_permutations(self, max_num_chars):
+        max_gen_perms = (
+            self.num_perms // 2 if self.add_mirrored_perms else self.num_perms
+        )
+
+        if max_num_chars == 1:
+            return ops.expand_dims(ops.arange(3), axis=0)
+
+        perms = [ops.arange(max_num_chars)] if self.add_forward_perms else []
+        max_num_perms = math.factorial(max_num_chars)
+        max_gen_perms = min(max_gen_perms, max_num_perms)
+
+        for _ in range(max_gen_perms - len(perms)):
+            perm = random.shuffle(
+                ops.arange(max_num_chars), seed=self.seed_generator
+            )
+            perms.append(perm)
+
+        perms = ops.stack(perms)
+        comp = ops.flip(perms, axis=-1)
+        perms = ops.stack([perms, comp])
+        perms = ops.reshape(
+            ops.transpose(perms, (1, 0, 2)), (-1, max_num_chars)
+        )
+
+        bos_idx = ops.zeros((ops.shape(perms)[0], 1), dtype="int32")
+        eos_idx = ops.full(
+            (ops.shape(perms)[0], 1), max_num_chars + 1, dtype="int32"
+        )
+        perms = ops.concatenate([bos_idx, perms + 1, eos_idx], axis=1)
+
+        if perms.shape[0] > 1:
+            perms = ops.scatter_update(
+                perms,
+                ops.concatenate(
+                    [
+                        ops.ones((max_num_chars + 1, 1), dtype="int32"),
+                        ops.expand_dims(
+                            ops.arange(1, max_num_chars + 2, dtype="int32"),
+                            axis=1,
+                        ),
+                    ],
+                    axis=1,
+                ),
+                max_num_chars + 1 - ops.arange(max_num_chars + 1),
+            )
+
+        return perms
+
+    def generate_attention_masks(self, perm):
+        """Generate attention masks given a sequence permutation
+        (includes pos. for BOS and EOS tokens)"""
+        input_length = ops.shape(perm)[0]
+        mask = ops.ones((input_length, input_length))
+        for i in range(input_length - 1):
+            masked_keys = perm[i + 1 : input_length]
+            query_idx = ops.broadcast_to(perm[i], ops.shape(masked_keys))
+            indices = ops.stack((query_idx, masked_keys), axis=1)
+            mask = keras.ops.scatter_update(
+                mask, indices, keras.ops.zeros(ops.shape(masked_keys)[0])
+            )
+        content_mask = mask[:-1, :-1]
+        mask = mask * (1 - ops.eye(input_length))
+        query_mask = mask[1:, :-1]
+        return query_mask, content_mask
+
+    def call_with_cache(
+        self,
+        token_ids,
+        cache,
+        cache_update_index,
+        img_embeddings,
+        padding_mask=None,
+    ):
+        bs = ops.shape(token_ids)[0]
+        # <bos> stands for the null context. We only supply position information
+        # for characters after <bos>.
+        content = ops.where(
+            cache_update_index == 0,
+            self.backbone.decoder_hidden_dim**0.5
+            * self.backbone.decoder.token_embedding(token_ids),
+            ops.expand_dims(
+                self.backbone.decoder.pos_query_embeddings[
+                    :, cache_update_index - 1, :
+                ],
+                axis=0,
+            )
+            + self.backbone.decoder_hidden_dim**0.5
+            * self.backbone.decoder.token_embedding(token_ids),
+        )
+        content = self.backbone.decoder.dropout(content)
+
+        query = ops.ones((bs, 1, 1)) * ops.expand_dims(
+            self.backbone.decoder.pos_query_embeddings[
+                :, cache_update_index, :
+            ],
+            axis=0,
+        )
+        query = self.backbone.decoder.dropout(query)
+
+        query_cache = []
+        content_cache = []
+        for i, decoder_layer in enumerate(self.backbone.decoder.decoder_layers):
+            last = i == self.backbone.num_decoder_layers - 1
+            current_query_cache = cache[:, i, 0, ...]
+            current_content_cache = cache[:, i, 1, ...]
+            (
+                query,
+                content,
+                query_self_attention_new_cache,
+                content_self_attention_cache,
+            ) = decoder_layer(
+                query=query,
+                content=content,
+                memory=img_embeddings,
+                padding_mask=padding_mask,
+                update_content=not last,
+                query_self_attention_cache=current_query_cache,
+                query_self_attention_cache_update_index=cache_update_index,
+                content_self_attention_cache=current_content_cache,
+                content_self_attention_cache_update_index=cache_update_index,
+            )
+            query_cache.append(query_self_attention_new_cache)
+            content_cache.append(content_self_attention_cache)
+
+        query_cache = ops.stack(query_cache, axis=1)
+        content_cache = ops.stack(content_cache, axis=1)
+        cache = ops.stack([query_cache, content_cache], axis=2)
+        hidden_states = self.backbone.decoder.layer_norm(query)
+        logits = self.backbone.head(hidden_states)
+        return logits, hidden_states, cache
+
+    def _build_cache(self, token_ids, img_embeddings, padding_mask):
+        batch_size = ops.shape(token_ids)[0]
+        max_length = ops.shape(token_ids)[1]
+        num_layers = self.backbone.num_decoder_layers
+        head_dim = (
+            self.backbone.decoder_hidden_dim // self.backbone.num_decoder_heads
+        )
+        num_heads = self.backbone.num_decoder_heads
+        shape = [batch_size, num_layers, 2, 2, max_length, num_heads, head_dim]
+        cache = ops.zeros(shape)
+
+        # Seed the cache.
+        logits, hidden_states, cache = self.call_with_cache(
+            token_ids=token_ids,
+            img_embeddings=img_embeddings,
+            cache=cache,
+            cache_update_index=0,
+            padding_mask=padding_mask,
+        )
+        return hidden_states, cache
+
+    def generate_step(self, inputs, stop_token_ids=None):
+        token_ids, padding_mask, images = (
+            inputs["token_ids"],
+            inputs["padding_mask"],
+            inputs["images"],
+        )
+        images_shape = ops.shape(images)
+        if len(images_shape) == 3:
+            # Handle an unbatched image. Unlike `token_ids` and `padding_mask`
+            # this will not automatically be upranked.
+            images = ops.expand_dims(images, axis=0)
+
+        img_embeddings = self.backbone.image_encoder(images)
+        # Create and seed cache with a single forward pass.
+        hidden_states, cache = self._build_cache(
+            token_ids=token_ids,
+            img_embeddings=img_embeddings,
+            padding_mask=padding_mask,
+        )
+        # Compute the lengths of all user inputted tokens ids.
+        row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
+        # Start at the first index that has no user inputted id.
+        index = ops.min(row_lengths)
+
+        def next(prompt, cache, index):
+            # The cache index is the index of our previous token.
+            cache_update_index = index - 1
+            batch_size = ops.shape(prompt)[0]
+            prompt = ops.slice(prompt, [0, index - 1], [batch_size, 1])
+            logits, hidden_states, cache = self.call_with_cache(
+                token_ids=prompt,
+                cache=cache,
+                cache_update_index=cache_update_index,
+                img_embeddings=img_embeddings,
+            )
+            return (
+                ops.squeeze(logits, axis=1),
+                ops.squeeze(hidden_states, axis=1),
+                cache,
+            )
+
+        token_ids = self.sampler(
+            next=next,
+            prompt=token_ids,
+            cache=cache,
+            index=index,
+            mask=padding_mask,
+            stop_token_ids=stop_token_ids,
+            hidden_states=hidden_states,
+            model=self,
+        )
+
+        # Compute an output padding mask with the token ids we updated.
+        if stop_token_ids is not None:
+            # Build a mask of `stop_token_ids` locations not in the original
+            # prompt (not in locations where `padding_mask` is True).
+            end_locations = any_equal(
+                token_ids, stop_token_ids, ops.logical_not(padding_mask)
+            )
+
+            end_locations = ops.cast(end_locations, "int32")
+            # Use cumsum to get ones in all locations after end_locations.
+            cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
+            overflow = cumsum - end_locations
+            # Our padding mask is the inverse of these overflow locations.
+            padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
+        else:
+            # Without early stopping, all locations will have been updated.
+            padding_mask = ops.ones_like(token_ids, dtype="bool")
+        return {
+            "token_ids": token_ids,
+            "padding_mask": padding_mask,
+            "images": images,
+        }
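The permutation masking at the heart of PARSeq's training loop (see `generate_attention_masks` and its use in `compute_loss` above) is easier to follow outside of Keras ops. Below is a plain NumPy transcription of that method, for illustration only; it is not part of the package.

```python
import numpy as np

def attention_masks_from_perm(perm):
    # `perm` is a permutation of [0, ..., L-1] where index 0 (BOS) comes
    # first and index L-1 (EOS) last, with the character positions in
    # between permuted (mirrors `generate_attention_masks` in the diff).
    L = len(perm)
    mask = np.ones((L, L))
    for i in range(L - 1):
        # The query at position perm[i] may not attend to tokens that
        # come later in the permutation order.
        later = perm[i + 1:]
        mask[perm[i], later] = 0.0
    # Content stream: self-attention to the own position is allowed.
    content_mask = mask[:-1, :-1]
    # Query stream: additionally forbid attending to the own position.
    query_mask = (mask * (1 - np.eye(L)))[1:, :-1]
    return query_mask, content_mask

# The identity order [BOS, 1, 2, 3, EOS] reduces to the usual causal masks.
q, c = attention_masks_from_perm(np.array([0, 1, 2, 3, 4]))
```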
keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py
@@ -0,0 +1,168 @@
+import keras
+from keras import ops
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
+from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone
+from keras_hub.src.models.parseq.parseq_image_converter import (
+    PARSeqImageConverter,
+)
+from keras_hub.src.models.parseq.parseq_tokenizer import PARSeqTokenizer
+from keras_hub.src.utils.tensor_utils import preprocessing_function
+from keras_hub.src.utils.tensor_utils import strip_to_ragged
+
+
+@keras_hub_export("keras_hub.models.PARSeqCausalLMPreprocessor")
+class PARSeqCausalLMPreprocessor(CausalLMPreprocessor):
+    backbone_cls = PARSeqBackbone
+    tokenizer_cls = PARSeqTokenizer
+    image_converter_cls = PARSeqImageConverter
+
+    def __init__(
+        self,
+        image_converter=None,
+        tokenizer=None,
+        sequence_length=25,
+        add_start_token=True,
+        add_end_token=True,
+        **kwargs,
+    ):
+        super().__init__(
+            tokenizer=tokenizer,
+            sequence_length=sequence_length,
+            add_start_token=add_start_token,
+            add_end_token=add_end_token,
+            **kwargs,
+        )
+        self.image_converter = image_converter
+
+    def build(self, input_shape):
+        # Defer packer creation to `build()` so that we can be sure tokenizer
+        # assets have loaded when restoring a saved model.
+        self.packer = StartEndPacker(
+            start_value=self.tokenizer.start_token_id,
+            end_value=self.tokenizer.end_token_id,
+            pad_value=self.tokenizer.pad_token_id,
+            sequence_length=self.sequence_length,
+            return_padding_mask=True,
+        )
+        self.built = True
+
+    @preprocessing_function
+    def call(self, x, y=None, sample_weight=None, sequence_length=None):
+        """Preprocesses the input data for training.
+
+        This method takes a dictionary containing images and text responses,
+        and converts them into a format suitable for training a PARSeq model.
+
+        Args:
+            x: dict. A dictionary containing the input data. Must have keys
+                "images" and "responses".
+            y: The target data. Defaults to None.
+            sample_weight: The sample weights. Defaults to None.
+            sequence_length: int. The maximum length of the input sequence.
+                Defaults to None, which uses the pre-defined sequence length.
+
+        Returns:
+            A tuple containing the preprocessed input data, target data, and
+            sample weights.
+        """
+        sequence_length = sequence_length or self.sequence_length
+        images, responses = x["images"], x["responses"]
+        if self.image_converter:
+            images = self.image_converter(images)
+        token_ids = self.tokenizer(responses)
+        token_ids, padding_mask = self.packer(
+            token_ids,
+            sequence_length=sequence_length + 1,
+            add_start_value=self.add_start_token,
+            add_end_value=self.add_end_token,
+        )
+        x = {
+            "images": images,
+            "token_ids": token_ids[..., :-1],
+            "padding_mask": padding_mask[..., :-1],
+        }
+        # Target `y` will be the next token.
+        y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
+        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
+
+    @preprocessing_function
+    def generate_preprocess(
+        self,
+        x,
+        sequence_length=None,
+    ):
+        """Convert strings to integer token input for generation.
+
+        Similar to calling the layer for training, this method takes in strings
+        or tensor strings, tokenizes and packs the input, and computes a padding
+        mask masking all inputs not filled in with a padded value.
+
+        Unlike calling the layer for training, this method does not compute
+        labels and will never append a `tokenizer.end_token_id` to the end of
+        the sequence (as generation is expected to continue at the end of the
+        inputted prompt).
+        """
+        if not self.built:
+            self.build(None)
+        sequence_length = sequence_length or self.sequence_length
+        images = x
+        if self.image_converter:
+            images = self.image_converter(images)
+
+        images_shape = keras.ops.shape(images)
+        if len(images_shape) == 3:
+            batch_size = 1
+        else:
+            batch_size = images_shape[0]
+
+        token_ids = ops.concatenate(
+            (
+                ops.full([batch_size, 1], self.tokenizer.start_token_id),
+                ops.full(
+                    [batch_size, sequence_length - 1],
+                    self.tokenizer.pad_token_id,
+                ),
+            ),
+            axis=1,
+        )
+
+        padding_mask = ops.equal(token_ids, self.tokenizer.start_token_id)
+
+        return {
+            "images": images,
+            "token_ids": token_ids,
+            "padding_mask": padding_mask,
+        }
+
+    @preprocessing_function
+    def generate_postprocess(
+        self,
+        x,
+    ):
+        """Convert integer token output to strings for generation.
+
+        This method reverses `generate_preprocess()`, by first removing all
+        padding and start/end tokens, and then converting the integer sequence
+        back to a string.
+        """
+        if not self.built:
+            self.build(None)
+
+        token_ids, padding_mask = x["token_ids"], x["padding_mask"]
+        ids_to_strip = self.tokenizer.special_token_ids
+        token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip)
+        return self.tokenizer.detokenize(token_ids)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "sequence_length": self.sequence_length,
+                "add_start_token": self.add_start_token,
+                "add_end_token": self.add_end_token,
+            }
+        )
+        return config
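Finally, a sketch of how the new preprocessor above is wired together for training data: it converts the images, tokenizes the label strings under `"responses"`, packs them with start/end/pad tokens, and returns next-token targets by shifting the packed ids. Constructor arguments mirror the docstring example in `parseq_causal_lm.py`; treat this as an illustrative sketch under those assumptions rather than documented API.

```python
import numpy as np
import keras_hub
from keras_hub.src.models.parseq.parseq_image_converter import PARSeqImageConverter
from keras_hub.src.models.parseq.parseq_tokenizer import PARSeqTokenizer

# Normalization values copied from the docstring example in the diff.
mean, std = 0.5, 0.5
image_converter = PARSeqImageConverter(
    image_size=(32, 128),
    offset=-mean / std,
    scale=1.0 / 255.0 / std,
    interpolation="bicubic",
)
tokenizer = PARSeqTokenizer(max_label_length=25)
preprocessor = keras_hub.models.PARSeqCausalLMPreprocessor(
    image_converter=image_converter,
    tokenizer=tokenizer,
)

# `call()` expects a dict with "images" and "responses" (the label strings).
images = np.random.randint(0, 256, size=(2, 32, 128, 3))
x, y, sample_weight = preprocessor(
    {"images": images, "responses": ["hello", "world"]}
)
# x["token_ids"] / x["padding_mask"] feed the decoder; y is the next-token
# target produced by shifting the packed token ids by one position.
```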