keras-hub-nightly 0.21.0.dev202505140407__py3-none-any.whl → 0.21.0.dev202505160409__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,383 @@
+ import keras
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.models.moonshine.moonshine_audio_to_text_preprocessor import (  # noqa: E501
+     MoonshineAudioToTextPreprocessor,
+ )
+ from keras_hub.src.models.moonshine.moonshine_backbone import Arange
+ from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone
+ from keras_hub.src.models.moonshine.moonshine_backbone import (
+     compute_output_lengths,
+ )
+ from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM
+ from keras_hub.src.utils.tensor_utils import any_equal
+
+
+ @keras_hub_export("keras_hub.models.MoonshineAudioToText")
+ class MoonshineAudioToText(Seq2SeqLM):
+     """An end-to-end Moonshine model for audio-to-text tasks.
+
+     A Seq2Seq LM designed for audio-to-text tasks, such as speech recognition.
+     The encoder processes audio features, and the decoder generates text
+     transcriptions. You can finetune `MoonshineAudioToText` for any
+     audio-to-text task (e.g., live transcription or voice commands).
+
+     This model includes a `generate()` method for text generation based on
+     audio inputs and an optional text prompt for the decoder. The generation
+     strategy is controlled by a `sampler` argument passed to `compile()`. By
+     default, `"top_k"` sampling is used.
+
+     Args:
+         backbone: A `keras_hub.models.MoonshineBackbone` instance.
+         preprocessor: A `keras_hub.models.MoonshineAudioToTextPreprocessor` or
+             `None`. If `None`, inputs must be preprocessed before calling the
+             model.
+
+     Examples:
+     ```python
+     # Initialize model from preset.
+     moonshine_lm = keras_hub.models.MoonshineAudioToText.from_preset(
+         "moonshine_base"
+     )
+
+     # Generate with single audio input.
+     audio_tensor = keras.random.normal((1, 16000, 1))
+     moonshine_lm.generate({"audio": audio_tensor})
+
+     # Generate with text prompt.
+     moonshine_lm.generate({"audio": audio_tensor, "text": "quick"})
+
+     # Use different sampling strategy.
+     moonshine_lm.compile(sampler="greedy")
+     moonshine_lm.generate({"audio": audio_tensor})
+     ```
+     """
+
+     # References:
+     # Defined and formulated based on the Hugging Face implementation of the
+     # MoonshineForConditionalGeneration class (https://github.com/huggingface/transformers/blob/dcbdf7e962c4b36140cc9ee76f870016121e69e5/src/transformers/models/moonshine/modeling_moonshine.py#L1509-L1626).
+
+     backbone_cls = MoonshineBackbone
+     preprocessor_cls = MoonshineAudioToTextPreprocessor
+
+     def __init__(self, backbone, preprocessor=None, **kwargs):
+         # === Layers ===
+         self.backbone = backbone
+         self.preprocessor = preprocessor
+
+         # === Functional Model ===
+         inputs = backbone.input
+         hidden_states = backbone(inputs)["decoder_sequence_output"]
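+         # Project decoder hidden states back to vocabulary logits with the
+         # tied token embedding (reverse=True reuses the embedding weights).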
+         outputs = backbone.token_embedding(hidden_states, reverse=True)
+         super().__init__(
+             inputs=inputs,
+             outputs=outputs,
+             **kwargs,
+         )
+
+     def call_decoder_with_cache(
+         self,
+         encoder_hidden_states,
+         encoder_padding_mask,
+         decoder_token_ids,
+         self_attention_cache=None,
+         self_attention_cache_update_index=None,
+         cross_attention_cache=None,
+     ):
+         """Process decoder inputs with attention caching for efficient
+         generation.
+
+         Args:
+             encoder_hidden_states: Tensor. Encoder outputs.
+             encoder_padding_mask: Tensor. Padding mask for encoder outputs.
+             decoder_token_ids: Tensor. Decoder input token IDs.
+             self_attention_cache: Tensor, optional. Cache for self-attention
+                 layers.
+             self_attention_cache_update_index: int, optional. Index for cache
+                 updates.
+             cross_attention_cache: Tensor, optional. Cache for cross-attention
+                 layers. This cache is computed once and reused.
+
+         Returns:
+             Tuple: Tuple of (logits, hidden_states, new_self_attention_cache,
+             cross_attention_cache).
+         """
+         tokens = self.backbone.token_embedding(decoder_token_ids)
+         x = tokens
+
+         # Cache management for audio-to-text generation.
+         self_attention_caches = []
+         position = keras.ops.array(
+             [self_attention_cache_update_index], dtype="int32"
+         )
+         rotary_embedding = self.backbone.decoder_rotary_embedding(position)
+
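+         # Run each decoder block on the new token, reading its slice of the
+         # self- and cross-attention caches and collecting the updated
+         # self-attention cache for every layer.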
+         for i, layer in enumerate(self.backbone.decoder_blocks):
+             current_self_cache = self_attention_cache[:, i, ...]
+             current_cross_cache = cross_attention_cache[:, i, ...]
+             x, new_self_cache = layer(
+                 decoder_sequence=x,
+                 encoder_sequence=encoder_hidden_states,
+                 rotary_embedding=rotary_embedding,
+                 encoder_padding_mask=encoder_padding_mask,
+                 self_attention_cache=current_self_cache,
+                 self_attention_cache_update_index=self_attention_cache_update_index,
+                 cross_attention_cache=current_cross_cache,
+                 training=False,
+             )
+             # Update self-attention cache.
+             self_attention_caches.append(new_self_cache)
+
+         # [batch_size, num_layers, 2, seq_len, num_heads, head_dim].
+         new_self_attention_cache = keras.ops.stack(
+             self_attention_caches, axis=1
+         )
+         hidden_states = self.backbone.decoder_post_norm(x)
+         logits = self.backbone.token_embedding(hidden_states, reverse=True)
+         return (
+             logits,
+             hidden_states,
+             new_self_attention_cache,
+             cross_attention_cache,
+         )
+
+     def _build_cache(
+         self,
+         encoder_input_values,
+         encoder_padding_mask,
+         decoder_token_ids,
+         decoder_padding_mask,
+     ):
+         """Build initial cache states from inputs."""
+         encoder_hidden_states, encoder_attention_mask_for_decoder = (
+             self.call_encoder(
+                 encoder_input_values=encoder_input_values,
+                 padding_mask=encoder_padding_mask,
+             )
+         )
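+         # Precompute the cross-attention key/value caches from the encoder
+         # output once; they stay fixed for the whole decoding loop.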
+         precomputed_cross_caches = []
+         for layer in self.backbone.decoder_blocks:
+             cross_k = layer.cross_attention._key_dense(encoder_hidden_states)
+             cross_v = layer.cross_attention._value_dense(encoder_hidden_states)
+             layer_cross_cache = keras.ops.stack([cross_k, cross_v], axis=1)
+             precomputed_cross_caches.append(layer_cross_cache)
+         precomputed_cross_cache = keras.ops.stack(
+             precomputed_cross_caches, axis=1
+         )
+         batch_size = keras.ops.shape(encoder_input_values)[0]
+         num_layers = self.backbone.decoder_num_layers
+         num_heads = self.backbone.decoder_num_heads
+         head_dim = self.backbone.hidden_dim // self.backbone.decoder_num_heads
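+         # If the backbone pads its attention head size, round head_dim up to
+         # the same multiple so the cache shape matches the padded heads.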
+         if self.backbone.pad_head_dim_to_multiple_of is not None:
+             head_dim = (
+                 (head_dim + self.backbone.pad_head_dim_to_multiple_of - 1)
+                 // self.backbone.pad_head_dim_to_multiple_of
+             ) * self.backbone.pad_head_dim_to_multiple_of
+         # Use the full sequence length for the cache dimension.
+         cache_length = keras.ops.shape(decoder_token_ids)[1]
+         initial_self_cache_shape = (
+             batch_size,
+             num_layers,
+             2,
+             cache_length,
+             num_heads,
+             head_dim,
+         )
+         initial_self_cache = keras.ops.zeros(
+             initial_self_cache_shape, dtype=self.compute_dtype
+         )
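+         # Seed the self-attention caches by running the full prompt through
+         # the decoder once, starting at cache index 0.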
+         tokens = self.backbone.token_embedding(decoder_token_ids)
+         x = tokens
+         positions = keras.ops.arange(0, cache_length, dtype="int32")
+         rotary_embedding = self.backbone.decoder_rotary_embedding(positions)
+         seeded_self_caches = []
+         for i, layer in enumerate(self.backbone.decoder_blocks):
+             current_initial_self_cache = initial_self_cache[:, i, ...]
+             current_precomputed_cross_cache = precomputed_cross_cache[:, i, ...]
+             x, seeded_self_cache_layer = layer(
+                 decoder_sequence=x,
+                 encoder_sequence=encoder_hidden_states,
+                 rotary_embedding=rotary_embedding,
+                 decoder_padding_mask=decoder_padding_mask,
+                 encoder_padding_mask=encoder_attention_mask_for_decoder,
+                 self_attention_cache=current_initial_self_cache,
+                 self_attention_cache_update_index=0,
+                 cross_attention_cache=current_precomputed_cross_cache,
+                 training=False,
+             )
+             seeded_self_caches.append(seeded_self_cache_layer)
+         hidden_states = self.backbone.decoder_post_norm(x)
+         self_attn_cache = keras.ops.stack(seeded_self_caches, axis=1)
+         return (
+             hidden_states,
+             self_attn_cache,
+             precomputed_cross_cache,
+             encoder_hidden_states,
+             encoder_attention_mask_for_decoder,
+         )
+
+     def call_encoder(self, encoder_input_values, padding_mask):
+         """Process audio input through the encoder stack."""
+         x = self.backbone.conv1(encoder_input_values)
+         x = self.backbone.tanh_after_conv1(x)
+         x = self.backbone.group_norm(x)
+         x = self.backbone.conv2(x)
+         x = self.backbone.gelu_after_conv2(x)
+         x = self.backbone.conv3(x)
+         x = self.backbone.gelu_after_conv3(x)
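+         # Recompute the padding mask at the downsampled frame rate so it
+         # lines up with the convolved features.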
+         original_lengths = keras.ops.sum(
+             keras.ops.cast(padding_mask, "int32"), axis=1
+         )
+         output_lengths = compute_output_lengths(original_lengths)
+         padding_mask = self.backbone._compute_mask_layer(x, output_lengths)
+         positions = Arange(name="encoder_positions")(x)
+         rotary_embedding = self.backbone.encoder_rotary_embedding(positions)
+         x = self.backbone.encoder_dropout(x, training=False)
+         for transformer_layer in self.backbone.encoder_blocks:
+             x = transformer_layer(
+                 inputs=x,
+                 rotary_embedding=rotary_embedding,
+                 attention_mask=padding_mask,
+                 training=False,
+             )
+         x = self.backbone.encoder_final_layer_norm(x)
+         return x, padding_mask
+
+     # Source: https://github.com/huggingface/transformers/blob/9e94801146ceeb3b215bbdb9492be74d7d7b7210/src/transformers/generation/utils.py#L1970-L2463
+     def generate_step(self, inputs, stop_token_ids=None):
+         """A compilable generation function for a batch of inputs.
+
+         This function represents the inner, XLA-compilable, generation
+         function for a single batch of inputs. Inputs should have the same
+         structure as model inputs, a dictionary with keys
+         `"encoder_input_values"`, `"encoder_padding_mask"`,
+         `"decoder_token_ids"` and `"decoder_padding_mask"`.
+
+         Args:
+             inputs: A dictionary with four keys - `"encoder_input_values"`,
+                 `"encoder_padding_mask"`, `"decoder_token_ids"` and
+                 `"decoder_padding_mask"`, with batched tensor values.
+             stop_token_ids: Tuple of ids of end tokens to stop on. If all
+                 sequences have produced a new stop token, generation will
+                 stop.
+
+         Returns:
+             Dictionary: A dictionary with two keys - `"decoder_token_ids"`
+                 containing the updated token sequence with newly generated
+                 tokens, and `"decoder_padding_mask"` containing the updated
+                 padding mask for the generated sequence.
+         """
+         encoder_input_values = inputs["encoder_input_values"]
+         encoder_padding_mask = inputs["encoder_padding_mask"]
+         decoder_token_ids = inputs["decoder_token_ids"]
+         decoder_padding_mask = inputs["decoder_padding_mask"]
+
+         if (
+             encoder_input_values is None
+             or encoder_padding_mask is None
+             or decoder_token_ids is None
+         ):
+             raise ValueError("Input tensors cannot be None")
+
+         (
+             hidden_states,
+             self_attention_cache,
+             cross_attention_cache,
+             encoder_hidden_states,
+             encoder_attention_mask_for_decoder,
+         ) = self._build_cache(
+             encoder_input_values,
+             encoder_padding_mask,
+             decoder_token_ids,
+             decoder_padding_mask,
+         )
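+         # Generation starts at the end of the shortest prompt in the batch.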
+         row_lengths = keras.ops.sum(
+             keras.ops.cast(decoder_padding_mask, "int32"),
+             axis=-1,
+         )
+         index = keras.ops.min(row_lengths)
+
+         def next(prompt, cache, index):
+             if isinstance(cache, tuple) and len(cache) == 2:
+                 current_self_attention_cache = cache[0]
+                 current_cross_attention_cache = cache[1]
+             elif cache is not None and not isinstance(cache, tuple):
+                 current_self_attention_cache = cache
+                 current_cross_attention_cache = cross_attention_cache
+             else:
+                 cache = None
+             cache_index = index - 1
+             num_samples = keras.ops.shape(prompt)[0]
+             next_token_input = keras.ops.slice(
+                 prompt, [0, cache_index], [num_samples, 1]
+             )
+
+             batch_size = keras.ops.shape(encoder_input_values)[0]
+
+             def repeat_tensor(x):
+                 if keras.ops.shape(x)[0] == num_samples:
+                     return x
+                 return keras.ops.repeat(
+                     x, repeats=num_samples // batch_size, axis=0
+                 )
+
+             cross_attention_cache_repeated = repeat_tensor(
+                 current_cross_attention_cache
+             )
+             logits, hidden_states, new_self_attention_cache, _ = (
+                 self.call_decoder_with_cache(
+                     encoder_hidden_states=repeat_tensor(encoder_hidden_states),
+                     encoder_padding_mask=repeat_tensor(
+                         encoder_attention_mask_for_decoder
+                     ),
+                     decoder_token_ids=next_token_input,
+                     self_attention_cache=current_self_attention_cache,
+                     self_attention_cache_update_index=cache_index,
+                     cross_attention_cache=cross_attention_cache_repeated,
+                 )
+             )
+             return (
+                 logits[:, 0, :],
+                 hidden_states[:, 0, :],
+                 (new_self_attention_cache, current_cross_attention_cache),
+             )
+
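+         # Run the compiled sampling loop; the sampler calls next() once per
+         # generated token with the current caches and index.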
+         decoder_token_ids = self.sampler(
+             next=next,
+             prompt=decoder_token_ids,
+             cache=(self_attention_cache, cross_attention_cache),
+             index=index,
+             mask=keras.ops.cast(
+                 decoder_token_ids != self.preprocessor.tokenizer.pad_token_id
+                 if self.preprocessor is not None
+                 else decoder_padding_mask,
+                 dtype="bool",
+             ),
+             stop_token_ids=stop_token_ids,
+             hidden_states=hidden_states,
+             model=self,
+         )
+
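+         # After a stop token is produced, mask out every later position so
+         # postprocessing ignores the overflow tokens.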
+         if stop_token_ids is not None:
+             end_locations = any_equal(
+                 decoder_token_ids,
+                 stop_token_ids,
+                 decoder_token_ids == self.preprocessor.tokenizer.pad_token_id
+                 if self.preprocessor is not None
+                 else False,
+             )
+             end_locations = keras.ops.cast(end_locations, "int32")
+             cumsum = keras.ops.cumsum(end_locations, axis=-1)
+             overflow = cumsum - end_locations
+             decoder_padding_mask = keras.ops.logical_not(
+                 keras.ops.cast(overflow, "bool")
+             )
+         else:
+             decoder_padding_mask = keras.ops.ones_like(
+                 decoder_token_ids, dtype="bool"
+             )
+
+         return {
+             "decoder_token_ids": decoder_token_ids,
+             "decoder_padding_mask": decoder_padding_mask,
+         }
@@ -0,0 +1,267 @@
+ import keras
+
+ try:
+     import tensorflow as tf
+ except ImportError:
+     tf = None
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
+ from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone
+ from keras_hub.src.models.moonshine.moonshine_tokenizer import (
+     MoonshineTokenizer,
+ )
+ from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor
+ from keras_hub.src.utils.tensor_utils import preprocessing_function
+
+
+ @keras_hub_export("keras_hub.models.MoonshineAudioToTextPreprocessor")
+ class MoonshineAudioToTextPreprocessor(Seq2SeqLMPreprocessor):
+     """Moonshine Seq2Seq LM preprocessor for audio-to-text tasks.
+
+     This preprocessor converts raw audio and text inputs into a format suitable
+     for the `MoonshineAudioToText` model. It processes audio waveforms using
+     `MoonshineAudioConverter` for basic preprocessing (padding, normalization)
+     and tokenizes text using `MoonshineTokenizer` for the decoder. It supports
+     training and generation.
+
+     Args:
+         audio_converter: A `MoonshineAudioConverter` instance to process audio.
+         tokenizer: A `MoonshineTokenizer` instance to tokenize text.
+         decoder_sequence_length: int, optional. Maximum length for decoder
+             token sequences. Defaults to 1024.
+         **kwargs: Additional keyword arguments for the parent class.
+
+     Examples:
+     ```python
+     import keras
+     import keras_hub
+     from keras_hub.layers import MoonshineAudioConverter
+     from keras_hub.models import MoonshineTokenizer
+
+     # Create audio converter and tokenizer instances.
+     audio_converter = MoonshineAudioConverter()
+     tokenizer = MoonshineTokenizer.from_preset("moonshine_base")
+
+     # Initialize the preprocessor.
+     preprocessor = keras_hub.models.MoonshineAudioToTextPreprocessor(
+         audio_converter=audio_converter,
+         tokenizer=tokenizer,
+         decoder_sequence_length=8
+     )
+
+     # Prepare input data (audio tensor and text).
+     inputs = {
+         "audio": keras.random.normal((1, 16000)),
+         "text": ["the quick brown fox"]
+     }
+
+     # Process the inputs for training.
+     x, y, sample_weight = preprocessor(inputs)
+
+     # Check output keys and shapes (shapes depend on padding/truncation).
+     print(x.keys())
+     # dict_keys(['encoder_input_values', 'encoder_padding_mask',
+     # 'decoder_token_ids', 'decoder_padding_mask'])
+     print(x["encoder_input_values"].shape)  # e.g., (1, 16000, 1) or padded length
+     print(x["encoder_padding_mask"].shape)  # e.g., (1, 16000) or padded length
+     print(x["decoder_token_ids"].shape)  # (1, 8)
+     print(x["decoder_padding_mask"].shape)  # (1, 8)
+     print(y.shape)  # (1, 8) - Labels
+     print(sample_weight.shape)  # (1, 8) - Sample weights
+
+     # Process inputs for generation.
+     gen_inputs = preprocessor.generate_preprocess(inputs)
+     print(gen_inputs.keys())
+     # dict_keys(['encoder_input_values', 'encoder_padding_mask',
+     # 'decoder_token_ids', 'decoder_padding_mask'])
+     ```
+     """
+
+     backbone_cls = MoonshineBackbone
+     tokenizer_cls = MoonshineTokenizer
+
+     def __init__(
+         self,
+         audio_converter,
+         tokenizer,
+         decoder_sequence_length=1024,
+         **kwargs,
+     ):
+         super().__init__(tokenizer=tokenizer, **kwargs)
+         self.audio_converter = audio_converter
+         self.decoder_sequence_length = decoder_sequence_length
+         self.decoder_packer = None
+         self._special_token_ids_set = None
+
+     def build(self, input_shape):
+         self.decoder_packer = StartEndPacker(
+             start_value=self.tokenizer.start_token_id,
+             end_value=self.tokenizer.end_token_id,
+             pad_value=self.tokenizer.pad_token_id,
+             sequence_length=self.decoder_sequence_length,
+             return_padding_mask=True,
+         )
+         self._special_token_ids_set = set(self.tokenizer.special_token_ids)
+         if self.tokenizer.pad_token_id is not None:
+             self._special_token_ids_set.add(self.tokenizer.pad_token_id)
+         self.built = True
+
+     @preprocessing_function
+     def call(
+         self,
+         x,
+         y=None,
+         sample_weight=None,
+         decoder_sequence_length=None,
+         sequence_length=None,
+     ):
+         if not self.built:
+             self.build(None)
+         if isinstance(x, tuple) and len(x) == 1:
+             x = x[0]
+         decoder_sequence_length = (
+             decoder_sequence_length
+             or sequence_length
+             or self.decoder_sequence_length
+         )
+         text = x["text"]
+         encoder_inputs = self.audio_converter(
+             x["audio"],
+             padding="longest",
+         )
+         encoder_inputs_shape = keras.ops.shape(encoder_inputs)
+         if len(encoder_inputs_shape) == 2:
+             encoder_inputs = keras.ops.expand_dims(encoder_inputs, axis=-1)
+         squeezed_inputs = encoder_inputs[:, :, 0]
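+         # Inside a tf.data pipeline these are symbolic tf tensors even on
+         # non-TF backends, so build the padding mask with tf ops there and
+         # with keras.ops otherwise.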
+         is_tf_symbolic = (
+             tf is not None
+             and hasattr(squeezed_inputs, "graph")
+             and hasattr(squeezed_inputs.graph, "as_graph_def")
+         )
+         if is_tf_symbolic and keras.config.backend() != "tensorflow":
+             encoder_padding_mask = tf.logical_not(
+                 tf.math.equal(
+                     squeezed_inputs, float(self.audio_converter.padding_value)
+                 )
+             )
+         else:
+             encoder_padding_mask = keras.ops.logical_not(
+                 keras.ops.equal(
+                     squeezed_inputs, self.audio_converter.padding_value
+                 )
+             )
+         decoder_inputs = self.tokenizer(text)
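+         # Pack to decoder_sequence_length + 1, then shift by one position:
+         # tokens[..., :-1] become the decoder inputs and tokens[..., 1:] the
+         # labels (teacher forcing).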
+         decoder_token_ids, decoder_padding_mask = self.decoder_packer(
+             decoder_inputs,
+             sequence_length=decoder_sequence_length + 1,
+             add_end_value=True,
+         )
+         x_out = {
+             "encoder_input_values": encoder_inputs,
+             "encoder_padding_mask": encoder_padding_mask,
+             "decoder_token_ids": decoder_token_ids[..., :-1],
+             "decoder_padding_mask": decoder_padding_mask[..., :-1],
+         }
+         y_out = decoder_token_ids[..., 1:]
+         sample_weight_out = decoder_padding_mask[..., 1:]
+
+         return keras.utils.pack_x_y_sample_weight(
+             x_out, y_out, sample_weight_out
+         )
+
+     @preprocessing_function
+     def generate_preprocess(
+         self,
+         x,
+         decoder_sequence_length=None,
+         sequence_length=None,
+     ):
+         if not self.built:
+             self.build(None)
+         if isinstance(x, tuple) and len(x) == 1:
+             x = x[0]
+         decoder_sequence_length = (
+             decoder_sequence_length
+             or sequence_length
+             or self.decoder_sequence_length
+         )
+         encoder_inputs = self.audio_converter(
+             x["audio"],
+             padding="longest",
+         )
+         encoder_inputs_shape = keras.ops.shape(encoder_inputs)
+         if len(encoder_inputs_shape) == 2:
+             encoder_inputs = keras.ops.expand_dims(encoder_inputs, axis=-1)
+         squeezed_inputs = encoder_inputs[:, :, 0]
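+         # Same symbolic-tensor handling as in `call()` above.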
+         is_tf_symbolic = (
+             tf is not None
+             and hasattr(squeezed_inputs, "graph")
+             and hasattr(squeezed_inputs.graph, "as_graph_def")
+         )
+         if is_tf_symbolic and keras.config.backend() != "tensorflow":
+             encoder_padding_mask = tf.logical_not(
+                 tf.math.equal(
+                     squeezed_inputs, float(self.audio_converter.padding_value)
+                 )
+             )
+         else:
+             encoder_padding_mask = keras.ops.logical_not(
+                 keras.ops.equal(
+                     squeezed_inputs, self.audio_converter.padding_value
+                 )
+             )
+         audio_batch_size = keras.ops.shape(x["audio"])[0]
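+         # Build the decoder prompt: default to a lone start token per clip,
+         # or tokenize the provided text, broadcasting a single prompt across
+         # the batch when needed.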
+         decoder_text = x.get("text", None)
+         if decoder_text is None:
+             decoder_token_ids = [
+                 [self.tokenizer.start_token_id]
+             ] * audio_batch_size
+         else:
+             if isinstance(decoder_text, str):
+                 decoder_text = [decoder_text] * audio_batch_size
+             elif len(decoder_text) != audio_batch_size:
+                 if len(decoder_text) == 1:
+                     decoder_text = decoder_text * audio_batch_size
+                 else:
+                     raise ValueError(
+                         f"Batch size mismatch between audio "
+                         f"({audio_batch_size}) and text prompts "
+                         f"({len(decoder_text)})"
+                     )
+             decoder_token_ids = self.tokenizer(decoder_text)
+         decoder_token_ids, decoder_padding_mask = self.decoder_packer(
+             decoder_token_ids,
+             sequence_length=decoder_sequence_length,
+             add_end_value=False,
+         )
+
+         return {
+             "encoder_input_values": encoder_inputs,
+             "encoder_padding_mask": encoder_padding_mask,
+             "decoder_token_ids": decoder_token_ids,
+             "decoder_padding_mask": decoder_padding_mask,
+         }
+
+     @preprocessing_function
+     def generate_postprocess(self, x):
+         if not self.built:
+             self.build(None)
+         token_ids, padding_mask = (
+             x["decoder_token_ids"],
+             x["decoder_padding_mask"],
+         )
+         token_ids_np = keras.ops.convert_to_numpy(token_ids)
+         padding_mask_np = keras.ops.convert_to_numpy(padding_mask)
+         vocab_size = self.tokenizer.vocabulary_size()
+         processed_sequences = []
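+         # Drop padded positions, special tokens, and any out-of-vocabulary
+         # ids before detokenizing each sequence.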
+         for i in range(token_ids_np.shape[0]):
+             sequence = token_ids_np[i]
+             mask = padding_mask_np[i].astype(bool)
+             valid_tokens = sequence[mask]
+             filtered_tokens = [
+                 int(token)
+                 for token in valid_tokens
+                 if token not in self._special_token_ids_set
+                 and 0 <= token < vocab_size
+             ]
+             processed_sequences.append(filtered_tokens)
+         return self.tokenizer.detokenize(processed_sequences)