keras-hub-nightly 0.23.0.dev202508240418__py3-none-any.whl → 0.23.0.dev202508260411__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
@@ -0,0 +1,355 @@
+ import keras
+
+ from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
+ from keras_hub.src.models.t5gemma.t5gemma_attention import T5GemmaAttention
+ from keras_hub.src.models.t5gemma.t5gemma_layers import T5GemmaMLP
+
+
+ class T5GemmaDecoderLayer(keras.layers.Layer):
+     """Decoder layer for the T5Gemma model.
+
+     This layer implements a single decoder block in the T5Gemma architecture,
+     comprising self-attention, cross-attention, and a feed-forward network
+     (MLP).
+
+     Args:
+         hidden_size: int, The dimensionality of the hidden states.
+         rms_norm_eps: float, The epsilon value for RMS normalization.
+         num_attention_heads: int, The number of attention heads in
+             self-attention and cross-attention.
+         num_key_value_heads: int, The number of key-value heads for grouped
+             query attention.
+         query_pre_attn_scalar: float, Scalar to multiply queries by before
+             attention.
+         attention_bias: bool, Whether to include bias in attention computations.
+         intermediate_size: int, The intermediate size of the feed-forward
+             network.
+         hidden_activation: str, The activation function used in the feed-forward
+             network.
+         dropout_rate: float, The dropout rate applied after attention and MLP.
+         head_dim: int, The dimensionality of each attention head.
+         initializer_range: float, The range for the random normal initializer.
+         attention_dropout: float, The dropout rate applied to attention weights.
+         layer_type: str, Type of attention layer, e.g., `"sliding_attention"`.
+         cross_attention_hidden_size: int, optional, The hidden size for
+             cross-attention. If None, it defaults to `hidden_size`. Defaults to
+             `None`.
+         attn_logit_softcapping: float, optional, The softcapping value for
+             attention logits. Defaults to `None`.
+         sliding_window: int, optional, The window size for sliding attention.
+             Required if `layer_type` is `"sliding_attention"`. Defaults to
+             `None`.
+         rope_max_wavelength: float, The maximum wavelength for Rotary
+             Positional Embeddings. Defaults to `10000.0`.
+         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+             for model computations and weights. Defaults to `None`.
+         **kwargs: Additional keyword arguments passed to the parent class.
+     """
+
+     def __init__(
+         self,
+         hidden_size,
+         rms_norm_eps,
+         num_attention_heads,
+         num_key_value_heads,
+         query_pre_attn_scalar,
+         attention_bias,
+         intermediate_size,
+         hidden_activation,
+         dropout_rate,
+         head_dim,
+         initializer_range,
+         attention_dropout,
+         layer_type,
+         cross_attention_hidden_size=None,
+         attn_logit_softcapping=None,
+         sliding_window=None,
+         rope_max_wavelength=10000.0,
+         dtype=None,
+         **kwargs,
+     ):
+         super().__init__(dtype=dtype, **kwargs)
+         self.head_dim = head_dim
+         self.hidden_size = hidden_size
+         self.rms_norm_eps = rms_norm_eps
+         self.num_attention_heads = num_attention_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.query_pre_attn_scalar = query_pre_attn_scalar
+         self.attention_bias = attention_bias
+         self.intermediate_size = intermediate_size
+         self.hidden_activation = hidden_activation
+         self.dropout_rate = dropout_rate
+         self.initializer_range = initializer_range
+         self.attention_dropout = attention_dropout
+         self.layer_type = layer_type
+         self.sliding_window = sliding_window
+         self.rope_max_wavelength = rope_max_wavelength
+         self.cross_attention_hidden_size = cross_attention_hidden_size
+         self.attn_logit_softcapping = attn_logit_softcapping
+         if (
+             self.layer_type == "sliding_attention"
+             and self.sliding_window is None
+         ):
+             raise ValueError(
+                 "`sliding_window` must be set for `sliding_attention` layer "
+                 "type."
+             )
+
+         # Self-attention.
+         self.self_attn = T5GemmaAttention(
+             hidden_size=hidden_size,
+             num_attention_heads=num_attention_heads,
+             num_key_value_heads=num_key_value_heads,
+             query_pre_attn_scalar=query_pre_attn_scalar,
+             attention_bias=attention_bias,
+             head_dim=self.head_dim,
+             attention_type="self",
+             initializer_range=initializer_range,
+             attention_dropout=attention_dropout,
+             attn_logit_softcapping=attn_logit_softcapping,
+             rope_max_wavelength=self.rope_max_wavelength,
+             dtype=self.dtype_policy,
+             name="self_attention",
+         )
+         self.pre_self_attn_layernorm = RMSNormalization(
+             epsilon=rms_norm_eps,
+             dtype=self.dtype_policy,
+             name="decoder_pre_self_attention_layernorm",
+         )
+         self.post_self_attn_layernorm = RMSNormalization(
+             epsilon=rms_norm_eps,
+             dtype=self.dtype_policy,
+             name="decoder_post_self_attention_layernorm",
+         )
+
+         # Cross-attention.
+         self.cross_attn = T5GemmaAttention(
+             hidden_size=hidden_size,
+             cross_attention_hidden_size=cross_attention_hidden_size,
+             num_attention_heads=num_attention_heads,
+             num_key_value_heads=num_key_value_heads,
+             query_pre_attn_scalar=query_pre_attn_scalar,
+             attention_bias=attention_bias,
+             head_dim=self.head_dim,
+             attention_type="cross",
+             initializer_range=initializer_range,
+             attention_dropout=attention_dropout,
+             attn_logit_softcapping=attn_logit_softcapping,
+             dtype=self.dtype_policy,
+             name="cross_attention",
+         )
+         self.pre_cross_attn_layernorm = RMSNormalization(
+             epsilon=rms_norm_eps,
+             dtype=self.dtype_policy,
+             name="decoder_pre_cross_attention_layernorm",
+         )
+         self.post_cross_attn_layernorm = RMSNormalization(
+             epsilon=rms_norm_eps,
+             dtype=self.dtype_policy,
+             name="decoder_post_cross_attention_layernorm",
+         )
+
+         # MLP.
+         self.mlp = T5GemmaMLP(
+             hidden_size,
+             intermediate_size,
+             hidden_activation,
+             dropout_rate,
+             initializer_range=initializer_range,
+             dtype=self.dtype_policy,
+             name="mlp",
+         )
+         self.pre_feedforward_layernorm = RMSNormalization(
+             epsilon=rms_norm_eps,
+             dtype=self.dtype_policy,
+             name="decoder_pre_feedforward_layernorm",
+         )
+         self.post_feedforward_layernorm = RMSNormalization(
+             epsilon=rms_norm_eps,
+             dtype=self.dtype_policy,
+             name="decoder_post_feedforward_layernorm",
+         )
+
+         self.dropout = keras.layers.Dropout(
+             dropout_rate,
+             dtype=self.dtype_policy,
+             name="decoder_residual_dropout",
+         )
+
+     def build(self, input_shape):
+         hidden_states_shape, encoder_hidden_states_shape = input_shape
+         self.pre_self_attn_layernorm.build(hidden_states_shape)
+         current_shape = hidden_states_shape
+         self.self_attn.build(current_shape)
+         attn_output_shape, _ = self.self_attn.compute_output_shape(
+             current_shape
+         )
+         self.post_self_attn_layernorm.build(attn_output_shape)
+         current_shape = attn_output_shape
+         self.dropout.build(current_shape)
+         self.pre_cross_attn_layernorm.build(current_shape)
+         self.cross_attn.build([current_shape, encoder_hidden_states_shape])
+         attn_output_shape, _ = self.cross_attn.compute_output_shape(
+             [current_shape, encoder_hidden_states_shape]
+         )
+         self.post_cross_attn_layernorm.build(attn_output_shape)
+         current_shape = attn_output_shape
+         self.pre_feedforward_layernorm.build(current_shape)
+         self.mlp.build(current_shape)
+         mlp_output_shape = self.mlp.compute_output_shape(current_shape)
+         self.post_feedforward_layernorm.build(mlp_output_shape)
+         self.built = True
+
+     def _make_self_attention_mask(
+         self,
+         hidden_states,
+         padding_mask,
+         cache=None,
+         cache_update_index=None,
+     ):
+         if cache is not None:
+             q_len = keras.ops.shape(hidden_states)[1]
+             kv_len = keras.ops.shape(cache)[2]
+             q_indices = (
+                 keras.ops.arange(0, q_len, dtype="int32") + cache_update_index
+             )
+             kv_indices = keras.ops.arange(0, kv_len, dtype="int32")
+         else:
+             q_len = kv_len = keras.ops.shape(hidden_states)[1]
+             q_indices = keras.ops.arange(0, q_len, dtype="int32")
+             kv_indices = keras.ops.arange(0, kv_len, dtype="int32")
+         # Create the causal mask.
+         causal_mask = kv_indices[None, :] <= q_indices[:, None]
+         # Apply sliding window if applicable.
+         if self.layer_type == "sliding_attention":
+             sliding_mask = (
+                 q_indices[:, None] - self.sliding_window
+             ) <= kv_indices[None, :]
+             causal_mask = keras.ops.logical_and(causal_mask, sliding_mask)
+         # Combine with padding mask.
+         final_mask = causal_mask[None, None, :, :]
+         if padding_mask is not None:
+             padding_mask_slice = padding_mask[:, :kv_len]
+             padding_mask_4d = padding_mask_slice[:, None, None, :]
+             final_mask = keras.ops.logical_and(final_mask, padding_mask_4d)
+         return (1.0 - keras.ops.cast(final_mask, hidden_states.dtype)) * -1e9
+
+     def _make_cross_attention_mask(self, hidden_states, padding_mask):
+         if padding_mask is None:
+             return None
+         bidirectional_mask = padding_mask[:, None, None, :]
+         additive_bidirectional_mask = (
+             1.0 - keras.ops.cast(bidirectional_mask, hidden_states.dtype)
+         ) * -1e9
+         return additive_bidirectional_mask
+
+     def call(
+         self,
+         inputs,
+         self_attention_padding_mask=None,
+         cross_attention_padding_mask=None,
+         cache=None,
+         cache_update_index=None,
+         training=None,
+     ):
+         hidden_states, encoder_hidden_states = inputs
+         self_attention_cache, cross_attention_cache = (
+             cache if cache is not None else (None, None)
+         )
+         # Self Attention.
+         residual = hidden_states
+         self_attention_mask = self._make_self_attention_mask(
+             hidden_states,
+             self_attention_padding_mask,
+             cache=self_attention_cache,
+             cache_update_index=cache_update_index,
+         )
+         hidden_states = self.pre_self_attn_layernorm(hidden_states)
+         hidden_states, updated_self_attention_cache = self.self_attn(
+             inputs=hidden_states,
+             attention_mask=self_attention_mask,
+             cache=self_attention_cache,
+             cache_update_index=cache_update_index,
+             training=training,
+         )
+         hidden_states = self.post_self_attn_layernorm(hidden_states)
+         hidden_states = residual + self.dropout(
+             hidden_states, training=training
+         )
+
+         # Cross Attention.
+         residual = hidden_states
+         cross_attention_mask = self._make_cross_attention_mask(
+             encoder_hidden_states, cross_attention_padding_mask
+         )
+         hidden_states = self.pre_cross_attn_layernorm(hidden_states)
+         hidden_states, updated_cross_attention_cache = self.cross_attn(
+             inputs=[hidden_states, encoder_hidden_states],
+             attention_mask=cross_attention_mask,
+             cache=cross_attention_cache,
+             training=training,
+         )
+
+         hidden_states = self.post_cross_attn_layernorm(hidden_states)
+         hidden_states = residual + self.dropout(
+             hidden_states, training=training
+         )
+
+         # MLP.
+         residual = hidden_states
+         hidden_states = self.pre_feedforward_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states, training=training)
+         hidden_states = self.post_feedforward_layernorm(hidden_states)
+         hidden_states = residual + self.dropout(
+             hidden_states, training=training
+         )
+         updated_cache = (
+             updated_self_attention_cache,
+             updated_cross_attention_cache,
+         )
+         return hidden_states, updated_cache
+
+     def compute_output_shape(self, input_shape):
+         hidden_states_shape, encoder_hidden_states_shape = input_shape
+         batch_size, dec_seq_len, _ = hidden_states_shape
+         _, enc_seq_len, _ = encoder_hidden_states_shape
+         self_cache_shape = (
+             batch_size,
+             2,
+             dec_seq_len,
+             self.num_key_value_heads,
+             self.head_dim,
+         )
+         cross_cache_shape = (
+             batch_size,
+             2,
+             enc_seq_len,
+             self.num_key_value_heads,
+             self.head_dim,
+         )
+         return hidden_states_shape, (self_cache_shape, cross_cache_shape)
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_size": self.hidden_size,
+                 "rms_norm_eps": self.rms_norm_eps,
+                 "num_attention_heads": self.num_attention_heads,
+                 "num_key_value_heads": self.num_key_value_heads,
+                 "query_pre_attn_scalar": self.query_pre_attn_scalar,
+                 "attention_bias": self.attention_bias,
+                 "intermediate_size": self.intermediate_size,
+                 "hidden_activation": self.hidden_activation,
+                 "dropout_rate": self.dropout_rate,
+                 "initializer_range": self.initializer_range,
+                 "attention_dropout": self.attention_dropout,
+                 "layer_type": self.layer_type,
+                 "sliding_window": self.sliding_window,
+                 "rope_max_wavelength": self.rope_max_wavelength,
+                 "head_dim": self.head_dim,
+                 "cross_attention_hidden_size": self.cross_attention_hidden_size,
+                 "attn_logit_softcapping": self.attn_logit_softcapping,
+             }
+         )
+         return config
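A note on the decoder's `_make_self_attention_mask` above: it folds the causal mask, the optional sliding-window mask, and the padding mask into a single additive bias that is `0.0` where attention is allowed and `-1e9` where it is blocked. The standalone sketch below (not part of the package; sizes are made up, and the cache and padding branches are omitted) reproduces just that combination with `keras.ops`:

    import keras

    # Hypothetical sizes for illustration only.
    seq_len, sliding_window = 6, 3
    q_indices = keras.ops.arange(0, seq_len, dtype="int32")
    kv_indices = keras.ops.arange(0, seq_len, dtype="int32")

    # Causal: a query position may attend to key positions at or before it.
    causal_mask = kv_indices[None, :] <= q_indices[:, None]
    # Sliding window: keys more than `sliding_window` positions behind the
    # query are also blocked, mirroring the `sliding_attention` branch above.
    sliding_mask = (q_indices[:, None] - sliding_window) <= kv_indices[None, :]
    mask = keras.ops.logical_and(causal_mask, sliding_mask)

    # Additive form consumed by the attention layer:
    # 0.0 where allowed, -1e9 where masked.
    additive_mask = (1.0 - keras.ops.cast(mask, "float32")) * -1e9
    print(keras.ops.convert_to_numpy(additive_mask))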
@@ -0,0 +1,214 @@
+ import keras
+
+ from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
+ from keras_hub.src.models.t5gemma.t5gemma_attention import T5GemmaAttention
+ from keras_hub.src.models.t5gemma.t5gemma_layers import T5GemmaMLP
+
+
+ class T5GemmaEncoderLayer(keras.layers.Layer):
+     """Encoder layer for the T5Gemma model.
+
+     This layer implements a single encoder block in the T5Gemma architecture,
+     comprising self-attention and a feed-forward network (MLP).
+
+     Args:
+         hidden_size: int, The dimensionality of the hidden states.
+         rms_norm_eps: float, The epsilon value for RMS normalization.
+         num_attention_heads: int, The number of attention heads in
+             self-attention.
+         num_key_value_heads: int, The number of key-value heads for grouped
+             query attention.
+         query_pre_attn_scalar: float, Scalar to multiply queries by before
+             attention.
+         attention_bias: bool, Whether to include bias in attention computations.
+         intermediate_size: int, The intermediate size of the feed-forward
+             network.
+         hidden_activation: str, The activation function used in the feed-forward
+             network.
+         dropout_rate: float, The dropout rate applied after attention and MLP.
+         initializer_range: float, The range for the random normal initializer.
+         attention_dropout: float, The dropout rate applied to attention weights.
+         layer_type: str, Type of attention layer, e.g., `"sliding_attention"`.
+         head_dim: int, The dimensionality of each attention head.
+         attn_logit_softcapping: float, optional, The softcapping value for
+             attention logits. Defaults to `None`.
+         sliding_window: int, optional, The window size for sliding attention.
+             Required if `layer_type` is `"sliding_attention"`. Defaults to
+             `None`.
+         rope_max_wavelength: float, The maximum wavelength for Rotary Positional
+             Embeddings. Defaults to `10000.0`.
+         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+             for model computations and weights. Defaults to `None`.
+         **kwargs: Additional keyword arguments passed to the parent class.
+     """
+
+     def __init__(
+         self,
+         hidden_size,
+         rms_norm_eps,
+         num_attention_heads,
+         num_key_value_heads,
+         query_pre_attn_scalar,
+         attention_bias,
+         intermediate_size,
+         hidden_activation,
+         dropout_rate,
+         initializer_range,
+         attention_dropout,
+         layer_type,
+         head_dim,
+         attn_logit_softcapping=None,
+         sliding_window=None,
+         rope_max_wavelength=10000.0,
+         dtype=None,
+         **kwargs,
+     ):
+         super().__init__(dtype=dtype, **kwargs)
+         self.hidden_size = hidden_size
+         self.rms_norm_eps = rms_norm_eps
+         self.num_attention_heads = num_attention_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.query_pre_attn_scalar = query_pre_attn_scalar
+         self.attention_bias = attention_bias
+         self.intermediate_size = intermediate_size
+         self.hidden_activation = hidden_activation
+         self.dropout_rate = dropout_rate
+         self.initializer_range = initializer_range
+         self.attention_dropout = attention_dropout
+         self.layer_type = layer_type
+         self.sliding_window = sliding_window
+         self.rope_max_wavelength = rope_max_wavelength
+         self.head_dim = head_dim
+         self.attn_logit_softcapping = attn_logit_softcapping
+         if (
+             self.layer_type == "sliding_attention"
+             and self.sliding_window is None
+         ):
+             raise ValueError(
+                 "`sliding_window` must be set for `sliding_attention` layer "
+                 "type."
+             )
+         self.self_attn = T5GemmaAttention(
+             hidden_size=hidden_size,
+             num_attention_heads=num_attention_heads,
+             num_key_value_heads=num_key_value_heads,
+             query_pre_attn_scalar=query_pre_attn_scalar,
+             attention_bias=attention_bias,
+             head_dim=self.head_dim,
+             attention_type="self",
+             initializer_range=initializer_range,
+             attention_dropout=attention_dropout,
+             attn_logit_softcapping=attn_logit_softcapping,
+             rope_max_wavelength=self.rope_max_wavelength,
+             dtype=self.dtype_policy,
+             name="self_attention",
+         )
+         self.pre_self_attn_layernorm = RMSNormalization(
+             epsilon=rms_norm_eps,
+             dtype=self.dtype_policy,
+             name="pre_self_attention_layernorm",
+         )
+         self.post_self_attn_layernorm = RMSNormalization(
+             epsilon=rms_norm_eps,
+             dtype=self.dtype_policy,
+             name="post_self_attention_layernorm",
+         )
+
+         self.mlp = T5GemmaMLP(
+             hidden_size,
+             intermediate_size,
+             hidden_activation,
+             dropout_rate,
+             initializer_range=initializer_range,
+             dtype=self.dtype_policy,
+             name="mlp",
+         )
+         self.pre_feedforward_layernorm = RMSNormalization(
+             epsilon=rms_norm_eps,
+             dtype=self.dtype_policy,
+             name="pre_feedforward_layernorm",
+         )
+         self.post_feedforward_layernorm = RMSNormalization(
+             epsilon=rms_norm_eps,
+             dtype=self.dtype_policy,
+             name="post_feedforward_layernorm",
+         )
+         self.dropout = keras.layers.Dropout(
+             dropout_rate,
+             dtype=self.dtype_policy,
+             name="residual_dropout",
+         )
+
+     def build(self, input_shape):
+         self.pre_self_attn_layernorm.build(input_shape)
+         self.self_attn.build(input_shape)
+         attn_output_shape, _ = self.self_attn.compute_output_shape(input_shape)
+         self.post_self_attn_layernorm.build(attn_output_shape)
+         self.dropout.build(attn_output_shape)
+         self.pre_feedforward_layernorm.build(attn_output_shape)
+         self.mlp.build(attn_output_shape)
+         mlp_output_shape = self.mlp.compute_output_shape(attn_output_shape)
+         self.post_feedforward_layernorm.build(mlp_output_shape)
+         self.built = True
+
+     def _make_attention_mask(self, hidden_states, padding_mask):
+         attention_mask = padding_mask[:, None, None, :]
+         additive_mask = (
+             1.0 - keras.ops.cast(attention_mask, hidden_states.dtype)
+         ) * -1e9
+         return additive_mask
+
+     def call(
+         self,
+         hidden_states,
+         padding_mask=None,
+         training=None,
+     ):
+         residual = hidden_states
+         attention_mask = self._make_attention_mask(hidden_states, padding_mask)
+         hidden_states = self.pre_self_attn_layernorm(hidden_states)
+         hidden_states, _ = self.self_attn(
+             inputs=hidden_states,
+             attention_mask=attention_mask,
+             training=training,
+         )
+         hidden_states = self.post_self_attn_layernorm(hidden_states)
+         hidden_states = residual + self.dropout(
+             hidden_states, training=training
+         )
+         residual = hidden_states
+         hidden_states = self.pre_feedforward_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states, training=training)
+         hidden_states = self.post_feedforward_layernorm(hidden_states)
+         hidden_states = residual + self.dropout(
+             hidden_states, training=training
+         )
+         return hidden_states
+
+     def compute_output_shape(self, input_shape):
+         # Isometric.
+         return input_shape
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_size": self.hidden_size,
+                 "rms_norm_eps": self.rms_norm_eps,
+                 "head_dim": self.head_dim,
+                 "num_attention_heads": self.num_attention_heads,
+                 "num_key_value_heads": self.num_key_value_heads,
+                 "query_pre_attn_scalar": self.query_pre_attn_scalar,
+                 "attention_bias": self.attention_bias,
+                 "intermediate_size": self.intermediate_size,
+                 "hidden_activation": self.hidden_activation,
+                 "dropout_rate": self.dropout_rate,
+                 "initializer_range": self.initializer_range,
+                 "attention_dropout": self.attention_dropout,
+                 "layer_type": self.layer_type,
+                 "sliding_window": self.sliding_window,
+                 "rope_max_wavelength": self.rope_max_wavelength,
+                 "attn_logit_softcapping": self.attn_logit_softcapping,
+             }
+         )
+         return config
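For orientation, here is a rough usage sketch of the encoder layer above. The module path is an assumption inferred from the sibling imports (the diff does not show file names), and the hyperparameters are arbitrary, so treat this as an untested illustration rather than documented API:

    import numpy as np

    # Assumed module path; the actual file name for T5GemmaEncoderLayer is
    # not shown in this diff.
    from keras_hub.src.models.t5gemma.t5gemma_encoder import T5GemmaEncoderLayer

    layer = T5GemmaEncoderLayer(
        hidden_size=64,
        rms_norm_eps=1e-6,
        num_attention_heads=4,
        num_key_value_heads=2,
        query_pre_attn_scalar=0.25,  # commonly head_dim ** -0.5
        attention_bias=False,
        intermediate_size=128,
        hidden_activation="gelu_approximate",
        dropout_rate=0.0,
        initializer_range=0.02,
        attention_dropout=0.0,
        layer_type="sliding_attention",
        head_dim=16,
        sliding_window=4,
    )

    # Dummy batch: 2 sequences of length 8; padding_mask marks valid tokens.
    hidden_states = np.random.uniform(size=(2, 8, 64)).astype("float32")
    padding_mask = np.ones((2, 8), dtype="int32")
    outputs = layer(hidden_states, padding_mask=padding_mask)
    print(outputs.shape)  # (2, 8, 64) -- the layer preserves the input shape.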
@@ -0,0 +1,118 @@
+ import keras
+
+ from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+ def t5gemma_kernel_initializer(initializer_range=0.01):
+     """Creates a RandomNormal initializer for T5Gemma kernels.
+
+     Args:
+         initializer_range: float, The standard deviation of the normal
+             distribution. Defaults to `0.01`.
+
+     Returns:
+         keras.initializers.RandomNormal: A Keras RandomNormal initializer.
+     """
+     return keras.initializers.RandomNormal(mean=0.0, stddev=initializer_range)
+
+
+ class T5GemmaMLP(keras.layers.Layer):
+     """Multilayer Perceptron (MLP) block for the T5Gemma model.
+
+     This layer implements the feed-forward part of a transformer block,
+     consisting of two dense layers with a GELU activation and dropout.
+
+     Args:
+         hidden_size: int, The dimensionality of the input and output hidden
+             states.
+         intermediate_size: int, The dimensionality of the intermediate layer.
+         hidden_activation: str, The activation function to use, e.g.,
+             "gelu_approximate".
+         dropout_rate: float, The dropout rate applied to the intermediate
+             hidden states.
+         initializer_range: float, The range for the random normal initializer
+             for kernel weights. Defaults to `0.02`.
+         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+             for model computations and weights. Defaults to `None`.
+         **kwargs: Additional keyword arguments passed to the parent class.
+     """
+
+     def __init__(
+         self,
+         hidden_size,
+         intermediate_size,
+         hidden_activation,
+         dropout_rate,
+         initializer_range=0.02,
+         dtype=None,
+         **kwargs,
+     ):
+         super().__init__(dtype=dtype, **kwargs)
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.hidden_activation = hidden_activation
+         self.dropout_rate = dropout_rate
+         self.initializer_range = initializer_range
+         self.kernel_initializer = t5gemma_kernel_initializer(initializer_range)
+
+         self.gate_proj = keras.layers.Dense(
+             self.intermediate_size,
+             use_bias=False,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             dtype=self.dtype_policy,
+             name="gate_proj",
+         )
+         self.up_proj = keras.layers.Dense(
+             self.intermediate_size,
+             use_bias=False,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             dtype=self.dtype_policy,
+             name="up_proj",
+         )
+         self.down_proj = keras.layers.Dense(
+             self.hidden_size,
+             use_bias=False,
+             kernel_initializer=clone_initializer(self.kernel_initializer),
+             dtype=self.dtype_policy,
+             name="down_proj",
+         )
+         if self.hidden_activation == "gelu_approximate":
+             # NOTE: `gelu_pytorch_tanh` is the same as `gelu(approximate=True)`.
+             self.act_fn = lambda x: keras.activations.gelu(x, approximate=True)
+         else:
+             self.act_fn = keras.activations.get(self.hidden_activation)
+         self.dropout = keras.layers.Dropout(
+             self.dropout_rate,
+             dtype=self.dtype_policy,
+             name="mlp_dropout",
+         )
+
+     def build(self, input_shape):
+         self.gate_proj.build(input_shape)
+         self.up_proj.build(input_shape)
+         intermediate_shape = self.gate_proj.compute_output_shape(input_shape)
+         self.dropout.build(intermediate_shape)
+         self.down_proj.build(intermediate_shape)
+         self.built = True
+
+     def compute_output_shape(self, input_shape):
+         return input_shape
+
+     def call(self, x, training=None):
+         hidden_states = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
+         hidden_states = self.dropout(hidden_states, training=training)
+         down_proj = self.down_proj(hidden_states)
+         return down_proj
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "hidden_size": self.hidden_size,
+                 "intermediate_size": self.intermediate_size,
+                 "hidden_activation": self.hidden_activation,
+                 "dropout_rate": self.dropout_rate,
+                 "initializer_range": self.initializer_range,
+             }
+         )
+         return config
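The MLP above uses the gated feed-forward pattern `down_proj(act_fn(gate_proj(x)) * up_proj(x))`, with dropout applied to the intermediate activations. A small usage sketch with arbitrary sizes (the import path does appear in the files above, but the example itself is illustrative only):

    import numpy as np

    from keras_hub.src.models.t5gemma.t5gemma_layers import T5GemmaMLP

    # Arbitrary small sizes chosen for illustration.
    mlp = T5GemmaMLP(
        hidden_size=64,
        intermediate_size=256,
        hidden_activation="gelu_approximate",
        dropout_rate=0.1,
    )

    x = np.random.uniform(size=(2, 8, 64)).astype("float32")
    # Computes down_proj(gelu(gate_proj(x)) * up_proj(x)); dropout is only
    # active when training=True.
    y = mlp(x)
    print(y.shape)  # (2, 8, 64)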