keras-hub-nightly 0.23.0.dev202508250413__py3-none-any.whl → 0.23.0.dev202508260411__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
- keras_hub/models/__init__.py +12 -0
- keras_hub/src/models/t5gemma/__init__.py +5 -0
- keras_hub/src/models/t5gemma/t5gemma_attention.py +370 -0
- keras_hub/src/models/t5gemma/t5gemma_backbone.py +366 -0
- keras_hub/src/models/t5gemma/t5gemma_decoder.py +355 -0
- keras_hub/src/models/t5gemma/t5gemma_encoder.py +214 -0
- keras_hub/src/models/t5gemma/t5gemma_layers.py +118 -0
- keras_hub/src/models/t5gemma/t5gemma_presets.py +15 -0
- keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm.py +442 -0
- keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py +216 -0
- keras_hub/src/models/t5gemma/t5gemma_tokenizer.py +84 -0
- keras_hub/src/utils/transformers/convert_t5gemma.py +229 -0
- keras_hub/src/utils/transformers/preset_loader.py +3 -0
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +3 -0
- {keras_hub_nightly-0.23.0.dev202508250413.dist-info → keras_hub_nightly-0.23.0.dev202508260411.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.23.0.dev202508250413.dist-info → keras_hub_nightly-0.23.0.dev202508260411.dist-info}/RECORD +19 -8
- {keras_hub_nightly-0.23.0.dev202508250413.dist-info → keras_hub_nightly-0.23.0.dev202508260411.dist-info}/WHEEL +0 -0
- {keras_hub_nightly-0.23.0.dev202508250413.dist-info → keras_hub_nightly-0.23.0.dev202508260411.dist-info}/top_level.txt +0 -0
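Taken together, these files add a T5Gemma encoder-decoder family to the nightly wheel. As a rough orientation only, the sketch below lists the public symbols the new modules appear to register; the `keras_hub.models.T5GemmaSeq2SeqLM` and `keras_hub.models.T5GemmaSeq2SeqLMPreprocessor` exports are visible in the hunks further down, while the backbone and tokenizer export paths are assumptions inferred from the touched `keras_hub/models/__init__.py` and `keras_hub/tokenizers/__init__.py`.

```python
# Orientation sketch, not part of the diff. Export paths marked "assumed"
# are inferred from the touched __init__.py files and may differ.
import keras_hub

print(keras_hub.models.T5GemmaSeq2SeqLM)              # export shown in the diff below
print(keras_hub.models.T5GemmaSeq2SeqLMPreprocessor)  # export shown in the diff below
print(keras_hub.models.T5GemmaBackbone)               # assumed export
print(keras_hub.tokenizers.T5GemmaTokenizer)          # assumed export
```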
keras_hub/src/models/t5gemma/t5gemma_presets.py
@@ -0,0 +1,15 @@
+# Metadata for loading pretrained model weights.
+backbone_presets = {
+    "t5gemma_b_b_prefixlm_it": {
+        "metadata": {
+            "description": (
+                "T5Gemma B/B model with a base encoder and base decoder, "
+                "adapted as a prefix language model and fine-tuned for "
+                "instruction following."
+            ),
+            "params": 591490560,
+            "path": "t5gemma",
+        },
+        "kaggle_handle": "kaggle://harshaljanjani/t5gemma/keras/t5gemma_b_b_prefixlm_it",
+    },
+}
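The entry above is what lets `from_preset()` resolve the short name `"t5gemma_b_b_prefixlm_it"`; its `kaggle_handle` points at the hosted weights. A minimal usage sketch under the usual KerasHub preset convention (the equivalence of the short name and the Kaggle handle is an assumption here; downloading the weights requires network and Kaggle access):

```python
import keras_hub

# Load by the short preset name registered above...
lm = keras_hub.models.T5GemmaSeq2SeqLM.from_preset("t5gemma_b_b_prefixlm_it")
# ...or by the Kaggle handle that the name resolves to.
lm = keras_hub.models.T5GemmaSeq2SeqLM.from_preset(
    "kaggle://harshaljanjani/t5gemma/keras/t5gemma_b_b_prefixlm_it"
)
lm.generate("The quick brown fox jumped.", max_length=30)
```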
keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm.py
@@ -0,0 +1,442 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM
+from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone
+from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm_preprocessor import (
+    T5GemmaSeq2SeqLMPreprocessor,
+)
+from keras_hub.src.utils.tensor_utils import any_equal
+
+
+@keras_hub_export("keras_hub.models.T5GemmaSeq2SeqLM")
+class T5GemmaSeq2SeqLM(Seq2SeqLM):
+    """An end-to-end T5Gemma model for seq2seq language modeling.
+
+    A seq2seq language model (LM) is an encoder-decoder model which is used for
+    conditional text generation. The encoder is given a "context" text (fed to
+    the encoder), and the decoder predicts the next token based on both the
+    encoder inputs and the previous tokens. You can finetune `T5GemmaSeq2SeqLM`
+    to generate text for any seq2seq task (e.g., translation or summarization).
+
+    This model has a `generate()` method, which generates text based on a
+    prompt. The generation strategy used is controlled by an additional
+    `sampler` argument on `compile()`. You can recompile the model with
+    different `keras_hub.samplers` objects to control the generation. By
+    default, `"greedy"` sampling will be used.
+
+    This model can optionally be configured with a `preprocessor` layer, in
+    which case it will automatically apply preprocessing to string inputs during
+    `fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default
+    when creating the model with `from_preset()`.
+
+    Args:
+        backbone: A `keras_hub.models.T5GemmaBackbone` instance.
+        preprocessor: A `keras_hub.models.T5GemmaSeq2SeqLMPreprocessor` or
+            `None`. If `None`, this model will not apply preprocessing, and
+            inputs should be preprocessed before calling the model. Defaults
+            to `None`.
+
+    Examples:
+
+    Use `generate()` to do text generation.
+    ```python
+    import numpy as np
+    t5gemma_lm = keras_hub.models.T5GemmaSeq2SeqLM.from_preset(
+        "t5gemma_b_b_prefixlm_it"
+    )
+    # Generate with encoder-only input.
+    t5gemma_lm.generate("The quick brown fox jumped.", max_length=30)
+
+    # Generate with batched encoder-only inputs.
+    t5gemma_lm.generate(
+        ["The quick brown fox jumped.", "The whale."],
+        max_length=30
+    )
+    # Generate with encoder and decoder inputs.
+    t5gemma_lm.generate(
+        {
+            "encoder_text": "The quick brown fox jumped.",
+            "decoder_text": "A fast fox"
+        },
+        max_length=30
+    )
+    ```
+
+    Compile the `generate()` function with a custom sampler.
+    ```python
+    t5gemma_lm = keras_hub.models.T5GemmaSeq2SeqLM.from_preset(
+        "t5gemma_b_b_prefixlm_it"
+    )
+    t5gemma_lm.compile(sampler="top_k")
+    t5gemma_lm.generate("I want to say", max_length=30)
+
+    t5gemma_lm.compile(sampler=keras_hub.samplers.BeamSampler(num_beams=2))
+    t5gemma_lm.generate("I want to say", max_length=30)
+    ```
+
+    Use `generate()` without preprocessing.
+    ```python
+    # Preprocessed inputs, with encoder inputs corresponding to
+    # "The quick brown fox", and the decoder inputs to "A fast fox".
+    # Use `"padding_mask"` to indicate values that should not be overridden.
+    prompt = {
+        "encoder_token_ids": np.array([[2, 10, 133, 2119, 6219, 23602, 1, 0]]),
+        "encoder_padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 0]]),
+        "decoder_token_ids": np.array([[2, 133, 1769, 1, 0, 0, 0]]),
+        "decoder_padding_mask": np.array([[1, 1, 1, 1, 0, 0, 0]])
+    }
+
+    t5gemma_lm = keras_hub.models.T5GemmaSeq2SeqLM.from_preset(
+        "t5gemma_b_b_prefixlm_it",
+        preprocessor=None,
+    )
+    t5gemma_lm.generate(prompt)
+    ```
+
+    Call `fit()` on a single batch.
+    ```python
+    features = {
+        "encoder_text": ["The quick fox jumped.", "I forgot my homework."],
+        "decoder_text": ["The fast hazel fox leapt.", "I forgot my assignment."]
+    }
+    t5gemma_lm = keras_hub.models.T5GemmaSeq2SeqLM.from_preset(
+        "t5gemma_b_b_prefixlm_it"
+    )
+    t5gemma_lm.fit(x=features, batch_size=2)
+    ```
+
+    Call `fit()` without preprocessing.
+    ```python
+    x = {
+        "encoder_token_ids": np.array([[2, 133, 2119, 1, 0]] * 2),
+        "encoder_padding_mask": np.array([[1, 1, 1, 1, 0]] * 2),
+        "decoder_token_ids": np.array([[2, 133, 1769, 1, 0]] * 2),
+        "decoder_padding_mask": np.array([[1, 1, 1, 1, 1]] * 2),
+    }
+    y = np.array([[133, 1769, 1, 0, 0]] * 2)
+    sw = np.array([[1, 1, 1, 0, 0]] * 2)
+
+    t5gemma_lm = keras_hub.models.T5GemmaSeq2SeqLM.from_preset(
+        "t5gemma_b_b_prefixlm_it",
+        preprocessor=None,
+    )
+    t5gemma_lm.fit(x=x, y=y, sample_weight=sw, batch_size=2)
+    ```
+
+    Custom backbone and vocabulary.
+    ```python
+    features = {
+        "encoder_text": ["The quick fox jumped.", "I forgot my homework."],
+        "decoder_text": ["The fast hazel fox leapt.", "I forgot my assignment."]
+    }
+    tokenizer = keras_hub.models.T5GemmaTokenizer(
+        proto="proto.spm",
+    )
+    preprocessor = keras_hub.models.T5GemmaSeq2SeqLMPreprocessor(
+        tokenizer=tokenizer,
+        encoder_sequence_length=128,
+        decoder_sequence_length=128,
+    )
+    backbone = keras_hub.models.T5GemmaBackbone(
+        vocabulary_size=32000,
+        # Encoder parameters.
+        encoder_hidden_dim=256,
+        encoder_intermediate_dim=512,
+        encoder_num_layers=4,
+        encoder_num_attention_heads=4,
+        encoder_num_key_value_heads=2,
+        encoder_head_dim=64,
+        encoder_layer_types=["full_attention"] * 4,
+        # Decoder parameters.
+        decoder_hidden_dim=256,
+        decoder_intermediate_dim=512,
+        decoder_num_layers=4,
+        decoder_num_attention_heads=4,
+        decoder_num_key_value_heads=2,
+        decoder_head_dim=64,
+        decoder_layer_types=["full_attention"] * 4,
+        # Common parameters.
+        dropout_rate=0.1,
+        rms_norm_eps=1e-6,
+        query_pre_attn_scalar=1.0,
+        attention_bias=False,
+        hidden_activation="gelu_approximate",
+    )
+    t5gemma_lm = keras_hub.models.T5GemmaSeq2SeqLM(
+        backbone=backbone,
+        preprocessor=preprocessor,
+    )
+    t5gemma_lm.fit(x=features, batch_size=2)
+    ```
+    """
+
+    backbone_cls = T5GemmaBackbone
+    preprocessor_cls = T5GemmaSeq2SeqLMPreprocessor
+
+    def __init__(self, backbone, preprocessor=None, **kwargs):
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+        # === Functional Model ===
+        # This must be "backbone.input" i.e. the full input structure,
+        # rather than "backbone.inputs" which is the flattened list of inputs.
+        inputs = backbone.input
+        sequence_output = backbone(inputs)["decoder_sequence_output"]
+        logits = backbone.decoder_token_embedding(sequence_output, reverse=True)
+        if self.backbone.final_logit_softcapping is not None:
+            logits = logits / self.backbone.final_logit_softcapping
+            logits = keras.ops.tanh(logits)
+            logits = logits * self.backbone.final_logit_softcapping
+        super().__init__(
+            inputs=inputs,
+            outputs=logits,
+            **kwargs,
+        )
+
+    def call_encoder(self, token_ids, padding_mask):
+        """Process inputs through the encoder stack."""
+        encoder_embeddings = self.backbone.token_embedding(token_ids)
+        encoder_embeddings *= keras.ops.cast(
+            keras.ops.sqrt(self.backbone.encoder_hidden_dim),
+            encoder_embeddings.dtype,
+        )
+        encoder_hidden_states = self.backbone.encoder_dropout(
+            encoder_embeddings, training=False
+        )
+        for layer in self.backbone.encoder_layers:
+            encoder_hidden_states = layer(
+                encoder_hidden_states, padding_mask=padding_mask, training=False
+            )
+        encoder_output = self.backbone.encoder_norm(encoder_hidden_states)
+        encoder_output = self.backbone.encoder_dropout(
+            encoder_output, training=False
+        )
+        return encoder_output, padding_mask
+
+    def call_decoder_with_cache(
+        self,
+        decoder_token_ids,
+        decoder_padding_mask,
+        cache,
+        cache_update_index,
+        encoder_output,
+        encoder_padding_mask,
+    ):
+        """Forward pass of `T5GemmaSeq2SeqLM`'s decoder with cache.
+
+        `call_decoder_with_cache` adds an additional forward pass for the model
+        for autoregressive inference. Unlike calling the model directly, this
+        method allows caching previous key/value Tensors in the attention
+        layers, and avoids recomputing the outputs of seen tokens.
+
+        Args:
+            decoder_token_ids: A dense int Tensor with shape
+                `(batch_size, max_length)`. The token ids for the decoder.
+            decoder_padding_mask: A dense int Tensor with shape `(batch_size,
+                max_length)`. The padding mask for the decoder.
+            cache: A dense float Tensor, the cache of key and value states.
+            cache_update_index: int, or int Tensor. The index of the current
+                token being processed in the whole sequence.
+            encoder_output: A dense float Tensor. The output of the encoder.
+            encoder_padding_mask: A dense int Tensor. The padding mask for
+                the encoder output.
+
+        Returns:
+            A `(logits, hidden_states, cache)` tuple. Where `logits` is the
+            language model logits for the input token_ids, `hidden_states` is
+            the final hidden representation of the input tokens, and `cache` is
+            the updated decoding cache.
+        """
+        self_attention_cache, cross_attention_cache = cache
+        hidden_states = self.backbone.decoder_token_embedding(decoder_token_ids)
+        hidden_states *= keras.ops.cast(
+            keras.ops.sqrt(self.backbone.decoder_hidden_dim),
+            hidden_states.dtype,
+        )
+        hidden_states = self.backbone.decoder_dropout(
+            hidden_states, training=False
+        )
+        # Every decoder layer has a separate cache for the self-attention layer
+        # and the cross-attention layer. We update all of them separately.
+        updated_self_attention_caches = []
+        updated_cross_attention_caches = []
+        for i, layer in enumerate(self.backbone.decoder_layers):
+            layer_self_cache = (
+                self_attention_cache[:, i, ...]
+                if self_attention_cache is not None
+                else None
+            )
+            layer_cross_cache = (
+                cross_attention_cache[:, i, ...]
+                if cross_attention_cache is not None
+                else None
+            )
+            layer_cache = (layer_self_cache, layer_cross_cache)
+            hidden_states, updated_layer_cache = layer(
+                (hidden_states, encoder_output),
+                self_attention_padding_mask=decoder_padding_mask,
+                cross_attention_padding_mask=encoder_padding_mask,
+                cache=layer_cache,
+                cache_update_index=cache_update_index,
+                training=False,
+            )
+            new_self_cache, new_cross_cache = updated_layer_cache
+            updated_self_attention_caches.append(new_self_cache)
+            updated_cross_attention_caches.append(new_cross_cache)
+        self_attention_cache = keras.ops.stack(
+            updated_self_attention_caches, axis=1
+        )
+        cross_attention_cache = keras.ops.stack(
+            updated_cross_attention_caches, axis=1
+        )
+        hidden_states = self.backbone.decoder_norm(hidden_states)
+        logits = self.backbone.decoder_token_embedding(
+            hidden_states, reverse=True
+        )
+        if self.backbone.final_logit_softcapping is not None:
+            logits = logits / self.backbone.final_logit_softcapping
+            logits = keras.ops.tanh(logits)
+            logits = logits * self.backbone.final_logit_softcapping
+        return (
+            logits,
+            hidden_states,
+            (self_attention_cache, cross_attention_cache),
+        )
+
+    def _build_cache(
+        self,
+        encoder_token_ids,
+        encoder_padding_mask,
+        decoder_token_ids,
+        decoder_padding_mask,
+    ):
+        """Build an empty cache for use with `call_with_cache()`."""
+        encoder_output, encoder_padding_mask = self.call_encoder(
+            encoder_token_ids, encoder_padding_mask
+        )
+        batch_size = keras.ops.shape(decoder_token_ids)[0]
+        num_layers = self.backbone.decoder_num_layers
+        num_kv_heads = self.backbone.decoder_num_key_value_heads
+        head_dim = self.backbone.decoder_head_dim
+        self_cache_shape = (
+            batch_size,
+            num_layers,
+            2,
+            keras.ops.shape(decoder_token_ids)[1],
+            num_kv_heads,
+            head_dim,
+        )
+        self_attention_cache = keras.ops.zeros(
+            self_cache_shape, dtype=self.compute_dtype
+        )
+        cross_attention_cache = None
+        _, hidden_states, cache = self.call_decoder_with_cache(
+            decoder_token_ids=decoder_token_ids,
+            decoder_padding_mask=decoder_padding_mask,
+            cache=(self_attention_cache, cross_attention_cache),
+            cache_update_index=0,
+            encoder_output=encoder_output,
+            encoder_padding_mask=encoder_padding_mask,
+        )
+        extra_cache_info = (encoder_output, encoder_padding_mask)
+        return hidden_states, cache, extra_cache_info
+
+    def generate_step(self, inputs, stop_token_ids=None):
+        """A compilable generation function for a single batch of inputs.
+
+        This function represents the inner, XLA-compilable, generation function
+        for a single batch of inputs. Inputs should have the same structure as
+        model inputs, a dictionary with keys `"encoder_token_ids"`,
+        `"encoder_padding_mask"`, `"decoder_token_ids"` and
+        `"decoder_padding_mask"`.
+
+        Args:
+            inputs: A dictionary with four keys - `"encoder_token_ids"`,
+                `"encoder_padding_mask"`, `"decoder_token_ids"` and
+                `"decoder_padding_mask"`, with batched tensor values.
+            stop_token_ids: Tuple of ids of end tokens to stop on. If all
+                sequences have produced a new stop token, generation
+                will stop.
+        """
+        encoder_token_ids = inputs["encoder_token_ids"]
+        encoder_padding_mask = inputs["encoder_padding_mask"]
+        decoder_token_ids = inputs["decoder_token_ids"]
+        decoder_padding_mask = inputs["decoder_padding_mask"]
+        # Create and seed cache with a single forward pass.
+        hidden_states, cache, extra_cache_info = self._build_cache(
+            encoder_token_ids=encoder_token_ids,
+            encoder_padding_mask=encoder_padding_mask,
+            decoder_token_ids=decoder_token_ids,
+            decoder_padding_mask=decoder_padding_mask,
+        )
+        encoder_output, encoder_padding_mask = extra_cache_info
+        # Compute the lengths of all user inputted token ids.
+        row_lengths = keras.ops.sum(
+            keras.ops.cast(decoder_padding_mask, "int32"), axis=-1
+        )
+        # Start at the first index that has no user inputted id.
+        index = keras.ops.min(row_lengths)
+
+        def next(prompt, cache, index):
+            # The cache index is the index of our previous token.
+            cache_update_index = index - 1
+            batch_size = keras.ops.shape(prompt)[0]
+            prompt = keras.ops.slice(
+                prompt, [0, cache_update_index], [batch_size, 1]
+            )
+            (
+                logits,
+                _,
+                updated_cache,
+            ) = self.call_decoder_with_cache(
+                decoder_token_ids=prompt,
+                decoder_padding_mask=None,
+                cache_update_index=cache_update_index,
+                cache=cache,
+                encoder_output=encoder_output,
+                encoder_padding_mask=encoder_padding_mask,
+            )
+            return keras.ops.squeeze(logits, axis=1), None, updated_cache
+
+        decoder_token_ids = self.sampler(
+            next=next,
+            prompt=decoder_token_ids,
+            cache=cache,
+            index=index,
+            mask=decoder_padding_mask,
+            stop_token_ids=stop_token_ids,
+            hidden_states=hidden_states,
+            model=self,
+        )
+
+        # Compute an output padding mask with the token ids we updated.
+        if stop_token_ids is not None:
+            # Build a mask of `stop_token_ids` locations not in the original
+            # prompt (not in locations where `decoder_padding_mask` is True).
+            end_locations = any_equal(
+                decoder_token_ids,
+                stop_token_ids,
+                keras.ops.logical_not(decoder_padding_mask),
+            )
+            # Use cumsum to get ones in all locations after end_locations.
+            end_locations = keras.ops.cast(end_locations, "int32")
+            cumsum = keras.ops.cast(
+                keras.ops.cumsum(end_locations, axis=-1), "int32"
+            )
+            overflow = cumsum - end_locations
+            # Our padding mask is the inverse of these overflow locations.
+            decoder_padding_mask = keras.ops.logical_not(
+                keras.ops.cast(overflow, "bool")
+            )
+        else:
+            # Without early stopping, all locations will have been updated.
+            decoder_padding_mask = keras.ops.ones_like(
+                decoder_token_ids, dtype="bool"
+            )
+
+        return {
+            "decoder_token_ids": decoder_token_ids,
+            "decoder_padding_mask": decoder_padding_mask,
+        }
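A detail worth pulling out of the class above is the final-logit soft-capping applied both in `__init__` and in `call_decoder_with_cache`: when `final_logit_softcapping` is set on the backbone, logits are rescaled with `cap * tanh(logits / cap)`, which bounds every logit to `(-cap, cap)` while leaving small values nearly untouched. A standalone numpy sketch of that transform (the cap value here is illustrative, not taken from the preset):

```python
import numpy as np

def softcap(logits, cap=30.0):
    # Same sequence as the Keras ops above: divide by cap, tanh, scale back.
    return cap * np.tanh(logits / cap)

logits = np.array([-200.0, -5.0, 0.0, 5.0, 200.0])
print(softcap(logits))
# Large magnitudes saturate near +/-cap; values well inside the range pass
# through almost unchanged.
```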
keras_hub/src/models/t5gemma/t5gemma_seq_2_seq_lm_preprocessor.py
@@ -0,0 +1,216 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor
+from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone
+from keras_hub.src.models.t5gemma.t5gemma_tokenizer import T5GemmaTokenizer
+from keras_hub.src.utils.tensor_utils import preprocessing_function
+
+try:
+    import tensorflow as tf
+except ImportError:
+    tf = None
+
+
+@keras_hub_export("keras_hub.models.T5GemmaSeq2SeqLMPreprocessor")
+class T5GemmaSeq2SeqLMPreprocessor(Seq2SeqLMPreprocessor):
+    """T5Gemma Seq2Seq LM preprocessor.
+
+    This preprocessing layer is meant for use with
+    `keras_hub.models.T5GemmaSeq2SeqLM`. By default, it will take in batches of
+    strings, and return outputs in a `(x, y, sample_weight)` format, where the
+    `y` label is the next token id in the `x` sequence.
+
+    For use with generation, the layer also exposes two methods
+    `generate_preprocess()` and `generate_postprocess()`. When this preprocessor
+    is attached to a `keras_hub.models.T5GemmaSeq2SeqLM` instance, these methods
+    will be called implicitly in `generate()`. They can also be called
+    standalone (e.g. to precompute preprocessing inputs for generation in a
+    separate process).
+
+    Args:
+        tokenizer: A `keras_hub.models.T5GemmaTokenizer` instance.
+        encoder_sequence_length: The length of the packed encoder inputs.
+        decoder_sequence_length: The length of the packed decoder inputs.
+        add_start_token: If `True`, the preprocessor will prepend the
+            tokenizer start token to each input sequence. For T5Gemma models,
+            this should be `False`. Defaults to `False`.
+        add_end_token: If `True`, the preprocessor will append the tokenizer end
+            token to each input sequence. For T5Gemma models, this should be
+            `True`. Defaults to `True`.
+
+    Call arguments:
+        x: A dictionary with two keys, `"encoder_text"` and `"decoder_text"`.
+            The values can be a string, a `tf.Tensor` or a list of python
+            strings.
+        y: Label data. Should always be `None` as the layer generates labels.
+        sample_weight: Label weights. Should always be `None` as the layer
+            generates label weights.
+        encoder_sequence_length: Pass to override the configured
+            `encoder_sequence_length` of the layer.
+        decoder_sequence_length: Pass to override the configured
+            `decoder_sequence_length` of the layer.
+
+    Examples:
+    ```python
+    import tensorflow as tf
+    import numpy as np
+
+    # Load the preprocessor from a preset.
+    preprocessor = keras_hub.models.T5GemmaSeq2SeqLMPreprocessor.from_preset(
+        "t5gemma_b_b_prefixlm_it"
+    )
+
+    # For example usage, see the dictionary example below which provides
+    # both encoder and decoder text.
+    # Tokenize a batch of sentences.
+    preprocessor(["The quick brown fox jumped.", "Call me Ishmael."])
+    # Tokenize a dictionary with separate encoder and decoder inputs.
+    preprocessor({
+        "encoder_text": "The quick brown fox jumped.",
+        "decoder_text": "The fast fox."
+    })
+
+    # Apply tokenization to a `tf.data.Dataset`.
+    encoder_features = tf.constant(["The quick brown fox.", "Call me Ishmael."])
+    decoder_features = tf.constant(["The fast fox.", "I am Ishmael."])
+    ds = tf.data.Dataset.from_tensor_slices(
+        {"encoder_text": encoder_features, "decoder_text": decoder_features}
+    )
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Prepare tokens for generation.
+    preprocessor.generate_preprocess({
+        "encoder_text": "The quick brown fox jumped.",
+        "decoder_text": "The fast fox."
+    })
+
+    # Map generation outputs back to strings.
+    preprocessor.generate_postprocess({
+        'decoder_token_ids': np.array([[2, 714, 4320, 8426, 25341, 1, 0, 0]]),
+        'decoder_padding_mask': np.array([[1, 1, 1, 1, 1, 1, 0, 0]]),
+    })
+    ```
+    """
+
+    backbone_cls = T5GemmaBackbone
+    tokenizer_cls = T5GemmaTokenizer
+
+    def __init__(
+        self,
+        tokenizer,
+        encoder_sequence_length=512,
+        decoder_sequence_length=512,
+        add_start_token=False,
+        add_end_token=True,
+        **kwargs,
+    ):
+        # Do not pass `add_start_token` and `add_end_token` to the base class.
+        super().__init__(
+            tokenizer=tokenizer,
+            encoder_sequence_length=encoder_sequence_length,
+            decoder_sequence_length=decoder_sequence_length,
+            **kwargs,
+        )
+        # Store them directly on the subclass instance.
+        self.add_start_token = add_start_token
+        self.add_end_token = add_end_token
+
+    @preprocessing_function
+    def call(
+        self,
+        x,
+        y=None,
+        sample_weight=None,
+        *,
+        encoder_sequence_length=None,
+        decoder_sequence_length=None,
+        sequence_length=None,
+    ):
+        if encoder_sequence_length is None:
+            encoder_sequence_length = self.encoder_sequence_length
+        decoder_sequence_length = decoder_sequence_length or sequence_length
+        if decoder_sequence_length is None:
+            decoder_sequence_length = self.decoder_sequence_length
+
+        encoder_inputs = self.tokenizer(x["encoder_text"])
+        encoder_token_ids, encoder_padding_mask = self.encoder_packer(
+            encoder_inputs,
+            sequence_length=encoder_sequence_length,
+            add_start_value=self.add_start_token,
+            add_end_value=self.add_end_token,
+        )
+        decoder_inputs = self.tokenizer(x["decoder_text"])
+        decoder_token_ids, decoder_padding_mask = self.decoder_packer(
+            decoder_inputs,
+            sequence_length=decoder_sequence_length + 1,
+            add_start_value=True,
+            add_end_value=self.add_end_token,
+        )
+        x = {
+            "encoder_token_ids": encoder_token_ids,
+            "encoder_padding_mask": encoder_padding_mask,
+            "decoder_token_ids": decoder_token_ids[..., :-1],
+            "decoder_padding_mask": decoder_padding_mask[..., :-1],
+        }
+        y = decoder_token_ids[..., 1:]
+        sample_weight = decoder_padding_mask[..., 1:]
+        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
+
+    @preprocessing_function
+    def generate_preprocess(
+        self,
+        x,
+        *,
+        encoder_sequence_length=None,
+        decoder_sequence_length=None,
+        sequence_length=None,
+    ):
+        if not self.built:
+            self.build(None)
+
+        if isinstance(x, dict):
+            encoder_text = x["encoder_text"]
+            decoder_text = x["decoder_text"]
+        else:
+            encoder_text = x
+            decoder_text = tf.fill((tf.shape(encoder_text)[0],), "")
+
+        if encoder_sequence_length is None:
+            encoder_sequence_length = self.encoder_sequence_length
+        decoder_sequence_length = decoder_sequence_length or sequence_length
+        if decoder_sequence_length is None:
+            decoder_sequence_length = self.decoder_sequence_length
+
+        encoder_token_ids = self.tokenizer(encoder_text)
+        encoder_token_ids, encoder_padding_mask = self.encoder_packer(
+            encoder_token_ids,
+            sequence_length=None,
+            add_start_value=self.add_start_token,
+            add_end_value=False,
+        )
+
+        decoder_token_ids = self.tokenizer(decoder_text)
+        decoder_token_ids, decoder_padding_mask = self.decoder_packer(
+            decoder_token_ids,
+            sequence_length=decoder_sequence_length,
+            add_start_value=True,
+            add_end_value=False,
+        )
+
+        return {
+            "encoder_token_ids": encoder_token_ids,
+            "encoder_padding_mask": encoder_padding_mask,
+            "decoder_token_ids": decoder_token_ids,
+            "decoder_padding_mask": decoder_padding_mask,
+        }
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "add_start_token": self.add_start_token,
+                "add_end_token": self.add_end_token,
+            }
+        )
+        return config
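For training, `call()` above packs the decoder sequence to `decoder_sequence_length + 1` and then shifts it by one position, so that each decoder input position is paired with the next token as its label, with the shifted padding mask reused as the sample weight. A small numpy sketch of that shift, using made-up token ids:

```python
# Sketch of the input/label shift performed in `call()` above.
# Token ids are illustrative only.
import numpy as np

packed_decoder_ids = np.array([[2, 10, 11, 12, 1, 0]])   # [start, ..., end, pad]
packed_decoder_mask = np.array([[1, 1, 1, 1, 1, 0]])

x_decoder_token_ids = packed_decoder_ids[..., :-1]   # fed to the model
y = packed_decoder_ids[..., 1:]                      # next-token labels
sample_weight = packed_decoder_mask[..., 1:]         # masks padding in the loss

print(x_decoder_token_ids, y, sample_weight, sep="\n")
```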