keras-hub-nightly 0.23.0.dev202508240418__py3-none-any.whl → 0.23.0.dev202508260411__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/models/__init__.py +12 -0
- keras_hub/src/models/t5gemma/__init__.py +5 -0
- keras_hub/src/models/t5gemma/t5gemma_attention.py +370 -0
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +366 -0
- keras_hub/src/models/t5gemma/t5gemma_decoder.py +355 -0
- keras_hub/src/models/t5gemma/t5gemma_encoder.py +214 -0
- keras_hub/src/models/t5gemma/t5gemma_layers.py +118 -0
- keras_hub/src/models/t5gemma/t5gemma_presets.py +15 -0
- keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm.py +442 -0
- keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py +216 -0
- keras_hub/src/models/t5gemma/t5gemma_tokenizer.py +84 -0
- keras_hub/src/utils/transformers/convert_t5gemma.py +229 -0
- keras_hub/src/utils/transformers/preset_loader.py +3 -0
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +3 -0
- {keras_hub_nightly-0.23.0.dev202508240418.dist-info → keras_hub_nightly-0.23.0.dev202508260411.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.23.0.dev202508240418.dist-info → keras_hub_nightly-0.23.0.dev202508260411.dist-info}/RECORD +19 -8
- {keras_hub_nightly-0.23.0.dev202508240418.dist-info → keras_hub_nightly-0.23.0.dev202508260411.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.23.0.dev202508240418.dist-info → keras_hub_nightly-0.23.0.dev202508260411.dist-info}/top_level.txt +0 -0
keras_hub/src/models/t5gemma/t5gemma_decoder.py
@@ -0,0 +1,355 @@
import keras

from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
from keras_hub.src.models.t5gemma.t5gemma_attention import T5GemmaAttention
from keras_hub.src.models.t5gemma.t5gemma_layers import T5GemmaMLP


class T5GemmaDecoderLayer(keras.layers.Layer):
    """Decoder layer for the T5Gemma model.

    This layer implements a single decoder block in the T5Gemma architecture,
    comprising self-attention, cross-attention, and a feed-forward network
    (MLP).

    Args:
        hidden_size: int, The dimensionality of the hidden states.
        rms_norm_eps: float, The epsilon value for RMS normalization.
        num_attention_heads: int, The number of attention heads in
            self-attention and cross-attention.
        num_key_value_heads: int, The number of key-value heads for grouped
            query attention.
        query_pre_attn_scalar: float, Scalar to multiply queries by before
            attention.
        attention_bias: bool, Whether to include bias in attention computations.
        intermediate_size: int, The intermediate size of the feed-forward
            network.
        hidden_activation: str, The activation function used in the feed-forward
            network.
        dropout_rate: float, The dropout rate applied after attention and MLP.
        head_dim: int, The dimensionality of each attention head.
        initializer_range: float, The range for the random normal initializer.
        attention_dropout: float, The dropout rate applied to attention weights.
        layer_type: str, Type of attention layer, e.g., `"sliding_attention"`.
        cross_attention_hidden_size: int, optional, The hidden size for
            cross-attention. If None, it defaults to `hidden_size`. Defaults to
            `None`.
        attn_logit_softcapping: float, optional, The softcapping value for
            attention logits. Defaults to `None`.
        sliding_window: int, optional, The window size for sliding attention.
            Required if `layer_type` is `"sliding_attention"`. Defaults to
            `None`.
        rope_max_wavelength: float, The maximum wavelength for Rotary
            Positional Embeddings. Defaults to `10000.0`.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for model computations and weights. Defaults to `None`.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(
        self,
        hidden_size,
        rms_norm_eps,
        num_attention_heads,
        num_key_value_heads,
        query_pre_attn_scalar,
        attention_bias,
        intermediate_size,
        hidden_activation,
        dropout_rate,
        head_dim,
        initializer_range,
        attention_dropout,
        layer_type,
        cross_attention_hidden_size=None,
        attn_logit_softcapping=None,
        sliding_window=None,
        rope_max_wavelength=10000.0,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.head_dim = head_dim
        self.hidden_size = hidden_size
        self.rms_norm_eps = rms_norm_eps
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.query_pre_attn_scalar = query_pre_attn_scalar
        self.attention_bias = attention_bias
        self.intermediate_size = intermediate_size
        self.hidden_activation = hidden_activation
        self.dropout_rate = dropout_rate
        self.initializer_range = initializer_range
        self.attention_dropout = attention_dropout
        self.layer_type = layer_type
        self.sliding_window = sliding_window
        self.rope_max_wavelength = rope_max_wavelength
        self.cross_attention_hidden_size = cross_attention_hidden_size
        self.attn_logit_softcapping = attn_logit_softcapping
        if (
            self.layer_type == "sliding_attention"
            and self.sliding_window is None
        ):
            raise ValueError(
                "`sliding_window` must be set for `sliding_attention` layer "
                "type."
            )

        # Self-attention.
        self.self_attn = T5GemmaAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            query_pre_attn_scalar=query_pre_attn_scalar,
            attention_bias=attention_bias,
            head_dim=self.head_dim,
            attention_type="self",
            initializer_range=initializer_range,
            attention_dropout=attention_dropout,
            attn_logit_softcapping=attn_logit_softcapping,
            rope_max_wavelength=self.rope_max_wavelength,
            dtype=self.dtype_policy,
            name="self_attention",
        )
        self.pre_self_attn_layernorm = RMSNormalization(
            epsilon=rms_norm_eps,
            dtype=self.dtype_policy,
            name="decoder_pre_self_attention_layernorm",
        )
        self.post_self_attn_layernorm = RMSNormalization(
            epsilon=rms_norm_eps,
            dtype=self.dtype_policy,
            name="decoder_post_self_attention_layernorm",
        )

        # Cross-attention.
        self.cross_attn = T5GemmaAttention(
            hidden_size=hidden_size,
            cross_attention_hidden_size=cross_attention_hidden_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            query_pre_attn_scalar=query_pre_attn_scalar,
            attention_bias=attention_bias,
            head_dim=self.head_dim,
            attention_type="cross",
            initializer_range=initializer_range,
            attention_dropout=attention_dropout,
            attn_logit_softcapping=attn_logit_softcapping,
            dtype=self.dtype_policy,
            name="cross_attention",
        )
        self.pre_cross_attn_layernorm = RMSNormalization(
            epsilon=rms_norm_eps,
            dtype=self.dtype_policy,
            name="decoder_pre_cross_attention_layernorm",
        )
        self.post_cross_attn_layernorm = RMSNormalization(
            epsilon=rms_norm_eps,
            dtype=self.dtype_policy,
            name="decoder_post_cross_attention_layernorm",
        )

        # MLP.
        self.mlp = T5GemmaMLP(
            hidden_size,
            intermediate_size,
            hidden_activation,
            dropout_rate,
            initializer_range=initializer_range,
            dtype=self.dtype_policy,
            name="mlp",
        )
        self.pre_feedforward_layernorm = RMSNormalization(
            epsilon=rms_norm_eps,
            dtype=self.dtype_policy,
            name="decoder_pre_feedforward_layernorm",
        )
        self.post_feedforward_layernorm = RMSNormalization(
            epsilon=rms_norm_eps,
            dtype=self.dtype_policy,
            name="decoder_post_feedforward_layernorm",
        )

        self.dropout = keras.layers.Dropout(
            dropout_rate,
            dtype=self.dtype_policy,
            name="decoder_residual_dropout",
        )

    def build(self, input_shape):
        hidden_states_shape, encoder_hidden_states_shape = input_shape
        self.pre_self_attn_layernorm.build(hidden_states_shape)
        current_shape = hidden_states_shape
        self.self_attn.build(current_shape)
        attn_output_shape, _ = self.self_attn.compute_output_shape(
            current_shape
        )
        self.post_self_attn_layernorm.build(attn_output_shape)
        current_shape = attn_output_shape
        self.dropout.build(current_shape)
        self.pre_cross_attn_layernorm.build(current_shape)
        self.cross_attn.build([current_shape, encoder_hidden_states_shape])
        attn_output_shape, _ = self.cross_attn.compute_output_shape(
            [current_shape, encoder_hidden_states_shape]
        )
        self.post_cross_attn_layernorm.build(attn_output_shape)
        current_shape = attn_output_shape
        self.pre_feedforward_layernorm.build(current_shape)
        self.mlp.build(current_shape)
        mlp_output_shape = self.mlp.compute_output_shape(current_shape)
        self.post_feedforward_layernorm.build(mlp_output_shape)
        self.built = True

    def _make_self_attention_mask(
        self,
        hidden_states,
        padding_mask,
        cache=None,
        cache_update_index=None,
    ):
        if cache is not None:
            q_len = keras.ops.shape(hidden_states)[1]
            kv_len = keras.ops.shape(cache)[2]
            q_indices = (
                keras.ops.arange(0, q_len, dtype="int32") + cache_update_index
            )
            kv_indices = keras.ops.arange(0, kv_len, dtype="int32")
        else:
            q_len = kv_len = keras.ops.shape(hidden_states)[1]
            q_indices = keras.ops.arange(0, q_len, dtype="int32")
            kv_indices = keras.ops.arange(0, kv_len, dtype="int32")
        # Create the causal mask.
        causal_mask = kv_indices[None, :] <= q_indices[:, None]
        # Apply sliding window if applicable.
        if self.layer_type == "sliding_attention":
            sliding_mask = (
                q_indices[:, None] - self.sliding_window
            ) <= kv_indices[None, :]
            causal_mask = keras.ops.logical_and(causal_mask, sliding_mask)
        # Combine with padding mask.
        final_mask = causal_mask[None, None, :, :]
        if padding_mask is not None:
            padding_mask_slice = padding_mask[:, :kv_len]
            padding_mask_4d = padding_mask_slice[:, None, None, :]
            final_mask = keras.ops.logical_and(final_mask, padding_mask_4d)
        return (1.0 - keras.ops.cast(final_mask, hidden_states.dtype)) * -1e9

    def _make_cross_attention_mask(self, hidden_states, padding_mask):
        if padding_mask is None:
            return None
        bidirectional_mask = padding_mask[:, None, None, :]
        additive_bidirectional_mask = (
            1.0 - keras.ops.cast(bidirectional_mask, hidden_states.dtype)
        ) * -1e9
        return additive_bidirectional_mask

    def call(
        self,
        inputs,
        self_attention_padding_mask=None,
        cross_attention_padding_mask=None,
        cache=None,
        cache_update_index=None,
        training=None,
    ):
        hidden_states, encoder_hidden_states = inputs
        self_attention_cache, cross_attention_cache = (
            cache if cache is not None else (None, None)
        )
        # Self Attention.
        residual = hidden_states
        self_attention_mask = self._make_self_attention_mask(
            hidden_states,
            self_attention_padding_mask,
            cache=self_attention_cache,
            cache_update_index=cache_update_index,
        )
        hidden_states = self.pre_self_attn_layernorm(hidden_states)
        hidden_states, updated_self_attention_cache = self.self_attn(
            inputs=hidden_states,
            attention_mask=self_attention_mask,
            cache=self_attention_cache,
            cache_update_index=cache_update_index,
            training=training,
        )
        hidden_states = self.post_self_attn_layernorm(hidden_states)
        hidden_states = residual + self.dropout(
            hidden_states, training=training
        )

        # Cross Attention.
        residual = hidden_states
        cross_attention_mask = self._make_cross_attention_mask(
            encoder_hidden_states, cross_attention_padding_mask
        )
        hidden_states = self.pre_cross_attn_layernorm(hidden_states)
        hidden_states, updated_cross_attention_cache = self.cross_attn(
            inputs=[hidden_states, encoder_hidden_states],
            attention_mask=cross_attention_mask,
            cache=cross_attention_cache,
            training=training,
        )

        hidden_states = self.post_cross_attn_layernorm(hidden_states)
        hidden_states = residual + self.dropout(
            hidden_states, training=training
        )

        # MLP.
        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states, training=training)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + self.dropout(
            hidden_states, training=training
        )
        updated_cache = (
            updated_self_attention_cache,
            updated_cross_attention_cache,
        )
        return hidden_states, updated_cache

    def compute_output_shape(self, input_shape):
        hidden_states_shape, encoder_hidden_states_shape = input_shape
        batch_size, dec_seq_len, _ = hidden_states_shape
        _, enc_seq_len, _ = encoder_hidden_states_shape
        self_cache_shape = (
            batch_size,
            2,
            dec_seq_len,
            self.num_key_value_heads,
            self.head_dim,
        )
        cross_cache_shape = (
            batch_size,
            2,
            enc_seq_len,
            self.num_key_value_heads,
            self.head_dim,
        )
        return hidden_states_shape, (self_cache_shape, cross_cache_shape)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_size": self.hidden_size,
                "rms_norm_eps": self.rms_norm_eps,
                "num_attention_heads": self.num_attention_heads,
                "num_key_value_heads": self.num_key_value_heads,
                "query_pre_attn_scalar": self.query_pre_attn_scalar,
                "attention_bias": self.attention_bias,
                "intermediate_size": self.intermediate_size,
                "hidden_activation": self.hidden_activation,
                "dropout_rate": self.dropout_rate,
                "initializer_range": self.initializer_range,
                "attention_dropout": self.attention_dropout,
                "layer_type": self.layer_type,
                "sliding_window": self.sliding_window,
                "rope_max_wavelength": self.rope_max_wavelength,
                "head_dim": self.head_dim,
                "cross_attention_hidden_size": self.cross_attention_hidden_size,
                "attn_logit_softcapping": self.attn_logit_softcapping,
            }
        )
        return config
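For reference, a minimal usage sketch of the new `T5GemmaDecoderLayer`, written against the constructor and `call` signatures shown in the hunk above. The layer sizes and the `"full_attention"` layer type are illustrative assumptions rather than values from any published preset; the snippet assumes this nightly wheel is installed and has not been run against the release.

# Hypothetical usage sketch: sizes and layer_type are illustrative assumptions.
import keras

from keras_hub.src.models.t5gemma.t5gemma_decoder import T5GemmaDecoderLayer

layer = T5GemmaDecoderLayer(
    hidden_size=64,
    rms_norm_eps=1e-6,
    num_attention_heads=4,
    num_key_value_heads=2,
    query_pre_attn_scalar=16,
    attention_bias=False,
    intermediate_size=128,
    hidden_activation="gelu_approximate",
    dropout_rate=0.0,
    head_dim=16,
    initializer_range=0.02,
    attention_dropout=0.0,
    layer_type="full_attention",  # any value other than "sliding_attention" skips the window check
)
decoder_states = keras.random.uniform((2, 5, 64))   # (batch, dec_len, hidden)
encoder_states = keras.random.uniform((2, 7, 64))   # (batch, enc_len, hidden)
dec_padding = keras.ops.ones((2, 5), dtype="bool")  # True = real token
enc_padding = keras.ops.ones((2, 7), dtype="bool")
outputs, (self_cache, cross_cache) = layer(
    (decoder_states, encoder_states),
    self_attention_padding_mask=dec_padding,
    cross_attention_padding_mask=enc_padding,
)
print(outputs.shape)  # (2, 5, 64)

The second return value is the `(self_attention_cache, cross_attention_cache)` tuple, which callers can feed back through `cache` and `cache_update_index` for incremental decoding.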
keras_hub/src/models/t5gemma/t5gemma_encoder.py
@@ -0,0 +1,214 @@
import keras

from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
from keras_hub.src.models.t5gemma.t5gemma_attention import T5GemmaAttention
from keras_hub.src.models.t5gemma.t5gemma_layers import T5GemmaMLP


class T5GemmaEncoderLayer(keras.layers.Layer):
    """Encoder layer for the T5Gemma model.

    This layer implements a single encoder block in the T5Gemma architecture,
    comprising self-attention and a feed-forward network (MLP).

    Args:
        hidden_size: int, The dimensionality of the hidden states.
        rms_norm_eps: float, The epsilon value for RMS normalization.
        num_attention_heads: int, The number of attention heads in
            self-attention.
        num_key_value_heads: int, The number of key-value heads for grouped
            query attention.
        query_pre_attn_scalar: float, Scalar to multiply queries by before
            attention.
        attention_bias: bool, Whether to include bias in attention computations.
        intermediate_size: int, The intermediate size of the feed-forward
            network.
        hidden_activation: str, The activation function used in the feed-forward
            network.
        dropout_rate: float, The dropout rate applied after attention and MLP.
        initializer_range: float, The range for the random normal initializer.
        attention_dropout: float, The dropout rate applied to attention weights.
        layer_type: str, Type of attention layer, e.g., `"sliding_attention"`.
        head_dim: int, The dimensionality of each attention head.
        attn_logit_softcapping: float, optional, The softcapping value for
            attention logits. Defaults to `None`.
        sliding_window: int, optional, The window size for sliding attention.
            Required if `layer_type` is `"sliding_attention"`. Defaults to
            `None`.
        rope_max_wavelength: float, The maximum wavelength for Rotary Positional
            Embeddings. Defaults to `10000.0`.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for model computations and weights. Defaults to `None`.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(
        self,
        hidden_size,
        rms_norm_eps,
        num_attention_heads,
        num_key_value_heads,
        query_pre_attn_scalar,
        attention_bias,
        intermediate_size,
        hidden_activation,
        dropout_rate,
        initializer_range,
        attention_dropout,
        layer_type,
        head_dim,
        attn_logit_softcapping=None,
        sliding_window=None,
        rope_max_wavelength=10000.0,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.hidden_size = hidden_size
        self.rms_norm_eps = rms_norm_eps
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.query_pre_attn_scalar = query_pre_attn_scalar
        self.attention_bias = attention_bias
        self.intermediate_size = intermediate_size
        self.hidden_activation = hidden_activation
        self.dropout_rate = dropout_rate
        self.initializer_range = initializer_range
        self.attention_dropout = attention_dropout
        self.layer_type = layer_type
        self.sliding_window = sliding_window
        self.rope_max_wavelength = rope_max_wavelength
        self.head_dim = head_dim
        self.attn_logit_softcapping = attn_logit_softcapping
        if (
            self.layer_type == "sliding_attention"
            and self.sliding_window is None
        ):
            raise ValueError(
                "`sliding_window` must be set for `sliding_attention` layer "
                "type."
            )
        self.self_attn = T5GemmaAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            query_pre_attn_scalar=query_pre_attn_scalar,
            attention_bias=attention_bias,
            head_dim=self.head_dim,
            attention_type="self",
            initializer_range=initializer_range,
            attention_dropout=attention_dropout,
            attn_logit_softcapping=attn_logit_softcapping,
            rope_max_wavelength=self.rope_max_wavelength,
            dtype=self.dtype_policy,
            name="self_attention",
        )
        self.pre_self_attn_layernorm = RMSNormalization(
            epsilon=rms_norm_eps,
            dtype=self.dtype_policy,
            name="pre_self_attention_layernorm",
        )
        self.post_self_attn_layernorm = RMSNormalization(
            epsilon=rms_norm_eps,
            dtype=self.dtype_policy,
            name="post_self_attention_layernorm",
        )

        self.mlp = T5GemmaMLP(
            hidden_size,
            intermediate_size,
            hidden_activation,
            dropout_rate,
            initializer_range=initializer_range,
            dtype=self.dtype_policy,
            name="mlp",
        )
        self.pre_feedforward_layernorm = RMSNormalization(
            epsilon=rms_norm_eps,
            dtype=self.dtype_policy,
            name="pre_feedforward_layernorm",
        )
        self.post_feedforward_layernorm = RMSNormalization(
            epsilon=rms_norm_eps,
            dtype=self.dtype_policy,
            name="post_feedforward_layernorm",
        )
        self.dropout = keras.layers.Dropout(
            dropout_rate,
            dtype=self.dtype_policy,
            name="residual_dropout",
        )

    def build(self, input_shape):
        self.pre_self_attn_layernorm.build(input_shape)
        self.self_attn.build(input_shape)
        attn_output_shape, _ = self.self_attn.compute_output_shape(input_shape)
        self.post_self_attn_layernorm.build(attn_output_shape)
        self.dropout.build(attn_output_shape)
        self.pre_feedforward_layernorm.build(attn_output_shape)
        self.mlp.build(attn_output_shape)
        mlp_output_shape = self.mlp.compute_output_shape(attn_output_shape)
        self.post_feedforward_layernorm.build(mlp_output_shape)
        self.built = True

    def _make_attention_mask(self, hidden_states, padding_mask):
        attention_mask = padding_mask[:, None, None, :]
        additive_mask = (
            1.0 - keras.ops.cast(attention_mask, hidden_states.dtype)
        ) * -1e9
        return additive_mask

    def call(
        self,
        hidden_states,
        padding_mask=None,
        training=None,
    ):
        residual = hidden_states
        attention_mask = self._make_attention_mask(hidden_states, padding_mask)
        hidden_states = self.pre_self_attn_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(
            inputs=hidden_states,
            attention_mask=attention_mask,
            training=training,
        )
        hidden_states = self.post_self_attn_layernorm(hidden_states)
        hidden_states = residual + self.dropout(
            hidden_states, training=training
        )
        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states, training=training)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + self.dropout(
            hidden_states, training=training
        )
        return hidden_states

    def compute_output_shape(self, input_shape):
        # Isometric.
        return input_shape

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_size": self.hidden_size,
                "rms_norm_eps": self.rms_norm_eps,
                "head_dim": self.head_dim,
                "num_attention_heads": self.num_attention_heads,
                "num_key_value_heads": self.num_key_value_heads,
                "query_pre_attn_scalar": self.query_pre_attn_scalar,
                "attention_bias": self.attention_bias,
                "intermediate_size": self.intermediate_size,
                "hidden_activation": self.hidden_activation,
                "dropout_rate": self.dropout_rate,
                "initializer_range": self.initializer_range,
                "attention_dropout": self.attention_dropout,
                "layer_type": self.layer_type,
                "sliding_window": self.sliding_window,
                "rope_max_wavelength": self.rope_max_wavelength,
                "attn_logit_softcapping": self.attn_logit_softcapping,
            }
        )
        return config
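A similar sketch for the new `T5GemmaEncoderLayer`, again with illustrative sizes; the `sliding_attention`/`sliding_window` pair is included only to show the constraint enforced in `__init__`, and the snippet assumes this nightly wheel is installed.

# Hypothetical usage sketch: all sizes are illustrative assumptions.
import keras

from keras_hub.src.models.t5gemma.t5gemma_encoder import T5GemmaEncoderLayer

layer = T5GemmaEncoderLayer(
    hidden_size=64,
    rms_norm_eps=1e-6,
    num_attention_heads=4,
    num_key_value_heads=2,
    query_pre_attn_scalar=16,
    attention_bias=False,
    intermediate_size=128,
    hidden_activation="gelu_approximate",
    dropout_rate=0.0,
    initializer_range=0.02,
    attention_dropout=0.0,
    layer_type="sliding_attention",
    head_dim=16,
    sliding_window=4,  # required when layer_type is "sliding_attention"
)
tokens = keras.random.uniform((2, 7, 64))             # (batch, seq_len, hidden)
padding_mask = keras.ops.ones((2, 7), dtype="bool")   # True = real token
outputs = layer(tokens, padding_mask=padding_mask)
print(outputs.shape)  # (2, 7, 64); the block is shape-preserving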
keras_hub/src/models/t5gemma/t5gemma_layers.py
@@ -0,0 +1,118 @@
import keras

from keras_hub.src.utils.keras_utils import clone_initializer


def t5gemma_kernel_initializer(initializer_range=0.01):
    """Creates a RandomNormal initializer for T5Gemma kernels.

    Args:
        initializer_range: float, The standard deviation of the normal
            distribution. Defaults to `0.01`.

    Returns:
        keras.initializers.RandomNormal: A Keras RandomNormal initializer.
    """
    return keras.initializers.RandomNormal(mean=0.0, stddev=initializer_range)


class T5GemmaMLP(keras.layers.Layer):
    """Multilayer Perceptron (MLP) block for the T5Gemma model.

    This layer implements the feed-forward part of a transformer block,
    consisting of two dense layers with a GELU activation and dropout.

    Args:
        hidden_size: int, The dimensionality of the input and output hidden
            states.
        intermediate_size: int, The dimensionality of the intermediate layer.
        hidden_activation: str, The activation function to use, e.g.,
            "gelu_approximate".
        dropout_rate: float, The dropout rate applied to the intermediate
            hidden states.
        initializer_range: float, The range for the random normal initializer
            for kernel weights. Defaults to `0.02`.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for model computations and weights. Defaults to `None`.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    def __init__(
        self,
        hidden_size,
        intermediate_size,
        hidden_activation,
        dropout_rate,
        initializer_range=0.02,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.hidden_activation = hidden_activation
        self.dropout_rate = dropout_rate
        self.initializer_range = initializer_range
        self.kernel_initializer = t5gemma_kernel_initializer(initializer_range)

        self.gate_proj = keras.layers.Dense(
            self.intermediate_size,
            use_bias=False,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            dtype=self.dtype_policy,
            name="gate_proj",
        )
        self.up_proj = keras.layers.Dense(
            self.intermediate_size,
            use_bias=False,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            dtype=self.dtype_policy,
            name="up_proj",
        )
        self.down_proj = keras.layers.Dense(
            self.hidden_size,
            use_bias=False,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            dtype=self.dtype_policy,
            name="down_proj",
        )
        if self.hidden_activation == "gelu_approximate":
            # NOTE: `gelu_pytorch_tanh` is the same as `gelu(approximate=True)`.
            self.act_fn = lambda x: keras.activations.gelu(x, approximate=True)
        else:
            self.act_fn = keras.activations.get(self.hidden_activation)
        self.dropout = keras.layers.Dropout(
            self.dropout_rate,
            dtype=self.dtype_policy,
            name="mlp_dropout",
        )

    def build(self, input_shape):
        self.gate_proj.build(input_shape)
        self.up_proj.build(input_shape)
        intermediate_shape = self.gate_proj.compute_output_shape(input_shape)
        self.dropout.build(intermediate_shape)
        self.down_proj.build(intermediate_shape)
        self.built = True

    def compute_output_shape(self, input_shape):
        return input_shape

    def call(self, x, training=None):
        hidden_states = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        hidden_states = self.dropout(hidden_states, training=training)
        down_proj = self.down_proj(hidden_states)
        return down_proj

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_size": self.hidden_size,
                "intermediate_size": self.intermediate_size,
                "hidden_activation": self.hidden_activation,
                "dropout_rate": self.dropout_rate,
                "initializer_range": self.initializer_range,
            }
        )
        return config
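Finally, a minimal sketch of the gated `T5GemmaMLP` block defined above, with illustrative sizes; as with the other sketches, it assumes this nightly wheel is installed.

# Hypothetical usage sketch: sizes are illustrative assumptions.
import keras

from keras_hub.src.models.t5gemma.t5gemma_layers import T5GemmaMLP

mlp = T5GemmaMLP(
    hidden_size=64,
    intermediate_size=128,
    hidden_activation="gelu_approximate",
    dropout_rate=0.1,
)
x = keras.random.uniform((2, 7, 64))  # (batch, seq_len, hidden_size)
y = mlp(x)  # act(gate_proj(x)) * up_proj(x), then dropout and down_proj
print(y.shape)  # (2, 7, 64); the block preserves the input shape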