keras-hub-nightly 0.23.0.dev202508240418__py3-none-any.whl → 0.23.0.dev202508260411__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,366 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.modeling.reversible_embedding import (
+    ReversibleEmbedding,
+)
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
+from keras_hub.src.models.t5gemma.t5gemma_decoder import T5GemmaDecoderLayer
+from keras_hub.src.models.t5gemma.t5gemma_encoder import T5GemmaEncoderLayer
+from keras_hub.src.models.t5gemma.t5gemma_layers import (
+    t5gemma_kernel_initializer,
+)
+from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+@keras_hub_export("keras_hub.models.T5GemmaBackbone")
+class T5GemmaBackbone(Backbone):
+    """T5Gemma backbone model.
+
+    This class implements the encoder-decoder backbone of the T5Gemma model,
+    consisting of an embedding layer, a stack of encoder layers, and a
+    stack of decoder layers.
+
+    Args:
+        vocabulary_size: int, The size of the vocabulary.
+        encoder_hidden_dim: int, The hidden dimensionality of the encoder.
+        encoder_intermediate_dim: int, The intermediate size of the encoder's
+            feed-forward networks.
+        encoder_num_layers: int, The number of encoder layers.
+        encoder_num_attention_heads: int, The number of attention heads in the
+            encoder.
+        encoder_num_key_value_heads: int, The number of key-value heads in the
+            encoder.
+        encoder_head_dim: int, The dimensionality of each attention head in the
+            encoder.
+        encoder_layer_types: list of str, A list of strings specifying the type
+            of attention layer for each encoder layer. Each element can be
+            either `"sliding_attention"` or `"full_attention"`. For example,
+            `["full_attention", "sliding_attention", ...]`.
+        decoder_hidden_dim: int, The hidden dimensionality of the decoder.
+        decoder_intermediate_dim: int, The intermediate size of the decoder's
+            feed-forward networks.
+        decoder_num_layers: int, The number of decoder layers.
+        decoder_num_attention_heads: int, The number of attention heads in the
+            decoder.
+        decoder_num_key_value_heads: int, The number of key-value heads in the
+            decoder.
+        decoder_head_dim: int, The dimensionality of each attention head in the
+            decoder.
+        decoder_layer_types: list of str, A list of strings specifying the type
+            of attention layer for each decoder layer. Each element can be
+            either `"sliding_attention"` or `"full_attention"`. For example,
+            `["full_attention", "sliding_attention", ...]`.
+        dropout_rate: float, The dropout rate applied throughout the model.
+            Defaults to `0.0`.
+        rms_norm_eps: float, The epsilon value for RMS normalization. Defaults
+            to `1e-6`.
+        query_pre_attn_scalar: float, Scalar to multiply queries by before
+            attention. Defaults to `1.0`.
+        attention_bias: bool, Whether to include bias in attention computations.
+            Defaults to `False`.
+        hidden_activation: str, The activation function used in the feed-forward
+            networks. Defaults to `"gelu_approximate"`.
+        tie_word_embeddings: bool, Whether to tie input and output word
+            embeddings. Defaults to `True`.
+        initializer_range: float, The range for the random normal initializer.
+            Defaults to `0.02`.
+        attention_dropout: float, The dropout rate applied to attention weights.
+            Defaults to `0.0`.
+        sliding_window: int, optional, The window size for sliding attention.
+            Required if any `layer_type` is `"sliding_attention"`. Defaults to
+            `None`.
+        cross_attention_hidden_size: int, optional, The hidden size for
+            cross-attention in the decoder layers. If None, it defaults to
+            `encoder_hidden_dim`. Defaults to `None`.
+        attn_logit_softcapping: float, optional, The softcapping value for
+            attention logits. Defaults to `None`.
+        final_logit_softcapping: float, optional, The softcapping value for
+            final logits. Defaults to `None`.
+        rope_max_wavelength: float, The maximum wavelength for Rotary Positional
+            Embeddings. Defaults to `10000.0`.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for model computations and weights. Note that some computations,
+            such as softmax and layer normalization, will always be done at
+            float32 precision regardless of dtype. Defaults to `None`.
+        **kwargs: Additional keyword arguments passed to the parent `Backbone`
+            class.
+
+    Examples:
+    ```python
+    import numpy as np
+    from keras_hub.models import T5GemmaBackbone
+
+    input_data = {
+        "encoder_token_ids": np.ones(shape=(1, 12), dtype="int32"),
+        "encoder_padding_mask": np.array(
+            [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]], dtype="int32"
+        ),
+        "decoder_token_ids": np.ones(shape=(1, 8), dtype="int32"),
+        "decoder_padding_mask": np.array(
+            [[1, 1, 1, 1, 1, 1, 1, 1]], dtype="int32"
+        ),
+    }
+
+    # Randomly initialized T5Gemma backbone with custom config.
+    model = T5GemmaBackbone(
+        vocabulary_size=32000,
+        # Encoder parameters.
+        encoder_hidden_dim=256,
+        encoder_intermediate_dim=512,
+        encoder_num_layers=4,
+        encoder_num_attention_heads=4,
+        encoder_num_key_value_heads=2,
+        encoder_head_dim=64,
+        encoder_layer_types=["full_attention"] * 4,
+        # Decoder parameters.
+        decoder_hidden_dim=256,
+        decoder_intermediate_dim=512,
+        decoder_num_layers=4,
+        decoder_num_attention_heads=4,
+        decoder_num_key_value_heads=2,
+        decoder_head_dim=64,
+        decoder_layer_types=["full_attention"] * 4,
+        # Common parameters.
+        dropout_rate=0.1,
+        rms_norm_eps=1e-6,
+        query_pre_attn_scalar=1.0,
+        attention_bias=False,
+        hidden_activation="gelu_approximate",
+    )
+    output = model(input_data)
+    ```
+    """
+
+    def __init__(
+        self,
+        vocabulary_size,
+        encoder_hidden_dim,
+        encoder_intermediate_dim,
+        encoder_num_layers,
+        encoder_num_attention_heads,
+        encoder_num_key_value_heads,
+        encoder_head_dim,
+        encoder_layer_types,
+        decoder_hidden_dim,
+        decoder_intermediate_dim,
+        decoder_num_layers,
+        decoder_num_attention_heads,
+        decoder_num_key_value_heads,
+        decoder_head_dim,
+        decoder_layer_types,
+        dropout_rate=0.0,
+        rms_norm_eps=1e-6,
+        query_pre_attn_scalar=1.0,
+        attention_bias=False,
+        hidden_activation="gelu_approximate",
+        tie_word_embeddings=True,
+        initializer_range=0.02,
+        attention_dropout=0.0,
+        sliding_window=None,
+        cross_attention_hidden_size=None,
+        attn_logit_softcapping=None,
+        final_logit_softcapping=None,
+        rope_max_wavelength=10000.0,
+        dtype=None,
+        **kwargs,
+    ):
+        self.kernel_initializer = t5gemma_kernel_initializer(initializer_range)
+
+        # === Layers ===
+        self.token_embedding = keras.layers.Embedding(
+            input_dim=vocabulary_size,
+            output_dim=encoder_hidden_dim,
+            embeddings_initializer=clone_initializer(self.kernel_initializer),
+            dtype=dtype,
+            name="encoder_token_embedding",
+        )
+        self.decoder_token_embedding = ReversibleEmbedding(
+            input_dim=vocabulary_size,
+            output_dim=decoder_hidden_dim,
+            tie_weights=tie_word_embeddings,
+            embeddings_initializer=clone_initializer(self.kernel_initializer),
+            dtype=dtype,
+            name="decoder_token_embedding",
+        )
+        self.encoder_layers = [
+            T5GemmaEncoderLayer(
+                hidden_size=encoder_hidden_dim,
+                rms_norm_eps=rms_norm_eps,
+                num_attention_heads=encoder_num_attention_heads,
+                num_key_value_heads=encoder_num_key_value_heads,
+                query_pre_attn_scalar=query_pre_attn_scalar,
+                attention_bias=attention_bias,
+                intermediate_size=encoder_intermediate_dim,
+                hidden_activation=hidden_activation,
+                head_dim=encoder_head_dim,
+                dropout_rate=dropout_rate,
+                initializer_range=initializer_range,
+                attention_dropout=attention_dropout,
+                layer_type=encoder_layer_types[i],
+                sliding_window=sliding_window,
+                attn_logit_softcapping=attn_logit_softcapping,
+                rope_max_wavelength=rope_max_wavelength,
+                name=f"encoder_layer_{i}",
+                dtype=dtype,
+            )
+            for i in range(encoder_num_layers)
+        ]
+        self.encoder_norm = RMSNormalization(epsilon=rms_norm_eps, dtype=dtype)
+        self.encoder_dropout = keras.layers.Dropout(dropout_rate, dtype=dtype)
+        self.decoder_layers = [
+            T5GemmaDecoderLayer(
+                hidden_size=decoder_hidden_dim,
+                rms_norm_eps=rms_norm_eps,
+                num_attention_heads=decoder_num_attention_heads,
+                num_key_value_heads=decoder_num_key_value_heads,
+                query_pre_attn_scalar=query_pre_attn_scalar,
+                attention_bias=attention_bias,
+                intermediate_size=decoder_intermediate_dim,
+                hidden_activation=hidden_activation,
+                dropout_rate=dropout_rate,
+                initializer_range=initializer_range,
+                head_dim=decoder_head_dim,
+                attention_dropout=attention_dropout,
+                layer_type=decoder_layer_types[i],
+                sliding_window=sliding_window,
+                cross_attention_hidden_size=(
+                    cross_attention_hidden_size or encoder_hidden_dim
+                ),
+                attn_logit_softcapping=attn_logit_softcapping,
+                rope_max_wavelength=rope_max_wavelength,
+                name=f"decoder_layer_{i}",
+                dtype=dtype,
+            )
+            for i in range(decoder_num_layers)
+        ]
+        self.decoder_norm = RMSNormalization(epsilon=rms_norm_eps, dtype=dtype)
+        self.decoder_dropout = keras.layers.Dropout(dropout_rate, dtype=dtype)
+
+        # === Functional Model ===
+        encoder_token_id_input = keras.Input(
+            shape=(None,), dtype="int32", name="encoder_token_ids"
+        )
+        encoder_padding_mask_input = keras.Input(
+            shape=(None,), dtype="int32", name="encoder_padding_mask"
+        )
+        decoder_token_id_input = keras.Input(
+            shape=(None,), dtype="int32", name="decoder_token_ids"
+        )
+        decoder_padding_mask_input = keras.Input(
+            shape=(None,), dtype="int32", name="decoder_padding_mask"
+        )
+
+        # Encoder.
+        encoder_embeddings = self.token_embedding(encoder_token_id_input)
+        encoder_embeddings = encoder_embeddings * keras.ops.cast(
+            keras.ops.sqrt(encoder_hidden_dim), encoder_embeddings.dtype
+        )
+        encoder_hidden_states = self.encoder_dropout(encoder_embeddings)
+        for layer in self.encoder_layers:
+            encoder_hidden_states = layer(
+                encoder_hidden_states, padding_mask=encoder_padding_mask_input
+            )
+        encoder_output = self.encoder_norm(encoder_hidden_states)
+        encoder_output = self.encoder_dropout(encoder_output)
+
+        # Decoder.
+        decoder_embeddings = self.decoder_token_embedding(
+            decoder_token_id_input
+        )
+        decoder_embeddings = decoder_embeddings * keras.ops.cast(
+            keras.ops.sqrt(decoder_hidden_dim), decoder_embeddings.dtype
+        )
+        decoder_hidden_states = self.decoder_dropout(decoder_embeddings)
+        for layer in self.decoder_layers:
+            decoder_hidden_states, _ = layer(
+                (decoder_hidden_states, encoder_output),
+                self_attention_padding_mask=decoder_padding_mask_input,
+                cross_attention_padding_mask=encoder_padding_mask_input,
+            )
+        decoder_output = self.decoder_norm(decoder_hidden_states)
+        decoder_output = self.decoder_dropout(decoder_output)
+
+        super().__init__(
+            inputs={
+                "encoder_token_ids": encoder_token_id_input,
+                "encoder_padding_mask": encoder_padding_mask_input,
+                "decoder_token_ids": decoder_token_id_input,
+                "decoder_padding_mask": decoder_padding_mask_input,
+            },
+            outputs={
+                "encoder_sequence_output": encoder_output,
+                "decoder_sequence_output": decoder_output,
+            },
+            dtype=dtype,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.encoder_hidden_dim = encoder_hidden_dim
+        self.encoder_intermediate_dim = encoder_intermediate_dim
+        self.encoder_num_layers = encoder_num_layers
+        self.encoder_num_attention_heads = encoder_num_attention_heads
+        self.encoder_num_key_value_heads = encoder_num_key_value_heads
+        self.encoder_head_dim = encoder_head_dim
+        self.encoder_layer_types = encoder_layer_types
+        self.decoder_hidden_dim = decoder_hidden_dim
+        self.decoder_intermediate_dim = decoder_intermediate_dim
+        self.decoder_num_layers = decoder_num_layers
+        self.decoder_num_attention_heads = decoder_num_attention_heads
+        self.decoder_num_key_value_heads = decoder_num_key_value_heads
+        self.decoder_head_dim = decoder_head_dim
+        self.decoder_layer_types = decoder_layer_types
+        self.vocabulary_size = vocabulary_size
+        self.dropout_rate = dropout_rate
+        self.rms_norm_eps = rms_norm_eps
+        self.tie_word_embeddings = tie_word_embeddings
+        self.query_pre_attn_scalar = query_pre_attn_scalar
+        self.attention_bias = attention_bias
+        self.hidden_activation = hidden_activation
+        self.initializer_range = initializer_range
+        self.attention_dropout = attention_dropout
+        self.sliding_window = sliding_window
+        self.cross_attention_hidden_size = (
+            cross_attention_hidden_size or encoder_hidden_dim
+        )
+        self.attn_logit_softcapping = attn_logit_softcapping
+        self.final_logit_softcapping = final_logit_softcapping
+        self.rope_max_wavelength = rope_max_wavelength
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "encoder_hidden_dim": self.encoder_hidden_dim,
+                "encoder_intermediate_dim": self.encoder_intermediate_dim,
+                "encoder_num_layers": self.encoder_num_layers,
+                "encoder_num_attention_heads": self.encoder_num_attention_heads,
+                "encoder_num_key_value_heads": self.encoder_num_key_value_heads,
+                "encoder_layer_types": self.encoder_layer_types,
+                "encoder_head_dim": self.encoder_head_dim,
+                "decoder_hidden_dim": self.decoder_hidden_dim,
+                "decoder_intermediate_dim": self.decoder_intermediate_dim,
+                "decoder_num_layers": self.decoder_num_layers,
+                "decoder_num_attention_heads": self.decoder_num_attention_heads,
+                "decoder_num_key_value_heads": self.decoder_num_key_value_heads,
+                "decoder_layer_types": self.decoder_layer_types,
+                "decoder_head_dim": self.decoder_head_dim,
+                "dropout_rate": self.dropout_rate,
+                "rms_norm_eps": self.rms_norm_eps,
+                "tie_word_embeddings": self.tie_word_embeddings,
+                "query_pre_attn_scalar": self.query_pre_attn_scalar,
+                "attention_bias": self.attention_bias,
+                "hidden_activation": self.hidden_activation,
+                "initializer_range": self.initializer_range,
+                "attention_dropout": self.attention_dropout,
+                "sliding_window": self.sliding_window,
+                "cross_attention_hidden_size": self.cross_attention_hidden_size,
+                "attn_logit_softcapping": self.attn_logit_softcapping,
+                "final_logit_softcapping": self.final_logit_softcapping,
+                "rope_max_wavelength": self.rope_max_wavelength,
+            }
+        )
+        return config
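
For reference, a minimal sketch of how the config round-trip exposed by `get_config` can be exercised, assuming the standard `Backbone.from_config` behavior of rebuilding the model from its constructor arguments. It reuses the small custom configuration from the class docstring above; the commented output shapes follow from the assumed hidden sizes of 256 and are not taken from the diff.

```python
import numpy as np
from keras_hub.models import T5GemmaBackbone

# Small, randomly initialized backbone (same settings as the docstring example).
model = T5GemmaBackbone(
    vocabulary_size=32000,
    encoder_hidden_dim=256,
    encoder_intermediate_dim=512,
    encoder_num_layers=4,
    encoder_num_attention_heads=4,
    encoder_num_key_value_heads=2,
    encoder_head_dim=64,
    encoder_layer_types=["full_attention"] * 4,
    decoder_hidden_dim=256,
    decoder_intermediate_dim=512,
    decoder_num_layers=4,
    decoder_num_attention_heads=4,
    decoder_num_key_value_heads=2,
    decoder_head_dim=64,
    decoder_layer_types=["full_attention"] * 4,
)

# `get_config` stores every constructor argument, so the architecture can be
# rebuilt from the config alone (weights are freshly initialized).
config = model.get_config()
restored = T5GemmaBackbone.from_config(config)

# The functional graph maps the four named inputs to two named outputs.
outputs = restored(
    {
        "encoder_token_ids": np.ones((1, 12), dtype="int32"),
        "encoder_padding_mask": np.ones((1, 12), dtype="int32"),
        "decoder_token_ids": np.ones((1, 8), dtype="int32"),
        "decoder_padding_mask": np.ones((1, 8), dtype="int32"),
    }
)
print(outputs["encoder_sequence_output"].shape)  # (1, 12, 256)
print(outputs["decoder_sequence_output"].shape)  # (1, 8, 256)
```

Because every constructor argument is captured in `get_config`, the restored backbone is architecturally identical to the original, only with re-initialized weights.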