keras-hub-nightly 0.23.0.dev202508250413__py3-none-any.whl → 0.23.0.dev202508260411__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -615,6 +615,18 @@ from keras_hub.src.models.t5.t5_preprocessor import (
     T5Preprocessor as T5Preprocessor,
 )
 from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer
+from keras_hub.src.models.t5gemma.t5gemma_backbone import (
+    T5GemmaBackbone as T5GemmaBackbone,
+)
+from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm import (
+    T5GemmaSeq2SeqLM as T5GemmaSeq2SeqLM,
+)
+from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm_preprocessor import (
+    T5GemmaSeq2SeqLMPreprocessor as T5GemmaSeq2SeqLMPreprocessor,
+)
+from keras_hub.src.models.t5gemma.t5gemma_tokenizer import (
+    T5GemmaTokenizer as T5GemmaTokenizer,
+)
 from keras_hub.src.models.task import Task as Task
 from keras_hub.src.models.text_classifier import TextClassifier as Classifier
 from keras_hub.src.models.text_classifier import (
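
The hunk above only adds aliased re-exports for the new T5Gemma model family; the path of the module being patched is not shown in this diff. As a minimal sketch (not part of the diff), the four new classes can be imported directly from the source paths the hunk names; whether they also surface under the public `keras_hub.models` namespace depends on which API module this hunk targets:

from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone
from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm import T5GemmaSeq2SeqLM
from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm_preprocessor import (
    T5GemmaSeq2SeqLMPreprocessor,
)
from keras_hub.src.models.t5gemma.t5gemma_tokenizer import T5GemmaTokenizer
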
@@ -0,0 +1,5 @@
+from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone
+from keras_hub.src.models.t5gemma.t5gemma_presets import backbone_presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, T5GemmaBackbone)
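
This new five-line module registers the T5Gemma preset table against the backbone class. Below is a minimal sketch (not part of the diff) of what the registration enables, assuming `backbone_presets` contains at least one entry; the preset name used is a hypothetical placeholder, not one confirmed by this diff:

from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone

# `register_presets` maps each entry of `backbone_presets` to
# `T5GemmaBackbone`, so the standard `from_preset` constructor can resolve
# presets by name.
backbone = T5GemmaBackbone.from_preset("t5gemma_placeholder")  # hypothetical name
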
@@ -0,0 +1,370 @@
+import inspect
+
+import keras
+
+from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+from keras_hub.src.models.gemma.gemma_attention import CachedGemmaAttention
+from keras_hub.src.models.t5gemma.t5gemma_layers import (
+    t5gemma_kernel_initializer,
+)
+from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+def repeat_kv(hidden_states, n_rep):
+    """Repeats the key/value hidden states to match the number of query heads
+    for Grouped Query Attention (GQA).
+
+    This function is used in `T5GemmaAttention` to broadcast key and value
+    states across multiple query heads when Grouped Query Attention (GQA) is
+    used (i.e., when `num_query_heads` > `num_key_value_heads`).
+
+    Args:
+        hidden_states: Tensor, The key or value hidden states with shape
+            `(batch, sequence_length, num_key_value_heads, head_dim)`.
+        n_rep: int, The number of times to repeat the key/value heads. This is
+            typically `num_query_heads // num_key_value_heads`.
+
+    Returns:
+        Tensor: The expanded key/value hidden states with shape
+        `(batch, sequence_length, num_query_heads, head_dim)`.
+    """
+    if n_rep == 1:
+        return hidden_states
+    batch, slen, num_key_value_heads, head_dim = keras.ops.shape(hidden_states)
+    hidden_states = keras.ops.expand_dims(hidden_states, 3)
+    hidden_states = keras.ops.tile(hidden_states, (1, 1, 1, n_rep, 1))
+    return keras.ops.reshape(
+        hidden_states, (batch, slen, num_key_value_heads * n_rep, head_dim)
+    )
+
+
+class T5GemmaAttention(CachedGemmaAttention):
+    """A unified attention layer for T5Gemma that handles both self-attention
+    and cross-attention.
+
+    This layer performs attention with optional Rotary Positional Embeddings
+    (RoPE) and supports Grouped Query Attention (GQA). It is used in
+    `T5GemmaEncoderLayer` and `T5GemmaDecoderLayer`.
+
+    Args:
+        hidden_size: int, The dimensionality of the hidden states.
+        num_attention_heads: int, The number of attention heads.
+        num_key_value_heads: int, The number of key-value heads. For GQA, this
+            can be less than `num_attention_heads`.
+        query_pre_attn_scalar: float, Scalar to multiply queries by before
+            attention.
+        attention_bias: bool, Whether to include bias in the dense layers.
+        head_dim: int, The dimensionality of each attention head.
+        attention_type: str, The type of attention, either 'self' or 'cross'.
+            Defaults to 'self'.
+        cross_attention_hidden_size: int, optional, The dimensionality of
+            encoder hidden states for cross-attention. Defaults to `None`.
+        initializer_range: float, The range for the random normal initializer
+            for kernel weights. Defaults to `0.02`.
+        attention_dropout: float, The dropout rate applied to attention weights.
+            Defaults to `0.0`.
+        attn_logit_softcapping: float, optional, The softcapping value for
+            attention logits. Defaults to `None`.
+        rope_max_wavelength: float, The maximum wavelength for Rotary Positional
+            Embeddings. Defaults to `10000.0`. Only used for self-attention.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for model computations and weights. Defaults to `None`.
+        **kwargs: Additional keyword arguments passed to the parent class.
+    """
+
+    def __init__(
+        self,
+        hidden_size,
+        num_attention_heads,
+        num_key_value_heads,
+        query_pre_attn_scalar,
+        attention_bias,
+        head_dim,
+        attention_type="self",
+        cross_attention_hidden_size=None,
+        initializer_range=0.02,
+        attention_dropout=0.0,
+        attn_logit_softcapping=None,
+        rope_max_wavelength=10000.0,
+        dtype=None,
+        **kwargs,
+    ):
+        super().__init__(
+            head_dim=head_dim,
+            num_query_heads=num_attention_heads,
+            num_key_value_heads=num_key_value_heads,
+            kernel_initializer=t5gemma_kernel_initializer(initializer_range),
+            logit_soft_cap=attn_logit_softcapping,
+            dropout=attention_dropout,
+            query_head_dim_normalize=False,
+            use_sliding_window_attention=False,
+            dtype=dtype,
+            **kwargs,
+        )
+        if attention_type not in ["self", "cross"]:
+            raise ValueError(
+                f"attention_type must be 'self' or 'cross', but got "
+                f"{attention_type}"
+            )
+        self.attention_type = attention_type
+        self.hidden_size = hidden_size
+        self.cross_attention_hidden_size = (
+            cross_attention_hidden_size or hidden_size
+        )
+        self.query_pre_attn_scalar = query_pre_attn_scalar
+        self.attention_bias = attention_bias
+        self.initializer_range = initializer_range
+        self.attention_dropout = attention_dropout
+        self.rope_max_wavelength = rope_max_wavelength
+        self.num_key_value_groups = (
+            self.num_query_heads // self.num_key_value_heads
+        )
+        self.scaling = self.query_pre_attn_scalar**-0.5
+        if self.attention_type == "self":
+            self.rotary_embedding = RotaryEmbedding(
+                max_wavelength=self.rope_max_wavelength,
+                sequence_axis=1,
+                feature_axis=3,
+                name="rotary_embedding",
+                dtype=self.dtype_policy,
+            )
+
+    def build(self, input_shape):
+        self._kernel_initializer = t5gemma_kernel_initializer(
+            self.initializer_range
+        )
+
+        if self.attention_type == "cross":
+            hidden_states_shape, kv_states_shape = input_shape
+        else:
+            hidden_states_shape = input_shape
+            kv_states_shape = input_shape
+        # Query projection layer.
+        self.hidden_dim = hidden_states_shape[-1]
+        self.query_dense = keras.layers.EinsumDense(
+            equation="btd,dnh->btnh",
+            output_shape=(None, self.num_query_heads, self.head_dim),
+            kernel_initializer=clone_initializer(self._kernel_initializer),
+            bias_axes="nh" if self.attention_bias else None,
+            dtype=self.dtype_policy,
+            name="query",
+        )
+        self.query_dense.build(hidden_states_shape)
+
+        # Key projection layer.
+        self.key_dense = keras.layers.EinsumDense(
+            equation="bsd,dkh->bskh",
+            output_shape=(None, self.num_key_value_heads, self.head_dim),
+            kernel_initializer=clone_initializer(self._kernel_initializer),
+            bias_axes="kh" if self.attention_bias else None,
+            dtype=self.dtype_policy,
+            name="key",
+        )
+        self.key_dense.build(kv_states_shape)
+
+        # Value projection layer.
+        self.value_dense = keras.layers.EinsumDense(
+            equation="bsd,dkh->bskh",
+            output_shape=(None, self.num_key_value_heads, self.head_dim),
+            kernel_initializer=clone_initializer(self._kernel_initializer),
+            bias_axes="kh" if self.attention_bias else None,
+            dtype=self.dtype_policy,
+            name="value",
+        )
+        self.value_dense.build(kv_states_shape)
+
+        # Output projection layer.
+        self.output_dense = keras.layers.EinsumDense(
+            equation="btnh,nhd->btd",
+            output_shape=(None, self.hidden_dim),
+            kernel_initializer=clone_initializer(self._kernel_initializer),
+            bias_axes="d" if self.attention_bias else None,
+            dtype=self.dtype_policy,
+            name="attention_output",
+        )
+        self.output_dense.build(
+            (
+                hidden_states_shape[0],
+                hidden_states_shape[1],
+                self.num_query_heads,
+                self.head_dim,
+            )
+        )
+        self.dropout_layer = keras.layers.Dropout(
+            rate=self.attention_dropout,
+            dtype=self.dtype_policy,
+        )
+        self.softmax = keras.layers.Softmax(axis=-1, dtype="float32")
+        self.built = True
+
+    def _compute_attention_without_fused_op(
+        self, query_states, key_states, value_states, attention_mask, training
+    ):
+        attn_weights = keras.ops.einsum(
+            "btnh,bsnh->bnts", query_states, key_states
+        )
+        attn_weights *= self.scaling
+        if self.logit_soft_cap is not None:
+            attn_weights = attn_weights / self.logit_soft_cap
+            attn_weights = keras.ops.tanh(attn_weights)
+            attn_weights = attn_weights * self.logit_soft_cap
+        if attention_mask is not None:
+            attn_weights += attention_mask
+        attn_weights = keras.ops.cast(
+            self.softmax(attn_weights),
+            query_states.dtype,
+        )
+        attn_weights = self.dropout_layer(attn_weights, training=training)
+        attn_output = keras.ops.einsum(
+            "bnts,bsnh->btnh", attn_weights, value_states
+        )
+        return attn_output
+
+    def _compute_attention(
+        self, query_states, key_states, value_states, attention_mask, training
+    ):
+        if self._use_fused_attention_op():
+            kwargs = {"bias": attention_mask}
+            if self.logit_soft_cap is not None:
+                sig = inspect.signature(keras.ops.dot_product_attention)
+                # This is only supported in JAX TPU backend.
+                # https://keras.io/api/ops/nn/#dot_product_attention-function
+                if "attn_logits_soft_cap" in sig.parameters:
+                    kwargs["attn_logits_soft_cap"] = self.logit_soft_cap
+            return keras.ops.dot_product_attention(
+                query=query_states,
+                key=key_states,
+                value=value_states,
+                scale=self.scaling,
+                **kwargs,
+            )
+        return self._compute_attention_without_fused_op(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            training,
+        )
+
+    def call(
+        self,
+        inputs,
+        attention_mask=None,
+        cache=None,
+        cache_update_index=None,
+        training=None,
+    ):
+        if self.attention_type == "cross":
+            if not isinstance(inputs, (list, tuple)) or len(inputs) != 2:
+                raise ValueError(
+                    "For cross-attention, `inputs` must be a list or tuple of "
+                    "two tensors: `[hidden_states, encoder_hidden_states]`."
+                )
+            hidden_states, kv_states = inputs
+            query_states = self.query_dense(hidden_states)
+            if cache is not None:
+                if cache_update_index is not None:
+                    raise ValueError(
+                        "`cache_update_index` should not be set for "
+                        "cross-attention caching."
+                    )
+                key_states, value_states = cache[:, 0, ...], cache[:, 1, ...]
+                updated_cache = cache
+            else:
+                key_states = self.key_dense(kv_states)
+                value_states = self.value_dense(kv_states)
+                updated_cache = keras.ops.stack(
+                    (key_states, value_states), axis=1
+                )
+            # Repeat key-value heads for GQA.
+            key_states = repeat_kv(key_states, self.num_key_value_groups)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)
+            attn_output = self._compute_attention(
+                query_states, key_states, value_states, attention_mask, training
+            )
+            attn_output = self.output_dense(attn_output)
+            return attn_output, updated_cache
+        else:  # Self-attention
+            hidden_states = inputs
+            kv_states = hidden_states
+            query_states = self.query_dense(hidden_states)
+            key_states = self.key_dense(kv_states)
+            value_states = self.value_dense(kv_states)
+            start_index = (
+                0 if cache_update_index is None else cache_update_index
+            )
+            query_states = self.rotary_embedding(
+                query_states, start_index=start_index
+            )
+            key_states = self.rotary_embedding(
+                key_states, start_index=start_index
+            )
+            if cache is not None:
+                if cache_update_index is None:
+                    raise ValueError(
+                        "Both `cache` and `cache_update_index` must be passed "
+                        "for self-attention caching."
+                    )
+                key_cache, value_cache = cache[:, 0, ...], cache[:, 1, ...]
+                start = [0, cache_update_index, 0, 0]
+                key_states = keras.ops.slice_update(
+                    key_cache, start, key_states
+                )
+                value_states = keras.ops.slice_update(
+                    value_cache, start, value_states
+                )
+                cache = keras.ops.stack((key_states, value_states), axis=1)
+            elif cache_update_index is not None:
+                raise ValueError(
+                    "`cache_update_index` should not be set if `cache` is "
+                    "`None`."
+                )
+            else:
+                cache = keras.ops.stack((key_states, value_states), axis=1)
+
+            # Repeat key-value heads for GQA.
+            key_states = repeat_kv(key_states, self.num_key_value_groups)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+            attn_output = self._compute_attention(
+                query_states, key_states, value_states, attention_mask, training
+            )
+            attn_output = self.output_dense(attn_output)
+            return attn_output, cache
+
+    def compute_output_shape(self, input_shape):
+        if self.attention_type == "cross":
+            hidden_states_shape, kv_states_shape = input_shape
+        else:
+            hidden_states_shape = input_shape
+            kv_states_shape = input_shape
+        attn_output_shape = hidden_states_shape
+        kv_len = kv_states_shape[1]
+        cache_shape = (
+            hidden_states_shape[0],  # batch
+            2,  # key and value
+            kv_len,
+            self.num_key_value_heads,
+            self.head_dim,
+        )
+        return attn_output_shape, cache_shape
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_size": self.hidden_size,
+                "head_dim": self.head_dim,
+                "num_attention_heads": self.num_query_heads,
+                "num_key_value_heads": self.num_key_value_heads,
+                "query_pre_attn_scalar": self.query_pre_attn_scalar,
+                "attention_bias": self.attention_bias,
+                "attention_type": self.attention_type,
+                "cross_attention_hidden_size": self.cross_attention_hidden_size,
+                "initializer_range": self.initializer_range,
+                "attention_dropout": self.attention_dropout,
+                "attn_logit_softcapping": self.logit_soft_cap,
+                "rope_max_wavelength": self.rope_max_wavelength,
+            }
+        )
+        return config
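
The new attention module above implements Grouped Query Attention with an optional fused attention op and a stacked key/value cache. Below is a minimal shape sketch (not part of the diff) that reproduces the `repeat_kv` expansion and the cache layout described in `compute_output_shape`; the concrete dimensions are arbitrary example values, not model defaults.

import keras

batch, seq_len, num_query_heads, num_kv_heads, head_dim = 2, 8, 8, 2, 16
n_rep = num_query_heads // num_kv_heads  # 4 query heads share each KV head.

# Key/value states as produced by the key/value EinsumDense projections.
kv = keras.random.normal((batch, seq_len, num_kv_heads, head_dim))

# Same steps as `repeat_kv`: add a repeat axis, tile it, then merge it into
# the head axis.
expanded = keras.ops.expand_dims(kv, 3)
expanded = keras.ops.tile(expanded, (1, 1, 1, n_rep, 1))
expanded = keras.ops.reshape(
    expanded, (batch, seq_len, num_kv_heads * n_rep, head_dim)
)
print(expanded.shape)  # (2, 8, 8, 16): now matches the query head count.

# Cache layout: keys and values stacked along axis 1, giving
# (batch, 2, kv_len, num_key_value_heads, head_dim).
cache = keras.ops.stack((kv, kv), axis=1)
print(cache.shape)  # (2, 2, 8, 2, 16)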