keras-hub-nightly 0.23.0.dev202508240418__py3-none-any.whl → 0.23.0.dev202508260411__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- keras_hub/models/__init__.py +12 -0
- keras_hub/src/models/t5gemma/__init__.py +5 -0
- keras_hub/src/models/t5gemma/t5gemma_attention.py +370 -0
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +366 -0
- keras_hub/src/models/t5gemma/t5gemma_decoder.py +355 -0
- keras_hub/src/models/t5gemma/t5gemma_encoder.py +214 -0
- keras_hub/src/models/t5gemma/t5gemma_layers.py +118 -0
- keras_hub/src/models/t5gemma/t5gemma_presets.py +15 -0
- keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm.py +442 -0
- keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py +216 -0
- keras_hub/src/models/t5gemma/t5gemma_tokenizer.py +84 -0
- keras_hub/src/utils/transformers/convert_t5gemma.py +229 -0
- keras_hub/src/utils/transformers/preset_loader.py +3 -0
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +3 -0
- {keras_hub_nightly-0.23.0.dev202508240418.dist-info → keras_hub_nightly-0.23.0.dev202508260411.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.23.0.dev202508240418.dist-info → keras_hub_nightly-0.23.0.dev202508260411.dist-info}/RECORD +19 -8
- {keras_hub_nightly-0.23.0.dev202508240418.dist-info → keras_hub_nightly-0.23.0.dev202508260411.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.23.0.dev202508240418.dist-info → keras_hub_nightly-0.23.0.dev202508260411.dist-info}/top_level.txt +0 -0
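The expanded hunk below is the newly added `keras_hub/src/models/t5gemma/t5gemma_backbone.py`; the other new t5gemma modules (attention, encoder/decoder layers, tokenizer, preprocessor, seq2seq LM, presets, and the `convert_t5gemma.py` converter) are not expanded here. As a quick orientation, here is a minimal, hedged sketch of exercising the exported backbone. The export path `keras_hub.models.T5GemmaBackbone` and the input/output names come from the hunk itself; the layer sizes are arbitrary toy values, not a released preset configuration.

```python
import numpy as np

from keras_hub.models import T5GemmaBackbone

# Toy configuration: small, arbitrary sizes chosen only for illustration.
backbone = T5GemmaBackbone(
    vocabulary_size=1000,
    encoder_hidden_dim=128,
    encoder_intermediate_dim=256,
    encoder_num_layers=2,
    encoder_num_attention_heads=4,
    encoder_num_key_value_heads=2,
    encoder_head_dim=32,
    encoder_layer_types=["full_attention"] * 2,
    decoder_hidden_dim=128,
    decoder_intermediate_dim=256,
    decoder_num_layers=2,
    decoder_num_attention_heads=4,
    decoder_num_key_value_heads=2,
    decoder_head_dim=32,
    decoder_layer_types=["full_attention"] * 2,
)
outputs = backbone(
    {
        "encoder_token_ids": np.ones((1, 12), dtype="int32"),
        "encoder_padding_mask": np.ones((1, 12), dtype="int32"),
        "decoder_token_ids": np.ones((1, 8), dtype="int32"),
        "decoder_padding_mask": np.ones((1, 8), dtype="int32"),
    }
)
# The backbone returns a dict keyed by "encoder_sequence_output" and
# "decoder_sequence_output" (see the functional model wiring in the hunk).
print(outputs["encoder_sequence_output"].shape)  # Expected: (1, 12, 128)
print(outputs["decoder_sequence_output"].shape)  # Expected: (1, 8, 128)
```

The full hunk for `t5gemma_backbone.py` follows.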
@@ -0,0 +1,366 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.modeling.reversible_embedding import (
+    ReversibleEmbedding,
+)
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
+from keras_hub.src.models.t5gemma.t5gemma_decoder import T5GemmaDecoderLayer
+from keras_hub.src.models.t5gemma.t5gemma_encoder import T5GemmaEncoderLayer
+from keras_hub.src.models.t5gemma.t5gemma_layers import (
+    t5gemma_kernel_initializer,
+)
+from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+@keras_hub_export("keras_hub.models.T5GemmaBackbone")
+class T5GemmaBackbone(Backbone):
+    """T5Gemma backbone model.
+
+    This class implements the encoder-decoder backbone of the T5Gemma model,
+    consisting of an embedding layer, a stack of encoder layers, and a
+    stack of decoder layers.
+
+    Args:
+        vocabulary_size: int, The size of the vocabulary.
+        encoder_hidden_dim: int, The hidden dimensionality of the encoder.
+        encoder_intermediate_dim: int, The intermediate size of the encoder's
+            feed-forward networks.
+        encoder_num_layers: int, The number of encoder layers.
+        encoder_num_attention_heads: int, The number of attention heads in the
+            encoder.
+        encoder_num_key_value_heads: int, The number of key-value heads in the
+            encoder.
+        encoder_head_dim: int, The dimensionality of each attention head in the
+            encoder.
+        encoder_layer_types: list of str, A list of strings specifying the type
+            of attention layer for each encoder layer. Each element can be
+            either `"sliding_attention"` or `"full_attention"`. For example,
+            `["full_attention", "sliding_attention", ...]`.
+        decoder_hidden_dim: int, The hidden dimensionality of the decoder.
+        decoder_intermediate_dim: int, The intermediate size of the decoder's
+            feed-forward networks.
+        decoder_num_layers: int, The number of decoder layers.
+        decoder_num_attention_heads: int, The number of attention heads in the
+            decoder.
+        decoder_num_key_value_heads: int, The number of key-value heads in the
+            decoder.
+        decoder_head_dim: int, The dimensionality of each attention head in the
+            decoder.
+        decoder_layer_types: list of str, A list of strings specifying the type
+            of attention layer for each decoder layer. Each element can be
+            either `"sliding_attention"` or `"full_attention"`. For example,
+            `["full_attention", "sliding_attention", ...]`.
+        dropout_rate: float, The dropout rate applied throughout the model.
+            Defaults to `0.0`.
+        rms_norm_eps: float, The epsilon value for RMS normalization. Defaults
+            to `1e-6`.
+        query_pre_attn_scalar: float, Scalar to multiply queries by before
+            attention. Defaults to `1.0`.
+        attention_bias: bool, Whether to include bias in attention computations.
+            Defaults to `False`.
+        hidden_activation: str, The activation function used in the feed-forward
+            networks. Defaults to `"gelu_approximate"`.
+        tie_word_embeddings: bool, Whether to tie input and output word
+            embeddings. Defaults to `True`.
+        initializer_range: float, The range for the random normal initializer.
+            Defaults to `0.02`.
+        attention_dropout: float, The dropout rate applied to attention weights.
+            Defaults to `0.0`.
+        sliding_window: int, optional, The window size for sliding attention.
+            Required if any `layer_type` is `"sliding_attention"`. Defaults to
+            `None`.
+        cross_attention_hidden_size: int, optional, The hidden size for
+            cross-attention in the decoder layers. If None, it defaults to
+            `encoder_hidden_dim`. Defaults to `None`.
+        attn_logit_softcapping: float, optional, The softcapping value for
+            attention logits. Defaults to `None`.
+        final_logit_softcapping: float, optional, The softcapping value for
+            final logits. Defaults to `None`.
+        rope_max_wavelength: float, The maximum wavelength for Rotary Positional
+            Embeddings. Defaults to `10000.0`.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for model computations and weights. Note that some computations,
+            such as softmax and layer normalization, will always be done at
+            float32 precision regardless of dtype. Defaults to `None`.
+        **kwargs: Additional keyword arguments passed to the parent `Backbone`
+            class.
+
+    Examples:
+    ```python
+    import numpy as np
+    from keras_hub.models import T5GemmaBackbone
+
+    input_data = {
+        "encoder_token_ids": np.ones(shape=(1, 12), dtype="int32"),
+        "encoder_padding_mask": np.array(
+            [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]], dtype="int32"
+        ),
+        "decoder_token_ids": np.ones(shape=(1, 8), dtype="int32"),
+        "decoder_padding_mask": np.array(
+            [[1, 1, 1, 1, 1, 1, 1, 1]], dtype="int32"
+        ),
+    }
+
+    # Randomly initialized T5Gemma backbone with custom config.
+    model = T5GemmaBackbone(
+        vocabulary_size=32000,
+        # Encoder parameters.
+        encoder_hidden_dim=256,
+        encoder_intermediate_dim=512,
+        encoder_num_layers=4,
+        encoder_num_attention_heads=4,
+        encoder_num_key_value_heads=2,
+        encoder_head_dim=64,
+        encoder_layer_types=["full_attention"] * 4,
+        # Decoder parameters.
+        decoder_hidden_dim=256,
+        decoder_intermediate_dim=512,
+        decoder_num_layers=4,
+        decoder_num_attention_heads=4,
+        decoder_num_key_value_heads=2,
+        decoder_head_dim=64,
+        decoder_layer_types=["full_attention"] * 4,
+        # Common parameters.
+        dropout_rate=0.1,
+        rms_norm_eps=1e-6,
+        query_pre_attn_scalar=1.0,
+        attention_bias=False,
+        hidden_activation="gelu_approximate",
+    )
+    output = model(input_data)
+    ```
+    """
+
+    def __init__(
+        self,
+        vocabulary_size,
+        encoder_hidden_dim,
+        encoder_intermediate_dim,
+        encoder_num_layers,
+        encoder_num_attention_heads,
+        encoder_num_key_value_heads,
+        encoder_head_dim,
+        encoder_layer_types,
+        decoder_hidden_dim,
+        decoder_intermediate_dim,
+        decoder_num_layers,
+        decoder_num_attention_heads,
+        decoder_num_key_value_heads,
+        decoder_head_dim,
+        decoder_layer_types,
+        dropout_rate=0.0,
+        rms_norm_eps=1e-6,
+        query_pre_attn_scalar=1.0,
+        attention_bias=False,
+        hidden_activation="gelu_approximate",
+        tie_word_embeddings=True,
+        initializer_range=0.02,
+        attention_dropout=0.0,
+        sliding_window=None,
+        cross_attention_hidden_size=None,
+        attn_logit_softcapping=None,
+        final_logit_softcapping=None,
+        rope_max_wavelength=10000.0,
+        dtype=None,
+        **kwargs,
+    ):
+        self.kernel_initializer = t5gemma_kernel_initializer(initializer_range)
+
+        # === Layers ===
+        self.token_embedding = keras.layers.Embedding(
+            input_dim=vocabulary_size,
+            output_dim=encoder_hidden_dim,
+            embeddings_initializer=clone_initializer(self.kernel_initializer),
+            dtype=dtype,
+            name="encoder_token_embedding",
+        )
+        self.decoder_token_embedding = ReversibleEmbedding(
+            input_dim=vocabulary_size,
+            output_dim=decoder_hidden_dim,
+            tie_weights=tie_word_embeddings,
+            embeddings_initializer=clone_initializer(self.kernel_initializer),
+            dtype=dtype,
+            name="decoder_token_embedding",
+        )
+        self.encoder_layers = [
+            T5GemmaEncoderLayer(
+                hidden_size=encoder_hidden_dim,
+                rms_norm_eps=rms_norm_eps,
+                num_attention_heads=encoder_num_attention_heads,
+                num_key_value_heads=encoder_num_key_value_heads,
+                query_pre_attn_scalar=query_pre_attn_scalar,
+                attention_bias=attention_bias,
+                intermediate_size=encoder_intermediate_dim,
+                hidden_activation=hidden_activation,
+                head_dim=encoder_head_dim,
+                dropout_rate=dropout_rate,
+                initializer_range=initializer_range,
+                attention_dropout=attention_dropout,
+                layer_type=encoder_layer_types[i],
+                sliding_window=sliding_window,
+                attn_logit_softcapping=attn_logit_softcapping,
+                rope_max_wavelength=rope_max_wavelength,
+                name=f"encoder_layer_{i}",
+                dtype=dtype,
+            )
+            for i in range(encoder_num_layers)
+        ]
+        self.encoder_norm = RMSNormalization(epsilon=rms_norm_eps, dtype=dtype)
+        self.encoder_dropout = keras.layers.Dropout(dropout_rate, dtype=dtype)
+        self.decoder_layers = [
+            T5GemmaDecoderLayer(
+                hidden_size=decoder_hidden_dim,
+                rms_norm_eps=rms_norm_eps,
+                num_attention_heads=decoder_num_attention_heads,
+                num_key_value_heads=decoder_num_key_value_heads,
+                query_pre_attn_scalar=query_pre_attn_scalar,
+                attention_bias=attention_bias,
+                intermediate_size=decoder_intermediate_dim,
+                hidden_activation=hidden_activation,
+                dropout_rate=dropout_rate,
+                initializer_range=initializer_range,
+                head_dim=decoder_head_dim,
+                attention_dropout=attention_dropout,
+                layer_type=decoder_layer_types[i],
+                sliding_window=sliding_window,
+                cross_attention_hidden_size=(
+                    cross_attention_hidden_size or encoder_hidden_dim
+                ),
+                attn_logit_softcapping=attn_logit_softcapping,
+                rope_max_wavelength=rope_max_wavelength,
+                name=f"decoder_layer_{i}",
+                dtype=dtype,
+            )
+            for i in range(decoder_num_layers)
+        ]
+        self.decoder_norm = RMSNormalization(epsilon=rms_norm_eps, dtype=dtype)
+        self.decoder_dropout = keras.layers.Dropout(dropout_rate, dtype=dtype)
+
+        # === Functional Model ===
+        encoder_token_id_input = keras.Input(
+            shape=(None,), dtype="int32", name="encoder_token_ids"
+        )
+        encoder_padding_mask_input = keras.Input(
+            shape=(None,), dtype="int32", name="encoder_padding_mask"
+        )
+        decoder_token_id_input = keras.Input(
+            shape=(None,), dtype="int32", name="decoder_token_ids"
+        )
+        decoder_padding_mask_input = keras.Input(
+            shape=(None,), dtype="int32", name="decoder_padding_mask"
+        )
+
+        # Encoder.
+        encoder_embeddings = self.token_embedding(encoder_token_id_input)
+        encoder_embeddings = encoder_embeddings * keras.ops.cast(
+            keras.ops.sqrt(encoder_hidden_dim), encoder_embeddings.dtype
+        )
+        encoder_hidden_states = self.encoder_dropout(encoder_embeddings)
+        for layer in self.encoder_layers:
+            encoder_hidden_states = layer(
+                encoder_hidden_states, padding_mask=encoder_padding_mask_input
+            )
+        encoder_output = self.encoder_norm(encoder_hidden_states)
+        encoder_output = self.encoder_dropout(encoder_output)
+
+        # Decoder.
+        decoder_embeddings = self.decoder_token_embedding(
+            decoder_token_id_input
+        )
+        decoder_embeddings = decoder_embeddings * keras.ops.cast(
+            keras.ops.sqrt(decoder_hidden_dim), decoder_embeddings.dtype
+        )
+        decoder_hidden_states = self.decoder_dropout(decoder_embeddings)
+        for layer in self.decoder_layers:
+            decoder_hidden_states, _ = layer(
+                (decoder_hidden_states, encoder_output),
+                self_attention_padding_mask=decoder_padding_mask_input,
+                cross_attention_padding_mask=encoder_padding_mask_input,
+            )
+        decoder_output = self.decoder_norm(decoder_hidden_states)
+        decoder_output = self.decoder_dropout(decoder_output)
+
+        super().__init__(
+            inputs={
+                "encoder_token_ids": encoder_token_id_input,
+                "encoder_padding_mask": encoder_padding_mask_input,
+                "decoder_token_ids": decoder_token_id_input,
+                "decoder_padding_mask": decoder_padding_mask_input,
+            },
+            outputs={
+                "encoder_sequence_output": encoder_output,
+                "decoder_sequence_output": decoder_output,
+            },
+            dtype=dtype,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.encoder_hidden_dim = encoder_hidden_dim
+        self.encoder_intermediate_dim = encoder_intermediate_dim
+        self.encoder_num_layers = encoder_num_layers
+        self.encoder_num_attention_heads = encoder_num_attention_heads
+        self.encoder_num_key_value_heads = encoder_num_key_value_heads
+        self.encoder_head_dim = encoder_head_dim
+        self.encoder_layer_types = encoder_layer_types
+        self.decoder_hidden_dim = decoder_hidden_dim
+        self.decoder_intermediate_dim = decoder_intermediate_dim
+        self.decoder_num_layers = decoder_num_layers
+        self.decoder_num_attention_heads = decoder_num_attention_heads
+        self.decoder_num_key_value_heads = decoder_num_key_value_heads
+        self.decoder_head_dim = decoder_head_dim
+        self.decoder_layer_types = decoder_layer_types
+        self.vocabulary_size = vocabulary_size
+        self.dropout_rate = dropout_rate
+        self.rms_norm_eps = rms_norm_eps
+        self.tie_word_embeddings = tie_word_embeddings
+        self.query_pre_attn_scalar = query_pre_attn_scalar
+        self.attention_bias = attention_bias
+        self.hidden_activation = hidden_activation
+        self.initializer_range = initializer_range
+        self.attention_dropout = attention_dropout
+        self.sliding_window = sliding_window
+        self.cross_attention_hidden_size = (
+            cross_attention_hidden_size or encoder_hidden_dim
+        )
+        self.attn_logit_softcapping = attn_logit_softcapping
+        self.final_logit_softcapping = final_logit_softcapping
+        self.rope_max_wavelength = rope_max_wavelength
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "encoder_hidden_dim": self.encoder_hidden_dim,
+                "encoder_intermediate_dim": self.encoder_intermediate_dim,
+                "encoder_num_layers": self.encoder_num_layers,
+                "encoder_num_attention_heads": self.encoder_num_attention_heads,
+                "encoder_num_key_value_heads": self.encoder_num_key_value_heads,
+                "encoder_layer_types": self.encoder_layer_types,
+                "encoder_head_dim": self.encoder_head_dim,
+                "decoder_hidden_dim": self.decoder_hidden_dim,
+                "decoder_intermediate_dim": self.decoder_intermediate_dim,
+                "decoder_num_layers": self.decoder_num_layers,
+                "decoder_num_attention_heads": self.decoder_num_attention_heads,
+                "decoder_num_key_value_heads": self.decoder_num_key_value_heads,
+                "decoder_layer_types": self.decoder_layer_types,
+                "decoder_head_dim": self.decoder_head_dim,
+                "dropout_rate": self.dropout_rate,
+                "rms_norm_eps": self.rms_norm_eps,
+                "tie_word_embeddings": self.tie_word_embeddings,
+                "query_pre_attn_scalar": self.query_pre_attn_scalar,
+                "attention_bias": self.attention_bias,
+                "hidden_activation": self.hidden_activation,
+                "initializer_range": self.initializer_range,
+                "attention_dropout": self.attention_dropout,
+                "sliding_window": self.sliding_window,
+                "cross_attention_hidden_size": self.cross_attention_hidden_size,
+                "attn_logit_softcapping": self.attn_logit_softcapping,
+                "final_logit_softcapping": self.final_logit_softcapping,
+                "rope_max_wavelength": self.rope_max_wavelength,
+            }
+        )
+        return config
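Because `get_config` above registers every constructor argument, the backbone should round-trip through standard Keras config serialization. The following is a minimal sketch, assuming the toy `backbone` from the earlier example and the standard `from_config` behavior inherited from `Backbone`/`keras.Model`; note that this recreates the architecture with fresh weights, it does not restore trained parameters.

```python
from keras_hub.models import T5GemmaBackbone

# Serialize the architecture to a plain dict and rebuild an equivalent,
# freshly initialized backbone from it (sketch: assumes standard Keras
# `from_config` semantics for `Backbone` subclasses).
config = backbone.get_config()
restored = T5GemmaBackbone.from_config(config)
assert restored.encoder_num_layers == backbone.encoder_num_layers
assert restored.decoder_layer_types == backbone.decoder_layer_types
```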