keras-hub-nightly 0.22.0.dev202508170419__py3-none-any.whl → 0.24.0.dev202511090424__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of keras-hub-nightly might be problematic.

Files changed (126)
  1. keras_hub/layers/__init__.py +15 -0
  2. keras_hub/models/__init__.py +93 -0
  3. keras_hub/src/layers/modeling/position_embedding.py +21 -6
  4. keras_hub/src/layers/modeling/reversible_embedding.py +8 -1
  5. keras_hub/src/layers/modeling/rotary_embedding.py +16 -6
  6. keras_hub/src/layers/modeling/sine_position_encoding.py +21 -8
  7. keras_hub/src/layers/modeling/token_and_position_embedding.py +2 -1
  8. keras_hub/src/models/backbone.py +28 -16
  9. keras_hub/src/models/causal_lm.py +37 -0
  10. keras_hub/src/models/causal_lm_preprocessor.py +14 -0
  11. keras_hub/src/models/clip/clip_presets.py +8 -8
  12. keras_hub/src/models/d_fine/__init__.py +5 -0
  13. keras_hub/src/models/d_fine/d_fine_attention.py +461 -0
  14. keras_hub/src/models/d_fine/d_fine_backbone.py +891 -0
  15. keras_hub/src/models/d_fine/d_fine_decoder.py +944 -0
  16. keras_hub/src/models/d_fine/d_fine_encoder.py +365 -0
  17. keras_hub/src/models/d_fine/d_fine_hybrid_encoder.py +642 -0
  18. keras_hub/src/models/d_fine/d_fine_image_converter.py +8 -0
  19. keras_hub/src/models/d_fine/d_fine_layers.py +1828 -0
  20. keras_hub/src/models/d_fine/d_fine_loss.py +938 -0
  21. keras_hub/src/models/d_fine/d_fine_object_detector.py +875 -0
  22. keras_hub/src/models/d_fine/d_fine_object_detector_preprocessor.py +14 -0
  23. keras_hub/src/models/d_fine/d_fine_presets.py +155 -0
  24. keras_hub/src/models/d_fine/d_fine_utils.py +827 -0
  25. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +7 -2
  26. keras_hub/src/models/depth_anything/__init__.py +9 -0
  27. keras_hub/src/models/depth_anything/depth_anything_backbone.py +232 -0
  28. keras_hub/src/models/depth_anything/depth_anything_depth_estimator.py +70 -0
  29. keras_hub/src/models/depth_anything/depth_anything_depth_estimator_preprocessor.py +16 -0
  30. keras_hub/src/models/depth_anything/depth_anything_image_converter.py +10 -0
  31. keras_hub/src/models/depth_anything/depth_anything_layers.py +725 -0
  32. keras_hub/src/models/depth_anything/depth_anything_loss.py +89 -0
  33. keras_hub/src/models/depth_anything/depth_anything_presets.py +41 -0
  34. keras_hub/src/models/depth_anything/interpolate.py +62 -0
  35. keras_hub/src/models/depth_estimator.py +239 -0
  36. keras_hub/src/models/depth_estimator_preprocessor.py +78 -0
  37. keras_hub/src/models/dinov2/dinov2_backbone.py +29 -3
  38. keras_hub/src/models/dinov2/dinov2_layers.py +16 -4
  39. keras_hub/src/models/dinov3/__init__.py +5 -0
  40. keras_hub/src/models/dinov3/dinov3_backbone.py +263 -0
  41. keras_hub/src/models/dinov3/dinov3_image_converter.py +8 -0
  42. keras_hub/src/models/dinov3/dinov3_layers.py +1013 -0
  43. keras_hub/src/models/dinov3/dinov3_presets.py +4 -0
  44. keras_hub/src/models/gemma/gemma_backbone.py +0 -1
  45. keras_hub/src/models/gemma/gemma_presets.py +30 -0
  46. keras_hub/src/models/gemma3/gemma3_attention.py +48 -0
  47. keras_hub/src/models/gemma3/gemma3_backbone.py +4 -1
  48. keras_hub/src/models/gemma3/gemma3_decoder_block.py +12 -0
  49. keras_hub/src/models/gemma3/gemma3_presets.py +39 -0
  50. keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +4 -1
  51. keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +3 -2
  52. keras_hub/src/models/hgnetv2/hgnetv2_layers.py +27 -11
  53. keras_hub/src/models/image_to_image.py +5 -0
  54. keras_hub/src/models/inpaint.py +5 -0
  55. keras_hub/src/models/mobilenetv5/__init__.py +9 -0
  56. keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py +699 -0
  57. keras_hub/src/models/mobilenetv5/mobilenetv5_backbone.py +396 -0
  58. keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py +890 -0
  59. keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py +436 -0
  60. keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py +157 -0
  61. keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_preprocessor.py +16 -0
  62. keras_hub/src/models/mobilenetv5/mobilenetv5_image_converter.py +10 -0
  63. keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py +462 -0
  64. keras_hub/src/models/mobilenetv5/mobilenetv5_presets.py +15 -0
  65. keras_hub/src/models/mobilenetv5/mobilenetv5_utils.py +146 -0
  66. keras_hub/src/models/parseq/__init__.py +5 -0
  67. keras_hub/src/models/parseq/parseq_backbone.py +134 -0
  68. keras_hub/src/models/parseq/parseq_causal_lm.py +466 -0
  69. keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py +168 -0
  70. keras_hub/src/models/parseq/parseq_decoder.py +418 -0
  71. keras_hub/src/models/parseq/parseq_image_converter.py +8 -0
  72. keras_hub/src/models/parseq/parseq_presets.py +15 -0
  73. keras_hub/src/models/parseq/parseq_tokenizer.py +221 -0
  74. keras_hub/src/models/qwen3_moe/__init__.py +5 -0
  75. keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py +371 -0
  76. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +365 -0
  77. keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm.py +357 -0
  78. keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor.py +12 -0
  79. keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py +672 -0
  80. keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py +45 -0
  81. keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +30 -0
  82. keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py +48 -0
  83. keras_hub/src/models/sam/sam_prompt_encoder.py +3 -1
  84. keras_hub/src/models/siglip/siglip_presets.py +15 -0
  85. keras_hub/src/models/smollm3/smollm3_backbone.py +211 -0
  86. keras_hub/src/models/smollm3/smollm3_causal_lm.py +310 -0
  87. keras_hub/src/models/smollm3/smollm3_causal_lm_preprocessor.py +84 -0
  88. keras_hub/src/models/smollm3/smollm3_layers.py +757 -0
  89. keras_hub/src/models/smollm3/smollm3_tokenizer.py +60 -0
  90. keras_hub/src/models/smollm3/smollm3_utils.py +56 -0
  91. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
  92. keras_hub/src/models/t5gemma/__init__.py +5 -0
  93. keras_hub/src/models/t5gemma/t5gemma_attention.py +370 -0
  94. keras_hub/src/models/t5gemma/t5gemma_backbone.py +366 -0
  95. keras_hub/src/models/t5gemma/t5gemma_decoder.py +355 -0
  96. keras_hub/src/models/t5gemma/t5gemma_encoder.py +214 -0
  97. keras_hub/src/models/t5gemma/t5gemma_layers.py +118 -0
  98. keras_hub/src/models/t5gemma/t5gemma_presets.py +374 -0
  99. keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm.py +442 -0
  100. keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py +216 -0
  101. keras_hub/src/models/t5gemma/t5gemma_tokenizer.py +84 -0
  102. keras_hub/src/models/text_to_image.py +5 -0
  103. keras_hub/src/samplers/beam_sampler.py +6 -6
  104. keras_hub/src/samplers/sampler.py +8 -6
  105. keras_hub/src/tests/test_case.py +40 -3
  106. keras_hub/src/tokenizers/tokenizer.py +15 -0
  107. keras_hub/src/utils/openvino_utils.py +141 -0
  108. keras_hub/src/utils/preset_utils.py +58 -2
  109. keras_hub/src/utils/tensor_utils.py +26 -2
  110. keras_hub/src/utils/timm/convert_mobilenetv5.py +321 -0
  111. keras_hub/src/utils/timm/preset_loader.py +8 -4
  112. keras_hub/src/utils/transformers/convert_dinov2.py +1 -0
  113. keras_hub/src/utils/transformers/convert_dinov3.py +106 -0
  114. keras_hub/src/utils/transformers/convert_qwen3_moe.py +216 -0
  115. keras_hub/src/utils/transformers/convert_smollm3.py +139 -0
  116. keras_hub/src/utils/transformers/convert_t5gemma.py +229 -0
  117. keras_hub/src/utils/transformers/convert_vit.py +4 -1
  118. keras_hub/src/utils/transformers/export/gemma.py +49 -4
  119. keras_hub/src/utils/transformers/export/hf_exporter.py +71 -25
  120. keras_hub/src/utils/transformers/preset_loader.py +12 -0
  121. keras_hub/src/version.py +1 -1
  122. keras_hub/tokenizers/__init__.py +15 -0
  123. {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/METADATA +1 -1
  124. {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/RECORD +126 -47
  125. {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/WHEEL +0 -0
  126. {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/top_level.txt +0 -0
keras_hub/src/models/parseq/parseq_backbone.py
@@ -0,0 +1,134 @@
+ import keras
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.backbone import Backbone
+ from keras_hub.src.models.parseq.parseq_decoder import PARSeqDecoder
+
+
+ @keras_hub_export("keras_hub.models.PARSeqBackbone")
+ class PARSeqBackbone(Backbone):
+     """Scene Text Recognition with PARSeq.
+
+     Performs OCR in natural scenes using the PARSeq model described in [Scene
+     Text Recognition with Permuted Autoregressive Sequence Models](
+     https://arxiv.org/abs/2207.06966). PARSeq is a ViT-based model that allows
+     iterative decoding by performing an autoregressive decoding phase, followed
+     by a refinement phase.
+
+     Args:
+         image_encoder: keras.Model. The image encoder model.
+         vocabulary_size: int. The size of the vocabulary.
+         max_label_length: int. The maximum length of the label sequence.
+         decoder_hidden_dim: int. The dimension of the decoder hidden layers.
+         num_decoder_layers: int. The number of decoder layers.
+         num_decoder_heads: int. The number of attention heads in the decoder.
+         decoder_mlp_dim: int. The dimension of the decoder MLP hidden layer.
+         dropout_rate: float. The dropout rate for the decoder network.
+             Defaults to `0.1`.
+         attention_dropout: float. The dropout rate for the attention weights.
+             Defaults to `0.1`.
+         dtype: `None`, str, or `keras.mixed_precision.DTypePolicy`. The dtype
+             to use for the computations and weights.
+         **kwargs: Additional keyword arguments passed to the base
+             `keras.Model` constructor.
+     """
+
+     def __init__(
+         self,
+         image_encoder,
+         vocabulary_size,
+         max_label_length,
+         decoder_hidden_dim,
+         num_decoder_layers,
+         num_decoder_heads,
+         decoder_mlp_dim,
+         dropout_rate=0.1,
+         attention_dropout=0.1,
+         dtype=None,
+         **kwargs,
+     ):
+         # === Layers ===
+         self.image_encoder = image_encoder
+         self.decoder = PARSeqDecoder(
+             vocabulary_size=vocabulary_size,
+             max_label_length=max_label_length,
+             num_layers=num_decoder_layers,
+             num_heads=num_decoder_heads,
+             hidden_dim=decoder_hidden_dim,
+             mlp_dim=decoder_mlp_dim,
+             dropout_rate=dropout_rate,
+             attention_dropout=attention_dropout,
+             name="decoder",
+             dtype=dtype,
+         )
+         self.head = keras.layers.Dense(
+             vocabulary_size - 2,  # We don't predict <bos> nor <pad>.
+             dtype=dtype,
+         )
+
+         # === Functional Model ===
+         image_input = self.image_encoder.input
+
+         token_id_input = keras.Input(
+             shape=(None,), dtype="int32", name="token_ids"
+         )
+         padding_mask_input = keras.Input(
+             shape=(None,), dtype="int32", name="padding_mask"
+         )
+
+         memory = self.image_encoder(image_input)
+         target_out = self.decoder(
+             token_id_input, memory, padding_mask=padding_mask_input
+         )
+         logits = self.head(target_out)
+
+         # === Config ===
+         self.vocabulary_size = vocabulary_size
+         self.max_label_length = max_label_length
+         self.decoder_hidden_dim = decoder_hidden_dim
+         self.num_decoder_layers = num_decoder_layers
+         self.num_decoder_heads = num_decoder_heads
+         self.decoder_mlp_dim = decoder_mlp_dim
+         self.dropout_rate = dropout_rate
+         self.attention_dropout = attention_dropout
+
+         super().__init__(
+             inputs={
+                 "images": image_input,
+                 "token_ids": token_id_input,
+                 "padding_mask": padding_mask_input,
+             },
+             outputs=logits,
+             dtype=dtype,
+             **kwargs,
+         )
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "image_encoder": keras.layers.serialize(self.image_encoder),
+                 "vocabulary_size": self.vocabulary_size,
+                 "max_label_length": self.max_label_length,
+                 "decoder_hidden_dim": self.decoder_hidden_dim,
+                 "num_decoder_layers": self.num_decoder_layers,
+                 "num_decoder_heads": self.num_decoder_heads,
+                 "decoder_mlp_dim": self.decoder_mlp_dim,
+                 "dropout_rate": self.dropout_rate,
+                 "attention_dropout": self.attention_dropout,
+             }
+         )
+
+         return config
+
+     @classmethod
+     def from_config(cls, config):
+         config.update(
+             {
+                 "image_encoder": keras.layers.deserialize(
+                     config["image_encoder"]
+                 ),
+             }
+         )
+
+         return super().from_config(config)
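For orientation, here is a minimal sketch of constructing this backbone and running a forward pass. It is not part of the diff: it assumes a `keras_hub.models.ViTBackbone` encoder and reuses the illustrative hyperparameters from the `PARSeqCausalLM` docstring in the next file.

```python
# Hypothetical usage sketch; the encoder configuration below is an assumption
# borrowed from the PARSeqCausalLM docstring example, not a published preset.
import numpy as np
import keras_hub

image_encoder = keras_hub.models.ViTBackbone(
    image_shape=(32, 128, 3),
    patch_size=(4, 8),
    num_layers=12,
    num_heads=6,
    hidden_dim=384,
    mlp_dim=384 * 4,
    use_class_token=False,
)
backbone = keras_hub.models.PARSeqBackbone(
    image_encoder=image_encoder,
    vocabulary_size=97,
    max_label_length=25,
    decoder_hidden_dim=384,
    num_decoder_layers=1,
    num_decoder_heads=12,
    decoder_mlp_dim=4 * 384,
)
# Forward pass on dummy data; the output should have shape
# (batch_size, sequence_length, vocabulary_size - 2).
logits = backbone(
    {
        "images": np.random.uniform(size=(2, 32, 128, 3)).astype("float32"),
        "token_ids": np.ones((2, 25), dtype="int32"),
        "padding_mask": np.ones((2, 25), dtype="int32"),
    }
)
print(logits.shape)
```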
keras_hub/src/models/parseq/parseq_causal_lm.py
@@ -0,0 +1,466 @@
+ import math
+
+ import keras
+ from keras import ops
+ from keras import random
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.causal_lm import CausalLM
+ from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone
+ from keras_hub.src.models.parseq.parseq_causal_lm_preprocessor import (
+     PARSeqCausalLMPreprocessor,
+ )
+ from keras_hub.src.utils.tensor_utils import any_equal
+
+
+ @keras_hub_export("keras_hub.models.PARSeqCausalLM")
+ class PARSeqCausalLM(CausalLM):
+     """Scene Text Recognition with PARSeq.
+     Performs OCR in natural scenes using the PARSeq model described in
+     [Scene Text Recognition with Permuted Autoregressive Sequence Models](
+     https://arxiv.org/abs/2207.06966). PARSeq is a ViT-based model that allows
+     iterative decoding by performing an autoregressive decoding phase, followed
+     by a refinement phase.
+     Args:
+         preprocessor: A `keras_hub.models.Preprocessor` instance or a
+             `keras.Layer` instance. The preprocessor to use for the model.
+         backbone: A `keras_hub.models.PARSeqBackbone` instance or a
+             `keras.Model`. The backbone model to use for the model.
+         num_perms: int. The number of permutations to generate for training.
+             Defaults to `6`.
+         add_forward_perms: bool. Whether to add forward permutations to the
+             generated permutations. Defaults to `True`.
+         add_mirrored_perms: bool. Whether to add mirrored permutations to the
+             generated permutations. Defaults to `True`.
+         seed: int. The random seed to use for generating permutations.
+             Defaults to `None`, which means no seed is set.
+         **kwargs: Additional keyword arguments passed to the base
+             `keras_hub.models.CausalLM` constructor.
+
+     Examples:
+
+     Call `generate()` to run inference.
+     ```python
+     # Load preset and run inference.
+     images = np.random.randint(0, 256, size=(2, 32, 128, 3))
+     parseq = keras_hub.models.PARSeqCausalLM.from_preset(
+         "parseq_vit"
+     )
+     parseq.generate(images)
+
+     # Call `fit()` on a single batch.
+     images = np.random.randint(0, 256, size=(2, 32, 128, 3))
+     token_ids = np.array([[1, 2, 3, 4], [1, 2, 3, 0]])
+     padding_mask = np.array([[1, 1, 1, 1], [1, 1, 1, 0]])
+     parseq = keras_hub.models.PARSeqCausalLM.from_preset(
+         "parseq_vit"
+     )
+     parseq.fit(
+         x={
+             "images": images,
+             "token_ids": token_ids,
+             "padding_mask": padding_mask
+         },
+         batch_size=2,
+     )
+     ```
+     Call `fit()` with a custom loss, optimizer and image encoder.
+     ```python
+     # Initialize the image converter, tokenizer and preprocessor.
+     mean, std = 0.5, 0.5
+     image_converter = PARSeqImageConverter(
+         image_size=(32, 128),
+         offset=-mean / std,
+         scale=1.0 / 255.0 / std,
+         interpolation="bicubic",
+     )
+     tokenizer = PARSeqTokenizer(max_label_length=25)
+     preprocessor = keras_hub.models.PARSeqCausalLMPreprocessor(
+         image_converter=image_converter,
+         tokenizer=tokenizer,
+     )
+
+     # Create the backbone.
+     image_encoder = ViTBackbone(
+         image_shape=(32, 128, 3),
+         patch_size=(4, 8),
+         num_layers=12,
+         num_heads=6,
+         hidden_dim=384,
+         mlp_dim=384 * 4,
+         use_class_token=False,
+         name="encoder",
+     )
+     backbone = PARSeqBackbone(
+         vocabulary_size=97,
+         max_label_length=25,
+         image_encoder=image_encoder,
+         num_decoder_heads=12,
+         num_decoder_layers=1,
+         decoder_hidden_dim=384,
+         decoder_mlp_dim=4 * 384,
+     )
+     # Create the PARSeq model.
+     parseq = keras_hub.models.PARSeqCausalLM(
+         backbone=backbone,
+         preprocessor=preprocessor,
+     )
+     parseq.compile(
+         loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+         optimizer=keras.optimizers.Adam(5e-5),
+     )
+     parseq.fit(
+         x={
+             "images": images,
+             "token_ids": token_ids,
+             "padding_mask": padding_mask
+         },
+         batch_size=2,
+     )
+     ```
+     """
+
+     backbone_cls = PARSeqBackbone
+     preprocessor_cls = PARSeqCausalLMPreprocessor
+
+     def __init__(
+         self,
+         preprocessor,
+         backbone,
+         num_perms=6,
+         add_forward_perms=True,
+         add_mirrored_perms=True,
+         seed=None,
+         end_token_id=0,  # Default `tokenizer.end_token_id`.
+         **kwargs,
+     ):
+         # === Layers ===
+         self.preprocessor = preprocessor
+         self.backbone = backbone
+
+         # === Functional Model ===
+         # This must be "backbone.input" i.e. the full input structure,
+         # rather than "backbone.inputs" which is the flattened list of inputs.
+         inputs = backbone.input
+         outputs = backbone(inputs=inputs)
+         super().__init__(
+             inputs=inputs,
+             outputs=outputs,
+             **kwargs,
+         )
+
+         # === Config ===
+         self.num_perms = num_perms
+         self.add_forward_perms = add_forward_perms
+         self.add_mirrored_perms = add_mirrored_perms
+         self.end_token_id = end_token_id
+         self.seed = seed
+         self.seed_generator = keras.random.SeedGenerator(seed)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "num_perms": self.num_perms,
+                 "add_forward_perms": self.add_forward_perms,
+                 "add_mirrored_perms": self.add_mirrored_perms,
+                 "seed": self.seed,
+                 "end_token_id": self.end_token_id,
+             }
+         )
+
+         return config
+
+     def compile(
+         self,
+         optimizer="auto",
+         loss="auto",
+         *,
+         weighted_metrics="auto",
+         sampler="greedy",
+         **kwargs,
+     ):
+         if loss == "auto":
+             loss = keras.losses.SparseCategoricalCrossentropy(
+                 from_logits=True,
+                 ignore_class=self.preprocessor.tokenizer.pad_token_id,
+             )
+         super().compile(
+             optimizer=optimizer,
+             loss=loss,
+             weighted_metrics=weighted_metrics,
+             sampler=sampler,
+             **kwargs,
+         )
+
+     def compute_loss(
+         self, x, y, y_pred, sample_weight, training=True, *args, **kwargs
+     ):
+         # Keras uses a fixed input shape for every batch, so here we permute
+         # `max_label_length - 1` tokens (excluding the BOS and EOS tokens)
+         # instead of the per-batch maximum character count used in the torch
+         # implementation. The `- 1` is because the permutation mask is
+         # generated over the considered tokens before creating the target.
+         max_num_chars = self.backbone.max_label_length - 1
+         perms = self.generate_training_permutations(max_num_chars)
+         max_label_length = self.backbone.max_label_length
+         memory = self.backbone.image_encoder(x["images"])
+         batch_size = ops.shape(x["images"])[0]
+         losses = []
+         for i in range(ops.shape(perms)[0]):
+             query_mask, content_mask = self.generate_attention_masks(perms[i])
+             query_mask = ops.broadcast_to(
+                 query_mask, (batch_size, max_label_length, max_label_length)
+             )
+             content_mask = ops.broadcast_to(
+                 content_mask, (batch_size, max_label_length, max_label_length)
+             )
+             out = self.backbone.decoder(
+                 x["token_ids"],
+                 memory,
+                 padding_mask=x["padding_mask"],
+                 query_mask=query_mask,
+                 content_mask=content_mask,
+             )
+             y_pred = self.backbone.head(out)
+             loss = super().compute_loss(
+                 x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, **kwargs
+             )
+             losses.append(loss)
+             if i == 1:
+                 # Sample weights are set to zero for end-of-sequence (EOS)
+                 # tokens to prevent them from affecting loss calculations.
+                 # Reference: https://github.com/baudm/parseq/blob/1902db043c029a7e03a3818c616c06600af574be/strhub/models/parseq/system.py#L194 # noqa: E501
+                 sample_weight = ops.logical_and(
+                     y != self.end_token_id, sample_weight
+                 )
+
+         return ops.sum(losses) / ops.shape(perms)[0]
+
+     def generate_training_permutations(self, max_num_chars):
+         max_gen_perms = (
+             self.num_perms // 2 if self.add_mirrored_perms else self.num_perms
+         )
+
+         if max_num_chars == 1:
+             return ops.expand_dims(ops.arange(3), axis=0)
+
+         perms = [ops.arange(max_num_chars)] if self.add_forward_perms else []
+         max_num_perms = math.factorial(max_num_chars)
+         max_gen_perms = min(max_gen_perms, max_num_perms)
+
+         for _ in range(max_gen_perms - len(perms)):
+             perm = random.shuffle(
+                 ops.arange(max_num_chars), seed=self.seed_generator
+             )
+             perms.append(perm)
+
+         perms = ops.stack(perms)
+         comp = ops.flip(perms, axis=-1)
+         perms = ops.stack([perms, comp])
+         perms = ops.reshape(
+             ops.transpose(perms, (1, 0, 2)), (-1, max_num_chars)
+         )
+
+         bos_idx = ops.zeros((ops.shape(perms)[0], 1), dtype="int32")
+         eos_idx = ops.full(
+             (ops.shape(perms)[0], 1), max_num_chars + 1, dtype="int32"
+         )
+         perms = ops.concatenate([bos_idx, perms + 1, eos_idx], axis=1)
+
+         if perms.shape[0] > 1:
+             perms = ops.scatter_update(
+                 perms,
+                 ops.concatenate(
+                     [
+                         ops.ones((max_num_chars + 1, 1), dtype="int32"),
+                         ops.expand_dims(
+                             ops.arange(1, max_num_chars + 2, dtype="int32"),
+                             axis=1,
+                         ),
+                     ],
+                     axis=1,
+                 ),
+                 max_num_chars + 1 - ops.arange(max_num_chars + 1),
+             )
+
+         return perms
+
+     def generate_attention_masks(self, perm):
+         """Generate attention masks given a sequence permutation
+         (includes positions for BOS and EOS tokens)."""
+         input_length = ops.shape(perm)[0]
+         mask = ops.ones((input_length, input_length))
+         for i in range(input_length - 1):
+             masked_keys = perm[i + 1 : input_length]
+             query_idx = ops.broadcast_to(perm[i], ops.shape(masked_keys))
+             indices = ops.stack((query_idx, masked_keys), axis=1)
+             mask = keras.ops.scatter_update(
+                 mask, indices, keras.ops.zeros(ops.shape(masked_keys)[0])
+             )
+         content_mask = mask[:-1, :-1]
+         mask = mask * (1 - ops.eye(input_length))
+         query_mask = mask[1:, :-1]
+         return query_mask, content_mask
+
+     def call_with_cache(
+         self,
+         token_ids,
+         cache,
+         cache_update_index,
+         img_embeddings,
+         padding_mask=None,
+     ):
+         bs = ops.shape(token_ids)[0]
+         # <bos> stands for the null context. We only supply position
+         # information for characters after <bos>.
+         content = ops.where(
+             cache_update_index == 0,
+             self.backbone.decoder_hidden_dim**0.5
+             * self.backbone.decoder.token_embedding(token_ids),
+             ops.expand_dims(
+                 self.backbone.decoder.pos_query_embeddings[
+                     :, cache_update_index - 1, :
+                 ],
+                 axis=0,
+             )
+             + self.backbone.decoder_hidden_dim**0.5
+             * self.backbone.decoder.token_embedding(token_ids),
+         )
+         content = self.backbone.decoder.dropout(content)
+
+         query = ops.ones((bs, 1, 1)) * ops.expand_dims(
+             self.backbone.decoder.pos_query_embeddings[
+                 :, cache_update_index, :
+             ],
+             axis=0,
+         )
+         query = self.backbone.decoder.dropout(query)
+
+         query_cache = []
+         content_cache = []
+         for i, decoder_layer in enumerate(self.backbone.decoder.decoder_layers):
+             last = i == self.backbone.num_decoder_layers - 1
+             current_query_cache = cache[:, i, 0, ...]
+             current_content_cache = cache[:, i, 1, ...]
+             (
+                 query,
+                 content,
+                 query_self_attention_new_cache,
+                 content_self_attention_cache,
+             ) = decoder_layer(
+                 query=query,
+                 content=content,
+                 memory=img_embeddings,
+                 padding_mask=padding_mask,
+                 update_content=not last,
+                 query_self_attention_cache=current_query_cache,
+                 query_self_attention_cache_update_index=cache_update_index,
+                 content_self_attention_cache=current_content_cache,
+                 content_self_attention_cache_update_index=cache_update_index,
+             )
+             query_cache.append(query_self_attention_new_cache)
+             content_cache.append(content_self_attention_cache)
+
+         query_cache = ops.stack(query_cache, axis=1)
+         content_cache = ops.stack(content_cache, axis=1)
+         cache = ops.stack([query_cache, content_cache], axis=2)
+         hidden_states = self.backbone.decoder.layer_norm(query)
+         logits = self.backbone.head(hidden_states)
+         return logits, hidden_states, cache
+
+     def _build_cache(self, token_ids, img_embeddings, padding_mask):
+         batch_size = ops.shape(token_ids)[0]
+         max_length = ops.shape(token_ids)[1]
+         num_layers = self.backbone.num_decoder_layers
+         head_dim = (
+             self.backbone.decoder_hidden_dim // self.backbone.num_decoder_heads
+         )
+         num_heads = self.backbone.num_decoder_heads
+         shape = [batch_size, num_layers, 2, 2, max_length, num_heads, head_dim]
+         cache = ops.zeros(shape)
+
+         # Seed the cache.
+         logits, hidden_states, cache = self.call_with_cache(
+             token_ids=token_ids,
+             img_embeddings=img_embeddings,
+             cache=cache,
+             cache_update_index=0,
+             padding_mask=padding_mask,
+         )
+         return hidden_states, cache
+
+     def generate_step(self, inputs, stop_token_ids=None):
+         token_ids, padding_mask, images = (
+             inputs["token_ids"],
+             inputs["padding_mask"],
+             inputs["images"],
+         )
+         images_shape = ops.shape(images)
+         if len(images_shape) == 3:
+             # Handle an unbatched image. Unlike `token_ids` and `padding_mask`,
+             # this will not automatically be upranked.
+             images = ops.expand_dims(images, axis=0)
+
+         img_embeddings = self.backbone.image_encoder(images)
+         # Create and seed the cache with a single forward pass.
+         hidden_states, cache = self._build_cache(
+             token_ids=token_ids,
+             img_embeddings=img_embeddings,
+             padding_mask=padding_mask,
+         )
+         # Compute the lengths of all user inputted token ids.
+         row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
+         # Start at the first index that has no user inputted id.
+         index = ops.min(row_lengths)
+
+         def next(prompt, cache, index):
+             # The cache index is the index of our previous token.
+             cache_update_index = index - 1
+             batch_size = ops.shape(prompt)[0]
+             prompt = ops.slice(prompt, [0, index - 1], [batch_size, 1])
+             logits, hidden_states, cache = self.call_with_cache(
+                 token_ids=prompt,
+                 cache=cache,
+                 cache_update_index=cache_update_index,
+                 img_embeddings=img_embeddings,
+             )
+             return (
+                 ops.squeeze(logits, axis=1),
+                 ops.squeeze(hidden_states, axis=1),
+                 cache,
+             )
+
+         token_ids = self.sampler(
+             next=next,
+             prompt=token_ids,
+             cache=cache,
+             index=index,
+             mask=padding_mask,
+             stop_token_ids=stop_token_ids,
+             hidden_states=hidden_states,
+             model=self,
+         )
+
+         # Compute an output padding mask with the token ids we updated.
+         if stop_token_ids is not None:
+             # Build a mask of `stop_token_ids` locations not in the original
+             # prompt (not in locations where `padding_mask` is True).
+             end_locations = any_equal(
+                 token_ids, stop_token_ids, ops.logical_not(padding_mask)
+             )
+
+             end_locations = ops.cast(end_locations, "int32")
+             # Use cumsum to get ones in all locations after end_locations.
+             cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
+             overflow = cumsum - end_locations
+             # Our padding mask is the inverse of these overflow locations.
+             padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
+         else:
+             # Without early stopping, all locations will have been updated.
+             padding_mask = ops.ones_like(token_ids, dtype="bool")
+         return {
+             "token_ids": token_ids,
+             "padding_mask": padding_mask,
+             "images": images,
+         }
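The permutation masks built in `generate_attention_masks` drive PARSeq's permuted autoregressive training: each query position may only attend to positions decoded earlier in the sampled permutation. The NumPy sketch below mirrors that masking rule for a toy permutation; the helper name and the example permutation are illustrative only and do not appear in the diff.

```python
# Illustrative re-implementation of the masking rule used above, assuming a
# toy sequence of [BOS, c1, c2, c3, EOS] positions.
import numpy as np


def toy_attention_masks(perm):
    # Position perm[i] may not attend to positions that are decoded later in
    # this permutation (perm[i + 1:]).
    n = len(perm)
    mask = np.ones((n, n))
    for i in range(n - 1):
        mask[perm[i], perm[i + 1:]] = 0.0
    content_mask = mask[:-1, :-1].copy()
    mask = mask * (1 - np.eye(n))  # A query never attends to its own position.
    query_mask = mask[1:, :-1]
    return query_mask, content_mask


# Decode order: BOS, then c2, then c1, then c3, then EOS.
query_mask, content_mask = toy_attention_masks(np.array([0, 2, 1, 3, 4]))
print(query_mask)    # (4, 4) mask used when predicting c1..c3 and EOS.
print(content_mask)  # (4, 4) mask over the content (token) stream.
```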