keras-hub-nightly 0.23.0.dev202508250413__py3-none-any.whl → 0.23.0.dev202508260411__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,15 @@
+ # Metadata for loading pretrained model weights.
+ backbone_presets = {
+     "t5gemma_b_b_prefixlm_it": {
+         "metadata": {
+             "description": (
+                 "T5Gemma B/B model with a base encoder and base decoder, "
+                 "adapted as a prefix language model and fine-tuned for "
+                 "instruction following."
+             ),
+             "params": 591490560,
+             "path": "t5gemma",
+         },
+         "kaggle_handle": "kaggle://harshaljanjani/t5gemma/keras/t5gemma_b_b_prefixlm_it",
+     },
+ }
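
The preset entry above is plain metadata: a human-readable description, a parameter count, and a Kaggle handle that `from_preset()` resolves to hosted weights. A minimal sketch of inspecting this registry, assuming only the `backbone_presets` dict defined above (the loop itself is illustrative, not part of the package):

```python
# Illustrative only: walk the preset registry defined above and print the
# metadata alongside the weights handle that `from_preset()` would resolve.
for name, preset in backbone_presets.items():
    metadata = preset["metadata"]
    print(f"{name}: ~{metadata['params'] / 1e6:.0f}M params")
    print(f"  {metadata['description']}")
    print(f"  weights: {preset['kaggle_handle']}")
```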
@@ -0,0 +1,442 @@
+ import keras
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM
+ from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone
+ from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm_preprocessor import (
+     T5GemmaSeq2SeqLMPreprocessor,
+ )
+ from keras_hub.src.utils.tensor_utils import any_equal
+
+
+ @keras_hub_export("keras_hub.models.T5GemmaSeq2SeqLM")
+ class T5GemmaSeq2SeqLM(Seq2SeqLM):
+     """An end-to-end T5Gemma model for seq2seq language modeling.
+
+     A seq2seq language model (LM) is an encoder-decoder model used for
+     conditional text generation. The encoder is given "context" text, and the
+     decoder predicts the next token conditioned on both the encoder output and
+     the previously generated tokens. You can finetune `T5GemmaSeq2SeqLM` to
+     generate text for any seq2seq task (e.g., translation or summarization).
+
+     This model has a `generate()` method, which generates text based on a
+     prompt. The generation strategy used is controlled by an additional
+     `sampler` argument on `compile()`. You can recompile the model with
+     different `keras_hub.samplers` objects to control the generation. By
+     default, `"greedy"` sampling will be used.
+
+     This model can optionally be configured with a `preprocessor` layer, in
+     which case it will automatically apply preprocessing to string inputs during
+     `fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default
+     when creating the model with `from_preset()`.
+
+     Args:
+         backbone: A `keras_hub.models.T5GemmaBackbone` instance.
+         preprocessor: A `keras_hub.models.T5GemmaSeq2SeqLMPreprocessor` or
+             `None`. If `None`, this model will not apply preprocessing, and
+             inputs should be preprocessed before calling the model. Defaults
+             to `None`.
+
+     Examples:
+
+     Use `generate()` to do text generation.
+     ```python
+     import numpy as np
+     t5gemma_lm = keras_hub.models.T5GemmaSeq2SeqLM.from_preset(
+         "t5gemma_b_b_prefixlm_it"
+     )
+     # Generate with encoder-only input.
+     t5gemma_lm.generate("The quick brown fox jumped.", max_length=30)
+
+     # Generate with batched encoder-only inputs.
+     t5gemma_lm.generate(
+         ["The quick brown fox jumped.", "The whale."],
+         max_length=30
+     )
+     # Generate with encoder and decoder inputs.
+     t5gemma_lm.generate(
+         {
+             "encoder_text": "The quick brown fox jumped.",
+             "decoder_text": "A fast fox"
+         },
+         max_length=30
+     )
+     ```
+
+     Compile the `generate()` function with a custom sampler.
+     ```python
+     t5gemma_lm = keras_hub.models.T5GemmaSeq2SeqLM.from_preset(
+         "t5gemma_b_b_prefixlm_it"
+     )
+     t5gemma_lm.compile(sampler="top_k")
+     t5gemma_lm.generate("I want to say", max_length=30)
+
+     t5gemma_lm.compile(sampler=keras_hub.samplers.BeamSampler(num_beams=2))
+     t5gemma_lm.generate("I want to say", max_length=30)
+     ```
+
+     Use `generate()` without preprocessing.
+     ```python
+     # Preprocessed inputs, with encoder inputs corresponding to
+     # "The quick brown fox", and the decoder inputs to "A fast fox".
+     # Use `"padding_mask"` to indicate values that should not be overridden.
+     prompt = {
+         "encoder_token_ids": np.array([[2, 10, 133, 2119, 6219, 23602, 1, 0]]),
+         "encoder_padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 0]]),
+         "decoder_token_ids": np.array([[2, 133, 1769, 1, 0, 0, 0]]),
+         "decoder_padding_mask": np.array([[1, 1, 1, 1, 0, 0, 0]])
+     }
+
+     t5gemma_lm = keras_hub.models.T5GemmaSeq2SeqLM.from_preset(
+         "t5gemma_b_b_prefixlm_it",
+         preprocessor=None,
+     )
+     t5gemma_lm.generate(prompt)
+     ```
+
+     Call `fit()` on a single batch.
+     ```python
+     features = {
+         "encoder_text": ["The quick fox jumped.", "I forgot my homework."],
+         "decoder_text": ["The fast hazel fox leapt.", "I forgot my assignment."]
+     }
+     t5gemma_lm = keras_hub.models.T5GemmaSeq2SeqLM.from_preset(
+         "t5gemma_b_b_prefixlm_it"
+     )
+     t5gemma_lm.fit(x=features, batch_size=2)
+     ```
+
+     Call `fit()` without preprocessing.
+     ```python
+     x = {
+         "encoder_token_ids": np.array([[2, 133, 2119, 1, 0]] * 2),
+         "encoder_padding_mask": np.array([[1, 1, 1, 1, 0]] * 2),
+         "decoder_token_ids": np.array([[2, 133, 1769, 1, 0]] * 2),
+         "decoder_padding_mask": np.array([[1, 1, 1, 1, 1]] * 2),
+     }
+     y = np.array([[133, 1769, 1, 0, 0]] * 2)
+     sw = np.array([[1, 1, 1, 0, 0]] * 2)
+
+     t5gemma_lm = keras_hub.models.T5GemmaSeq2SeqLM.from_preset(
+         "t5gemma_b_b_prefixlm_it",
+         preprocessor=None,
+     )
+     t5gemma_lm.fit(x=x, y=y, sample_weight=sw, batch_size=2)
+     ```
+
+     Custom backbone and vocabulary.
+     ```python
+     features = {
+         "encoder_text": ["The quick fox jumped.", "I forgot my homework."],
+         "decoder_text": ["The fast hazel fox leapt.", "I forgot my assignment."]
+     }
+     tokenizer = keras_hub.models.T5GemmaTokenizer(
+         proto="proto.spm",
+     )
+     preprocessor = keras_hub.models.T5GemmaSeq2SeqLMPreprocessor(
+         tokenizer=tokenizer,
+         encoder_sequence_length=128,
+         decoder_sequence_length=128,
+     )
+     backbone = keras_hub.models.T5GemmaBackbone(
+         vocabulary_size=32000,
+         # Encoder parameters.
+         encoder_hidden_dim=256,
+         encoder_intermediate_dim=512,
+         encoder_num_layers=4,
+         encoder_num_attention_heads=4,
+         encoder_num_key_value_heads=2,
+         encoder_head_dim=64,
+         encoder_layer_types=["full_attention"] * 4,
+         # Decoder parameters.
+         decoder_hidden_dim=256,
+         decoder_intermediate_dim=512,
+         decoder_num_layers=4,
+         decoder_num_attention_heads=4,
+         decoder_num_key_value_heads=2,
+         decoder_head_dim=64,
+         decoder_layer_types=["full_attention"] * 4,
+         # Common parameters.
+         dropout_rate=0.1,
+         rms_norm_eps=1e-6,
+         query_pre_attn_scalar=1.0,
+         attention_bias=False,
+         hidden_activation="gelu_approximate",
+     )
+     t5gemma_lm = keras_hub.models.T5GemmaSeq2SeqLM(
+         backbone=backbone,
+         preprocessor=preprocessor,
+     )
+     t5gemma_lm.fit(x=features, batch_size=2)
+     ```
+     """
+
+     backbone_cls = T5GemmaBackbone
+     preprocessor_cls = T5GemmaSeq2SeqLMPreprocessor
+
+     def __init__(self, backbone, preprocessor=None, **kwargs):
+         # === Layers ===
+         self.backbone = backbone
+         self.preprocessor = preprocessor
+
+         # === Functional Model ===
+         # This must be "backbone.input" i.e. the full input structure,
+         # rather than "backbone.inputs" which is the flattened list of inputs.
+         inputs = backbone.input
+         sequence_output = backbone(inputs)["decoder_sequence_output"]
+         logits = backbone.decoder_token_embedding(sequence_output, reverse=True)
+         if self.backbone.final_logit_softcapping is not None:
+             logits = logits / self.backbone.final_logit_softcapping
+             logits = keras.ops.tanh(logits)
+             logits = logits * self.backbone.final_logit_softcapping
+         super().__init__(
+             inputs=inputs,
+             outputs=logits,
+             **kwargs,
+         )
+
+     def call_encoder(self, token_ids, padding_mask):
+         """Process inputs through the encoder stack."""
+         encoder_embeddings = self.backbone.token_embedding(token_ids)
+         encoder_embeddings *= keras.ops.cast(
+             keras.ops.sqrt(self.backbone.encoder_hidden_dim),
+             encoder_embeddings.dtype,
+         )
+         encoder_hidden_states = self.backbone.encoder_dropout(
+             encoder_embeddings, training=False
+         )
+         for layer in self.backbone.encoder_layers:
+             encoder_hidden_states = layer(
+                 encoder_hidden_states, padding_mask=padding_mask, training=False
+             )
+         encoder_output = self.backbone.encoder_norm(encoder_hidden_states)
+         encoder_output = self.backbone.encoder_dropout(
+             encoder_output, training=False
+         )
+         return encoder_output, padding_mask
+
+     def call_decoder_with_cache(
+         self,
+         decoder_token_ids,
+         decoder_padding_mask,
+         cache,
+         cache_update_index,
+         encoder_output,
+         encoder_padding_mask,
+     ):
+         """Forward pass of `T5GemmaSeq2SeqLM`'s decoder with cache.
+
+         `call_decoder_with_cache` adds an additional forward pass for the model
+         for autoregressive inference. Unlike calling the model directly, this
+         method allows caching previous key/value Tensors in the attention
+         layers, and avoids recomputing the outputs of seen tokens.
+
+         Args:
+             decoder_token_ids: A dense int Tensor with shape
+                 `(batch_size, max_length)`. The token ids for the decoder.
+             decoder_padding_mask: A dense int Tensor with shape `(batch_size,
+                 max_length)`. The padding mask for the decoder.
+             cache: A dense float Tensor, the cache of key and value states.
+             cache_update_index: int, or int Tensor. The index of the current
+                 token being processed in the whole sequence.
+             encoder_output: A dense float Tensor. The output of the encoder.
+             encoder_padding_mask: A dense int Tensor. The padding mask for
+                 the encoder output.
+
+         Returns:
+             A `(logits, hidden_states, cache)` tuple, where `logits` is the
+             language model logits for the input token ids, `hidden_states` is
+             the final hidden representation of the input tokens, and `cache` is
+             the updated decoding cache.
+         """
+         self_attention_cache, cross_attention_cache = cache
+         hidden_states = self.backbone.decoder_token_embedding(decoder_token_ids)
+         hidden_states *= keras.ops.cast(
+             keras.ops.sqrt(self.backbone.decoder_hidden_dim),
+             hidden_states.dtype,
+         )
+         hidden_states = self.backbone.decoder_dropout(
+             hidden_states, training=False
+         )
+         # Every decoder layer has a separate cache for the self-attention layer
+         # and the cross-attention layer. We update all of them separately.
+         updated_self_attention_caches = []
+         updated_cross_attention_caches = []
+         for i, layer in enumerate(self.backbone.decoder_layers):
+             layer_self_cache = (
+                 self_attention_cache[:, i, ...]
+                 if self_attention_cache is not None
+                 else None
+             )
+             layer_cross_cache = (
+                 cross_attention_cache[:, i, ...]
+                 if cross_attention_cache is not None
+                 else None
+             )
+             layer_cache = (layer_self_cache, layer_cross_cache)
+             hidden_states, updated_layer_cache = layer(
+                 (hidden_states, encoder_output),
+                 self_attention_padding_mask=decoder_padding_mask,
+                 cross_attention_padding_mask=encoder_padding_mask,
+                 cache=layer_cache,
+                 cache_update_index=cache_update_index,
+                 training=False,
+             )
+             new_self_cache, new_cross_cache = updated_layer_cache
+             updated_self_attention_caches.append(new_self_cache)
+             updated_cross_attention_caches.append(new_cross_cache)
+         self_attention_cache = keras.ops.stack(
+             updated_self_attention_caches, axis=1
+         )
+         cross_attention_cache = keras.ops.stack(
+             updated_cross_attention_caches, axis=1
+         )
+         hidden_states = self.backbone.decoder_norm(hidden_states)
+         logits = self.backbone.decoder_token_embedding(
+             hidden_states, reverse=True
+         )
+         if self.backbone.final_logit_softcapping is not None:
+             logits = logits / self.backbone.final_logit_softcapping
+             logits = keras.ops.tanh(logits)
+             logits = logits * self.backbone.final_logit_softcapping
+         return (
+             logits,
+             hidden_states,
+             (self_attention_cache, cross_attention_cache),
+         )
+
+     def _build_cache(
+         self,
+         encoder_token_ids,
+         encoder_padding_mask,
+         decoder_token_ids,
+         decoder_padding_mask,
+     ):
+         """Build an empty cache for use with `call_decoder_with_cache()`."""
+         encoder_output, encoder_padding_mask = self.call_encoder(
+             encoder_token_ids, encoder_padding_mask
+         )
+         batch_size = keras.ops.shape(decoder_token_ids)[0]
+         num_layers = self.backbone.decoder_num_layers
+         num_kv_heads = self.backbone.decoder_num_key_value_heads
+         head_dim = self.backbone.decoder_head_dim
+         self_cache_shape = (
+             batch_size,
+             num_layers,
+             2,
+             keras.ops.shape(decoder_token_ids)[1],
+             num_kv_heads,
+             head_dim,
+         )
+         self_attention_cache = keras.ops.zeros(
+             self_cache_shape, dtype=self.compute_dtype
+         )
+         cross_attention_cache = None
+         _, hidden_states, cache = self.call_decoder_with_cache(
+             decoder_token_ids=decoder_token_ids,
+             decoder_padding_mask=decoder_padding_mask,
+             cache=(self_attention_cache, cross_attention_cache),
+             cache_update_index=0,
+             encoder_output=encoder_output,
+             encoder_padding_mask=encoder_padding_mask,
+         )
+         extra_cache_info = (encoder_output, encoder_padding_mask)
+         return hidden_states, cache, extra_cache_info
+
+     def generate_step(self, inputs, stop_token_ids=None):
+         """A compilable generation function for a single batch of inputs.
+
+         This function represents the inner, XLA-compilable, generation function
+         for a single batch of inputs. Inputs should have the same structure as
+         model inputs: a dictionary with keys `"encoder_token_ids"`,
+         `"encoder_padding_mask"`, `"decoder_token_ids"` and
+         `"decoder_padding_mask"`.
+
+         Args:
+             inputs: A dictionary with four keys - `"encoder_token_ids"`,
+                 `"encoder_padding_mask"`, `"decoder_token_ids"` and
+                 `"decoder_padding_mask"`, with batched tensor values.
+             stop_token_ids: Tuple of ids of end tokens to stop on. If all
+                 sequences have produced a new stop token, generation
+                 will stop.
+         """
+         encoder_token_ids = inputs["encoder_token_ids"]
+         encoder_padding_mask = inputs["encoder_padding_mask"]
+         decoder_token_ids = inputs["decoder_token_ids"]
+         decoder_padding_mask = inputs["decoder_padding_mask"]
+         # Create and seed cache with a single forward pass.
+         hidden_states, cache, extra_cache_info = self._build_cache(
+             encoder_token_ids=encoder_token_ids,
+             encoder_padding_mask=encoder_padding_mask,
+             decoder_token_ids=decoder_token_ids,
+             decoder_padding_mask=decoder_padding_mask,
+         )
+         encoder_output, encoder_padding_mask = extra_cache_info
+         # Compute the lengths of all user-provided token ids.
+         row_lengths = keras.ops.sum(
+             keras.ops.cast(decoder_padding_mask, "int32"), axis=-1
+         )
+         # Start at the first index that has no user-provided id.
+         index = keras.ops.min(row_lengths)
+
+         def next(prompt, cache, index):
+             # The cache index is the index of our previous token.
+             cache_update_index = index - 1
+             batch_size = keras.ops.shape(prompt)[0]
+             prompt = keras.ops.slice(
+                 prompt, [0, cache_update_index], [batch_size, 1]
+             )
+             (
+                 logits,
+                 _,
+                 updated_cache,
+             ) = self.call_decoder_with_cache(
+                 decoder_token_ids=prompt,
+                 decoder_padding_mask=None,
+                 cache_update_index=cache_update_index,
+                 cache=cache,
+                 encoder_output=encoder_output,
+                 encoder_padding_mask=encoder_padding_mask,
+             )
+             return keras.ops.squeeze(logits, axis=1), None, updated_cache
+
+         decoder_token_ids = self.sampler(
+             next=next,
+             prompt=decoder_token_ids,
+             cache=cache,
+             index=index,
+             mask=decoder_padding_mask,
+             stop_token_ids=stop_token_ids,
+             hidden_states=hidden_states,
+             model=self,
+         )
+
+         # Compute an output padding mask with the token ids we updated.
+         if stop_token_ids is not None:
+             # Build a mask of `stop_token_ids` locations not in the original
+             # prompt (not in locations where `decoder_padding_mask` is True).
+             end_locations = any_equal(
+                 decoder_token_ids,
+                 stop_token_ids,
+                 keras.ops.logical_not(decoder_padding_mask),
+             )
+             # Use cumsum to get ones in all locations after end_locations.
+             end_locations = keras.ops.cast(end_locations, "int32")
+             cumsum = keras.ops.cast(
+                 keras.ops.cumsum(end_locations, axis=-1), "int32"
+             )
+             overflow = cumsum - end_locations
+             # Our padding mask is the inverse of these overflow locations.
+             decoder_padding_mask = keras.ops.logical_not(
+                 keras.ops.cast(overflow, "bool")
+             )
+         else:
+             # Without early stopping, all locations will have been updated.
+             decoder_padding_mask = keras.ops.ones_like(
+                 decoder_token_ids, dtype="bool"
+             )
+
+         return {
+             "decoder_token_ids": decoder_token_ids,
+             "decoder_padding_mask": decoder_padding_mask,
+         }
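
Both the functional graph built in `__init__` and `call_decoder_with_cache` finish with the same final-logit soft-capping step. A minimal standalone sketch of that transform, assuming only `keras.ops`; the helper name `softcap_logits` is ours, not part of the package:

```python
import keras


def softcap_logits(logits, softcap):
    # Mirrors the in-file pattern: scale down, squash with tanh, scale back,
    # so logits are smoothly bounded to (-softcap, softcap).
    if softcap is None:
        return logits
    return keras.ops.tanh(logits / softcap) * softcap


# Tiny demo: extreme logits are clamped to roughly +/- 30.
x = keras.ops.convert_to_tensor([[-100.0, 0.0, 100.0]])
print(softcap_logits(x, 30.0))  # approximately [[-30., 0., 30.]]
```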
@@ -0,0 +1,216 @@
+ import keras
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor
+ from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone
+ from keras_hub.src.models.t5gemma.t5gemma_tokenizer import T5GemmaTokenizer
+ from keras_hub.src.utils.tensor_utils import preprocessing_function
+
+ try:
+     import tensorflow as tf
+ except ImportError:
+     tf = None
+
+
+ @keras_hub_export("keras_hub.models.T5GemmaSeq2SeqLMPreprocessor")
+ class T5GemmaSeq2SeqLMPreprocessor(Seq2SeqLMPreprocessor):
+     """T5Gemma Seq2Seq LM preprocessor.
+
+     This preprocessing layer is meant for use with
+     `keras_hub.models.T5GemmaSeq2SeqLM`. By default, it will take in batches of
+     strings, and return outputs in a `(x, y, sample_weight)` format, where the
+     `y` label is the next token id in the `x` sequence.
+
+     For use with generation, the layer also exposes two methods
+     `generate_preprocess()` and `generate_postprocess()`. When this preprocessor
+     is attached to a `keras_hub.models.T5GemmaSeq2SeqLM` instance, these methods
+     will be called implicitly in `generate()`. They can also be called
+     standalone (e.g. to precompute preprocessing inputs for generation in a
+     separate process).
+
+     Args:
+         tokenizer: A `keras_hub.models.T5GemmaTokenizer` instance.
+         encoder_sequence_length: The length of the packed encoder inputs.
+         decoder_sequence_length: The length of the packed decoder inputs.
+         add_start_token: If `True`, the preprocessor will prepend the
+             tokenizer start token to each input sequence. For T5Gemma models,
+             this should be `False`. Defaults to `False`.
+         add_end_token: If `True`, the preprocessor will append the tokenizer end
+             token to each input sequence. For T5Gemma models, this should be
+             `True`. Defaults to `True`.
+
+     Call arguments:
+         x: A dictionary with two keys, `"encoder_text"` and `"decoder_text"`.
+             The values can be a string, a `tf.Tensor` or a list of python
+             strings.
+         y: Label data. Should always be `None` as the layer generates labels.
+         sample_weight: Label weights. Should always be `None` as the layer
+             generates label weights.
+         encoder_sequence_length: Pass to override the configured
+             `encoder_sequence_length` of the layer.
+         decoder_sequence_length: Pass to override the configured
+             `decoder_sequence_length` of the layer.
+
+     Examples:
+     ```python
+     import tensorflow as tf
+     import numpy as np
+
+     # Load the preprocessor from a preset.
+     preprocessor = keras_hub.models.T5GemmaSeq2SeqLMPreprocessor.from_preset(
+         "t5gemma_b_b_prefixlm_it"
+     )
+
+     # For example usage, see the dictionary example below which provides
+     # both encoder and decoder text.
+     # Tokenize a batch of sentences.
+     preprocessor(["The quick brown fox jumped.", "Call me Ishmael."])
+     # Tokenize a dictionary with separate encoder and decoder inputs.
+     preprocessor({
+         "encoder_text": "The quick brown fox jumped.",
+         "decoder_text": "The fast fox."
+     })
+
+     # Apply tokenization to a `tf.data.Dataset`.
+     encoder_features = tf.constant(["The quick brown fox.", "Call me Ishmael."])
+     decoder_features = tf.constant(["The fast fox.", "I am Ishmael."])
+     ds = tf.data.Dataset.from_tensor_slices(
+         {"encoder_text": encoder_features, "decoder_text": decoder_features}
+     )
+     ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+     # Prepare tokens for generation.
+     preprocessor.generate_preprocess({
+         "encoder_text": "The quick brown fox jumped.",
+         "decoder_text": "The fast fox."
+     })
+
+     # Map generation outputs back to strings.
+     preprocessor.generate_postprocess({
+         'decoder_token_ids': np.array([[2, 714, 4320, 8426, 25341, 1, 0, 0]]),
+         'decoder_padding_mask': np.array([[1, 1, 1, 1, 1, 1, 0, 0]]),
+     })
+     ```
+     """
+
+     backbone_cls = T5GemmaBackbone
+     tokenizer_cls = T5GemmaTokenizer
+
+     def __init__(
+         self,
+         tokenizer,
+         encoder_sequence_length=512,
+         decoder_sequence_length=512,
+         add_start_token=False,
+         add_end_token=True,
+         **kwargs,
+     ):
+         # Do not pass `add_start_token` and `add_end_token` to the base class.
+         super().__init__(
+             tokenizer=tokenizer,
+             encoder_sequence_length=encoder_sequence_length,
+             decoder_sequence_length=decoder_sequence_length,
+             **kwargs,
+         )
+         # Store them directly on the subclass instance.
+         self.add_start_token = add_start_token
+         self.add_end_token = add_end_token
+
+     @preprocessing_function
+     def call(
+         self,
+         x,
+         y=None,
+         sample_weight=None,
+         *,
+         encoder_sequence_length=None,
+         decoder_sequence_length=None,
+         sequence_length=None,
+     ):
+         if encoder_sequence_length is None:
+             encoder_sequence_length = self.encoder_sequence_length
+         decoder_sequence_length = decoder_sequence_length or sequence_length
+         if decoder_sequence_length is None:
+             decoder_sequence_length = self.decoder_sequence_length
+
+         encoder_inputs = self.tokenizer(x["encoder_text"])
+         encoder_token_ids, encoder_padding_mask = self.encoder_packer(
+             encoder_inputs,
+             sequence_length=encoder_sequence_length,
+             add_start_value=self.add_start_token,
+             add_end_value=self.add_end_token,
+         )
+         decoder_inputs = self.tokenizer(x["decoder_text"])
+         decoder_token_ids, decoder_padding_mask = self.decoder_packer(
+             decoder_inputs,
+             sequence_length=decoder_sequence_length + 1,
+             add_start_value=True,
+             add_end_value=self.add_end_token,
+         )
+         x = {
+             "encoder_token_ids": encoder_token_ids,
+             "encoder_padding_mask": encoder_padding_mask,
+             "decoder_token_ids": decoder_token_ids[..., :-1],
+             "decoder_padding_mask": decoder_padding_mask[..., :-1],
+         }
+         y = decoder_token_ids[..., 1:]
+         sample_weight = decoder_padding_mask[..., 1:]
+         return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
+
+     @preprocessing_function
+     def generate_preprocess(
+         self,
+         x,
+         *,
+         encoder_sequence_length=None,
+         decoder_sequence_length=None,
+         sequence_length=None,
+     ):
+         if not self.built:
+             self.build(None)
+
+         if isinstance(x, dict):
+             encoder_text = x["encoder_text"]
+             decoder_text = x["decoder_text"]
+         else:
+             encoder_text = x
+             decoder_text = tf.fill((tf.shape(encoder_text)[0],), "")
+
+         if encoder_sequence_length is None:
+             encoder_sequence_length = self.encoder_sequence_length
+         decoder_sequence_length = decoder_sequence_length or sequence_length
+         if decoder_sequence_length is None:
+             decoder_sequence_length = self.decoder_sequence_length
+
+         encoder_token_ids = self.tokenizer(encoder_text)
+         encoder_token_ids, encoder_padding_mask = self.encoder_packer(
+             encoder_token_ids,
+             sequence_length=None,
+             add_start_value=self.add_start_token,
+             add_end_value=False,
+         )
+
+         decoder_token_ids = self.tokenizer(decoder_text)
+         decoder_token_ids, decoder_padding_mask = self.decoder_packer(
+             decoder_token_ids,
+             sequence_length=decoder_sequence_length,
+             add_start_value=True,
+             add_end_value=False,
+         )
+
+         return {
+             "encoder_token_ids": encoder_token_ids,
+             "encoder_padding_mask": encoder_padding_mask,
+             "decoder_token_ids": decoder_token_ids,
+             "decoder_padding_mask": decoder_padding_mask,
+         }
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "add_start_token": self.add_start_token,
+                 "add_end_token": self.add_end_token,
+             }
+         )
+         return config
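
The `call()` method above implements standard teacher forcing: the decoder sequence is packed to `decoder_sequence_length + 1`, the model input takes tokens `[..., :-1]`, the labels take tokens `[..., 1:]`, and the shifted padding mask becomes the sample weight. A minimal NumPy sketch of that shift, with made-up token ids for illustration:

```python
import numpy as np

# A packed decoder sequence: [start, "The", "fast", "fox", end, pad].
decoder_token_ids = np.array([[2, 133, 1769, 99, 1, 0]])
decoder_padding_mask = np.array([[1, 1, 1, 1, 1, 0]])

decoder_inputs = decoder_token_ids[..., :-1]    # what the decoder reads
labels = decoder_token_ids[..., 1:]             # next-token targets
sample_weight = decoder_padding_mask[..., 1:]   # mask out padded positions

print(decoder_inputs)  # [[2 133 1769 99 1]]
print(labels)          # [[133 1769 99 1 0]]
print(sample_weight)   # [[1 1 1 1 0]]
```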