keras-hub-nightly 0.23.0.dev202508250413__py3-none-any.whl → 0.23.0.dev202508260411__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/models/__init__.py +12 -0
- keras_hub/src/models/t5gemma/__init__.py +5 -0
- keras_hub/src/models/t5gemma/t5gemma_attention.py +370 -0
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +366 -0
- keras_hub/src/models/t5gemma/t5gemma_decoder.py +355 -0
- keras_hub/src/models/t5gemma/t5gemma_encoder.py +214 -0
- keras_hub/src/models/t5gemma/t5gemma_layers.py +118 -0
- keras_hub/src/models/t5gemma/t5gemma_presets.py +15 -0
- keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm.py +442 -0
- keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py +216 -0
- keras_hub/src/models/t5gemma/t5gemma_tokenizer.py +84 -0
- keras_hub/src/utils/transformers/convert_t5gemma.py +229 -0
- keras_hub/src/utils/transformers/preset_loader.py +3 -0
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +3 -0
- {keras_hub_nightly-0.23.0.dev202508250413.dist-info → keras_hub_nightly-0.23.0.dev202508260411.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.23.0.dev202508250413.dist-info → keras_hub_nightly-0.23.0.dev202508260411.dist-info}/RECORD +19 -8
- {keras_hub_nightly-0.23.0.dev202508250413.dist-info → keras_hub_nightly-0.23.0.dev202508260411.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.23.0.dev202508250413.dist-info → keras_hub_nightly-0.23.0.dev202508260411.dist-info}/top_level.txt +0 -0
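Besides the new model package, this release wires T5Gemma into the Transformers checkpoint converter (`convert_t5gemma.py`) and registers it with the preset loader. A minimal sketch of how that path is typically exercised is shown below; the `hf://` handle is an illustrative assumption, not a preset defined in this diff.

    import keras_hub

    # Hypothetical Hugging Face checkpoint name. A Transformers-format T5Gemma
    # repository would be converted on the fly by
    # keras_hub/src/utils/transformers/convert_t5gemma.py via the preset loader.
    backbone = keras_hub.models.T5GemmaBackbone.from_preset(
        "hf://google/t5gemma-b-b-prefixlm"
    )
    backbone.summary()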
keras_hub/models/__init__.py
CHANGED
@@ -615,6 +615,18 @@ from keras_hub.src.models.t5.t5_preprocessor import (
     T5Preprocessor as T5Preprocessor,
 )
 from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer
+from keras_hub.src.models.t5gemma.t5gemma_backbone import (
+    T5GemmaBackbone as T5GemmaBackbone,
+)
+from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm import (
+    T5GemmaSeq2SeqLM as T5GemmaSeq2SeqLM,
+)
+from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm_preprocessor import (
+    T5GemmaSeq2SeqLMPreprocessor as T5GemmaSeq2SeqLMPreprocessor,
+)
+from keras_hub.src.models.t5gemma.t5gemma_tokenizer import (
+    T5GemmaTokenizer as T5GemmaTokenizer,
+)
 from keras_hub.src.models.task import Task as Task
 from keras_hub.src.models.text_classifier import TextClassifier as Classifier
 from keras_hub.src.models.text_classifier import (
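With the exports above, the T5Gemma family is reachable from the public `keras_hub.models` namespace. A minimal sketch of the intended task-level surface, assuming a published preset is available (the preset name below is a placeholder, not one defined in this diff):

    import keras_hub

    # Placeholder preset name; the real names ship in t5gemma_presets.py.
    lm = keras_hub.models.T5GemmaSeq2SeqLM.from_preset("t5gemma_b_b_prefixlm_it")
    # When loaded from a preset, the task bundles a T5GemmaSeq2SeqLMPreprocessor
    # (and its T5GemmaTokenizer), so raw strings can be passed to generate().
    print(lm.generate("Translate to German: The weather is nice today.", max_length=64))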
keras_hub/src/models/t5gemma/t5gemma_attention.py
ADDED
@@ -0,0 +1,370 @@
+import inspect
+
+import keras
+
+from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+from keras_hub.src.models.gemma.gemma_attention import CachedGemmaAttention
+from keras_hub.src.models.t5gemma.t5gemma_layers import (
+    t5gemma_kernel_initializer,
+)
+from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+def repeat_kv(hidden_states, n_rep):
+    """Repeats the key/value hidden states to match the number of query heads
+    for Grouped Query Attention (GQA).
+
+    This function is used in `T5GemmaAttention` to broadcast key and value
+    states across multiple query heads when Grouped Query Attention (GQA) is
+    used (i.e., when `num_query_heads` > `num_key_value_heads`).
+
+    Args:
+        hidden_states: Tensor, The key or value hidden states with shape
+            `(batch, sequence_length, num_key_value_heads, head_dim)`.
+        n_rep: int, The number of times to repeat the key/value heads. This is
+            typically `num_query_heads // num_key_value_heads`.
+
+    Returns:
+        Tensor: The expanded key/value hidden states with shape
+            `(batch, sequence_length, num_query_heads, head_dim)`.
+    """
+    if n_rep == 1:
+        return hidden_states
+    batch, slen, num_key_value_heads, head_dim = keras.ops.shape(hidden_states)
+    hidden_states = keras.ops.expand_dims(hidden_states, 3)
+    hidden_states = keras.ops.tile(hidden_states, (1, 1, 1, n_rep, 1))
+    return keras.ops.reshape(
+        hidden_states, (batch, slen, num_key_value_heads * n_rep, head_dim)
+    )
+
+
+class T5GemmaAttention(CachedGemmaAttention):
+    """A unified attention layer for T5Gemma that handles both self-attention
+    and cross-attention.
+
+    This layer performs attention with optional Rotary Positional Embeddings
+    (RoPE) and supports Grouped Query Attention (GQA). It is used in
+    `T5GemmaEncoderLayer` and `T5GemmaDecoderLayer`.
+
+    Args:
+        hidden_size: int, The dimensionality of the hidden states.
+        num_attention_heads: int, The number of attention heads.
+        num_key_value_heads: int, The number of key-value heads. For GQA, this
+            can be less than `num_attention_heads`.
+        query_pre_attn_scalar: float, Scalar to multiply queries by before
+            attention.
+        attention_bias: bool, Whether to include bias in the dense layers.
+        head_dim: int, The dimensionality of each attention head.
+        attention_type: str, The type of attention, either 'self' or 'cross'.
+            Defaults to 'self'.
+        cross_attention_hidden_size: int, optional, The dimensionality of
+            encoder hidden states for cross-attention. Defaults to `None`.
+        initializer_range: float, The range for the random normal initializer
+            for kernel weights. Defaults to `0.02`.
+        attention_dropout: float, The dropout rate applied to attention weights.
+            Defaults to `0.0`.
+        attn_logit_softcapping: float, optional, The softcapping value for
+            attention logits. Defaults to `None`.
+        rope_max_wavelength: float, The maximum wavelength for Rotary Positional
+            Embeddings. Defaults to `10000.0`. Only used for self-attention.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for model computations and weights. Defaults to `None`.
+        **kwargs: Additional keyword arguments passed to the parent class.
+    """
+
+    def __init__(
+        self,
+        hidden_size,
+        num_attention_heads,
+        num_key_value_heads,
+        query_pre_attn_scalar,
+        attention_bias,
+        head_dim,
+        attention_type="self",
+        cross_attention_hidden_size=None,
+        initializer_range=0.02,
+        attention_dropout=0.0,
+        attn_logit_softcapping=None,
+        rope_max_wavelength=10000.0,
+        dtype=None,
+        **kwargs,
+    ):
+        super().__init__(
+            head_dim=head_dim,
+            num_query_heads=num_attention_heads,
+            num_key_value_heads=num_key_value_heads,
+            kernel_initializer=t5gemma_kernel_initializer(initializer_range),
+            logit_soft_cap=attn_logit_softcapping,
+            dropout=attention_dropout,
+            query_head_dim_normalize=False,
+            use_sliding_window_attention=False,
+            dtype=dtype,
+            **kwargs,
+        )
+        if attention_type not in ["self", "cross"]:
+            raise ValueError(
+                f"attention_type must be 'self' or 'cross', but got "
+                f"{attention_type}"
+            )
+        self.attention_type = attention_type
+        self.hidden_size = hidden_size
+        self.cross_attention_hidden_size = (
+            cross_attention_hidden_size or hidden_size
+        )
+        self.query_pre_attn_scalar = query_pre_attn_scalar
+        self.attention_bias = attention_bias
+        self.initializer_range = initializer_range
+        self.attention_dropout = attention_dropout
+        self.rope_max_wavelength = rope_max_wavelength
+        self.num_key_value_groups = (
+            self.num_query_heads // self.num_key_value_heads
+        )
+        self.scaling = self.query_pre_attn_scalar**-0.5
+        if self.attention_type == "self":
+            self.rotary_embedding = RotaryEmbedding(
+                max_wavelength=self.rope_max_wavelength,
+                sequence_axis=1,
+                feature_axis=3,
+                name="rotary_embedding",
+                dtype=self.dtype_policy,
+            )
+
+    def build(self, input_shape):
+        self._kernel_initializer = t5gemma_kernel_initializer(
+            self.initializer_range
+        )
+
+        if self.attention_type == "cross":
+            hidden_states_shape, kv_states_shape = input_shape
+        else:
+            hidden_states_shape = input_shape
+            kv_states_shape = input_shape
+        # Query projection layer.
+        self.hidden_dim = hidden_states_shape[-1]
+        self.query_dense = keras.layers.EinsumDense(
+            equation="btd,dnh->btnh",
+            output_shape=(None, self.num_query_heads, self.head_dim),
+            kernel_initializer=clone_initializer(self._kernel_initializer),
+            bias_axes="nh" if self.attention_bias else None,
+            dtype=self.dtype_policy,
+            name="query",
+        )
+        self.query_dense.build(hidden_states_shape)
+
+        # Key projection layer.
+        self.key_dense = keras.layers.EinsumDense(
+            equation="bsd,dkh->bskh",
+            output_shape=(None, self.num_key_value_heads, self.head_dim),
+            kernel_initializer=clone_initializer(self._kernel_initializer),
+            bias_axes="kh" if self.attention_bias else None,
+            dtype=self.dtype_policy,
+            name="key",
+        )
+        self.key_dense.build(kv_states_shape)
+
+        # Value projection layer.
+        self.value_dense = keras.layers.EinsumDense(
+            equation="bsd,dkh->bskh",
+            output_shape=(None, self.num_key_value_heads, self.head_dim),
+            kernel_initializer=clone_initializer(self._kernel_initializer),
+            bias_axes="kh" if self.attention_bias else None,
+            dtype=self.dtype_policy,
+            name="value",
+        )
+        self.value_dense.build(kv_states_shape)
+
+        # Output projection layer.
+        self.output_dense = keras.layers.EinsumDense(
+            equation="btnh,nhd->btd",
+            output_shape=(None, self.hidden_dim),
+            kernel_initializer=clone_initializer(self._kernel_initializer),
+            bias_axes="d" if self.attention_bias else None,
+            dtype=self.dtype_policy,
+            name="attention_output",
+        )
+        self.output_dense.build(
+            (
+                hidden_states_shape[0],
+                hidden_states_shape[1],
+                self.num_query_heads,
+                self.head_dim,
+            )
+        )
+        self.dropout_layer = keras.layers.Dropout(
+            rate=self.attention_dropout,
+            dtype=self.dtype_policy,
+        )
+        self.softmax = keras.layers.Softmax(axis=-1, dtype="float32")
+        self.built = True
+
+    def _compute_attention_without_fused_op(
+        self, query_states, key_states, value_states, attention_mask, training
+    ):
+        attn_weights = keras.ops.einsum(
+            "btnh,bsnh->bnts", query_states, key_states
+        )
+        attn_weights *= self.scaling
+        if self.logit_soft_cap is not None:
+            attn_weights = attn_weights / self.logit_soft_cap
+            attn_weights = keras.ops.tanh(attn_weights)
+            attn_weights = attn_weights * self.logit_soft_cap
+        if attention_mask is not None:
+            attn_weights += attention_mask
+        attn_weights = keras.ops.cast(
+            self.softmax(attn_weights),
+            query_states.dtype,
+        )
+        attn_weights = self.dropout_layer(attn_weights, training=training)
+        attn_output = keras.ops.einsum(
+            "bnts,bsnh->btnh", attn_weights, value_states
+        )
+        return attn_output
+
+    def _compute_attention(
+        self, query_states, key_states, value_states, attention_mask, training
+    ):
+        if self._use_fused_attention_op():
+            kwargs = {"bias": attention_mask}
+            if self.logit_soft_cap is not None:
+                sig = inspect.signature(keras.ops.dot_product_attention)
+                # This is only supported in JAX TPU backend.
+                # https://keras.io/api/ops/nn/#dot_product_attention-function
+                if "attn_logits_soft_cap" in sig.parameters:
+                    kwargs["attn_logits_soft_cap"] = self.logit_soft_cap
+            return keras.ops.dot_product_attention(
+                query=query_states,
+                key=key_states,
+                value=value_states,
+                scale=self.scaling,
+                **kwargs,
+            )
+        return self._compute_attention_without_fused_op(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            training,
+        )
+
+    def call(
+        self,
+        inputs,
+        attention_mask=None,
+        cache=None,
+        cache_update_index=None,
+        training=None,
+    ):
+        if self.attention_type == "cross":
+            if not isinstance(inputs, (list, tuple)) or len(inputs) != 2:
+                raise ValueError(
+                    "For cross-attention, `inputs` must be a list or tuple of "
+                    "two tensors: `[hidden_states, encoder_hidden_states]`."
+                )
+            hidden_states, kv_states = inputs
+            query_states = self.query_dense(hidden_states)
+            if cache is not None:
+                if cache_update_index is not None:
+                    raise ValueError(
+                        "`cache_update_index` should not be set for "
+                        "cross-attention caching."
+                    )
+                key_states, value_states = cache[:, 0, ...], cache[:, 1, ...]
+                updated_cache = cache
+            else:
+                key_states = self.key_dense(kv_states)
+                value_states = self.value_dense(kv_states)
+                updated_cache = keras.ops.stack(
+                    (key_states, value_states), axis=1
+                )
+            # Repeat key-value heads for GQA.
+            key_states = repeat_kv(key_states, self.num_key_value_groups)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)
+            attn_output = self._compute_attention(
+                query_states, key_states, value_states, attention_mask, training
+            )
+            attn_output = self.output_dense(attn_output)
+            return attn_output, updated_cache
+        else:  # Self-attention
+            hidden_states = inputs
+            kv_states = hidden_states
+            query_states = self.query_dense(hidden_states)
+            key_states = self.key_dense(kv_states)
+            value_states = self.value_dense(kv_states)
+            start_index = (
+                0 if cache_update_index is None else cache_update_index
+            )
+            query_states = self.rotary_embedding(
+                query_states, start_index=start_index
+            )
+            key_states = self.rotary_embedding(
+                key_states, start_index=start_index
+            )
+            if cache is not None:
+                if cache_update_index is None:
+                    raise ValueError(
+                        "Both `cache` and `cache_update_index` must be passed "
+                        "for self-attention caching."
+                    )
+                key_cache, value_cache = cache[:, 0, ...], cache[:, 1, ...]
+                start = [0, cache_update_index, 0, 0]
+                key_states = keras.ops.slice_update(
+                    key_cache, start, key_states
+                )
+                value_states = keras.ops.slice_update(
+                    value_cache, start, value_states
+                )
+                cache = keras.ops.stack((key_states, value_states), axis=1)
+            elif cache_update_index is not None:
+                raise ValueError(
+                    "`cache_update_index` should not be set if `cache` is "
+                    "`None`."
+                )
+            else:
+                cache = keras.ops.stack((key_states, value_states), axis=1)
+
+            # Repeat key-value heads for GQA.
+            key_states = repeat_kv(key_states, self.num_key_value_groups)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+            attn_output = self._compute_attention(
+                query_states, key_states, value_states, attention_mask, training
+            )
+            attn_output = self.output_dense(attn_output)
+            return attn_output, cache
+
+    def compute_output_shape(self, input_shape):
+        if self.attention_type == "cross":
+            hidden_states_shape, kv_states_shape = input_shape
+        else:
+            hidden_states_shape = input_shape
+            kv_states_shape = input_shape
+        attn_output_shape = hidden_states_shape
+        kv_len = kv_states_shape[1]
+        cache_shape = (
+            hidden_states_shape[0],  # batch
+            2,  # key and value
+            kv_len,
+            self.num_key_value_heads,
+            self.head_dim,
+        )
+        return attn_output_shape, cache_shape
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_size": self.hidden_size,
+                "head_dim": self.head_dim,
+                "num_attention_heads": self.num_query_heads,
+                "num_key_value_heads": self.num_key_value_heads,
+                "query_pre_attn_scalar": self.query_pre_attn_scalar,
+                "attention_bias": self.attention_bias,
+                "attention_type": self.attention_type,
+                "cross_attention_hidden_size": self.cross_attention_hidden_size,
+                "initializer_range": self.initializer_range,
+                "attention_dropout": self.attention_dropout,
+                "attn_logit_softcapping": self.logit_soft_cap,
+                "rope_max_wavelength": self.rope_max_wavelength,
+            }
+        )
+        return config
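Taken together, `repeat_kv` and `T5GemmaAttention` implement grouped-query attention with an explicit key/value cache stacked along axis 1. The sketch below only illustrates the shapes involved; the layer configuration is an arbitrary toy setup, not one taken from a T5Gemma preset, and it imports from the internal `src` path added by this wheel.

    import keras
    from keras_hub.src.models.t5gemma.t5gemma_attention import (
        T5GemmaAttention,
        repeat_kv,
    )

    # GQA broadcast: 2 key/value heads repeated to match 8 query heads.
    kv = keras.ops.ones((2, 7, 2, 16))
    print(keras.ops.shape(repeat_kv(kv, 4)))  # (2, 7, 8, 16)

    # Toy self-attention layer: 8 query heads sharing 2 key/value heads.
    attention = T5GemmaAttention(
        hidden_size=64,
        num_attention_heads=8,
        num_key_value_heads=2,
        query_pre_attn_scalar=16.0,
        attention_bias=False,
        head_dim=16,
    )
    hidden_states = keras.ops.ones((2, 7, 64))
    attn_output, cache = attention(hidden_states)
    # attn_output: (batch, seq, hidden) = (2, 7, 64)
    # cache: (batch, key/value, seq, kv_heads, head_dim) = (2, 2, 7, 2, 16)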