keras-hub 0.25.1__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. keras_hub/layers/__init__.py +21 -0
  2. keras_hub/models/__init__.py +27 -0
  3. keras_hub/src/layers/modeling/non_max_supression.py +5 -2
  4. keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
  5. keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
  6. keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
  7. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
  8. keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
  9. keras_hub/src/models/albert/albert_backbone.py +1 -3
  10. keras_hub/src/models/backbone.py +3 -0
  11. keras_hub/src/models/bart/bart_backbone.py +1 -3
  12. keras_hub/src/models/bert/bert_backbone.py +2 -4
  13. keras_hub/src/models/bloom/bloom_backbone.py +1 -3
  14. keras_hub/src/models/causal_lm.py +2 -2
  15. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
  16. keras_hub/src/models/edrec/edrec_backbone.py +147 -0
  17. keras_hub/src/models/edrec/edrec_layers.py +434 -0
  18. keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
  19. keras_hub/src/models/electra/electra_backbone.py +1 -3
  20. keras_hub/src/models/f_net/f_net_backbone.py +1 -3
  21. keras_hub/src/models/falcon/falcon_backbone.py +1 -3
  22. keras_hub/src/models/flux/flux_layers.py +3 -3
  23. keras_hub/src/models/flux/flux_maths.py +29 -15
  24. keras_hub/src/models/gemma/gemma_backbone.py +1 -3
  25. keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
  26. keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
  27. keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
  28. keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
  29. keras_hub/src/models/gemma3/gemma3_decoder_block.py +1 -1
  30. keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
  31. keras_hub/src/models/gemma3/gemma3_presets.py +67 -7
  32. keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
  33. keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
  34. keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
  35. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
  36. keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
  37. keras_hub/src/models/llama/llama_backbone.py +1 -3
  38. keras_hub/src/models/masked_lm.py +1 -1
  39. keras_hub/src/models/mistral/mistral_backbone.py +1 -3
  40. keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
  41. keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
  42. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
  43. keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
  44. keras_hub/src/models/phi3/phi3_backbone.py +1 -3
  45. keras_hub/src/models/qwen/qwen_backbone.py +1 -3
  46. keras_hub/src/models/qwen/qwen_presets.py +209 -0
  47. keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
  48. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
  49. keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
  50. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
  51. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
  52. keras_hub/src/models/rqvae/__init__.py +5 -0
  53. keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
  54. keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
  55. keras_hub/src/models/rwkv7/__init__.py +5 -0
  56. keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
  57. keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
  58. keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
  59. keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
  60. keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
  61. keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
  62. keras_hub/src/models/sam/sam_backbone.py +5 -1
  63. keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
  64. keras_hub/src/models/sam3/__init__.py +7 -0
  65. keras_hub/src/models/sam3/roi_align.py +222 -0
  66. keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
  67. keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
  68. keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
  69. keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
  70. keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
  71. keras_hub/src/models/sam3/sam3_layers.py +814 -0
  72. keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
  73. keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
  74. keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
  75. keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
  76. keras_hub/src/models/sam3/sam3_presets.py +16 -0
  77. keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
  78. keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
  79. keras_hub/src/models/sam3/sam3_utils.py +134 -0
  80. keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
  81. keras_hub/src/models/segformer/segformer_backbone.py +6 -6
  82. keras_hub/src/models/siglip/siglip_layers.py +1 -3
  83. keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
  84. keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
  85. keras_hub/src/models/t5/t5_backbone.py +1 -3
  86. keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
  87. keras_hub/src/models/task.py +1 -1
  88. keras_hub/src/tests/test_case.py +394 -3
  89. keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
  90. keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
  91. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
  92. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
  93. keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
  94. keras_hub/src/utils/preset_utils.py +1 -1
  95. keras_hub/src/utils/tensor_utils.py +12 -0
  96. keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
  97. keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
  98. keras_hub/src/utils/transformers/convert_sam3.py +472 -0
  99. keras_hub/src/utils/transformers/export/gemma3.py +196 -0
  100. keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
  101. keras_hub/src/utils/transformers/export/qwen.py +136 -0
  102. keras_hub/src/utils/transformers/preset_loader.py +15 -1
  103. keras_hub/src/version.py +1 -1
  104. keras_hub/tokenizers/__init__.py +6 -0
  105. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
  106. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
  107. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
  108. keras_hub/src/models/gemma3/rms_normalization.py +0 -26
  109. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,293 @@
1
+ from keras import layers
2
+ from keras import ops
3
+
4
+ from keras_hub.src.api_export import keras_hub_export
5
+ from keras_hub.src.models.sam3.sam3_layers import SAM3MLP
6
+ from keras_hub.src.models.sam3.sam3_layers import SAM3Attention
7
+ from keras_hub.src.models.sam3.sam3_utils import create_bidirectional_mask
8
+
9
+
10
class SAM3DetrEncoderLayer(layers.Layer):
    """One DETR-style fusion layer used by the SAM3 encoder.

    Runs three pre-norm sub-blocks over the vision features, each followed
    by dropout and a residual connection:

    1. Self-attention over the vision tokens, where the positional
       encodings are added to the query/key inputs only.
    2. Cross-attention from the vision tokens to the prompt tokens.
    3. A feed-forward MLP.

    Args:
        hidden_dim: int. Width of the attention and MLP streams.
        intermediate_dim: int. Hidden width of the MLP.
        num_heads: int. Number of attention heads.
        hidden_activation: str. Activation used inside the MLP. Defaults to
            `"relu"`.
        dropout_rate: float. Dropout applied after each sub-block. Defaults
            to `0.0`.
        layer_norm_epsilon: float. Epsilon for the layer norms. Defaults to
            `1e-6`.
    """

    def __init__(
        self,
        hidden_dim,
        intermediate_dim,
        num_heads,
        hidden_activation="relu",
        dropout_rate=0.0,
        layer_norm_epsilon=1e-6,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.intermediate_dim = int(intermediate_dim)
        self.num_heads = int(num_heads)
        self.hidden_activation = hidden_activation
        self.dropout_rate = float(dropout_rate)
        self.layer_norm_epsilon = float(layer_norm_epsilon)

        # NOTE: sub-layer names are load-bearing for checkpoint weight
        # paths; do not rename them.
        self.layer_norm1 = layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="layer_norm1",
        )
        self.self_attn = SAM3Attention(
            hidden_dim=self.hidden_dim,
            num_heads=self.num_heads,
            dtype=self.dtype_policy,
            name="self_attn",
        )
        # A single dropout layer is shared by all three residual branches
        # (it is stateless, so sharing is safe).
        self.dropout = layers.Dropout(
            self.dropout_rate, dtype=self.dtype_policy, name="dropout"
        )
        self.cross_attn = SAM3Attention(
            hidden_dim=self.hidden_dim,
            num_heads=self.num_heads,
            dtype=self.dtype_policy,
            name="cross_attn",
        )
        self.layer_norm2 = layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="layer_norm2",
        )
        self.mlp = SAM3MLP(
            hidden_dim=self.hidden_dim,
            intermediate_dim=self.intermediate_dim,
            activation=self.hidden_activation,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype_policy,
            name="mlp",
        )
        self.layer_norm3 = layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="layer_norm3",
        )

    def build(
        self,
        vision_feats_shape,
        prompt_feats_shape,
        vision_pos_encodings_shape,
        prompt_cross_attn_masks_shape,
    ):
        # Every sub-block preserves the vision feature shape, so it is the
        # input shape for all vision-side sub-layers. Build order is kept
        # stable so variable creation order does not change.
        self.layer_norm1.build(vision_feats_shape)
        self.self_attn.build(
            vision_feats_shape, vision_feats_shape, vision_feats_shape
        )
        self.dropout.build(vision_feats_shape)
        self.layer_norm2.build(vision_feats_shape)
        self.cross_attn.build(
            vision_feats_shape, prompt_feats_shape, prompt_feats_shape
        )
        self.layer_norm3.build(vision_feats_shape)
        self.mlp.build(vision_feats_shape)

    def call(
        self,
        vision_feats,
        prompt_feats,
        vision_pos_encodings,
        prompt_cross_attn_masks=None,
        training=None,
    ):
        # Self-attention sub-block: positions enter queries/keys only, the
        # values stay position-free.
        shortcut = vision_feats
        x = self.layer_norm1(vision_feats, training=training)
        qk = ops.add(x, vision_pos_encodings)
        x = self.self_attn(query=qk, key=qk, value=x, training=training)
        x = ops.add(self.dropout(x, training=training), shortcut)

        # Cross-attention sub-block: vision tokens attend to the prompt.
        shortcut = x
        x = self.layer_norm2(x, training=training)
        x = self.cross_attn(
            query=x,
            key=prompt_feats,
            value=prompt_feats,
            attention_mask=prompt_cross_attn_masks,
            training=training,
        )
        x = ops.add(self.dropout(x, training=training), shortcut)

        # Feed-forward sub-block.
        shortcut = x
        x = self.layer_norm3(x, training=training)
        x = self.mlp(x, training=training)
        return ops.add(self.dropout(x, training=training), shortcut)

    def get_config(self):
        return {
            **super().get_config(),
            "hidden_dim": self.hidden_dim,
            "intermediate_dim": self.intermediate_dim,
            "num_heads": self.num_heads,
            "hidden_activation": self.hidden_activation,
            "dropout_rate": self.dropout_rate,
            "layer_norm_epsilon": self.layer_norm_epsilon,
        }

    def compute_output_shape(
        self,
        vision_feats_shape,
        prompt_feats_shape,
        vision_pos_encodings_shape,
        prompt_cross_attn_masks_shape,
    ):
        # The layer is shape-preserving on the vision stream.
        return vision_feats_shape
152
+
153
+
154
@keras_hub_export("keras_hub.layers.SAM3DetrEncoder")
class SAM3DetrEncoder(layers.Layer):
    """A DETR encoder for the Segment Anything Model 3 (SAM3).

    This layer implements a transformer-based encoder that fuses vision and
    prompt features. It flattens the spatial grid of vision features into a
    token sequence and passes it, together with the prompt features, through
    a stack of self-attention + cross-attention layers.

    Args:
        num_layers: int. The number of transformer layers.
        hidden_dim: int. The hidden dimension of the transformer layers.
        intermediate_dim: int. The dimension of the intermediate layer in
            the transformer's MLP.
        num_heads: int. The number of attention heads.
        hidden_activation: str. The activation function for the transformer
            layers. Defaults to `"relu"`.
        dropout_rate: float. The dropout rate for the MLP and attention.
            Defaults to `0.0`.
        layer_norm_epsilon: float. The epsilon value for layer
            normalization. Defaults to `1e-6`.
    """

    def __init__(
        self,
        num_layers,
        hidden_dim,
        intermediate_dim,
        num_heads,
        hidden_activation="relu",
        dropout_rate=0.0,
        layer_norm_epsilon=1e-6,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_layers = int(num_layers)
        self.hidden_dim = int(hidden_dim)
        self.intermediate_dim = int(intermediate_dim)
        self.num_heads = int(num_heads)
        self.hidden_activation = hidden_activation
        self.dropout_rate = float(dropout_rate)
        self.layer_norm_epsilon = float(layer_norm_epsilon)

        # One fusion layer per depth step; `layer_{i}` names are
        # load-bearing for checkpoint weight paths.
        self.layers = []
        for i in range(self.num_layers):
            self.layers.append(
                SAM3DetrEncoderLayer(
                    hidden_dim=self.hidden_dim,
                    intermediate_dim=self.intermediate_dim,
                    num_heads=self.num_heads,
                    dropout_rate=self.dropout_rate,
                    hidden_activation=self.hidden_activation,
                    layer_norm_epsilon=self.layer_norm_epsilon,
                    dtype=self.dtype_policy,
                    name=f"layer_{i}",
                )
            )

    def build(
        self,
        vision_features_shape,
        text_features_shape,
        vision_pos_embeds_shape,
        text_masks_shape,
    ):
        # Spatial dims must be static here: they are baked into the
        # reshape performed in `call`.
        self.height = int(vision_features_shape[1])
        self.width = int(vision_features_shape[2])
        flat_shape = [
            vision_features_shape[0],
            self.height * self.width,
            vision_features_shape[-1],
        ]
        for layer in self.layers:
            layer.build(flat_shape, text_features_shape, flat_shape, None)

    def call(
        self,
        vision_features,
        text_features,
        vision_pos_embeds,
        text_masks,
        training=None,
    ):
        # Collapse the (height, width) grid into a single token axis.
        batch_size = ops.shape(vision_features)[0]
        channels = ops.shape(vision_features)[-1]
        flat_shape = (batch_size, self.height * self.width, channels)
        tokens = ops.reshape(vision_features, flat_shape)
        pos_tokens = ops.reshape(vision_pos_embeds, flat_shape)

        # Mask so vision tokens only attend to valid (unpadded) text tokens.
        cross_attn_masks = create_bidirectional_mask(tokens, text_masks)
        for layer in self.layers:
            tokens = layer(
                tokens,
                prompt_feats=text_features,
                vision_pos_encodings=pos_tokens,
                prompt_cross_attn_masks=cross_attn_masks,
                training=training,
            )
        return tokens, pos_tokens

    def get_config(self):
        return {
            **super().get_config(),
            "num_layers": self.num_layers,
            "hidden_dim": self.hidden_dim,
            "intermediate_dim": self.intermediate_dim,
            "num_heads": self.num_heads,
            "hidden_activation": self.hidden_activation,
            "dropout_rate": self.dropout_rate,
            "layer_norm_epsilon": self.layer_norm_epsilon,
        }

    def compute_output_shape(
        self,
        vision_features_shape,
        text_features_shape,
        vision_pos_embeds_shape,
        text_masks_shape,
    ):
        # Both outputs (hidden states and flattened position embeddings)
        # share the flattened vision shape.
        flat_shape = [
            vision_features_shape[0],
            vision_features_shape[1] * vision_features_shape[2],
            vision_features_shape[-1],
        ]
        return flat_shape, flat_shape
@@ -0,0 +1,120 @@
1
+ import numpy as np
2
+ from keras import layers
3
+ from keras import ops
4
+
5
+ from keras_hub.src.models.sam3.sam3_layers import SAM3DecoderMLP
6
+
7
+
8
class SAM3DotProductScoring(layers.Layer):
    """Scores decoder queries against a pooled text embedding.

    The text features are refined by a residual MLP, mean-pooled over the
    valid tokens, and projected; the decoder hidden states are projected the
    same way. Each query's score is the scaled, clamped dot product between
    its projection and the pooled text projection.

    Args:
        hidden_dim: int. Width of the projected text/query embeddings.
        intermediate_dim: int. Hidden width of the text refinement MLP.
        dropout_rate: float. Dropout applied after the text MLP. Defaults
            to `0.0`.
        layer_norm_epsilon: float. Epsilon for the output layer norm.
            Defaults to `1e-6`.
    """

    def __init__(
        self,
        hidden_dim,
        intermediate_dim,
        dropout_rate=0.0,
        layer_norm_epsilon=1e-6,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_dim = int(hidden_dim)
        self.intermediate_dim = int(intermediate_dim)
        self.dropout_rate = float(dropout_rate)
        self.layer_norm_epsilon = float(layer_norm_epsilon)

        # Residual refinement stack for the text features. Sub-layer names
        # are load-bearing for checkpoint weight paths.
        self.text_mlp = SAM3DecoderMLP(
            num_layers=2,
            hidden_dim=self.intermediate_dim,
            output_dim=self.hidden_dim,
            dtype=self.dtype_policy,
            name="text_mlp",
        )
        self.text_mlp_dropout = layers.Dropout(
            self.dropout_rate, dtype=self.dtype_policy, name="text_mlp_dropout"
        )
        self.text_mlp_out_norm = layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
            name="text_mlp_out_norm",
        )

        # Projections for text and query features.
        self.text_proj = layers.Dense(
            self.hidden_dim, dtype=self.dtype_policy, name="text_proj"
        )
        self.query_proj = layers.Dense(
            self.hidden_dim, dtype=self.dtype_policy, name="query_proj"
        )

        # 1/sqrt(d) scaling, as in standard dot-product attention.
        self.scale = float(1.0 / np.sqrt(self.hidden_dim))

        # Logits are clamped to +/- this value to avoid numerical issues.
        self.clamp_max_val = 12.0

    def build(
        self, decoder_hidden_states_shape, text_features_shape, text_masks_shape
    ):
        self.text_mlp.build(text_features_shape)
        self.text_mlp_dropout.build(text_features_shape)
        self.text_mlp_out_norm.build(text_features_shape)
        # Pooling removes the token axis from the text features.
        pooled_shape = [text_features_shape[0], text_features_shape[-1]]
        self.text_proj.build(pooled_shape)
        self.query_proj.build(decoder_hidden_states_shape)

    def _pool_text_features(self, text_features, text_mask=None):
        """Mean-pool the token axis, ignoring masked-out tokens."""
        if text_mask is None:
            # No padding information: plain mean over all tokens.
            return ops.mean(text_features, axis=1)

        weights = ops.expand_dims(
            ops.cast(text_mask, text_features.dtype), axis=-1
        )
        # Guard against all-masked rows with a minimum count of one.
        counts = ops.maximum(ops.sum(weights, axis=1), 1.0)
        totals = ops.sum(ops.multiply(text_features, weights), axis=1)
        return ops.divide(totals, counts)

    def call(
        self,
        decoder_hidden_states,
        text_features,
        text_masks=None,
        training=None,
    ):
        # Residual MLP refinement of the text features.
        refined = self.text_mlp(text_features, training=training)
        refined = self.text_mlp_dropout(refined, training=training)
        refined = ops.add(refined, text_features)
        refined = self.text_mlp_out_norm(refined, training=training)

        pooled_text = self._pool_text_features(refined, text_masks)

        text_vec = self.text_proj(pooled_text, training=training)
        query_vecs = self.query_proj(decoder_hidden_states, training=training)

        # (batch, hidden) -> (batch, 1, hidden, 1) so the matmul broadcasts
        # over the decoder-layer axis of the queries.
        text_vec = ops.expand_dims(text_vec, axis=-1)
        text_vec = ops.expand_dims(text_vec, axis=1)
        logits = ops.matmul(query_vecs, text_vec)
        logits = ops.multiply(logits, self.scale)
        return ops.clip(logits, -self.clamp_max_val, self.clamp_max_val)

    def get_config(self):
        return {
            **super().get_config(),
            "hidden_dim": self.hidden_dim,
            "intermediate_dim": self.intermediate_dim,
            "dropout_rate": self.dropout_rate,
            "layer_norm_epsilon": self.layer_norm_epsilon,
        }

    def compute_output_shape(
        self, decoder_hidden_states_shape, text_features_shape, text_masks_shape
    ):
        # One scalar score per query, per decoder layer.
        batch_size, num_layers, num_queries = decoder_hidden_states_shape[:3]
        return [batch_size, num_layers, num_queries, 1]