keras-hub-nightly 0.22.0.dev202508170419__py3-none-any.whl → 0.24.0.dev202511090424__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of keras-hub-nightly might be problematic.
- keras_hub/layers/__init__.py +15 -0
- keras_hub/models/__init__.py +93 -0
- keras_hub/src/layers/modeling/position_embedding.py +21 -6
- keras_hub/src/layers/modeling/reversible_embedding.py +8 -1
- keras_hub/src/layers/modeling/rotary_embedding.py +16 -6
- keras_hub/src/layers/modeling/sine_position_encoding.py +21 -8
- keras_hub/src/layers/modeling/token_and_position_embedding.py +2 -1
- keras_hub/src/models/backbone.py +28 -16
- keras_hub/src/models/causal_lm.py +37 -0
- keras_hub/src/models/causal_lm_preprocessor.py +14 -0
- keras_hub/src/models/clip/clip_presets.py +8 -8
- keras_hub/src/models/d_fine/__init__.py +5 -0
- keras_hub/src/models/d_fine/d_fine_attention.py +461 -0
- keras_hub/src/models/d_fine/d_fine_backbone.py +891 -0
- keras_hub/src/models/d_fine/d_fine_decoder.py +944 -0
- keras_hub/src/models/d_fine/d_fine_encoder.py +365 -0
- keras_hub/src/models/d_fine/d_fine_hybrid_encoder.py +642 -0
- keras_hub/src/models/d_fine/d_fine_image_converter.py +8 -0
- keras_hub/src/models/d_fine/d_fine_layers.py +1828 -0
- keras_hub/src/models/d_fine/d_fine_loss.py +938 -0
- keras_hub/src/models/d_fine/d_fine_object_detector.py +875 -0
- keras_hub/src/models/d_fine/d_fine_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/d_fine/d_fine_presets.py +155 -0
- keras_hub/src/models/d_fine/d_fine_utils.py +827 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +7 -2
- keras_hub/src/models/depth_anything/__init__.py +9 -0
- keras_hub/src/models/depth_anything/depth_anything_backbone.py +232 -0
- keras_hub/src/models/depth_anything/depth_anything_depth_estimator.py +70 -0
- keras_hub/src/models/depth_anything/depth_anything_depth_estimator_preprocessor.py +16 -0
- keras_hub/src/models/depth_anything/depth_anything_image_converter.py +10 -0
- keras_hub/src/models/depth_anything/depth_anything_layers.py +725 -0
- keras_hub/src/models/depth_anything/depth_anything_loss.py +89 -0
- keras_hub/src/models/depth_anything/depth_anything_presets.py +41 -0
- keras_hub/src/models/depth_anything/interpolate.py +62 -0
- keras_hub/src/models/depth_estimator.py +239 -0
- keras_hub/src/models/depth_estimator_preprocessor.py +78 -0
- keras_hub/src/models/dinov2/dinov2_backbone.py +29 -3
- keras_hub/src/models/dinov2/dinov2_layers.py +16 -4
- keras_hub/src/models/dinov3/__init__.py +5 -0
- keras_hub/src/models/dinov3/dinov3_backbone.py +263 -0
- keras_hub/src/models/dinov3/dinov3_image_converter.py +8 -0
- keras_hub/src/models/dinov3/dinov3_layers.py +1013 -0
- keras_hub/src/models/dinov3/dinov3_presets.py +4 -0
- keras_hub/src/models/gemma/gemma_backbone.py +0 -1
- keras_hub/src/models/gemma/gemma_presets.py +30 -0
- keras_hub/src/models/gemma3/gemma3_attention.py +48 -0
- keras_hub/src/models/gemma3/gemma3_backbone.py +4 -1
- keras_hub/src/models/gemma3/gemma3_decoder_block.py +12 -0
- keras_hub/src/models/gemma3/gemma3_presets.py +39 -0
- keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +4 -1
- keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +3 -2
- keras_hub/src/models/hgnetv2/hgnetv2_layers.py +27 -11
- keras_hub/src/models/image_to_image.py +5 -0
- keras_hub/src/models/inpaint.py +5 -0
- keras_hub/src/models/mobilenetv5/__init__.py +9 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py +699 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_backbone.py +396 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py +890 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py +436 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py +157 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_image_converter.py +10 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py +462 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_presets.py +15 -0
- keras_hub/src/models/mobilenetv5/mobilenetv5_utils.py +146 -0
- keras_hub/src/models/parseq/__init__.py +5 -0
- keras_hub/src/models/parseq/parseq_backbone.py +134 -0
- keras_hub/src/models/parseq/parseq_causal_lm.py +466 -0
- keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py +168 -0
- keras_hub/src/models/parseq/parseq_decoder.py +418 -0
- keras_hub/src/models/parseq/parseq_image_converter.py +8 -0
- keras_hub/src/models/parseq/parseq_presets.py +15 -0
- keras_hub/src/models/parseq/parseq_tokenizer.py +221 -0
- keras_hub/src/models/qwen3_moe/__init__.py +5 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py +371 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +365 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm.py +357 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor.py +12 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py +672 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py +45 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +30 -0
- keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py +48 -0
- keras_hub/src/models/sam/sam_prompt_encoder.py +3 -1
- keras_hub/src/models/siglip/siglip_presets.py +15 -0
- keras_hub/src/models/smollm3/smollm3_backbone.py +211 -0
- keras_hub/src/models/smollm3/smollm3_causal_lm.py +310 -0
- keras_hub/src/models/smollm3/smollm3_causal_lm_preprocessor.py +84 -0
- keras_hub/src/models/smollm3/smollm3_layers.py +757 -0
- keras_hub/src/models/smollm3/smollm3_tokenizer.py +60 -0
- keras_hub/src/models/smollm3/smollm3_utils.py +56 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
- keras_hub/src/models/t5gemma/__init__.py +5 -0
- keras_hub/src/models/t5gemma/t5gemma_attention.py +370 -0
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +366 -0
- keras_hub/src/models/t5gemma/t5gemma_decoder.py +355 -0
- keras_hub/src/models/t5gemma/t5gemma_encoder.py +214 -0
- keras_hub/src/models/t5gemma/t5gemma_layers.py +118 -0
- keras_hub/src/models/t5gemma/t5gemma_presets.py +374 -0
- keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm.py +442 -0
- keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py +216 -0
- keras_hub/src/models/t5gemma/t5gemma_tokenizer.py +84 -0
- keras_hub/src/models/text_to_image.py +5 -0
- keras_hub/src/samplers/beam_sampler.py +6 -6
- keras_hub/src/samplers/sampler.py +8 -6
- keras_hub/src/tests/test_case.py +40 -3
- keras_hub/src/tokenizers/tokenizer.py +15 -0
- keras_hub/src/utils/openvino_utils.py +141 -0
- keras_hub/src/utils/preset_utils.py +58 -2
- keras_hub/src/utils/tensor_utils.py +26 -2
- keras_hub/src/utils/timm/convert_mobilenetv5.py +321 -0
- keras_hub/src/utils/timm/preset_loader.py +8 -4
- keras_hub/src/utils/transformers/convert_dinov2.py +1 -0
- keras_hub/src/utils/transformers/convert_dinov3.py +106 -0
- keras_hub/src/utils/transformers/convert_qwen3_moe.py +216 -0
- keras_hub/src/utils/transformers/convert_smollm3.py +139 -0
- keras_hub/src/utils/transformers/convert_t5gemma.py +229 -0
- keras_hub/src/utils/transformers/convert_vit.py +4 -1
- keras_hub/src/utils/transformers/export/gemma.py +49 -4
- keras_hub/src/utils/transformers/export/hf_exporter.py +71 -25
- keras_hub/src/utils/transformers/preset_loader.py +12 -0
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +15 -0
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/RECORD +126 -47
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/top_level.txt +0 -0
keras_hub/src/models/parseq/parseq_backbone.py

@@ -0,0 +1,134 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.parseq.parseq_decoder import PARSeqDecoder
+
+
+@keras_hub_export("keras_hub.models.PARSeqBackbone")
+class PARSeqBackbone(Backbone):
+    """Scene Text Detection with PARSeq.
+
+    Performs OCR in natural scenes using the PARSeq model described in [Scene
+    Text Recognition with Permuted Autoregressive Sequence Models](
+    https://arxiv.org/abs/2207.06966). PARSeq is a ViT-based model that allows
+    iterative decoding by performing an autoregressive decoding phase, followed
+    by a refinement phase.
+
+    Args:
+        image_encoder: keras.Model. The image encoder model.
+        vocabulary_size: int. The size of the vocabulary.
+        max_label_length: int. The maximum length of the label sequence.
+        decoder_hidden_dim: int. The dimension of the decoder hidden layers.
+        num_decoder_layers: int. The number of decoder layers.
+        num_decoder_heads: int. The number of attention heads in the decoder.
+        decoder_mlp_dim: int. The dimension of the decoder MLP hidden layer.
+        dropout_rate: float. The dropout rate for the decoder network.
+            Defaults to `0.1`.
+        attention_dropout: float. The dropout rate for the attention weights.
+            Defaults to `0.1`.
+        dtype: str. `None`, str, or `keras.mixed_precision.DTypePolicy`. The
+            dtype to use for the computations and weights.
+        **kwargs: Additional keyword arguments passed to the base
+            `keras.Model` constructor.
+    """
+
+    def __init__(
+        self,
+        image_encoder,
+        vocabulary_size,
+        max_label_length,
+        decoder_hidden_dim,
+        num_decoder_layers,
+        num_decoder_heads,
+        decoder_mlp_dim,
+        dropout_rate=0.1,
+        attention_dropout=0.1,
+        dtype=None,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.image_encoder = image_encoder
+        self.decoder = PARSeqDecoder(
+            vocabulary_size=vocabulary_size,
+            max_label_length=max_label_length,
+            num_layers=num_decoder_layers,
+            num_heads=num_decoder_heads,
+            hidden_dim=decoder_hidden_dim,
+            mlp_dim=decoder_mlp_dim,
+            dropout_rate=dropout_rate,
+            attention_dropout=attention_dropout,
+            name="decoder",
+            dtype=dtype,
+        )
+        self.head = keras.layers.Dense(
+            vocabulary_size - 2,  # We don't predict <bos> nor <pad>
+            dtype=dtype,
+        )
+
+        # === Functional Model ===
+        image_input = self.image_encoder.input
+
+        token_id_input = keras.Input(
+            shape=(None,), dtype="int32", name="token_ids"
+        )
+        padding_mask_input = keras.Input(
+            shape=(None,), dtype="int32", name="padding_mask"
+        )
+
+        memory = self.image_encoder(image_input)
+        target_out = self.decoder(
+            token_id_input, memory, padding_mask=padding_mask_input
+        )
+        logits = self.head(target_out)
+
+        # === Config ===
+        self.vocabulary_size = vocabulary_size
+        self.max_label_length = max_label_length
+        self.decoder_hidden_dim = decoder_hidden_dim
+        self.num_decoder_layers = num_decoder_layers
+        self.num_decoder_heads = num_decoder_heads
+        self.decoder_mlp_dim = decoder_mlp_dim
+        self.dropout_rate = dropout_rate
+        self.attention_dropout = attention_dropout
+
+        super().__init__(
+            inputs={
+                "images": image_input,
+                "token_ids": token_id_input,
+                "padding_mask": padding_mask_input,
+            },
+            outputs=logits,
+            dtype=dtype,
+            **kwargs,
+        )
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "image_encoder": keras.layers.serialize(self.image_encoder),
+                "vocabulary_size": self.vocabulary_size,
+                "max_label_length": self.max_label_length,
+                "decoder_hidden_dim": self.decoder_hidden_dim,
+                "num_decoder_layers": self.num_decoder_layers,
+                "num_decoder_heads": self.num_decoder_heads,
+                "decoder_mlp_dim": self.decoder_mlp_dim,
+                "dropout_rate": self.dropout_rate,
+                "attention_dropout": self.attention_dropout,
+            }
+        )
+
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        config.update(
+            {
+                "image_encoder": keras.layers.deserialize(
+                    config["image_encoder"]
+                ),
+            }
+        )
+
+        return super().from_config(config)
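For orientation, the sketch below shows how the `PARSeqBackbone` added above might be instantiated around a ViT image encoder. It mirrors the example embedded in the `PARSeqCausalLM` docstring in the next file; the parameter values, the `keras_hub.models.ViTBackbone` public name, and the dummy forward pass are illustrative assumptions, not an official recipe from this release.

```python
# Minimal sketch, assuming `keras_hub.models.ViTBackbone` and
# `keras_hub.models.PARSeqBackbone` are exported by this nightly (the
# `@keras_hub_export` decorator above suggests the latter). Values mirror
# the PARSeqCausalLM docstring example; they are illustrative only.
import numpy as np

import keras_hub

image_encoder = keras_hub.models.ViTBackbone(
    image_shape=(32, 128, 3),
    patch_size=(4, 8),
    num_layers=12,
    num_heads=6,
    hidden_dim=384,
    mlp_dim=384 * 4,
    use_class_token=False,
)
backbone = keras_hub.models.PARSeqBackbone(
    image_encoder=image_encoder,
    vocabulary_size=97,
    max_label_length=25,
    decoder_hidden_dim=384,
    num_decoder_layers=1,
    num_decoder_heads=12,
    decoder_mlp_dim=4 * 384,
)

# Smoke-test forward pass with dummy inputs; the head projects to
# vocabulary_size - 2 classes because <bos> and <pad> are never predicted.
logits = backbone(
    {
        "images": np.ones((1, 32, 128, 3), dtype="float32"),
        "token_ids": np.zeros((1, 25), dtype="int32"),
        "padding_mask": np.ones((1, 25), dtype="int32"),
    }
)
```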
keras_hub/src/models/parseq/parseq_causal_lm.py

@@ -0,0 +1,466 @@
+import math
+
+import keras
+from keras import ops
+from keras import random
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.causal_lm import CausalLM
+from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone
+from keras_hub.src.models.parseq.parseq_causal_lm_preprocessor import (
+    PARSeqCausalLMPreprocessor,
+)
+from keras_hub.src.utils.tensor_utils import any_equal
+
+
+@keras_hub_export("keras_hub.models.PARSeqCausalLM")
+class PARSeqCausalLM(CausalLM):
+    """Scene Text Recognition with PARSeq.
+    Performs OCR in natural scenes using the PARSeq model described in
+    [Scene Text Recognition with Permuted Autoregressive Sequence Models](
+    https://arxiv.org/abs/2207.06966). PARSeq is a ViT-based model that allows
+    iterative decoding by performing an autoregressive decoding phase, followed
+    by a refinement phase.
+    Args:
+        preprocessor: A `keras_hub.models.Preprocessor` instance or a
+            `keras.Layer` instance. The preprocessor to use for the model.
+        backbone: A `keras_hub.models.PARSeqBackbone` instance or a
+            `keras.Model`. The backbone model to use for the model.
+        num_perms: int. The number of permutations to generate for training.
+            Defaults to 6.
+        add_forward_perms: bool. Whether to add forward permutations to the
+            generated permutations. Defaults to `True`.
+        add_mirrored_perms: bool. Whether to add mirrored permutations to the
+            generated permutations. Defaults to `True`.
+        seed: int. The random seed to use for generating permutations.
+            Defaults to `None`, which means no seed is set.
+        **kwargs: Additional keyword arguments passed to the base
+            `keras_hub.models.CausalLM` constructor.
+
+    Examples:
+
+    Call `predict()` to run inference.
+    ```python
+    # Load preset and run inference
+    images = np.random.randint(0, 256, size=(2, 32, 128, 3))
+    parseq = keras_hub.models.PARSeqCausalLM.from_preset(
+        "parseq_vit"
+    )
+    parseq.generate(images)
+
+    # Call `fit()` on a single batch.
+    images = np.random.randint(0, 256, size=(2, 32, 128, 3))
+    token_ids = np.array([[1, 2, 3, 4], [1, 2, 3, 0]])
+    padding_mask = np.array([[1, 1, 1, 1], [1, 1, 1, 0]])
+    parseq = keras_hub.models.PARSeqCausalLM.from_preset(
+        "parseq_vit"
+    )
+    parseq.fit(
+        x={
+            "images": images,
+            "token_ids": token_ids,
+            "padding_mask": padding_mask
+        },
+        batch_size=2,
+    )
+    ```
+    # Call `fit()` with custom loss, optimizer and image encoder.
+    ```python
+    # Initialize the image encoder, preprocessor and tokenizer
+    mean, std = 0.5, 0.5
+    image_converter = PARSeqImageConverter(
+        image_size=(32, 128),
+        offset=-mean / std,
+        scale=1.0 / 255.0 / std,
+        interpolation="bicubic",
+    )
+    tokenizer = PARSeqTokenizer(max_label_length=25)
+    preprocessor = keras_hub.models.PARSeqCausalLMPreprocessor(
+        image_converter=image_converter,
+        tokenizer=tokenizer,
+    )
+
+    # Create the backbone
+    image_encoder = ViTBackbone(
+        image_shape=(32, 128, 3),
+        patch_size=(4, 8),
+        num_layers=12,
+        num_heads=6,
+        hidden_dim=384,
+        mlp_dim=384 * 4,
+        use_class_token=False,
+        name="encoder",
+    )
+    backbone = PARSeqBackbone(
+        vocabulary_size=97,
+        max_label_length=25,
+        image_encoder=image_encoder,
+        num_decoder_heads=12,
+        num_decoder_layers=1,
+        decoder_hidden_dim=384,
+        decoder_mlp_dim=4 * 384,
+    )
+    # Create the PARSeq model
+    parseq = keras_hub.models.PARSeqCausalLM(
+        backbone=backbone,
+        preprocessor=preprocessor,
+    )
+    parseq.compile(
+        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        optimizer=keras.optimizers.Adam(5e-5),
+    )
+    parseq.fit(
+        x={
+            "images": images,
+            "token_ids": token_ids,
+            "padding_mask": padding_mask
+        },
+        batch_size=2,
+    )
+    ```
+    """
+
+    backbone_cls = PARSeqBackbone
+    preprocessor_cls = PARSeqCausalLMPreprocessor
+
+    def __init__(
+        self,
+        preprocessor,
+        backbone,
+        num_perms=6,
+        add_forward_perms=True,
+        add_mirrored_perms=True,
+        seed=None,
+        end_token_id=0,  # default tokenizer.end_token_id
+        **kwargs,
+    ):
+        # === Layers ===
+        self.preprocessor = preprocessor
+        self.backbone = backbone
+
+        # === Functional Model ===
+        # This must be "backbone.input" i.e. the full input structure,
+        # rather than "backbone.inputs" which is the flattened list of inputs.
+        inputs = backbone.input
+        outputs = backbone(inputs=inputs)
+        super().__init__(
+            inputs=inputs,
+            outputs=outputs,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.num_perms = num_perms
+        self.add_forward_perms = add_forward_perms
+        self.add_mirrored_perms = add_mirrored_perms
+        self.end_token_id = end_token_id
+        self.seed = seed
+        self.seed_generator = keras.random.SeedGenerator(seed)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_perms": self.num_perms,
+                "add_forward_perms": self.add_forward_perms,
+                "add_mirrored_perms": self.add_mirrored_perms,
+                "seed": self.seed,
+                "end_token_id": self.end_token_id,
+            }
+        )
+
+        return config
+
+    def compile(
+        self,
+        optimizer="auto",
+        loss="auto",
+        *,
+        weighted_metrics="auto",
+        sampler="greedy",
+        **kwargs,
+    ):
+        if loss == "auto":
+            loss = keras.losses.SparseCategoricalCrossentropy(
+                from_logits=True,
+                ignore_class=self.preprocessor.tokenizer.pad_token_id,
+            )
+        super().compile(
+            optimizer=optimizer,
+            loss=loss,
+            weighted_metrics=weighted_metrics,
+            sampler=sampler,
+            **kwargs,
+        )
+
+    def compute_loss(
+        self, x, y, y_pred, sample_weight, training=True, *args, **kwargs
+    ):
+        # For keras we have fixed input for all batches, so in this case
+        # we permute 23 tokens excluding BOS and EOS tokens instead of max
+        # characters for current batch used in torch implementation
+        # -1 because we will be generating permutation mask for considering
+        # tokens before creating target label.
+        max_num_chars = self.backbone.max_label_length - 1
+        perms = self.generate_training_permutations(max_num_chars)
+        max_label_length = self.backbone.max_label_length
+        memory = self.backbone.image_encoder(x["images"])
+        batch_size = ops.shape(x["images"])[0]
+        losses = []
+        for i in range(ops.shape(perms)[0]):
+            query_mask, content_mask = self.generate_attention_masks(perms[i])
+            query_mask = ops.broadcast_to(
+                query_mask, (batch_size, max_label_length, max_label_length)
+            )
+            content_mask = ops.broadcast_to(
+                content_mask, (batch_size, max_label_length, max_label_length)
+            )
+            out = self.backbone.decoder(
+                x["token_ids"],
+                memory,
+                padding_mask=x["padding_mask"],
+                query_mask=query_mask,
+                content_mask=content_mask,
+            )
+            y_pred = self.backbone.head(out)
+            loss = super().compute_loss(
+                x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, **kwargs
+            )
+            losses.append(loss)
+            if i == 1:
+                # Sample weights are set to zero for end-of-sequence (EOS)
+                # tokens to prevent them from affecting loss calculations.
+                # reference: https://github.com/baudm/parseq/blob/1902db043c029a7e03a3818c616c06600af574be/strhub/models/parseq/system.py#L194 # noqa: E501
+                sample_weight = ops.logical_and(
+                    y != self.end_token_id, sample_weight
+                )
+
+        return ops.sum(losses) / ops.shape(perms)[0]
+
+    def generate_training_permutations(self, max_num_chars):
+        max_gen_perms = (
+            self.num_perms // 2 if self.add_mirrored_perms else self.num_perms
+        )
+
+        if max_num_chars == 1:
+            return ops.expand_dims(ops.arange(3), axis=0)
+
+        perms = [ops.arange(max_num_chars)] if self.add_forward_perms else []
+        max_num_perms = math.factorial(max_num_chars)
+        max_gen_perms = min(max_gen_perms, max_num_perms)
+
+        for _ in range(max_gen_perms - len(perms)):
+            perm = random.shuffle(
+                ops.arange(max_num_chars), seed=self.seed_generator
+            )
+            perms.append(perm)
+
+        perms = ops.stack(perms)
+        comp = ops.flip(perms, axis=-1)
+        perms = ops.stack([perms, comp])
+        perms = ops.reshape(
+            ops.transpose(perms, (1, 0, 2)), (-1, max_num_chars)
+        )
+
+        bos_idx = ops.zeros((ops.shape(perms)[0], 1), dtype="int32")
+        eos_idx = ops.full(
+            (ops.shape(perms)[0], 1), max_num_chars + 1, dtype="int32"
+        )
+        perms = ops.concatenate([bos_idx, perms + 1, eos_idx], axis=1)
+
+        if perms.shape[0] > 1:
+            perms = ops.scatter_update(
+                perms,
+                ops.concatenate(
+                    [
+                        ops.ones((max_num_chars + 1, 1), dtype="int32"),
+                        ops.expand_dims(
+                            ops.arange(1, max_num_chars + 2, dtype="int32"),
+                            axis=1,
+                        ),
+                    ],
+                    axis=1,
+                ),
+                max_num_chars + 1 - ops.arange(max_num_chars + 1),
+            )
+
+        return perms
+
+    def generate_attention_masks(self, perm):
+        """Generate attention masks given a sequence permutation
+        (includes pos. for BOS and EOS tokens)"""
+        input_length = ops.shape(perm)[0]
+        mask = ops.ones((input_length, input_length))
+        for i in range(input_length - 1):
+            masked_keys = perm[i + 1 : input_length]
+            query_idx = ops.broadcast_to(perm[i], ops.shape(masked_keys))
+            indices = ops.stack((query_idx, masked_keys), axis=1)
+            mask = keras.ops.scatter_update(
+                mask, indices, keras.ops.zeros(ops.shape(masked_keys)[0])
+            )
+        content_mask = mask[:-1, :-1]
+        mask = mask * (1 - ops.eye(input_length))
+        query_mask = mask[1:, :-1]
+        return query_mask, content_mask
+
+    def call_with_cache(
+        self,
+        token_ids,
+        cache,
+        cache_update_index,
+        img_embeddings,
+        padding_mask=None,
+    ):
+        bs = ops.shape(token_ids)[0]
+        # <bos> stands for the null context. We only supply position information
+        # for characters after <bos>.
+        content = ops.where(
+            cache_update_index == 0,
+            self.backbone.decoder_hidden_dim**0.5
+            * self.backbone.decoder.token_embedding(token_ids),
+            ops.expand_dims(
+                self.backbone.decoder.pos_query_embeddings[
+                    :, cache_update_index - 1, :
+                ],
+                axis=0,
+            )
+            + self.backbone.decoder_hidden_dim**0.5
+            * self.backbone.decoder.token_embedding(token_ids),
+        )
+        content = self.backbone.decoder.dropout(content)
+
+        query = ops.ones((bs, 1, 1)) * ops.expand_dims(
+            self.backbone.decoder.pos_query_embeddings[
+                :, cache_update_index, :
+            ],
+            axis=0,
+        )
+        query = self.backbone.decoder.dropout(query)
+
+        query_cache = []
+        content_cache = []
+        for i, decoder_layer in enumerate(self.backbone.decoder.decoder_layers):
+            last = i == self.backbone.num_decoder_layers - 1
+            current_query_cache = cache[:, i, 0, ...]
+            current_content_cache = cache[:, i, 1, ...]
+            (
+                query,
+                content,
+                query_self_attention_new_cache,
+                content_self_attention_cache,
+            ) = decoder_layer(
+                query=query,
+                content=content,
+                memory=img_embeddings,
+                padding_mask=padding_mask,
+                update_content=not last,
+                query_self_attention_cache=current_query_cache,
+                query_self_attention_cache_update_index=cache_update_index,
+                content_self_attention_cache=current_content_cache,
+                content_self_attention_cache_update_index=cache_update_index,
+            )
+            query_cache.append(query_self_attention_new_cache)
+            content_cache.append(content_self_attention_cache)
+
+        query_cache = ops.stack(query_cache, axis=1)
+        content_cache = ops.stack(content_cache, axis=1)
+        cache = ops.stack([query_cache, content_cache], axis=2)
+        hidden_states = self.backbone.decoder.layer_norm(query)
+        logits = self.backbone.head(hidden_states)
+        return logits, hidden_states, cache
+
+    def _build_cache(self, token_ids, img_embeddings, padding_mask):
+        batch_size = ops.shape(token_ids)[0]
+        max_length = ops.shape(token_ids)[1]
+        num_layers = self.backbone.num_decoder_layers
+        head_dim = (
+            self.backbone.decoder_hidden_dim // self.backbone.num_decoder_heads
+        )
+        num_heads = self.backbone.num_decoder_heads
+        shape = [batch_size, num_layers, 2, 2, max_length, num_heads, head_dim]
+        cache = ops.zeros(shape)
+
+        # Seed the cache.
+        logits, hidden_states, cache = self.call_with_cache(
+            token_ids=token_ids,
+            img_embeddings=img_embeddings,
+            cache=cache,
+            cache_update_index=0,
+            padding_mask=padding_mask,
+        )
+        return hidden_states, cache
+
+    def generate_step(self, inputs, stop_token_ids=None):
+        token_ids, padding_mask, images = (
+            inputs["token_ids"],
+            inputs["padding_mask"],
+            inputs["images"],
+        )
+        images_shape = ops.shape(images)
+        if len(images_shape) == 3:
+            # Handle an unbatched image. Unlike `token_ids` and `padding_mask`
+            # this will not automatically be upranked.
+            images = ops.expand_dims(images, axis=0)
+
+        img_embeddings = self.backbone.image_encoder(images)
+        # Create and seed cache with a single forward pass.
+        hidden_states, cache = self._build_cache(
+            token_ids=token_ids,
+            img_embeddings=img_embeddings,
+            padding_mask=padding_mask,
+        )
+        # Compute the lengths of all user inputted tokens ids.
+        row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
+        # Start at the first index that has no user inputted id.
+        index = ops.min(row_lengths)
+
+        def next(prompt, cache, index):
+            # The cache index is the index of our previous token.
+            cache_update_index = index - 1
+            batch_size = ops.shape(prompt)[0]
+            prompt = ops.slice(prompt, [0, index - 1], [batch_size, 1])
+            logits, hidden_states, cache = self.call_with_cache(
+                token_ids=prompt,
+                cache=cache,
+                cache_update_index=cache_update_index,
+                img_embeddings=img_embeddings,
+            )
+            return (
+                ops.squeeze(logits, axis=1),
+                ops.squeeze(hidden_states, axis=1),
+                cache,
+            )
+
+        token_ids = self.sampler(
+            next=next,
+            prompt=token_ids,
+            cache=cache,
+            index=index,
+            mask=padding_mask,
+            stop_token_ids=stop_token_ids,
+            hidden_states=hidden_states,
+            model=self,
+        )
+
+        # Compute an output padding mask with the token ids we updated.
+        if stop_token_ids is not None:
+            # Build a mask of `stop_token_ids` locations not in the original
+            # prompt (not in locations where `padding_mask` is True).
+            end_locations = any_equal(
+                token_ids, stop_token_ids, ops.logical_not(padding_mask)
+            )
+
+            end_locations = ops.cast(end_locations, "int32")
+            # Use cumsum to get ones in all locations after end_locations.
+            cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
+            overflow = cumsum - end_locations
+            # Our padding mask is the inverse of these overflow locations.
+            padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
+        else:
+            # Without early stopping, all locations will have been updated.
+            padding_mask = ops.ones_like(token_ids, dtype="bool")
+        return {
+            "token_ids": token_ids,
+            "padding_mask": padding_mask,
+            "images": images,
+        }