keras-hub-nightly 0.21.0.dev202505140407__py3-none-any.whl → 0.21.0.dev202505160409__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +3 -0
- keras_hub/models/__init__.py +12 -0
- keras_hub/src/models/moonshine/__init__.py +0 -0
- keras_hub/src/models/moonshine/moonshine_audio_converter.py +301 -0
- keras_hub/src/models/moonshine/moonshine_audio_to_text.py +383 -0
- keras_hub/src/models/moonshine/moonshine_audio_to_text_preprocessor.py +267 -0
- keras_hub/src/models/moonshine/moonshine_backbone.py +478 -0
- keras_hub/src/models/moonshine/moonshine_decoder.py +313 -0
- keras_hub/src/models/moonshine/moonshine_encoder.py +212 -0
- keras_hub/src/models/moonshine/moonshine_layers.py +239 -0
- keras_hub/src/models/moonshine/moonshine_multi_head_attention.py +355 -0
- keras_hub/src/models/moonshine/moonshine_presets.py +25 -0
- keras_hub/src/models/moonshine/moonshine_tokenizer.py +62 -0
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +3 -0
- {keras_hub_nightly-0.21.0.dev202505140407.dist-info → keras_hub_nightly-0.21.0.dev202505160409.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.21.0.dev202505140407.dist-info → keras_hub_nightly-0.21.0.dev202505160409.dist-info}/RECORD +19 -8
- {keras_hub_nightly-0.21.0.dev202505140407.dist-info → keras_hub_nightly-0.21.0.dev202505160409.dist-info}/WHEEL +1 -1
- {keras_hub_nightly-0.21.0.dev202505140407.dist-info → keras_hub_nightly-0.21.0.dev202505160409.dist-info}/top_level.txt +0 -0
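The new modules register a Moonshine speech-to-text task under `keras_hub.models`. As a quick orientation before the file contents below, here is a minimal usage sketch assembled from the docstring examples in the added files (the `moonshine_base` preset name and input shapes come from those docstrings; treat this as a sketch rather than a tested snippet):

```python
import keras
import keras_hub

# Load the new audio-to-text task from a preset; this pulls in the backbone,
# tokenizer, audio converter, and preprocessor added in this release.
moonshine_lm = keras_hub.models.MoonshineAudioToText.from_preset(
    "moonshine_base"
)

# One second of 16 kHz mono audio as a placeholder waveform.
audio_tensor = keras.random.normal((1, 16000, 1))

# Plain transcription, and transcription seeded with a text prompt.
moonshine_lm.generate({"audio": audio_tensor})
moonshine_lm.generate({"audio": audio_tensor, "text": "quick"})

# The generation strategy is set via compile(); "top_k" is the default sampler.
moonshine_lm.compile(sampler="greedy")
moonshine_lm.generate({"audio": audio_tensor})
```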
keras_hub/src/models/moonshine/moonshine_audio_to_text.py (new file)

@@ -0,0 +1,383 @@
````python
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.moonshine.moonshine_audio_to_text_preprocessor import (  # noqa: E501
    MoonshineAudioToTextPreprocessor,
)
from keras_hub.src.models.moonshine.moonshine_backbone import Arange
from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone
from keras_hub.src.models.moonshine.moonshine_backbone import (
    compute_output_lengths,
)
from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM
from keras_hub.src.utils.tensor_utils import any_equal


@keras_hub_export("keras_hub.models.MoonshineAudioToText")
class MoonshineAudioToText(Seq2SeqLM):
    """An end-to-end Moonshine model for audio-to-text tasks.

    A Seq2Seq LM designed for audio-to-text tasks, such as speech recognition.
    The encoder processes audio features, and the decoder generates text
    transcriptions. You can finetune `MoonshineAudioToText` for any
    audio-to-text task (e.g., live transcription or voice commands).

    This model includes a `generate()` method for text generation based on
    audio inputs and an optional text prompt for the decoder. The generation
    strategy is controlled by a `sampler` argument passed to `compile()`. By
    default, `"top_k"` sampling is used.

    Args:
        backbone: A `keras_hub.models.MoonshineBackbone` instance.
        preprocessor: A `keras_hub.models.MoonshineAudioToTextPreprocessor` or
            `None`. If `None`, inputs must be preprocessed before calling the
            model.

    Examples:
    ```python
    # Initialize model from preset.
    moonshine_lm = keras_hub.models.MoonshineAudioToText.from_preset(
        "moonshine_base"
    )

    # Generate with single audio input.
    audio_tensor = keras.random.normal((1, 16000, 1))
    moonshine_lm.generate({"audio": audio_tensor})

    # Generate with text prompt.
    moonshine_lm.generate({"audio": audio_tensor, "text": "quick"})

    # Use different sampling strategy.
    moonshine_lm.compile(sampler="greedy")
    moonshine_lm.generate({"audio": audio_tensor})
    ```
    """

    # References:
    # Defined and formulated based on the Hugging Face implementation of the
    # MoonshineForConditionalGeneration class (https://github.com/huggingface/transformers/blob/dcbdf7e962c4b36140cc9ee76f870016121e69e5/src/transformers/models/moonshine/modeling_moonshine.py#L1509-L1626).

    backbone_cls = MoonshineBackbone
    preprocessor_cls = MoonshineAudioToTextPreprocessor

    def __init__(self, backbone, preprocessor=None, **kwargs):
        # === Layers ===
        self.backbone = backbone
        self.preprocessor = preprocessor

        # === Functional Model ===
        inputs = backbone.input
        hidden_states = backbone(inputs)["decoder_sequence_output"]
        outputs = backbone.token_embedding(hidden_states, reverse=True)
        super().__init__(
            inputs=inputs,
            outputs=outputs,
            **kwargs,
        )

    def call_decoder_with_cache(
        self,
        encoder_hidden_states,
        encoder_padding_mask,
        decoder_token_ids,
        self_attention_cache=None,
        self_attention_cache_update_index=None,
        cross_attention_cache=None,
    ):
        """Process decoder inputs with attention caching for efficient
        generation.

        Args:
            encoder_hidden_states: Tensor. Encoder outputs.
            encoder_padding_mask: Tensor. Padding mask for encoder outputs.
            decoder_token_ids: Tensor. Decoder input token IDs.
            self_attention_cache: Tensor, optional. Cache for self-attention
                layers.
            self_attention_cache_update_index: int, optional. Index for cache
                updates.
            cross_attention_cache: Tensor, optional. Cache for cross-attention
                layers. This cache is computed once and reused.

        Returns:
            Tuple: Tuple of (logits, hidden_states, new_self_attention_cache,
            cross_attention_cache).
        """
        tokens = self.backbone.token_embedding(decoder_token_ids)
        x = tokens

        # Cache management for audio-to-text generation.
        self_attention_caches = []
        position = keras.ops.array(
            [self_attention_cache_update_index], dtype="int32"
        )
        rotary_embedding = self.backbone.decoder_rotary_embedding(position)

        for i, layer in enumerate(self.backbone.decoder_blocks):
            current_self_cache = self_attention_cache[:, i, ...]
            current_cross_cache = cross_attention_cache[:, i, ...]
            x, new_self_cache = layer(
                decoder_sequence=x,
                encoder_sequence=encoder_hidden_states,
                rotary_embedding=rotary_embedding,
                encoder_padding_mask=encoder_padding_mask,
                self_attention_cache=current_self_cache,
                self_attention_cache_update_index=self_attention_cache_update_index,
                cross_attention_cache=current_cross_cache,
                training=False,
            )
            # Update self-attention cache.
            self_attention_caches.append(new_self_cache)

        # [batch_size, num_layers, 2, seq_len, num_heads, head_dim].
        new_self_attention_cache = keras.ops.stack(
            self_attention_caches, axis=1
        )
        hidden_states = self.backbone.decoder_post_norm(x)
        logits = self.backbone.token_embedding(hidden_states, reverse=True)
        return (
            logits,
            hidden_states,
            new_self_attention_cache,
            cross_attention_cache,
        )

    def _build_cache(
        self,
        encoder_input_values,
        encoder_padding_mask,
        decoder_token_ids,
        decoder_padding_mask,
    ):
        """Build initial cache states from inputs."""
        encoder_hidden_states, encoder_attention_mask_for_decoder = (
            self.call_encoder(
                encoder_input_values=encoder_input_values,
                padding_mask=encoder_padding_mask,
            )
        )
        precomputed_cross_caches = []
        for layer in self.backbone.decoder_blocks:
            cross_k = layer.cross_attention._key_dense(encoder_hidden_states)
            cross_v = layer.cross_attention._value_dense(encoder_hidden_states)
            layer_cross_cache = keras.ops.stack([cross_k, cross_v], axis=1)
            precomputed_cross_caches.append(layer_cross_cache)
        precomputed_cross_cache = keras.ops.stack(
            precomputed_cross_caches, axis=1
        )
        batch_size = keras.ops.shape(encoder_input_values)[0]
        num_layers = self.backbone.decoder_num_layers
        num_heads = self.backbone.decoder_num_heads
        head_dim = self.backbone.hidden_dim // self.backbone.decoder_num_heads
        if self.backbone.pad_head_dim_to_multiple_of is not None:
            head_dim = (
                (head_dim + self.backbone.pad_head_dim_to_multiple_of - 1)
                // self.backbone.pad_head_dim_to_multiple_of
            ) * self.backbone.pad_head_dim_to_multiple_of
        # Use the full sequence length for the cache dimension.
        cache_length = keras.ops.shape(decoder_token_ids)[1]
        initial_self_cache_shape = (
            batch_size,
            num_layers,
            2,
            cache_length,
            num_heads,
            head_dim,
        )
        initial_self_cache = keras.ops.zeros(
            initial_self_cache_shape, dtype=self.compute_dtype
        )
        tokens = self.backbone.token_embedding(decoder_token_ids)
        x = tokens
        positions = keras.ops.arange(0, cache_length, dtype="int32")
        rotary_embedding = self.backbone.decoder_rotary_embedding(positions)
        seeded_self_caches = []
        for i, layer in enumerate(self.backbone.decoder_blocks):
            current_initial_self_cache = initial_self_cache[:, i, ...]
            current_precomputed_cross_cache = precomputed_cross_cache[:, i, ...]
            x, seeded_self_cache_layer = layer(
                decoder_sequence=x,
                encoder_sequence=encoder_hidden_states,
                rotary_embedding=rotary_embedding,
                decoder_padding_mask=decoder_padding_mask,
                encoder_padding_mask=encoder_attention_mask_for_decoder,
                self_attention_cache=current_initial_self_cache,
                self_attention_cache_update_index=0,
                cross_attention_cache=current_precomputed_cross_cache,
                training=False,
            )
            seeded_self_caches.append(seeded_self_cache_layer)
        hidden_states = self.backbone.decoder_post_norm(x)
        self_attn_cache = keras.ops.stack(seeded_self_caches, axis=1)
        return (
            hidden_states,
            self_attn_cache,
            precomputed_cross_cache,
            encoder_hidden_states,
            encoder_attention_mask_for_decoder,
        )

    def call_encoder(self, encoder_input_values, padding_mask):
        """Process audio input through the encoder stack."""
        x = self.backbone.conv1(encoder_input_values)
        x = self.backbone.tanh_after_conv1(x)
        x = self.backbone.group_norm(x)
        x = self.backbone.conv2(x)
        x = self.backbone.gelu_after_conv2(x)
        x = self.backbone.conv3(x)
        x = self.backbone.gelu_after_conv3(x)
        original_lengths = keras.ops.sum(
            keras.ops.cast(padding_mask, "int32"), axis=1
        )
        output_lengths = compute_output_lengths(original_lengths)
        padding_mask = self.backbone._compute_mask_layer(x, output_lengths)
        positions = Arange(name="encoder_positions")(x)
        rotary_embedding = self.backbone.encoder_rotary_embedding(positions)
        x = self.backbone.encoder_dropout(x, training=False)
        for transformer_layer in self.backbone.encoder_blocks:
            x = transformer_layer(
                inputs=x,
                rotary_embedding=rotary_embedding,
                attention_mask=padding_mask,
                training=False,
            )
        x = self.backbone.encoder_final_layer_norm(x)
        return x, padding_mask

    # Source: https://github.com/huggingface/transformers/blob/9e94801146ceeb3b215bbdb9492be74d7d7b7210/src/transformers/generation/utils.py#L1970-L2463
    def generate_step(self, inputs, stop_token_ids=None):
        """A compilable generation function for a batch of inputs.

        This function represents the inner, XLA-compilable, generation
        function for a single batch of inputs. Inputs should have the same
        structure as model inputs, a dictionary with keys
        `"encoder_input_values"`, `"encoder_padding_mask"`,
        `"decoder_token_ids"` and `"decoder_padding_mask"`.

        Args:
            inputs: A dictionary with four keys - `"encoder_input_values"`,
                `"encoder_padding_mask"`, `"decoder_token_ids"` and
                `"decoder_padding_mask"`, with batched tensor values.
            stop_token_ids: Tuple of ids of end tokens to stop on. If all
                sequences have produced a new stop token, generation
                will stop.

        Returns:
            Dictionary: A dictionary with two keys - `"decoder_token_ids"`
                containing the updated token sequence with newly generated
                tokens, and `"decoder_padding_mask"` containing the updated
                padding mask for the generated sequence.
        """
        encoder_input_values = inputs["encoder_input_values"]
        encoder_padding_mask = inputs["encoder_padding_mask"]
        decoder_token_ids = inputs["decoder_token_ids"]
        decoder_padding_mask = inputs["decoder_padding_mask"]

        if (
            encoder_input_values is None
            or encoder_padding_mask is None
            or decoder_token_ids is None
        ):
            raise ValueError("Input tensors cannot be None")

        (
            hidden_states,
            self_attention_cache,
            cross_attention_cache,
            encoder_hidden_states,
            encoder_attention_mask_for_decoder,
        ) = self._build_cache(
            encoder_input_values,
            encoder_padding_mask,
            decoder_token_ids,
            decoder_padding_mask,
        )
        row_lengths = keras.ops.sum(
            keras.ops.cast(decoder_padding_mask, "int32"),
            axis=-1,
        )
        index = keras.ops.min(row_lengths)

        def next(prompt, cache, index):
            if isinstance(cache, tuple) and len(cache) == 2:
                current_self_attention_cache = cache[0]
                current_cross_attention_cache = cache[1]
            elif cache is not None and not isinstance(cache, tuple):
                current_self_attention_cache = cache
                current_cross_attention_cache = cross_attention_cache
            else:
                cache = None
            cache_index = index - 1
            num_samples = keras.ops.shape(prompt)[0]
            next_token_input = keras.ops.slice(
                prompt, [0, cache_index], [num_samples, 1]
            )

            batch_size = keras.ops.shape(encoder_input_values)[0]

            def repeat_tensor(x):
                if keras.ops.shape(x)[0] == num_samples:
                    return x
                return keras.ops.repeat(
                    x, repeats=num_samples // batch_size, axis=0
                )

            cross_attention_cache_repeated = repeat_tensor(
                current_cross_attention_cache
            )
            logits, hidden_states, new_self_attention_cache, _ = (
                self.call_decoder_with_cache(
                    encoder_hidden_states=repeat_tensor(encoder_hidden_states),
                    encoder_padding_mask=repeat_tensor(
                        encoder_attention_mask_for_decoder
                    ),
                    decoder_token_ids=next_token_input,
                    self_attention_cache=current_self_attention_cache,
                    self_attention_cache_update_index=cache_index,
                    cross_attention_cache=cross_attention_cache_repeated,
                )
            )
            return (
                logits[:, 0, :],
                hidden_states[:, 0, :],
                (new_self_attention_cache, current_cross_attention_cache),
            )

        decoder_token_ids = self.sampler(
            next=next,
            prompt=decoder_token_ids,
            cache=(self_attention_cache, cross_attention_cache),
            index=index,
            mask=keras.ops.cast(
                decoder_token_ids != self.preprocessor.tokenizer.pad_token_id
                if self.preprocessor is not None
                else decoder_padding_mask,
                dtype="bool",
            ),
            stop_token_ids=stop_token_ids,
            hidden_states=hidden_states,
            model=self,
        )

        if stop_token_ids is not None:
            end_locations = any_equal(
                decoder_token_ids,
                stop_token_ids,
                decoder_token_ids == self.preprocessor.tokenizer.pad_token_id
                if self.preprocessor is not None
                else False,
            )
            end_locations = keras.ops.cast(end_locations, "int32")
            cumsum = keras.ops.cumsum(end_locations, axis=-1)
            overflow = cumsum - end_locations
            decoder_padding_mask = keras.ops.logical_not(
                keras.ops.cast(overflow, "bool")
            )
        else:
            decoder_padding_mask = keras.ops.ones_like(
                decoder_token_ids, dtype="bool"
            )

        return {
            "decoder_token_ids": decoder_token_ids,
            "decoder_padding_mask": decoder_padding_mask,
        }
````
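The generation path in this file hinges on one cache layout: each decoder block's key/value self-attention cache is stacked along axis 1 into a single `[batch_size, num_layers, 2, seq_len, num_heads, head_dim]` tensor, and the cross-attention cache is built the same way, computed once from the encoder output and reused at every step. A small standalone sketch of that bookkeeping with `keras.ops`; the dimensions are made up for illustration:

```python
import keras

batch_size, num_layers, seq_len, num_heads, head_dim = 2, 4, 16, 8, 64

# One cache per decoder layer; the leading 2 holds the stacked key and value
# tensors, mirroring the per-layer caches returned by each decoder block.
per_layer_caches = [
    keras.ops.zeros((batch_size, 2, seq_len, num_heads, head_dim))
    for _ in range(num_layers)
]

# Stack into the combined cache that _build_cache and call_decoder_with_cache
# pass around: [batch_size, num_layers, 2, seq_len, num_heads, head_dim].
combined_cache = keras.ops.stack(per_layer_caches, axis=1)
print(combined_cache.shape)  # (2, 4, 2, 16, 8, 64)

# Inside the decoder loop, layer i recovers its own slice of the cache.
layer_0_cache = combined_cache[:, 0, ...]
print(layer_0_cache.shape)  # (2, 2, 16, 8, 64)
```

Keeping every layer's cache in one tensor lets `generate_step` thread a single `(self_attention_cache, cross_attention_cache)` pair through the sampler instead of a per-layer list.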
keras_hub/src/models/moonshine/moonshine_audio_to_text_preprocessor.py (new file)

@@ -0,0 +1,267 @@
````python
import keras

try:
    import tensorflow as tf
except ImportError:
    tf = None
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone
from keras_hub.src.models.moonshine.moonshine_tokenizer import (
    MoonshineTokenizer,
)
from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor
from keras_hub.src.utils.tensor_utils import preprocessing_function


@keras_hub_export("keras_hub.models.MoonshineAudioToTextPreprocessor")
class MoonshineAudioToTextPreprocessor(Seq2SeqLMPreprocessor):
    """Moonshine Seq2Seq LM preprocessor for audio-to-text tasks.

    This preprocessor converts raw audio and text inputs into a format
    suitable for the `MoonshineAudioToText` model. It processes audio
    waveforms using `MoonshineAudioConverter` for basic preprocessing
    (padding, normalization) and tokenizes text using `MoonshineTokenizer`
    for the decoder. It supports training and generation.

    Args:
        audio_converter: A `MoonshineAudioConverter` instance to process audio.
        tokenizer: A `MoonshineTokenizer` instance to tokenize text.
        decoder_sequence_length: int, optional. Maximum length for decoder
            token sequences. Defaults to 1024.
        **kwargs: Additional keyword arguments for the parent class.

    Examples:
    ```python
    import keras
    from keras_hub.layers import MoonshineAudioConverter
    from keras_hub.models import MoonshineTokenizer

    # Create audio converter and tokenizer instances.
    audio_converter = MoonshineAudioConverter()
    tokenizer = MoonshineTokenizer.from_preset("moonshine_base")

    # Initialize the preprocessor.
    preprocessor = keras_hub.models.MoonshineAudioToTextPreprocessor(
        audio_converter=audio_converter,
        tokenizer=tokenizer,
        decoder_sequence_length=8
    )

    # Prepare input data (audio tensor and text).
    inputs = {
        "audio": keras.random.normal((1, 16000)),
        "text": ["the quick brown fox"]
    }

    # Process the inputs for training.
    x, y, sample_weight = preprocessor(inputs)

    # Check output keys and shapes (shapes depend on padding/truncation).
    print(x.keys())
    # dict_keys(['encoder_input_values', 'encoder_padding_mask',
    # 'decoder_token_ids', 'decoder_padding_mask']).
    print(x["encoder_input_values"].shape)  # e.g., (1, 16000, 1) / padded length
    print(x["encoder_padding_mask"].shape)  # e.g., (1, 16000) or padded length
    print(x["decoder_token_ids"].shape)  # (1, 8)
    print(x["decoder_padding_mask"].shape)  # (1, 8)
    print(y.shape)  # (1, 8) - Labels
    print(sample_weight.shape)  # (1, 8) - Sample weights

    # Process inputs for generation.
    gen_inputs = preprocessor.generate_preprocess(inputs)
    print(gen_inputs.keys())
    # dict_keys(['encoder_input_values', 'encoder_padding_mask',
    # 'decoder_token_ids', 'decoder_padding_mask']).
    ```
    """

    backbone_cls = MoonshineBackbone
    tokenizer_cls = MoonshineTokenizer

    def __init__(
        self,
        audio_converter,
        tokenizer,
        decoder_sequence_length=1024,
        **kwargs,
    ):
        super().__init__(tokenizer=tokenizer, **kwargs)
        self.audio_converter = audio_converter
        self.decoder_sequence_length = decoder_sequence_length
        self.decoder_packer = None
        self._special_token_ids_set = None

    def build(self, input_shape):
        self.decoder_packer = StartEndPacker(
            start_value=self.tokenizer.start_token_id,
            end_value=self.tokenizer.end_token_id,
            pad_value=self.tokenizer.pad_token_id,
            sequence_length=self.decoder_sequence_length,
            return_padding_mask=True,
        )
        self._special_token_ids_set = set(self.tokenizer.special_token_ids)
        if self.tokenizer.pad_token_id is not None:
            self._special_token_ids_set.add(self.tokenizer.pad_token_id)
        self.built = True

    @preprocessing_function
    def call(
        self,
        x,
        y=None,
        sample_weight=None,
        decoder_sequence_length=None,
        sequence_length=None,
    ):
        if not self.built:
            self.build(None)
        if isinstance(x, tuple) and len(x) == 1:
            x = x[0]
        decoder_sequence_length = (
            decoder_sequence_length
            or sequence_length
            or self.decoder_sequence_length
        )
        text = x["text"]
        encoder_inputs = self.audio_converter(
            x["audio"],
            padding="longest",
        )
        encoder_inputs_shape = keras.ops.shape(encoder_inputs)
        if len(encoder_inputs_shape) == 2:
            encoder_inputs = keras.ops.expand_dims(encoder_inputs, axis=-1)
        squeezed_inputs = encoder_inputs[:, :, 0]
        is_tf_symbolic = (
            tf is not None
            and hasattr(squeezed_inputs, "graph")
            and hasattr(squeezed_inputs.graph, "as_graph_def")
        )
        if is_tf_symbolic and keras.config.backend() != "tensorflow":
            encoder_padding_mask = tf.logical_not(
                tf.math.equal(
                    squeezed_inputs, float(self.audio_converter.padding_value)
                )
            )
        else:
            encoder_padding_mask = keras.ops.logical_not(
                keras.ops.equal(
                    squeezed_inputs, self.audio_converter.padding_value
                )
            )
        decoder_inputs = self.tokenizer(text)
        decoder_token_ids, decoder_padding_mask = self.decoder_packer(
            decoder_inputs,
            sequence_length=decoder_sequence_length + 1,
            add_end_value=True,
        )
        x_out = {
            "encoder_input_values": encoder_inputs,
            "encoder_padding_mask": encoder_padding_mask,
            "decoder_token_ids": decoder_token_ids[..., :-1],
            "decoder_padding_mask": decoder_padding_mask[..., :-1],
        }
        y_out = decoder_token_ids[..., 1:]
        sample_weight_out = decoder_padding_mask[..., 1:]

        return keras.utils.pack_x_y_sample_weight(
            x_out, y_out, sample_weight_out
        )

    @preprocessing_function
    def generate_preprocess(
        self,
        x,
        decoder_sequence_length=None,
        sequence_length=None,
    ):
        if not self.built:
            self.build(None)
        if isinstance(x, tuple) and len(x) == 1:
            x = x[0]
        decoder_sequence_length = (
            decoder_sequence_length
            or sequence_length
            or self.decoder_sequence_length
        )
        encoder_inputs = self.audio_converter(
            x["audio"],
            padding="longest",
        )
        encoder_inputs_shape = keras.ops.shape(encoder_inputs)
        if len(encoder_inputs_shape) == 2:
            encoder_inputs = keras.ops.expand_dims(encoder_inputs, axis=-1)
        squeezed_inputs = encoder_inputs[:, :, 0]
        is_tf_symbolic = (
            tf is not None
            and hasattr(squeezed_inputs, "graph")
            and hasattr(squeezed_inputs.graph, "as_graph_def")
        )
        if is_tf_symbolic and keras.config.backend() != "tensorflow":
            encoder_padding_mask = tf.logical_not(
                tf.math.equal(
                    squeezed_inputs, float(self.audio_converter.padding_value)
                )
            )
        else:
            encoder_padding_mask = keras.ops.logical_not(
                keras.ops.equal(
                    squeezed_inputs, self.audio_converter.padding_value
                )
            )
        audio_batch_size = keras.ops.shape(x["audio"])[0]
        decoder_text = x.get("text", None)
        if decoder_text is None:
            decoder_token_ids = [
                [self.tokenizer.start_token_id]
            ] * audio_batch_size
        else:
            if isinstance(decoder_text, str):
                decoder_text = [decoder_text] * audio_batch_size
            elif len(decoder_text) != audio_batch_size:
                if len(decoder_text) == 1:
                    decoder_text = decoder_text * audio_batch_size
                else:
                    raise ValueError(
                        f"Batch size mismatch between audio "
                        f"({audio_batch_size}) and text prompts "
                        f"({len(decoder_text)})"
                    )
            decoder_token_ids = self.tokenizer(decoder_text)
        decoder_token_ids, decoder_padding_mask = self.decoder_packer(
            decoder_token_ids,
            sequence_length=decoder_sequence_length,
            add_end_value=False,
        )

        return {
            "encoder_input_values": encoder_inputs,
            "encoder_padding_mask": encoder_padding_mask,
            "decoder_token_ids": decoder_token_ids,
            "decoder_padding_mask": decoder_padding_mask,
        }

    @preprocessing_function
    def generate_postprocess(self, x):
        if not self.built:
            self.build(None)
        token_ids, padding_mask = (
            x["decoder_token_ids"],
            x["decoder_padding_mask"],
        )
        token_ids_np = keras.ops.convert_to_numpy(token_ids)
        padding_mask_np = keras.ops.convert_to_numpy(padding_mask)
        vocab_size = self.tokenizer.vocabulary_size()
        processed_sequences = []
        for i in range(token_ids_np.shape[0]):
            sequence = token_ids_np[i]
            mask = padding_mask_np[i].astype(bool)
            valid_tokens = sequence[mask]
            filtered_tokens = [
                int(token)
                for token in valid_tokens
                if token not in self._special_token_ids_set
                and 0 <= token < vocab_size
            ]
            processed_sequences.append(filtered_tokens)
        return self.tokenizer.detokenize(processed_sequences)
````
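One detail of the preprocessor's `call()` worth spelling out is the teacher-forcing shift: the packer is asked for `decoder_sequence_length + 1` tokens, the model inputs then drop the last position while the labels drop the first, so each input token is trained to predict its successor. A tiny sketch of that slicing with hypothetical token ids (the real values come from `MoonshineTokenizer`):

```python
import numpy as np

# Hypothetical packed decoder ids for one example with
# decoder_sequence_length=8, so the packer returns 8 + 1 = 9 positions:
# <start> the quick fox <end> <pad> <pad> <pad> <pad>
packed = np.array([[1, 7, 12, 9, 2, 0, 0, 0, 0]])

decoder_token_ids = packed[..., :-1]  # fed to the model: <start> the quick fox <end> <pad>...
labels = packed[..., 1:]              # training targets, shifted left by one position

# Position j of `labels` is the token the decoder should emit after seeing
# positions 0..j of `decoder_token_ids`; padding positions are masked out by
# the matching slice of the padding mask used as sample weights.
print(decoder_token_ids.shape, labels.shape)  # (1, 8) (1, 8)
```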