keras-hub 0.25.1__py3-none-any.whl → 0.26.0.dev0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
Files changed (109)
  1. keras_hub/layers/__init__.py +21 -0
  2. keras_hub/models/__init__.py +27 -0
  3. keras_hub/src/layers/modeling/non_max_supression.py +5 -2
  4. keras_hub/src/layers/modeling/reversible_embedding.py +2 -275
  5. keras_hub/src/layers/modeling/token_and_position_embedding.py +6 -6
  6. keras_hub/src/layers/modeling/transformer_layer_utils.py +9 -9
  7. keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +3 -1
  8. keras_hub/src/layers/preprocessing/multi_segment_packer.py +3 -1
  9. keras_hub/src/models/albert/albert_backbone.py +1 -3
  10. keras_hub/src/models/backbone.py +3 -0
  11. keras_hub/src/models/bart/bart_backbone.py +1 -3
  12. keras_hub/src/models/bert/bert_backbone.py +2 -4
  13. keras_hub/src/models/bloom/bloom_backbone.py +1 -3
  14. keras_hub/src/models/causal_lm.py +2 -2
  15. keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -3
  16. keras_hub/src/models/edrec/edrec_backbone.py +147 -0
  17. keras_hub/src/models/edrec/edrec_layers.py +434 -0
  18. keras_hub/src/models/edrec/edrec_seq2seq_lm.py +273 -0
  19. keras_hub/src/models/electra/electra_backbone.py +1 -3
  20. keras_hub/src/models/f_net/f_net_backbone.py +1 -3
  21. keras_hub/src/models/falcon/falcon_backbone.py +1 -3
  22. keras_hub/src/models/flux/flux_layers.py +3 -3
  23. keras_hub/src/models/flux/flux_maths.py +29 -15
  24. keras_hub/src/models/gemma/gemma_backbone.py +1 -3
  25. keras_hub/src/models/gemma/gemma_causal_lm.py +1 -1
  26. keras_hub/src/models/gemma3/gemma3_attention.py +1 -1
  27. keras_hub/src/models/gemma3/gemma3_backbone.py +70 -8
  28. keras_hub/src/models/gemma3/gemma3_causal_lm.py +16 -1
  29. keras_hub/src/models/gemma3/gemma3_decoder_block.py +1 -1
  30. keras_hub/src/models/gemma3/{gemma3_interleave_embeddings.py → gemma3_layers.py} +101 -0
  31. keras_hub/src/models/gemma3/gemma3_presets.py +67 -7
  32. keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
  33. keras_hub/src/models/gpt2/gpt2_backbone.py +1 -3
  34. keras_hub/src/models/gpt2/gpt2_causal_lm.py +1 -1
  35. keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +1 -3
  36. keras_hub/src/models/gpt_oss/gpt_oss_backbone.py +1 -3
  37. keras_hub/src/models/llama/llama_backbone.py +1 -3
  38. keras_hub/src/models/masked_lm.py +1 -1
  39. keras_hub/src/models/mistral/mistral_backbone.py +1 -3
  40. keras_hub/src/models/mixtral/mixtral_backbone.py +1 -3
  41. keras_hub/src/models/moonshine/moonshine_backbone.py +1 -3
  42. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +1 -3
  43. keras_hub/src/models/parseq/parseq_tokenizer.py +3 -1
  44. keras_hub/src/models/phi3/phi3_backbone.py +1 -3
  45. keras_hub/src/models/qwen/qwen_backbone.py +1 -3
  46. keras_hub/src/models/qwen/qwen_presets.py +209 -0
  47. keras_hub/src/models/qwen3/qwen3_backbone.py +1 -3
  48. keras_hub/src/models/qwen3_moe/qwen3_moe_backbone.py +1 -3
  49. keras_hub/src/models/qwen3_moe/qwen3_moe_presets.py +15 -0
  50. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +1 -3
  51. keras_hub/src/models/roformer_v2/roformer_v2_backbone.py +1 -3
  52. keras_hub/src/models/rqvae/__init__.py +5 -0
  53. keras_hub/src/models/rqvae/rqvae_backbone.py +167 -0
  54. keras_hub/src/models/rqvae/rqvae_layers.py +335 -0
  55. keras_hub/src/models/rwkv7/__init__.py +5 -0
  56. keras_hub/src/models/rwkv7/rwkv7_backbone.py +180 -0
  57. keras_hub/src/models/rwkv7/rwkv7_causal_lm.py +259 -0
  58. keras_hub/src/models/rwkv7/rwkv7_causal_lm_preprocessor.py +214 -0
  59. keras_hub/src/models/rwkv7/rwkv7_layer.py +724 -0
  60. keras_hub/src/models/rwkv7/rwkv7_presets.py +26 -0
  61. keras_hub/src/models/rwkv7/rwkv7_tokenizer.py +495 -0
  62. keras_hub/src/models/sam/sam_backbone.py +5 -1
  63. keras_hub/src/models/sam/sam_prompt_encoder.py +1 -1
  64. keras_hub/src/models/sam3/__init__.py +7 -0
  65. keras_hub/src/models/sam3/roi_align.py +222 -0
  66. keras_hub/src/models/sam3/sam3_detr_decoder.py +641 -0
  67. keras_hub/src/models/sam3/sam3_detr_encoder.py +293 -0
  68. keras_hub/src/models/sam3/sam3_dot_product_scoring.py +120 -0
  69. keras_hub/src/models/sam3/sam3_geometry_encoder.py +517 -0
  70. keras_hub/src/models/sam3/sam3_image_converter.py +10 -0
  71. keras_hub/src/models/sam3/sam3_layers.py +814 -0
  72. keras_hub/src/models/sam3/sam3_mask_decoder.py +374 -0
  73. keras_hub/src/models/sam3/sam3_pc_backbone.py +306 -0
  74. keras_hub/src/models/sam3/sam3_pc_image_segmenter.py +282 -0
  75. keras_hub/src/models/sam3/sam3_pc_image_segmenter_preprocessor.py +336 -0
  76. keras_hub/src/models/sam3/sam3_presets.py +16 -0
  77. keras_hub/src/models/sam3/sam3_text_encoder.py +212 -0
  78. keras_hub/src/models/sam3/sam3_tokenizer.py +65 -0
  79. keras_hub/src/models/sam3/sam3_utils.py +134 -0
  80. keras_hub/src/models/sam3/sam3_vision_encoder.py +738 -0
  81. keras_hub/src/models/segformer/segformer_backbone.py +6 -6
  82. keras_hub/src/models/siglip/siglip_layers.py +1 -3
  83. keras_hub/src/models/smollm3/smollm3_backbone.py +1 -3
  84. keras_hub/src/models/stable_diffusion_3/t5_encoder.py +1 -3
  85. keras_hub/src/models/t5/t5_backbone.py +1 -3
  86. keras_hub/src/models/t5gemma/t5gemma_backbone.py +1 -3
  87. keras_hub/src/models/task.py +1 -1
  88. keras_hub/src/tests/test_case.py +394 -3
  89. keras_hub/src/tokenizers/byte_pair_tokenizer.py +33 -2
  90. keras_hub/src/tokenizers/byte_tokenizer.py +3 -1
  91. keras_hub/src/tokenizers/sentence_piece_tokenizer.py +15 -1
  92. keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +3 -1
  93. keras_hub/src/tokenizers/word_piece_tokenizer.py +15 -1
  94. keras_hub/src/utils/preset_utils.py +1 -1
  95. keras_hub/src/utils/tensor_utils.py +12 -0
  96. keras_hub/src/utils/transformers/convert_gemma3.py +68 -22
  97. keras_hub/src/utils/transformers/convert_qwen3_moe.py +4 -1
  98. keras_hub/src/utils/transformers/convert_sam3.py +472 -0
  99. keras_hub/src/utils/transformers/export/gemma3.py +196 -0
  100. keras_hub/src/utils/transformers/export/hf_exporter.py +86 -25
  101. keras_hub/src/utils/transformers/export/qwen.py +136 -0
  102. keras_hub/src/utils/transformers/preset_loader.py +15 -1
  103. keras_hub/src/version.py +1 -1
  104. keras_hub/tokenizers/__init__.py +6 -0
  105. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/METADATA +6 -13
  106. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/RECORD +108 -76
  107. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/WHEEL +1 -1
  108. keras_hub/src/models/gemma3/rms_normalization.py +0 -26
  109. {keras_hub-0.25.1.dist-info → keras_hub-0.26.0.dev0.dist-info}/top_level.txt +0 -0
keras_hub/src/models/edrec/edrec_backbone.py (new file)
@@ -0,0 +1,147 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.edrec.edrec_layers import EdRecDecoderBlock
+from keras_hub.src.models.edrec.edrec_layers import EdRecEncoderBlock
+
+
+@keras_hub_export("keras_hub.models.EdRecBackbone")
+class EdRecBackbone(Backbone):
+    """EdRec Backbone model.
+
+    Args:
+        vocab_size: int, size of the vocabulary.
+        num_layers_enc: int, number of encoder layers.
+        num_layers_dec: int, number of decoder layers.
+        hidden_dim: int, hidden dimension (d_model).
+        intermediate_dim: int, intermediate dimension (d_ff).
+        num_heads: int, number of attention heads.
+        dropout: float, dropout rate.
+        epsilon: float, epsilon for simple RMSNorm.
+    """
+
+    def __init__(
+        self,
+        vocab_size,
+        num_layers_enc,
+        num_layers_dec,
+        hidden_dim,
+        intermediate_dim,
+        num_heads,
+        dropout=0.0,
+        epsilon=1e-6,
+        dtype=None,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.embedding = keras.layers.Embedding(
+            input_dim=vocab_size,
+            output_dim=hidden_dim,
+            dtype=dtype,
+            name="embedding",
+        )
+        self.encoder_layers = []
+        for i in range(num_layers_enc):
+            self.encoder_layers.append(
+                EdRecEncoderBlock(
+                    hidden_dim=hidden_dim,
+                    num_heads=num_heads,
+                    intermediate_dim=intermediate_dim,
+                    dropout_rate=dropout,
+                    epsilon=epsilon,
+                    dtype=dtype,
+                    name=f"encoder_layer_{i}",
+                )
+            )
+        self.decoder_layers = []
+        for i in range(num_layers_dec):
+            self.decoder_layers.append(
+                EdRecDecoderBlock(
+                    hidden_dim=hidden_dim,
+                    num_heads=num_heads,
+                    intermediate_dim=intermediate_dim,
+                    dropout_rate=dropout,
+                    epsilon=epsilon,
+                    dtype=dtype,
+                    name=f"decoder_layer_{i}",
+                )
+            )
+
+        # === Functional Model ===
+        encoder_token_ids = keras.Input(
+            shape=(None,), dtype="int32", name="encoder_token_ids"
+        )
+        decoder_token_ids = keras.Input(
+            shape=(None,), dtype="int32", name="decoder_token_ids"
+        )
+        encoder_padding_mask = keras.Input(
+            shape=(None,), dtype="bool", name="encoder_padding_mask"
+        )
+        decoder_padding_mask = keras.Input(
+            shape=(None,), dtype="bool", name="decoder_padding_mask"
+        )
+
+        # Encoder
+        x_enc = self.embedding(encoder_token_ids)
+
+        for layer in self.encoder_layers:
+            x_enc = layer(
+                x_enc,
+                padding_mask=encoder_padding_mask,
+            )
+
+        # Decoder
+        x_dec = self.embedding(decoder_token_ids)
+        for layer in self.decoder_layers:
+            x_dec, _, _ = layer(
+                x_dec,
+                encoder_outputs=x_enc,
+                decoder_padding_mask=decoder_padding_mask,
+                encoder_padding_mask=encoder_padding_mask,
+            )
+
+        super().__init__(
+            inputs={
+                "encoder_token_ids": encoder_token_ids,
+                "decoder_token_ids": decoder_token_ids,
+                "encoder_padding_mask": encoder_padding_mask,
+                "decoder_padding_mask": decoder_padding_mask,
+            },
+            outputs={
+                "encoder_sequence_output": x_enc,
+                "decoder_sequence_output": x_dec,
+            },
+            dtype=dtype,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.vocab_size = vocab_size
+        self.num_layers_enc = num_layers_enc
+        self.num_layers_dec = num_layers_dec
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.epsilon = epsilon
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "vocab_size": self.vocab_size,
+                "num_layers_enc": self.num_layers_enc,
+                "num_layers_dec": self.num_layers_dec,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "num_heads": self.num_heads,
+                "dropout": self.dropout,
+                "epsilon": self.epsilon,
+            }
+        )
+        return config
+
+    @property
+    def token_embedding(self):
+        return self.embedding
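
For orientation, here is a minimal usage sketch (not part of the diff) of how the new backbone could be constructed and called, assuming keras-hub 0.26.0.dev0 with the EdRec files above importable; the hyperparameter values and shapes are illustrative, not a published preset.

import numpy as np

from keras_hub.models import EdRecBackbone

# Illustrative (non-preset) hyperparameters.
backbone = EdRecBackbone(
    vocab_size=32000,
    num_layers_enc=2,
    num_layers_dec=2,
    hidden_dim=128,
    intermediate_dim=256,
    num_heads=4,
    dropout=0.1,
)
batch, enc_len, dec_len = 2, 16, 8
outputs = backbone(
    {
        "encoder_token_ids": np.random.randint(0, 32000, (batch, enc_len)),
        "decoder_token_ids": np.random.randint(0, 32000, (batch, dec_len)),
        "encoder_padding_mask": np.ones((batch, enc_len), dtype=bool),
        "decoder_padding_mask": np.ones((batch, dec_len), dtype=bool),
    }
)
# Dict with "encoder_sequence_output" [batch, enc_len, hidden_dim]
# and "decoder_sequence_output" [batch, dec_len, hidden_dim].

The edrec_seq2seq_lm.py file in the list above presumably wraps this backbone as a Seq2SeqLM task for generation.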
keras_hub/src/models/edrec/edrec_layers.py (new file)
@@ -0,0 +1,434 @@
+import keras
+from keras import ops
+
+from keras_hub.src.layers.modeling.cached_multi_head_attention import (
+    CachedMultiHeadAttention,
+)
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
+    compute_causal_mask,
+)
+from keras_hub.src.layers.modeling.transformer_layer_utils import (
+    merge_padding_and_attention_mask,
+)
+
+
+class EdRecRMSNormalization(keras.layers.Layer):
+    """RMSNorm layer that matches JAX EdRec implementation.
+
+    Attributes:
+        epsilon: float, epsilon value for numerical stability.
+    """
+
+    def __init__(self, epsilon=1e-6, **kwargs):
+        super().__init__(**kwargs)
+        self.epsilon = epsilon
+
+    def build(self, input_shape):
+        self.scale = self.add_weight(
+            name="scale",
+            shape=(input_shape[-1],),
+            initializer="ones",
+            trainable=True,
+        )
+        super().build(input_shape)
+
+    def call(self, x):
+        # JAX: rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True)
+        #     + self.eps)
+        # JAX: normed = x / rms
+        # JAX: normed = normed * (1 + scale)
+
+        # Standard RMSNorm is x * scale / rms.
+        # EdRec RMSNorm is x * (1 + scale) / rms.
+        # Note: If scale is initialized to ones, (1+scale) starts at 2.
+
+        mean_square = ops.mean(ops.square(x), axis=-1, keepdims=True)
+        rms = ops.sqrt(mean_square + self.epsilon)
+        normed = x / rms
+        return normed * ops.cast(1.0 + self.scale, x.dtype)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"epsilon": self.epsilon})
+        return config
+
+
+class EdRecGatedFeedForward(keras.layers.Layer):
+    """Gated FeedForward (GLU-style) layer.
+
+    y = GELU(up_proj(x)) * gate_proj(x)
+    y = down_proj(y)
+    """
+
+    def __init__(
+        self,
+        intermediate_dim,
+        hidden_dim,
+        dropout_rate=0.0,
+        activation="gelu",
+        kernel_initializer="glorot_uniform",
+        bias_initializer="zeros",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.intermediate_dim = intermediate_dim
+        self.hidden_dim = hidden_dim  # The output dimension (d_model)
+        self.dropout_rate = dropout_rate
+        self.activation = activation
+        self.kernel_initializer = kernel_initializer
+        self.bias_initializer = bias_initializer
+
+    def build(self, input_shape):
+        self.up_proj = keras.layers.Dense(
+            self.intermediate_dim,
+            use_bias=False,
+            kernel_initializer=self.kernel_initializer,
+            dtype=self.dtype_policy,
+            name="up_proj",
+        )
+        self.gate_proj = keras.layers.Dense(
+            self.intermediate_dim,
+            use_bias=False,
+            kernel_initializer=self.kernel_initializer,
+            dtype=self.dtype_policy,
+            name="gate_proj",
+        )
+        self.down_proj = keras.layers.Dense(
+            self.hidden_dim,
+            use_bias=False,
+            kernel_initializer=self.kernel_initializer,
+            dtype=self.dtype_policy,
+            name="down_proj",
+        )
+        self.dropout = keras.layers.Dropout(
+            self.dropout_rate, dtype=self.dtype_policy, name="dropout"
+        )
+
+    def call(self, x, training=False):
+        # Up projection + activation (GELU)
+        h = self.up_proj(x)
+        if self.activation == "gelu":
+            h = keras.activations.gelu(h, approximate=True)
+        else:
+            h = keras.activations.get(self.activation)(h)
+
+        # Gate projection
+        g = self.gate_proj(x)
+
+        # Elementwise gating
+        y = h * g
+
+        # Down projection
+        y = self.down_proj(y)
+
+        # Dropout
+        if self.dropout_rate > 0.0:
+            y = self.dropout(y, training=training)
+
+        return y
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "intermediate_dim": self.intermediate_dim,
+                "hidden_dim": self.hidden_dim,
+                "dropout_rate": self.dropout_rate,
+                "activation": self.activation,
+                "kernel_initializer": self.kernel_initializer,
+                "bias_initializer": self.bias_initializer,
+            }
+        )
+        return config
+
+
+class EdRecEncoderBlock(keras.layers.Layer):
+    """EdRec Encoder Block.
+
+    Pre-norm: x = x + Dropout(Attention(RMSNorm(x)))
+    x = x + GatedFeedForward(RMSNorm(x))
+    """
+
+    def __init__(
+        self,
+        hidden_dim,
+        num_heads,
+        intermediate_dim,
+        dropout_rate=0.0,
+        epsilon=1e-6,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.hidden_dim = hidden_dim
+        self.num_heads = num_heads
+        self.intermediate_dim = intermediate_dim
+        self.dropout_rate = dropout_rate
+        self.epsilon = epsilon
+        self.head_dim = hidden_dim // num_heads
+
+    def build(self, input_shape):
+        self.pre_attention_norm = EdRecRMSNormalization(
+            epsilon=self.epsilon,
+            dtype=self.dtype_policy,
+            name="pre_attention_norm",
+        )
+        self.attention = keras.layers.MultiHeadAttention(
+            num_heads=self.num_heads,
+            key_dim=self.head_dim,
+            use_bias=False,
+            output_shape=self.hidden_dim,
+            dtype=self.dtype_policy,
+            name="attention",
+        )
+        self.dropout1 = keras.layers.Dropout(
+            self.dropout_rate, dtype=self.dtype_policy, name="dropout1"
+        )
+
+        self.pre_ffw_norm = EdRecRMSNormalization(
+            epsilon=self.epsilon, dtype=self.dtype_policy, name="pre_ffw_norm"
+        )
+        self.mlp = EdRecGatedFeedForward(
+            intermediate_dim=self.intermediate_dim,
+            hidden_dim=self.hidden_dim,
+            dropout_rate=self.dropout_rate,
+            dtype=self.dtype_policy,
+            name="mlp",
+        )
+
+    def call(self, x, padding_mask=None, training=False):
+        # Self Attention
+        residual = x
+        x_norm = self.pre_attention_norm(x)
+
+        # padding_mask is [B, L]
+        # We need to expand it to [B, 1, 1, L] for broadcasting against
+        # [B, H, L, L]
+        if padding_mask is not None:
+            padding_mask = merge_padding_and_attention_mask(
+                x, padding_mask, None
+            )
+
+        attn_out = self.attention(
+            query=x_norm,
+            value=x_norm,
+            attention_mask=padding_mask,
+            training=training,
+        )
+        attn_out = self.dropout1(attn_out, training=training)
+        x = residual + attn_out
+
+        # Feed Forward
+        residual = x
+        ff_norm = self.pre_ffw_norm(x)
+        ff_out = self.mlp(ff_norm, training=training)
+        x = residual + ff_out
+
+        return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "num_heads": self.num_heads,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout_rate": self.dropout_rate,
+                "epsilon": self.epsilon,
+            }
+        )
+        return config
+
+
+class EdRecDecoderBlock(keras.layers.Layer):
+    """EdRec Decoder Block.
+
+    x = x + Dropout(SelfAttention(RMSNorm(x)))
+    x = x + Dropout(CrossAttention(RMSNorm(x), encoder_outputs))
+    x = x + GatedFeedForward(RMSNorm(x))
+    """
+
+    def __init__(
+        self,
+        hidden_dim,
+        num_heads,
+        intermediate_dim,
+        dropout_rate=0.0,
+        epsilon=1e-6,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.hidden_dim = hidden_dim
+        self.num_heads = num_heads
+        self.intermediate_dim = intermediate_dim
+        self.dropout_rate = dropout_rate
+        self.epsilon = epsilon
+        self.head_dim = hidden_dim // num_heads
+
+    def build(self, input_shape):
+        self.pre_self_attn_norm = EdRecRMSNormalization(
+            epsilon=self.epsilon,
+            dtype=self.dtype_policy,
+            name="pre_self_attn_norm",
+        )
+        self.self_attention = CachedMultiHeadAttention(
+            num_heads=self.num_heads,
+            key_dim=self.head_dim,
+            use_bias=False,
+            output_shape=self.hidden_dim,
+            dtype=self.dtype_policy,
+            name="self_attention",
+        )
+        self.dropout1 = keras.layers.Dropout(
+            self.dropout_rate, dtype=self.dtype_policy, name="dropout1"
+        )
+
+        self.pre_cross_attn_norm = EdRecRMSNormalization(
+            epsilon=self.epsilon,
+            dtype=self.dtype_policy,
+            name="pre_cross_attn_norm",
+        )
+        self.cross_attention = CachedMultiHeadAttention(
+            num_heads=self.num_heads,
+            key_dim=self.head_dim,
+            use_bias=False,
+            output_shape=self.hidden_dim,
+            dtype=self.dtype_policy,
+            name="cross_attention",
+        )
+        self.dropout2 = keras.layers.Dropout(
+            self.dropout_rate, dtype=self.dtype_policy, name="dropout2"
+        )
+
+        self.pre_ffw_norm = EdRecRMSNormalization(
+            epsilon=self.epsilon, dtype=self.dtype_policy, name="pre_ffw_norm"
+        )
+        self.mlp = EdRecGatedFeedForward(
+            intermediate_dim=self.intermediate_dim,
+            hidden_dim=self.hidden_dim,
+            dropout_rate=self.dropout_rate,
+            dtype=self.dtype_policy,
+            name="mlp",
+        )
+
+    def call(
+        self,
+        x,
+        encoder_outputs,
+        decoder_padding_mask=None,
+        encoder_padding_mask=None,
+        self_attention_cache=None,
+        self_attention_cache_update_index=None,
+        cross_attention_cache=None,
+        cross_attention_cache_update_index=None,
+        use_causal_mask=True,
+        training=False,
+    ):
+        # Self Attention
+        residual = x
+        x_norm = self.pre_self_attn_norm(x)
+
+        batch_size = ops.shape(x)[0]
+        input_length = ops.shape(x)[1]
+
+        total_length = input_length
+        if self_attention_cache is not None:
+            total_length = ops.shape(self_attention_cache)[2]
+
+        # Compute causal mask
+        causal_mask = None
+        if use_causal_mask:
+            causal_mask = compute_causal_mask(
+                batch_size,
+                total_length,
+                input_length,
+                0
+                if self_attention_cache_update_index is None
+                else self_attention_cache_update_index,
+            )
+
+        # Merge with padding mask
+        self_attn_mask = causal_mask
+        if decoder_padding_mask is not None:
+            # decoder_padding_mask is [B, L_dec]
+            # merge_padding_and_attention_mask gives [B, 1, L, L]
+            padding_mask_merged = merge_padding_and_attention_mask(
+                x, decoder_padding_mask, None
+            )
+
+            if causal_mask is not None:
+                self_attn_mask = ops.minimum(padding_mask_merged, causal_mask)
+            else:
+                self_attn_mask = padding_mask_merged
+
+        self_attn_out = self.self_attention(
+            query=x_norm,
+            value=x_norm,
+            attention_mask=self_attn_mask,
+            cache=self_attention_cache,
+            cache_update_index=self_attention_cache_update_index,
+            training=training,
+        )
+
+        if self_attention_cache is not None:
+            self_attn_out, self_attention_cache = self_attn_out
+
+        self_attn_out = self.dropout1(self_attn_out, training=training)
+        x = residual + self_attn_out
+
+        # Cross Attention
+        residual = x
+        x_norm = self.pre_cross_attn_norm(x)
+
+        cross_mask = None
+        if encoder_padding_mask is not None:
+            cross_mask = merge_padding_and_attention_mask(
+                encoder_outputs, encoder_padding_mask, None
+            )
+
+        cross_attn_out = self.cross_attention(
+            query=x_norm,
+            value=encoder_outputs,
+            attention_mask=cross_mask,
+            cache=cross_attention_cache,
+            cache_update_index=cross_attention_cache_update_index,
+            training=training,
+        )
+
+        if cross_attention_cache is not None:
+            cross_attn_out, cross_attention_cache = cross_attn_out
+
+        cross_attn_out = self.dropout2(cross_attn_out, training=training)
+        x = residual + cross_attn_out
+
+        # Feed Forward
+        residual = x
+        ff_norm = self.pre_ffw_norm(x)
+        ff_out = self.mlp(ff_norm, training=training)
+        x = residual + ff_out
+
+        if self_attention_cache is not None:
+            if cross_attention_cache is not None:
+                return x, self_attention_cache, cross_attention_cache
+            return (
+                x,
+                self_attention_cache,
+                ops.zeros((0,), dtype=self.compute_dtype),
+            )
+        return (
+            x,
+            ops.zeros((0,), dtype=self.compute_dtype),
+            ops.zeros((0,), dtype=self.compute_dtype),
+        )
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "num_heads": self.num_heads,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout_rate": self.dropout_rate,
+                "epsilon": self.epsilon,
+            }
+        )
+        return config
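
A minimal sketch (not part of the diff) of the standalone layers defined above, showing the RMSNorm's (1 + scale) scaling and the decoder block's fixed three-tuple return; the sizes are illustrative and the calls use the uncached path that EdRecBackbone itself takes.

import numpy as np

from keras_hub.src.models.edrec.edrec_layers import EdRecDecoderBlock
from keras_hub.src.models.edrec.edrec_layers import EdRecEncoderBlock
from keras_hub.src.models.edrec.edrec_layers import EdRecRMSNormalization

# Illustrative sizes only.
batch, enc_len, dec_len, hidden_dim = 2, 16, 8, 64
enc_x = np.random.uniform(size=(batch, enc_len, hidden_dim)).astype("float32")
dec_x = np.random.uniform(size=(batch, dec_len, hidden_dim)).astype("float32")

# EdRec RMSNorm computes x / rms(x) * (1 + scale); scale starts at ones.
norm = EdRecRMSNormalization()
normed = norm(enc_x)

encoder_block = EdRecEncoderBlock(
    hidden_dim=hidden_dim, num_heads=4, intermediate_dim=128
)
decoder_block = EdRecDecoderBlock(
    hidden_dim=hidden_dim, num_heads=4, intermediate_dim=128
)

enc_out = encoder_block(enc_x)  # [batch, enc_len, hidden_dim]
# Without caches the decoder still returns a three-tuple; the two cache
# slots hold zero-size placeholder tensors, so callers can always unpack.
dec_out, _, _ = decoder_block(dec_x, encoder_outputs=enc_out)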