keras-hub-nightly 0.22.0.dev202508170419__py3-none-any.whl → 0.24.0.dev202511090424__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of keras-hub-nightly has been flagged as a potentially problematic release.

Files changed (126)
  1. keras_hub/layers/__init__.py +15 -0
  2. keras_hub/models/__init__.py +93 -0
  3. keras_hub/src/layers/modeling/position_embedding.py +21 -6
  4. keras_hub/src/layers/modeling/reversible_embedding.py +8 -1
  5. keras_hub/src/layers/modeling/rotary_embedding.py +16 -6
  6. keras_hub/src/layers/modeling/sine_position_encoding.py +21 -8
  7. keras_hub/src/layers/modeling/token_and_position_embedding.py +2 -1
  8. keras_hub/src/models/backbone.py +28 -16
  9. keras_hub/src/models/causal_lm.py +37 -0
  10. keras_hub/src/models/causal_lm_preprocessor.py +14 -0
  11. keras_hub/src/models/clip/clip_presets.py +8 -8
  12. keras_hub/src/models/d_fine/__init__.py +5 -0
  13. keras_hub/src/models/d_fine/d_fine_attention.py +461 -0
  14. keras_hub/src/models/d_fine/d_fine_backbone.py +891 -0
  15. keras_hub/src/models/d_fine/d_fine_decoder.py +944 -0
  16. keras_hub/src/models/d_fine/d_fine_encoder.py +365 -0
  17. keras_hub/src/models/d_fine/d_fine_hybrid_encoder.py +642 -0
  18. keras_hub/src/models/d_fine/d_fine_image_converter.py +8 -0
  19. keras_hub/src/models/d_fine/d_fine_layers.py +1828 -0
  20. keras_hub/src/models/d_fine/d_fine_loss.py +938 -0
  21. keras_hub/src/models/d_fine/d_fine_object_detector.py +875 -0
  22. keras_hub/src/models/d_fine/d_fine_object_detector_preprocessor.py +14 -0
  23. keras_hub/src/models/d_fine/d_fine_presets.py +155 -0
  24. keras_hub/src/models/d_fine/d_fine_utils.py +827 -0
  25. keras_hub/src/models/deberta_v3/disentangled_self_attention.py +7 -2
  26. keras_hub/src/models/depth_anything/__init__.py +9 -0
  27. keras_hub/src/models/depth_anything/depth_anything_backbone.py +232 -0
  28. keras_hub/src/models/depth_anything/depth_anything_depth_estimator.py +70 -0
  29. keras_hub/src/models/depth_anything/depth_anything_depth_estimator_preprocessor.py +16 -0
  30. keras_hub/src/models/depth_anything/depth_anything_image_converter.py +10 -0
  31. keras_hub/src/models/depth_anything/depth_anything_layers.py +725 -0
  32. keras_hub/src/models/depth_anything/depth_anything_loss.py +89 -0
  33. keras_hub/src/models/depth_anything/depth_anything_presets.py +41 -0
  34. keras_hub/src/models/depth_anything/interpolate.py +62 -0
  35. keras_hub/src/models/depth_estimator.py +239 -0
  36. keras_hub/src/models/depth_estimator_preprocessor.py +78 -0
  37. keras_hub/src/models/dinov2/dinov2_backbone.py +29 -3
  38. keras_hub/src/models/dinov2/dinov2_layers.py +16 -4
  39. keras_hub/src/models/dinov3/__init__.py +5 -0
  40. keras_hub/src/models/dinov3/dinov3_backbone.py +263 -0
  41. keras_hub/src/models/dinov3/dinov3_image_converter.py +8 -0
  42. keras_hub/src/models/dinov3/dinov3_layers.py +1013 -0
  43. keras_hub/src/models/dinov3/dinov3_presets.py +4 -0
  44. keras_hub/src/models/gemma/gemma_backbone.py +0 -1
  45. keras_hub/src/models/gemma/gemma_presets.py +30 -0
  46. keras_hub/src/models/gemma3/gemma3_attention.py +48 -0
  47. keras_hub/src/models/gemma3/gemma3_backbone.py +4 -1
  48. keras_hub/src/models/gemma3/gemma3_decoder_block.py +12 -0
  49. keras_hub/src/models/gemma3/gemma3_presets.py +39 -0
  50. keras_hub/src/models/hgnetv2/hgnetv2_backbone.py +4 -1
  51. keras_hub/src/models/hgnetv2/hgnetv2_encoder.py +3 -2
  52. keras_hub/src/models/hgnetv2/hgnetv2_layers.py +27 -11
  53. keras_hub/src/models/image_to_image.py +5 -0
  54. keras_hub/src/models/inpaint.py +5 -0
  55. keras_hub/src/models/mobilenetv5/__init__.py +9 -0
  56. keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py +699 -0
  57. keras_hub/src/models/mobilenetv5/mobilenetv5_backbone.py +396 -0
  58. keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py +890 -0
  59. keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py +436 -0
  60. keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py +157 -0
  61. keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_preprocessor.py +16 -0
  62. keras_hub/src/models/mobilenetv5/mobilenetv5_image_converter.py +10 -0
  63. keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py +462 -0
  64. keras_hub/src/models/mobilenetv5/mobilenetv5_presets.py +15 -0
  65. keras_hub/src/models/mobilenetv5/mobilenetv5_utils.py +146 -0
  66. keras_hub/src/models/parseq/__init__.py +5 -0
  67. keras_hub/src/models/parseq/parseq_backbone.py +134 -0
  68. keras_hub/src/models/parseq/parseq_causal_lm.py +466 -0
  69. keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py +168 -0
  70. keras_hub/src/models/parseq/parseq_decoder.py +418 -0
  71. keras_hub/src/models/parseq/parseq_image_converter.py +8 -0
  72. keras_hub/src/models/parseq/parseq_presets.py +15 -0
  73. keras_hub/src/models/parseq/parseq_tokenizer.py +221 -0
  74. keras_hub/src/models/qwen3_moe/__init__.py +5 -0
  75. keras_hub/src/models/qwen3_moe/qwen3_moe_attention.py +371 -0
  76. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +365 -0
  77. keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm.py +357 -0
  78. keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor.py +12 -0
  79. keras_hub/src/models/qwen3_moe/qwen3_moe_decoder.py +672 -0
  80. keras_hub/src/models/qwen3_moe/qwen3_moe_layernorm.py +45 -0
  81. keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +30 -0
  82. keras_hub/src/models/qwen3_moe/qwen3_moe_tokenizer.py +48 -0
  83. keras_hub/src/models/sam/sam_prompt_encoder.py +3 -1
  84. keras_hub/src/models/siglip/siglip_presets.py +15 -0
  85. keras_hub/src/models/smollm3/smollm3_backbone.py +211 -0
  86. keras_hub/src/models/smollm3/smollm3_causal_lm.py +310 -0
  87. keras_hub/src/models/smollm3/smollm3_causal_lm_preprocessor.py +84 -0
  88. keras_hub/src/models/smollm3/smollm3_layers.py +757 -0
  89. keras_hub/src/models/smollm3/smollm3_tokenizer.py +60 -0
  90. keras_hub/src/models/smollm3/smollm3_utils.py +56 -0
  91. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +3 -3
  92. keras_hub/src/models/t5gemma/__init__.py +5 -0
  93. keras_hub/src/models/t5gemma/t5gemma_attention.py +370 -0
  94. keras_hub/src/models/t5gemma/t5gemma_backbone.py +366 -0
  95. keras_hub/src/models/t5gemma/t5gemma_decoder.py +355 -0
  96. keras_hub/src/models/t5gemma/t5gemma_encoder.py +214 -0
  97. keras_hub/src/models/t5gemma/t5gemma_layers.py +118 -0
  98. keras_hub/src/models/t5gemma/t5gemma_presets.py +374 -0
  99. keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm.py +442 -0
  100. keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py +216 -0
  101. keras_hub/src/models/t5gemma/t5gemma_tokenizer.py +84 -0
  102. keras_hub/src/models/text_to_image.py +5 -0
  103. keras_hub/src/samplers/beam_sampler.py +6 -6
  104. keras_hub/src/samplers/sampler.py +8 -6
  105. keras_hub/src/tests/test_case.py +40 -3
  106. keras_hub/src/tokenizers/tokenizer.py +15 -0
  107. keras_hub/src/utils/openvino_utils.py +141 -0
  108. keras_hub/src/utils/preset_utils.py +58 -2
  109. keras_hub/src/utils/tensor_utils.py +26 -2
  110. keras_hub/src/utils/timm/convert_mobilenetv5.py +321 -0
  111. keras_hub/src/utils/timm/preset_loader.py +8 -4
  112. keras_hub/src/utils/transformers/convert_dinov2.py +1 -0
  113. keras_hub/src/utils/transformers/convert_dinov3.py +106 -0
  114. keras_hub/src/utils/transformers/convert_qwen3_moe.py +216 -0
  115. keras_hub/src/utils/transformers/convert_smollm3.py +139 -0
  116. keras_hub/src/utils/transformers/convert_t5gemma.py +229 -0
  117. keras_hub/src/utils/transformers/convert_vit.py +4 -1
  118. keras_hub/src/utils/transformers/export/gemma.py +49 -4
  119. keras_hub/src/utils/transformers/export/hf_exporter.py +71 -25
  120. keras_hub/src/utils/transformers/preset_loader.py +12 -0
  121. keras_hub/src/version.py +1 -1
  122. keras_hub/tokenizers/__init__.py +15 -0
  123. {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/METADATA +1 -1
  124. {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/RECORD +126 -47
  125. {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/WHEEL +0 -0
  126. {keras_hub_nightly-0.22.0.dev202508170419.dist-info → keras_hub_nightly-0.24.0.dev202511090424.dist-info}/top_level.txt +0 -0
keras_hub/src/models/parseq/parseq_causal_lm_preprocessor.py
@@ -0,0 +1,168 @@
+ import keras
+ from keras import ops
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
+ from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
+ from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone
+ from keras_hub.src.models.parseq.parseq_image_converter import (
+     PARSeqImageConverter,
+ )
+ from keras_hub.src.models.parseq.parseq_tokenizer import PARSeqTokenizer
+ from keras_hub.src.utils.tensor_utils import preprocessing_function
+ from keras_hub.src.utils.tensor_utils import strip_to_ragged
+
+
+ @keras_hub_export("keras_hub.models.PARSeqCausalLMPreprocessor")
+ class PARSeqCausalLMPreprocessor(CausalLMPreprocessor):
+     backbone_cls = PARSeqBackbone
+     tokenizer_cls = PARSeqTokenizer
+     image_converter_cls = PARSeqImageConverter
+
+     def __init__(
+         self,
+         image_converter=None,
+         tokenizer=None,
+         sequence_length=25,
+         add_start_token=True,
+         add_end_token=True,
+         **kwargs,
+     ):
+         super().__init__(
+             tokenizer=tokenizer,
+             sequence_length=sequence_length,
+             add_start_token=add_start_token,
+             add_end_token=add_end_token,
+             **kwargs,
+         )
+         self.image_converter = image_converter
+
+     def build(self, input_shape):
+         # Defer packer creation to `build()` so that we can be sure tokenizer
+         # assets have loaded when restoring a saved model.
+         self.packer = StartEndPacker(
+             start_value=self.tokenizer.start_token_id,
+             end_value=self.tokenizer.end_token_id,
+             pad_value=self.tokenizer.pad_token_id,
+             sequence_length=self.sequence_length,
+             return_padding_mask=True,
+         )
+         self.built = True
+
+     @preprocessing_function
+     def call(self, x, y=None, sample_weight=None, sequence_length=None):
+         """Preprocesses the input data for training.
+
+         This method takes a dictionary containing images and text responses,
+         and converts them into a format suitable for training a PARSeq model.
+
+         Args:
+             x: dict. A dictionary containing the input data. Must have keys
+                 "images" and "responses".
+             y: The target data. Defaults to None.
+             sample_weight: The sample weights. Defaults to None.
+             sequence_length: int. The maximum length of the input sequence.
+                 Defaults to None, which uses the pre-defined sequence length.
+
+         Returns:
+             A tuple containing the preprocessed input data, target data, and
+             sample weights.
+         """
+         sequence_length = sequence_length or self.sequence_length
+         images, responses = x["images"], x["responses"]
+         if self.image_converter:
+             images = self.image_converter(images)
+         token_ids = self.tokenizer(responses)
+         token_ids, padding_mask = self.packer(
+             token_ids,
+             sequence_length=sequence_length + 1,
+             add_start_value=self.add_start_token,
+             add_end_value=self.add_end_token,
+         )
+         x = {
+             "images": images,
+             "token_ids": token_ids[..., :-1],
+             "padding_mask": padding_mask[..., :-1],
+         }
+         # Target `y` will be the next token.
+         y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
+         return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
+
+     @preprocessing_function
+     def generate_preprocess(
+         self,
+         x,
+         sequence_length=None,
+     ):
+         """Convert strings to integer token input for generation.
+
+         Similar to calling the layer for training, this method takes in strings
+         or tensor strings, tokenizes and packs the input, and computes a padding
+         mask masking all inputs not filled in with a padded value.
+
+         Unlike calling the layer for training, this method does not compute
+         labels and will never append a `tokenizer.end_token_id` to the end of
+         the sequence (as generation is expected to continue at the end of the
+         inputted prompt).
+         """
+         if not self.built:
+             self.build(None)
+         sequence_length = sequence_length or self.sequence_length
+         images = x
+         if self.image_converter:
+             images = self.image_converter(images)
+
+         images_shape = keras.ops.shape(images)
+         if len(images_shape) == 3:
+             batch_size = 1
+         else:
+             batch_size = images_shape[0]
+
+         token_ids = ops.concatenate(
+             (
+                 ops.full([batch_size, 1], self.tokenizer.start_token_id),
+                 ops.full(
+                     [batch_size, sequence_length - 1],
+                     self.tokenizer.pad_token_id,
+                 ),
+             ),
+             axis=1,
+         )
+
+         padding_mask = ops.equal(token_ids, self.tokenizer.start_token_id)
+
+         return {
+             "images": images,
+             "token_ids": token_ids,
+             "padding_mask": padding_mask,
+         }
+
+     @preprocessing_function
+     def generate_postprocess(
+         self,
+         x,
+     ):
+         """Convert integer token output to strings for generation.
+
+         This method reverses `generate_preprocess()`, by first removing all
+         padding and start/end tokens, and then converting the integer sequence
+         back to a string.
+         """
+         if not self.built:
+             self.build(None)
+
+         token_ids, padding_mask = x["token_ids"], x["padding_mask"]
+         ids_to_strip = self.tokenizer.special_token_ids
+         token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip)
+         return self.tokenizer.detokenize(token_ids)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "sequence_length": self.sequence_length,
+                 "add_start_token": self.add_start_token,
+                 "add_end_token": self.add_end_token,
+             }
+         )
+         return config
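For orientation, a minimal usage sketch of the new preprocessor (not part of this diff). It assumes the "parseq" preset registered at the end of this changeset resolves, and that the model consumes 32x128 RGB crops; both are assumptions rather than facts established here.

    import numpy as np
    import keras_hub

    # Assumes the "parseq" preset (registered later in this diff) is available.
    preprocessor = keras_hub.models.PARSeqCausalLMPreprocessor.from_preset(
        "parseq"
    )
    x, y, sample_weight = preprocessor(
        {
            "images": np.random.uniform(size=(2, 32, 128, 3)).astype("float32"),
            "responses": ["hello", "world"],
        }
    )
    # x["token_ids"] and x["padding_mask"] have shape (2, 25); y and
    # sample_weight are the same ids and mask shifted one step left.

The off-by-one slicing in `call()` (pack to `sequence_length + 1`, then slice `[..., :-1]` for inputs and `[..., 1:]` for targets) is what produces the shifted next-token labels.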
keras_hub/src/models/parseq/parseq_decoder.py
@@ -0,0 +1,418 @@
+ import keras
+ from keras import ops
+
+ from keras_hub.src.layers.modeling.cached_multi_head_attention import (
+     CachedMultiHeadAttention,
+ )
+ from keras_hub.src.layers.modeling.transformer_layer_utils import (
+     compute_causal_mask,
+ )
+ from keras_hub.src.layers.modeling.transformer_layer_utils import (
+     merge_padding_and_attention_mask,
+ )
+ from keras_hub.src.models.vit.vit_layers import MLP
+
+
+ class PARSeqDecoderBlock(keras.layers.Layer):
+     """A decoder block for the PARSeq model.
+
+     This block consists of self-attention, cross-attention, and a multilayer
+     perceptron (MLP). It also includes layer normalization and dropout layers.
+
+     Args:
+         hidden_dim: int. The dimension of the hidden layers.
+         num_heads: int. The number of attention heads.
+         mlp_dim: int. The dimension of the MLP hidden layer.
+         dropout_rate: float. The dropout rate used in the feedforward layers.
+         attention_dropout: float. The dropout rate for the attention weights.
+         layer_norm_epsilon: float. A small float added to the denominator for
+             numerical stability in layer normalization.
+         **kwargs: Additional keyword arguments passed to the base
+             `keras.layers.Layer` constructor.
+     """
+
+     def __init__(
+         self,
+         hidden_dim,
+         num_heads,
+         mlp_dim,
+         dropout_rate=0.1,
+         attention_dropout=0.1,
+         layer_norm_epsilon=1e-5,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         key_dim = hidden_dim // num_heads
+
+         # === Config ===
+         self.hidden_dim = hidden_dim
+         self.num_heads = num_heads
+         self.mlp_dim = mlp_dim
+         self.key_dim = key_dim
+         self.dropout_rate = dropout_rate
+         self.attention_dropout = attention_dropout
+         self.layer_norm_epsilon = layer_norm_epsilon
+
+     def build(self, input_shape):
+         self.query_layer_norm = keras.layers.LayerNormalization(
+             epsilon=self.layer_norm_epsilon,
+             name="query_layer_norm",
+             dtype=self.dtype_policy,
+         )
+         self.query_layer_norm.build(input_shape)
+         self.content_layer_norm = keras.layers.LayerNormalization(
+             epsilon=self.layer_norm_epsilon,
+             name="content_layer_norm",
+             dtype=self.dtype_policy,
+         )
+         self.content_layer_norm.build(input_shape)
+         self.self_attention = CachedMultiHeadAttention(
+             num_heads=self.num_heads,
+             key_dim=self.key_dim,
+             dropout=self.attention_dropout,
+             name="self_attention",
+             dtype=self.dtype_policy,
+         )
+         self.self_attention.build(input_shape, input_shape)
+         self.cross_attention = CachedMultiHeadAttention(
+             num_heads=self.num_heads,
+             key_dim=self.key_dim,
+             dropout=self.attention_dropout,
+             name="cross_attention",
+             dtype=self.dtype_policy,
+         )
+         self.cross_attention.build(input_shape, input_shape)
+
+         self.layer_norm_1 = keras.layers.LayerNormalization(
+             epsilon=self.layer_norm_epsilon,
+             name="ln_1",
+             dtype=self.dtype_policy,
+         )
+         self.layer_norm_1.build((None, None, self.hidden_dim))
+         self.layer_norm_2 = keras.layers.LayerNormalization(
+             epsilon=self.layer_norm_epsilon,
+             name="ln_2",
+             dtype=self.dtype_policy,
+         )
+         self.layer_norm_2.build((None, None, self.hidden_dim))
+         self.mlp = MLP(
+             hidden_dim=self.hidden_dim,
+             mlp_dim=self.mlp_dim,
+             dropout_rate=self.dropout_rate,
+             name="mlp",
+             dtype=self.dtype_policy,
+         )
+         self.mlp.build((None, None, self.hidden_dim))
+         self.dropout = keras.layers.Dropout(
+             rate=self.dropout_rate,
+             dtype=self.dtype_policy,
+             name="decoder_block_dropout",
+         )
+
+         self.built = True
+
+     def forward_stream(
+         self,
+         target,
+         target_norm,
+         target_kv,
+         memory,
+         padding_mask=None,
+         self_attention_cache=None,
+         self_attention_cache_update_index=0,
+         train_attention_mask=None,
+     ):
+         self_attention_new_cache = None
+         if train_attention_mask is None:
+             target_attention_mask = self._compute_attention_mask(
+                 target_norm,
+                 padding_mask,
+                 self_attention_cache,
+                 self_attention_cache_update_index,
+             )
+         else:
+             target_attention_mask = merge_padding_and_attention_mask(
+                 target_norm, padding_mask, attention_mask=train_attention_mask
+             )
+
+         if self_attention_cache is not None:
+             target2, self_attention_new_cache = self.self_attention(
+                 target_norm,
+                 target_kv,
+                 target_kv,
+                 attention_mask=target_attention_mask,
+                 cache=self_attention_cache,
+                 cache_update_index=self_attention_cache_update_index,
+             )
+         else:
+             target2 = self.self_attention(
+                 target_norm,
+                 target_kv,
+                 target_kv,
+                 attention_mask=target_attention_mask,
+             )
+         target = ops.add(target, self.dropout(target2))
+         target2 = self.cross_attention(
+             self.layer_norm_1(target),
+             memory,
+             memory,
+         )
+         target = ops.add(target, self.dropout(target2))
+
+         target2 = self.mlp(self.layer_norm_2(target))
+         target = ops.add(target, target2)
+
+         return target, self_attention_new_cache
+
+     def call(
+         self,
+         query,
+         content,
+         memory,
+         padding_mask=None,
+         update_content=True,
+         query_self_attention_cache=None,
+         query_self_attention_cache_update_index=0,
+         content_self_attention_cache=None,
+         content_self_attention_cache_update_index=0,
+         query_mask=None,
+         content_mask=None,
+     ):
+         # position + token embeddings
+         query_norm = self.query_layer_norm(query)
+         # position embeddings
+         content_norm = self.content_layer_norm(content)
+         (
+             query,
+             query_self_attention_new_cache,
+         ) = self.forward_stream(
+             query,
+             query_norm,
+             content_norm,
+             memory,
+             padding_mask=padding_mask,
+             train_attention_mask=query_mask,
+             self_attention_cache=query_self_attention_cache,
+             self_attention_cache_update_index=query_self_attention_cache_update_index,
+         )
+
+         if update_content:
+             (
+                 content,
+                 content_self_attention_new_cache,
+             ) = self.forward_stream(
+                 content,
+                 content_norm,
+                 content_norm,
+                 memory,  # image embeddings (encoder embeddings)
+                 padding_mask=padding_mask,
+                 train_attention_mask=content_mask,
+                 self_attention_cache=content_self_attention_cache,
+                 self_attention_cache_update_index=content_self_attention_cache_update_index,
+             )
+
+         return_values = [query, content]
+
+         if query_self_attention_cache is not None:
+             return_values.append(query_self_attention_new_cache)
+         if update_content and content_self_attention_cache is not None:
+             return_values.append(content_self_attention_new_cache)
+         elif not update_content and content_self_attention_cache is not None:
+             return_values.append(content_self_attention_cache)
+
+         return tuple(return_values)
+
+     def _compute_attention_mask(
+         self, x, padding_mask, cache, cache_update_index
+     ):
+         decoder_mask = merge_padding_and_attention_mask(
+             inputs=x, padding_mask=padding_mask, attention_mask=None
+         )
+         batch_size = ops.shape(x)[0]
+         input_length = output_length = ops.shape(x)[1]
+         if cache is not None:
+             input_length = ops.shape(cache)[2]
+
+         causal_mask = compute_causal_mask(
+             batch_size=batch_size,
+             input_length=input_length,
+             output_length=output_length,
+             cache_index=cache_update_index,
+         )
+
+         return (
+             ops.minimum(decoder_mask, causal_mask)
+             if decoder_mask is not None
+             else causal_mask
+         )
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "num_heads": self.num_heads,
+                 "hidden_dim": self.hidden_dim,
+                 "key_dim": self.key_dim,
+                 "mlp_dim": self.mlp_dim,
+                 "dropout_rate": self.dropout_rate,
+                 "attention_dropout": self.attention_dropout,
+                 "layer_norm_epsilon": self.layer_norm_epsilon,
+             }
+         )
+         return config
+
+
+ class PARSeqDecoder(keras.layers.Layer):
+     """The PARSeq decoder.
+
+     This decoder consists of multiple decoder blocks and a token embedding
+     layer. It takes token IDs and memory from the encoder as input and
+     outputs a sequence of hidden states.
+
+     Args:
+         vocabulary_size: int. The size of the vocabulary.
+         max_label_length: int. The maximum length of the label sequence.
+         num_layers: int. The number of decoder layers.
+         hidden_dim: int. The dimension of the hidden layers.
+         mlp_dim: int. The dimension of the MLP hidden layer.
+         num_heads: int. The number of attention heads.
+         dropout_rate: float. The dropout rate.
+         attention_dropout: float. The dropout rate for the attention weights.
+         layer_norm_epsilon: float. A small float added to the denominator for
+             numerical stability in layer normalization.
+         **kwargs: Additional keyword arguments passed to the base
+             `keras.layers.Layer` constructor.
+     """
+
+     def __init__(
+         self,
+         vocabulary_size,
+         max_label_length,
+         num_layers,
+         hidden_dim,
+         mlp_dim,
+         num_heads,
+         dropout_rate=0.1,
+         attention_dropout=0.1,
+         layer_norm_epsilon=1e-5,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         # === Config ===
+         self.vocabulary_size = vocabulary_size
+         self.max_label_length = max_label_length
+         self.hidden_dim = hidden_dim
+         self.mlp_dim = mlp_dim
+         self.num_heads = num_heads
+         self.dropout_rate = dropout_rate
+         self.attention_dropout = attention_dropout
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.num_layers = num_layers
+
+     def build(self, input_shape):
+         self.token_embedding = keras.layers.Embedding(
+             input_dim=self.vocabulary_size,
+             output_dim=self.hidden_dim,
+             dtype=self.dtype_policy,
+             name="token_embedding",
+         )
+         self.token_embedding.build((1, self.vocabulary_size))
+         self.pos_query_embeddings = self.add_weight(
+             shape=(1, self.max_label_length + 1, self.hidden_dim),
+             name="pos_query_embeddings",
+             dtype=self.dtype,
+         )
+         self.dropout = keras.layers.Dropout(
+             self.dropout_rate, dtype=self.dtype_policy, name="decoder_dropout"
+         )
+         self.decoder_layers = []
+         for i in range(self.num_layers):
+             decoder_layer = PARSeqDecoderBlock(
+                 hidden_dim=self.hidden_dim,
+                 num_heads=self.num_heads,
+                 mlp_dim=self.mlp_dim,
+                 dropout_rate=self.dropout_rate,
+                 attention_dropout=self.attention_dropout,
+                 layer_norm_epsilon=self.layer_norm_epsilon,
+                 dtype=self.dtype_policy,
+                 name=f"decoder_layer_{i}",
+             )
+             decoder_layer.build((None, None, self.hidden_dim))
+             self.decoder_layers.append(decoder_layer)
+
+         self.layer_norm = keras.layers.LayerNormalization(
+             epsilon=self.layer_norm_epsilon,
+             dtype=self.dtype_policy,
+             name="layer_norm",
+         )
+         self.layer_norm.build((None, None, self.hidden_dim))
+         self.built = True
+
+     def call(
+         self,
+         token_ids,
+         memory,
+         padding_mask=None,
+         query_mask=None,
+         content_mask=None,
+     ):
+         bs, tokens_length = ops.shape(token_ids)
+         # <bos> stands for the null context. We only supply position
+         # information for characters after <bos>.
+         null_context = self.hidden_dim**0.5 * self.token_embedding(
+             token_ids[:, :1]
+         )
+         if tokens_length > 1:
+             content = self.pos_query_embeddings[:, : tokens_length - 1, :]
+             content = content + self.hidden_dim**0.5 * self.token_embedding(
+                 token_ids[:, 1:]
+             )
+             content = ops.concatenate([null_context, content], axis=1)
+         else:
+             content = null_context
+
+         content = self.dropout(content)
+
+         query = ops.multiply(
+             ops.ones((bs, 1, 1), dtype=self.dtype),
+             self.pos_query_embeddings[:, :tokens_length, :],
+         )
+         query = self.dropout(query)
+
+         for i, decoder_layer in enumerate(self.decoder_layers):
+             last = i == self.num_layers - 1
+             query, content = decoder_layer(
+                 query=query,
+                 content=content,
+                 memory=memory,
+                 padding_mask=padding_mask,
+                 update_content=not last,
+                 query_mask=query_mask,
+                 content_mask=content_mask,
+             )
+
+         query = self.layer_norm(query)
+
+         return query
+
+     def compute_output_shape(self, input_shape):
+         return (None, None, self.hidden_dim)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "vocabulary_size": self.vocabulary_size,
+                 "max_label_length": self.max_label_length,
+                 "num_layers": self.num_layers,
+                 "num_heads": self.num_heads,
+                 "hidden_dim": self.hidden_dim,
+                 "mlp_dim": self.mlp_dim,
+                 "dropout_rate": self.dropout_rate,
+                 "attention_dropout": self.attention_dropout,
+                 "layer_norm_epsilon": self.layer_norm_epsilon,
+             }
+         )
+         return config
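PARSeqDecoderBlock and PARSeqDecoder live under src and are not exported, but the decoder can be exercised directly from the src path. A sketch with illustrative hyperparameters (not the published PARSeq configuration; the real memory tensor would come from the ViT encoder):

    import numpy as np
    from keras_hub.src.models.parseq.parseq_decoder import PARSeqDecoder

    decoder = PARSeqDecoder(
        vocabulary_size=97,  # illustrative charset size
        max_label_length=25,
        num_layers=1,
        hidden_dim=384,
        mlp_dim=1536,
        num_heads=6,
    )
    token_ids = np.random.randint(0, 97, size=(2, 26))  # packed label ids
    # Stand-in for encoder output: 196 patch embeddings of width 384.
    memory = np.random.uniform(size=(2, 196, 384)).astype("float32")
    hidden_states = decoder(token_ids, memory)  # shape (2, 26, 384)

With num_layers=1 the single block is also the last, so update_content is False and only the query stream is refined, mirroring the loop in PARSeqDecoder.call().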
keras_hub/src/models/parseq/parseq_image_converter.py
@@ -0,0 +1,8 @@
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+ from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone
+
+
+ @keras_hub_export("keras_hub.layers.PARSeqImageConverter")
+ class PARSeqImageConverter(ImageConverter):
+     backbone_cls = PARSeqBackbone
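The converter is a thin ImageConverter subclass. A sketch of standalone use, where the 32x128 target size is an assumption rather than a value read from a preset:

    import numpy as np
    import keras_hub

    converter = keras_hub.layers.PARSeqImageConverter(image_size=(32, 128))
    # Resizes arbitrary input resolutions to the backbone's expected size.
    images = converter(np.random.uniform(size=(2, 64, 256, 3)).astype("float32"))
    print(images.shape)  # (2, 32, 128, 3)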
keras_hub/src/models/parseq/parseq_presets.py
@@ -0,0 +1,15 @@
+ """PARSeq preset configurations."""
+
+ backbone_presets = {
+     "parseq": {
+         "metadata": {
+             "description": (
+                 "Permuted autoregressive sequence (PARSeq) base "
+                 "model for scene text recognition"
+             ),
+             "params": 23_832_671,
+             "path": "parseq",
+         },
+         "kaggle_handle": "kaggle://keras/parseq/keras/parseq/1",
+     }
+ }
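Once the Kaggle handle above is published, loading the preset would look roughly as follows; a sketch, assuming PARSeqBackbone is exported under keras_hub.models:

    import keras_hub

    backbone = keras_hub.models.PARSeqBackbone.from_preset("parseq")
    # Or the full OCR pipeline, assuming the task class is exported as well:
    # model = keras_hub.models.PARSeqCausalLM.from_preset("parseq")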