keras-hub 0.20.0.dev1__py3-none-any.whl → 0.21.0.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +15 -33
- keras_hub/layers/__init__.py +134 -0
- keras_hub/metrics/__init__.py +11 -0
- keras_hub/models/__init__.py +642 -0
- keras_hub/samplers/__init__.py +18 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +25 -35
- keras_hub/src/layers/preprocessing/image_converter.py +1 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +1 -1
- keras_hub/src/layers/preprocessing/random_swap.py +1 -1
- keras_hub/src/models/audio_to_text.py +66 -0
- keras_hub/src/models/audio_to_text_preprocessor.py +80 -0
- keras_hub/src/models/backbone.py +5 -2
- keras_hub/src/models/cspnet/cspnet_backbone.py +51 -26
- keras_hub/src/models/cspnet/cspnet_presets.py +38 -3
- keras_hub/src/models/falcon/falcon_backbone.py +1 -1
- keras_hub/src/models/gemma/gemma_presets.py +10 -10
- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +3 -2
- keras_hub/src/models/gemma3/gemma3_presets.py +8 -8
- keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
- keras_hub/src/models/llama/llama_attention.py +24 -6
- keras_hub/src/models/llama/llama_backbone.py +50 -16
- keras_hub/src/models/llama/llama_decoder.py +20 -3
- keras_hub/src/models/llama/llama_presets.py +3 -3
- keras_hub/src/models/llama/llama_rotary_embedding.py +180 -0
- keras_hub/src/models/llama3/llama3_backbone.py +10 -2
- keras_hub/src/models/llama3/llama3_presets.py +84 -2
- keras_hub/src/models/mistral/mistral_presets.py +3 -3
- keras_hub/src/models/mixtral/__init__.py +5 -0
- keras_hub/src/models/mixtral/mixtral_attention.py +252 -0
- keras_hub/src/models/mixtral/mixtral_backbone.py +207 -0
- keras_hub/src/models/mixtral/mixtral_causal_lm.py +281 -0
- keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py +76 -0
- keras_hub/src/models/mixtral/mixtral_decoder.py +494 -0
- keras_hub/src/models/mixtral/mixtral_layer_norm.py +34 -0
- keras_hub/src/models/mixtral/mixtral_presets.py +26 -0
- keras_hub/src/models/mixtral/mixtral_tokenizer.py +21 -0
- keras_hub/src/models/moonshine/__init__.py +5 -0
- keras_hub/src/models/moonshine/moonshine_audio_converter.py +301 -0
- keras_hub/src/models/moonshine/moonshine_audio_to_text.py +383 -0
- keras_hub/src/models/moonshine/moonshine_audio_to_text_preprocessor.py +272 -0
- keras_hub/src/models/moonshine/moonshine_backbone.py +478 -0
- keras_hub/src/models/moonshine/moonshine_decoder.py +313 -0
- keras_hub/src/models/moonshine/moonshine_encoder.py +212 -0
- keras_hub/src/models/moonshine/moonshine_layers.py +239 -0
- keras_hub/src/models/moonshine/moonshine_multi_head_attention.py +355 -0
- keras_hub/src/models/moonshine/moonshine_presets.py +25 -0
- keras_hub/src/models/moonshine/moonshine_tokenizer.py +62 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +11 -11
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +1 -1
- keras_hub/src/models/qwen/__init__.py +4 -0
- keras_hub/src/models/qwen/qwen_attention.py +3 -1
- keras_hub/src/models/qwen/qwen_backbone.py +8 -1
- keras_hub/src/models/qwen/qwen_causal_lm.py +7 -0
- keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +7 -0
- keras_hub/src/models/qwen/qwen_presets.py +61 -0
- keras_hub/src/models/qwen/qwen_tokenizer.py +9 -0
- keras_hub/src/models/qwen_moe/__init__.py +5 -0
- keras_hub/src/models/qwen_moe/qwen_moe_attention.py +375 -0
- keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +373 -0
- keras_hub/src/models/qwen_moe/qwen_moe_causal_lm.py +350 -0
- keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor.py +17 -0
- keras_hub/src/models/qwen_moe/qwen_moe_decoder.py +625 -0
- keras_hub/src/models/qwen_moe/qwen_moe_layernorm.py +32 -0
- keras_hub/src/models/qwen_moe/qwen_moe_presets.py +15 -0
- keras_hub/src/models/qwen_moe/qwen_moe_tokenizer.py +46 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +0 -13
- keras_hub/src/models/retinanet/retinanet_presets.py +2 -2
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +0 -18
- keras_hub/src/models/segformer/segformer_presets.py +12 -12
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +6 -0
- keras_hub/src/models/task.py +5 -2
- keras_hub/src/models/xception/__init__.py +5 -0
- keras_hub/src/models/xception/xception_backbone.py +188 -0
- keras_hub/src/models/xception/xception_image_classifier.py +12 -0
- keras_hub/src/models/xception/xception_image_classifier_preprocessor.py +14 -0
- keras_hub/src/models/xception/xception_image_converter.py +8 -0
- keras_hub/src/models/xception/xception_presets.py +14 -0
- keras_hub/src/tests/mocks/mock_gemma3_tokenizer.py +155 -0
- keras_hub/src/utils/coco/__init__.py +0 -0
- keras_hub/src/utils/coco/coco_utils.py +133 -0
- keras_hub/src/utils/imagenet/imagenet_utils.py +36 -0
- keras_hub/src/utils/keras_utils.py +11 -0
- keras_hub/src/utils/preset_utils.py +70 -10
- keras_hub/src/utils/tensor_utils.py +27 -1
- keras_hub/src/utils/timm/convert_cspnet.py +94 -23
- keras_hub/src/utils/timm/preset_loader.py +6 -6
- keras_hub/src/utils/transformers/convert_llama3.py +21 -1
- keras_hub/src/utils/transformers/convert_mixtral.py +139 -0
- keras_hub/src/utils/transformers/convert_qwen.py +1 -0
- keras_hub/src/utils/transformers/convert_qwen_moe.py +253 -0
- keras_hub/src/utils/transformers/preset_loader.py +6 -0
- keras_hub/src/{version_utils.py → version.py} +1 -1
- keras_hub/tokenizers/__init__.py +117 -0
- keras_hub/utils/__init__.py +21 -0
- {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/METADATA +6 -20
- {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/RECORD +98 -55
- {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/WHEEL +1 -1
- keras_hub/api/__init__.py +0 -15
- keras_hub/api/layers/__init__.py +0 -86
- keras_hub/api/metrics/__init__.py +0 -11
- keras_hub/api/models/__init__.py +0 -416
- keras_hub/api/samplers/__init__.py +0 -16
- keras_hub/api/tokenizers/__init__.py +0 -58
- keras_hub/api/utils/__init__.py +0 -9
- {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/top_level.txt +0 -0
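Of the changes listed above, the most self-contained new user-facing feature shown in full in this diff is the Moonshine audio-to-text family. As orientation for the two hunks that follow, here is a minimal usage sketch assembled from the docstring example shipped in `moonshine_audio_to_text.py` below; the `moonshine_base` preset name and the `{"audio": ...}` input format are taken from that docstring, and this snippet is an illustration rather than part of the released package.

```python
import keras
import keras_hub

# Load the new Moonshine audio-to-text task from a preset (preset name taken
# from the MoonshineAudioToText docstring in the hunk below).
moonshine_lm = keras_hub.models.MoonshineAudioToText.from_preset(
    "moonshine_base"
)

# One second of dummy 16 kHz mono audio: (batch_size, time_steps, 1).
audio_tensor = keras.random.normal((1, 16000, 1))

# Transcribe; an optional text prompt can be passed under the "text" key.
moonshine_lm.generate({"audio": audio_tensor})
```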
keras_hub/src/models/moonshine/moonshine_audio_converter.py (new file)

@@ -0,0 +1,301 @@
import keras

try:
    import tensorflow as tf
except ImportError:
    tf = None

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.audio_converter import AudioConverter
from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone


@keras_hub_export("keras_hub.layers.MoonshineAudioConverter")
class MoonshineAudioConverter(AudioConverter):
    """Moonshine audio preprocessing layer.

    This layer processes raw audio waveforms for the Moonshine ASR model. Audio
    is formatted as a batched tensor at a 16kHz sample rate and validated for
    length (0.1 to 64 seconds). The layer handles padding and optional
    normalization. It does not contain trainable weights.

    Args:
        sampling_rate: int, optional. The audio sampling rate in Hz. Defaults to
            16,000.
        padding_value: float, optional. The value for padding. Defaults to 0.0.
        do_normalize: bool, optional. Whether to normalize inputs. Defaults to
            False.
        **kwargs: Additional keyword arguments passed to the base AudioConverter
            class for customizing the underlying preprocessing behavior.

    Call arguments:
        - `inputs`: The raw audio data to be processed. It should be a tensor of
          shape `(batch_size, time_steps, 1)` for mono audio. If the input has
          shape `(batch_size, time_steps)`, the layer will add the channel
          dimension.
        - `sampling_rate`: The sampling rate of the audio in Hz. If
          provided, it must match the expected sampling rate set during
          initialization (default is 16,000 Hz). If not provided, the expected
          sampling rate is taken from the initialization arguments.
        - `padding`: The padding strategy to apply. If provided, can be one of:
            - `"longest"`: If `pad_to_multiple_of` is set, pads the audio to
              make the time_steps dimension a multiple of `pad_to_multiple_of`.
            - `"max_length"`: Pads or truncates the audio to `max_length` time
              steps. If `pad_to_multiple_of` is set, the target length will be
              the smallest multiple of `pad_to_multiple_of` that is greater than
              or equal to `max_length`.
            - If not specified or `None`, no padding is applied.
        - `max_length`: The target number of time steps when `padding` is
          `"max_length"`. If not provided and `padding` is `"max_length"`, no
          padding or truncation is applied.
        - `pad_to_multiple_of`: If set, the padded time_steps will be a
          multiple of this value for the chosen padding strategy.

    Examples:
    ```python
    import keras
    from keras_hub.layers import MoonshineAudioConverter

    # Create a dummy audio input (1 second at 16kHz).
    dummy_audio = keras.ops.convert_to_tensor(
        [[0.1] * 16000],
        dtype="float32"
    )
    dummy_audio = keras.ops.expand_dims(dummy_audio, axis=-1)

    # Initialize the preprocessor.
    preprocessor = MoonshineAudioConverter(do_normalize=True)

    # Process the audio.
    processed_audio = preprocessor(dummy_audio)

    # Output shape.
    print(processed_audio.shape)  # Expected: (1, 16000, 1) or padded length
    ```
    """

    # References:
    # Defined and formulated based on the UsefulSensors implementation of audio
    # preprocessing logic (https://github.com/usefulsensors/moonshine/blob/main/moonshine/transcribe.py).

    backbone_cls = MoonshineBackbone

    def __init__(
        self,
        sampling_rate=16000,
        padding_value=0.0,
        do_normalize=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._convert_input_args = False
        self._allow_non_tensor_positional_args = True
        self.sampling_rate = sampling_rate
        self.padding_value = padding_value
        self.do_normalize = do_normalize

    def call(
        self,
        inputs,
        sampling_rate=None,
        padding=None,
        max_length=None,
        pad_to_multiple_of=None,
    ):
        # Validate sampling rate.
        if sampling_rate is not None and sampling_rate != self.sampling_rate:
            raise ValueError(
                f"Expected sampling_rate {self.sampling_rate}, got "
                f"{sampling_rate}"
            )

        # Ensure inputs are (batch_size, time_steps, 1).
        input_shape = keras.ops.shape(inputs)
        input_rank = len(input_shape)
        if input_rank == 2:
            processed_inputs = keras.ops.expand_dims(inputs, axis=-1)
        elif input_rank == 3:
            processed_inputs = inputs
        else:
            raise ValueError(
                "Inputs must be mono audio: (batch_size, time_steps, 1)"
            )

        # Get original length and validate duration.
        current_shape = keras.ops.shape(processed_inputs)
        original_length = current_shape[1]
        duration = (
            keras.ops.cast(original_length, keras.backend.floatx())
            / self.sampling_rate
        )
        # Source: https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/transcribe.py#L20
        is_invalid_duration = keras.ops.logical_or(
            keras.ops.less(duration, 0.1), keras.ops.greater(duration, 64.0)
        )

        def print_warning_fn():
            import warnings

            warnings.warn(
                "Audio duration must be between 0.1 and 64 seconds. For "
                "transcribing longer segments, pre-segment your audio and "
                "provide shorter segments."
            )
            return keras.ops.convert_to_tensor(True, dtype="bool")

        is_tf_symbolic = (
            tf is not None
            and hasattr(processed_inputs, "graph")
            and hasattr(processed_inputs.graph, "as_graph_def")
        )
        use_tf_graph_ops = tf is not None and is_tf_symbolic
        if use_tf_graph_ops and keras.config.backend() != "torch":
            _ = tf.cond(
                is_invalid_duration,
                print_warning_fn,
                lambda: keras.ops.convert_to_tensor(False, dtype="bool"),
            )
        else:
            if keras.ops.convert_to_numpy(is_invalid_duration):
                print_warning_fn()

        # Handle padding.
        if padding == "longest":
            target_length = original_length
            if pad_to_multiple_of:
                target_length = (
                    (target_length + pad_to_multiple_of - 1)
                    // pad_to_multiple_of
                ) * pad_to_multiple_of

            needs_padding = keras.ops.greater(target_length, original_length)

            def pad_fn():
                padding_amount = target_length - original_length
                paddings = [[0, 0], [0, padding_amount], [0, 0]]
                if use_tf_graph_ops and keras.config.backend() != "tensorflow":
                    return tf.pad(
                        processed_inputs,
                        paddings,
                        mode="CONSTANT",
                        constant_values=float(self.padding_value),
                    )
                else:
                    return keras.ops.pad(
                        processed_inputs,
                        paddings,
                        mode="constant",
                        constant_values=self.padding_value,
                    )

            if use_tf_graph_ops and keras.config.backend() != "torch":
                processed_inputs = tf.cond(
                    needs_padding, pad_fn, lambda: processed_inputs
                )
            else:
                processed_inputs = keras.ops.cond(
                    needs_padding, pad_fn, lambda: processed_inputs
                )

        elif padding == "max_length" and max_length is not None:
            target_length_const = max_length
            if pad_to_multiple_of:
                target_length_const = (
                    (target_length_const + pad_to_multiple_of - 1)
                    // pad_to_multiple_of
                ) * pad_to_multiple_of

            needs_padding = keras.ops.less(original_length, target_length_const)
            needs_truncating = keras.ops.greater(
                original_length, target_length_const
            )

            def pad_fn():
                padding_amount = target_length_const - original_length
                paddings = [[0, 0], [0, padding_amount], [0, 0]]
                if use_tf_graph_ops and keras.config.backend() != "tensorflow":
                    return tf.pad(
                        processed_inputs,
                        paddings,
                        mode="CONSTANT",
                        constant_values=float(self.padding_value),
                    )
                else:
                    return keras.ops.pad(
                        processed_inputs,
                        paddings,
                        mode="constant",
                        constant_values=self.padding_value,
                    )

            def trunc_fn():
                if use_tf_graph_ops and keras.config.backend() != "tensorflow":
                    return processed_inputs[:, :target_length_const, :]
                else:
                    return keras.ops.slice(
                        processed_inputs,
                        [0, 0, 0],
                        [-1, target_length_const, -1],
                    )

            if use_tf_graph_ops and keras.config.backend() != "torch":
                processed_inputs = tf.cond(
                    needs_padding,
                    pad_fn,
                    lambda: tf.cond(
                        needs_truncating, trunc_fn, lambda: processed_inputs
                    ),
                )
            else:
                needs_padding = keras.ops.less(
                    original_length, target_length_const
                )
                needs_truncating = keras.ops.greater(
                    original_length, target_length_const
                )
                needs_padding_bool = keras.ops.convert_to_numpy(needs_padding)
                needs_truncating_bool = keras.ops.convert_to_numpy(
                    needs_truncating
                )

                if needs_padding_bool:
                    padding_amount = target_length_const - original_length
                    paddings = [[0, 0], [0, padding_amount], [0, 0]]
                    processed_inputs = keras.ops.pad(
                        processed_inputs,
                        paddings,
                        mode="constant",
                        constant_values=self.padding_value,
                    )
                elif needs_truncating_bool:
                    processed_inputs = processed_inputs[
                        :, :target_length_const, :
                    ]

        # Normalize if enabled.
        if self.do_normalize:
            mean = keras.ops.mean(processed_inputs, axis=1, keepdims=True)
            var = keras.ops.var(processed_inputs, axis=1, keepdims=True)
            processed_inputs = (processed_inputs - mean) / keras.ops.sqrt(
                var + 1e-7
            )

        return processed_inputs

    def compute_output_shape(self, input_shape):
        # [batch_size, time_steps] → [batch_size, time_steps, 1].
        if len(input_shape) == 2 or len(input_shape) == 3:
            return (input_shape[0], None, 1)
        else:
            raise ValueError("Input shape must be rank 2 or 3.")

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "sampling_rate": self.sampling_rate,
                "padding_value": self.padding_value,
                "do_normalize": self.do_normalize,
            }
        )
        return config
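Both padding branches in `MoonshineAudioConverter.call()` above round the target length up to the nearest multiple of `pad_to_multiple_of` with the same integer-ceiling idiom before padding or truncating. A small standalone sketch of that computation (the `round_up_to_multiple` helper name is ours, for illustration only):

```python
def round_up_to_multiple(length, multiple):
    # Integer-ceiling idiom from MoonshineAudioConverter.call(): the smallest
    # multiple of `multiple` that is greater than or equal to `length`.
    return ((length + multiple - 1) // multiple) * multiple

# With padding="max_length", max_length=16000, pad_to_multiple_of=512, the
# audio is padded to 16384 time steps, i.e. 32 * 512.
assert round_up_to_multiple(16000, 512) == 16384
```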
keras_hub/src/models/moonshine/moonshine_audio_to_text.py (new file)

@@ -0,0 +1,383 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.audio_to_text import AudioToText
from keras_hub.src.models.moonshine.moonshine_audio_to_text_preprocessor import (  # noqa: E501
    MoonshineAudioToTextPreprocessor,
)
from keras_hub.src.models.moonshine.moonshine_backbone import Arange
from keras_hub.src.models.moonshine.moonshine_backbone import MoonshineBackbone
from keras_hub.src.models.moonshine.moonshine_backbone import (
    compute_output_lengths,
)
from keras_hub.src.utils.tensor_utils import any_equal


@keras_hub_export("keras_hub.models.MoonshineAudioToText")
class MoonshineAudioToText(AudioToText):
    """An end-to-end Moonshine model for audio-to-text tasks.

    A Seq2Seq LM designed for audio-to-text tasks, such as speech recognition.
    The encoder processes audio features, and the decoder generates text
    transcriptions. You can finetune `MoonshineAudioToText` for any
    audio-to-text task (e.g., live transcription or voice commands).

    This model includes a `generate()` method for text generation based on audio
    inputs and an optional text prompt for the decoder. The generation strategy
    is controlled by a `sampler` argument passed to `compile()`. By default,
    `"top_k"` sampling is used.

    Args:
        backbone: A `keras_hub.models.MoonshineBackbone` instance.
        preprocessor: A `keras_hub.models.MoonshineAudioToTextPreprocessor` or
            `None`. If `None`, inputs must be preprocessed before calling the
            model.

    Examples:
    ```python
    # Initialize model from preset.
    moonshine_lm = keras_hub.models.MoonshineAudioToText.from_preset(
        "moonshine_base"
    )

    # Generate with single audio input.
    audio_tensor = keras.random.normal((1, 16000, 1))
    moonshine_lm.generate({"audio": audio_tensor})

    # Generate with text prompt.
    moonshine_lm.generate({"audio": audio_tensor, "text": "quick"})

    # Use different sampling strategy.
    moonshine_lm.compile(sampler="greedy")
    moonshine_lm.generate({"audio": audio_tensor})
    ```
    """

    # References:
    # Defined and formulated based on the Hugging Face implementation of the
    # MoonshineForConditionalGeneration class (https://github.com/huggingface/transformers/blob/dcbdf7e962c4b36140cc9ee76f870016121e69e5/src/transformers/models/moonshine/modeling_moonshine.py#L1509-L1626).

    backbone_cls = MoonshineBackbone
    preprocessor_cls = MoonshineAudioToTextPreprocessor

    def __init__(self, backbone, preprocessor=None, **kwargs):
        # === Layers ===
        self.backbone = backbone
        self.preprocessor = preprocessor

        # === Functional Model ===
        inputs = backbone.input
        hidden_states = backbone(inputs)["decoder_sequence_output"]
        outputs = backbone.token_embedding(hidden_states, reverse=True)
        super().__init__(
            inputs=inputs,
            outputs=outputs,
            **kwargs,
        )

    def call_decoder_with_cache(
        self,
        encoder_hidden_states,
        encoder_padding_mask,
        decoder_token_ids,
        self_attention_cache=None,
        self_attention_cache_update_index=None,
        cross_attention_cache=None,
    ):
        """Process decoder inputs with attention caching for efficient
        generation.

        Args:
            encoder_hidden_states: Tensor. Encoder outputs.
            encoder_padding_mask: Tensor. Padding mask for encoder outputs.
            decoder_token_ids: Tensor. Decoder input token IDs.
            self_attention_cache: Tensor, optional. Cache for self-attention
                layers.
            self_attention_cache_update_index: int, optional. Index for cache
                updates.
            cross_attention_cache: Tensor, optional. Cache for cross-attention
                layers. This cache is computed once and reused.

        Returns:
            Tuple: Tuple of (logits, hidden_states, new_self_attention_cache,
            cross_attention_cache).
        """
        tokens = self.backbone.token_embedding(decoder_token_ids)
        x = tokens

        # Cache management for audio-to-text generation.
        self_attention_caches = []
        position = keras.ops.array(
            [self_attention_cache_update_index], dtype="int32"
        )
        rotary_embedding = self.backbone.decoder_rotary_embedding(position)

        for i, layer in enumerate(self.backbone.decoder_blocks):
            current_self_cache = self_attention_cache[:, i, ...]
            current_cross_cache = cross_attention_cache[:, i, ...]
            x, new_self_cache = layer(
                decoder_sequence=x,
                encoder_sequence=encoder_hidden_states,
                rotary_embedding=rotary_embedding,
                encoder_padding_mask=encoder_padding_mask,
                self_attention_cache=current_self_cache,
                self_attention_cache_update_index=self_attention_cache_update_index,
                cross_attention_cache=current_cross_cache,
                training=False,
            )
            # Update self-attention cache.
            self_attention_caches.append(new_self_cache)

        # [batch_size, num_layers, 2, seq_len, num_heads, head_dim].
        new_self_attention_cache = keras.ops.stack(
            self_attention_caches, axis=1
        )
        hidden_states = self.backbone.decoder_post_norm(x)
        logits = self.backbone.token_embedding(hidden_states, reverse=True)
        return (
            logits,
            hidden_states,
            new_self_attention_cache,
            cross_attention_cache,
        )

    def _build_cache(
        self,
        encoder_input_values,
        encoder_padding_mask,
        decoder_token_ids,
        decoder_padding_mask,
    ):
        """Build initial cache states from inputs."""
        encoder_hidden_states, encoder_attention_mask_for_decoder = (
            self.call_encoder(
                encoder_input_values=encoder_input_values,
                padding_mask=encoder_padding_mask,
            )
        )
        precomputed_cross_caches = []
        for layer in self.backbone.decoder_blocks:
            cross_k = layer.cross_attention._key_dense(encoder_hidden_states)
            cross_v = layer.cross_attention._value_dense(encoder_hidden_states)
            layer_cross_cache = keras.ops.stack([cross_k, cross_v], axis=1)
            precomputed_cross_caches.append(layer_cross_cache)
        precomputed_cross_cache = keras.ops.stack(
            precomputed_cross_caches, axis=1
        )
        batch_size = keras.ops.shape(encoder_input_values)[0]
        num_layers = self.backbone.decoder_num_layers
        num_heads = self.backbone.decoder_num_heads
        head_dim = self.backbone.hidden_dim // self.backbone.decoder_num_heads
        if self.backbone.pad_head_dim_to_multiple_of is not None:
            head_dim = (
                (head_dim + self.backbone.pad_head_dim_to_multiple_of - 1)
                // self.backbone.pad_head_dim_to_multiple_of
            ) * self.backbone.pad_head_dim_to_multiple_of
        # Use the full sequence length for the cache dimension.
        cache_length = keras.ops.shape(decoder_token_ids)[1]
        initial_self_cache_shape = (
            batch_size,
            num_layers,
            2,
            cache_length,
            num_heads,
            head_dim,
        )
        initial_self_cache = keras.ops.zeros(
            initial_self_cache_shape, dtype=self.compute_dtype
        )
        tokens = self.backbone.token_embedding(decoder_token_ids)
        x = tokens
        positions = keras.ops.arange(0, cache_length, dtype="int32")
        rotary_embedding = self.backbone.decoder_rotary_embedding(positions)
        seeded_self_caches = []
        for i, layer in enumerate(self.backbone.decoder_blocks):
            current_initial_self_cache = initial_self_cache[:, i, ...]
            current_precomputed_cross_cache = precomputed_cross_cache[:, i, ...]
            x, seeded_self_cache_layer = layer(
                decoder_sequence=x,
                encoder_sequence=encoder_hidden_states,
                rotary_embedding=rotary_embedding,
                decoder_padding_mask=decoder_padding_mask,
                encoder_padding_mask=encoder_attention_mask_for_decoder,
                self_attention_cache=current_initial_self_cache,
                self_attention_cache_update_index=0,
                cross_attention_cache=current_precomputed_cross_cache,
                training=False,
            )
            seeded_self_caches.append(seeded_self_cache_layer)
        hidden_states = self.backbone.decoder_post_norm(x)
        self_attn_cache = keras.ops.stack(seeded_self_caches, axis=1)
        return (
            hidden_states,
            self_attn_cache,
            precomputed_cross_cache,
            encoder_hidden_states,
            encoder_attention_mask_for_decoder,
        )

    def call_encoder(self, encoder_input_values, padding_mask):
        """Process audio input through the encoder stack."""
        x = self.backbone.conv1(encoder_input_values)
        x = self.backbone.tanh_after_conv1(x)
        x = self.backbone.group_norm(x)
        x = self.backbone.conv2(x)
        x = self.backbone.gelu_after_conv2(x)
        x = self.backbone.conv3(x)
        x = self.backbone.gelu_after_conv3(x)
        original_lengths = keras.ops.sum(
            keras.ops.cast(padding_mask, "int32"), axis=1
        )
        output_lengths = compute_output_lengths(original_lengths)
        padding_mask = self.backbone._compute_mask_layer(x, output_lengths)
        positions = Arange(name="encoder_positions")(x)
        rotary_embedding = self.backbone.encoder_rotary_embedding(positions)
        x = self.backbone.encoder_dropout(x, training=False)
        for transformer_layer in self.backbone.encoder_blocks:
            x = transformer_layer(
                inputs=x,
                rotary_embedding=rotary_embedding,
                attention_mask=padding_mask,
                training=False,
            )
        x = self.backbone.encoder_final_layer_norm(x)
        return x, padding_mask

    # Source: https://github.com/huggingface/transformers/blob/9e94801146ceeb3b215bbdb9492be74d7d7b7210/src/transformers/generation/utils.py#L1970-L2463
    def generate_step(self, inputs, stop_token_ids=None):
        """A compilable generation function for a batch of inputs.

        This function represents the inner, XLA-compilable, generation function
        for a single batch of inputs. Inputs should have the same structure as
        model inputs, a dictionary with keys `"encoder_input_values"`,
        `"encoder_padding_mask"`, `"decoder_token_ids"` and
        `"decoder_padding_mask"`.

        Args:
            inputs: A dictionary with four keys - `"encoder_input_values"`,
                `"encoder_padding_mask"`, `"decoder_token_ids"` and
                `"decoder_padding_mask"`, with batched tensor values.
            stop_token_ids: Tuple of id's of end token's to stop on. If all
                sequences have produced a new stop token, generation
                will stop.

        Returns:
            Dictionary: A dictionary with two keys - `"decoder_token_ids"`
                containing the updated token sequence with newly generated
                tokens, and `"decoder_padding_mask"` containing the updated
                padding mask for the generated sequence.
        """
        encoder_input_values = inputs["encoder_input_values"]
        encoder_padding_mask = inputs["encoder_padding_mask"]
        decoder_token_ids = inputs["decoder_token_ids"]
        decoder_padding_mask = inputs["decoder_padding_mask"]

        if (
            encoder_input_values is None
            or encoder_padding_mask is None
            or decoder_token_ids is None
        ):
            raise ValueError("Input tensors cannot be None")

        (
            hidden_states,
            self_attention_cache,
            cross_attention_cache,
            encoder_hidden_states,
            encoder_attention_mask_for_decoder,
        ) = self._build_cache(
            encoder_input_values,
            encoder_padding_mask,
            decoder_token_ids,
            decoder_padding_mask,
        )
        row_lengths = keras.ops.sum(
            keras.ops.cast(decoder_padding_mask, "int32"),
            axis=-1,
        )
        index = keras.ops.min(row_lengths)

        def next(prompt, cache, index):
            if isinstance(cache, tuple) and len(cache) == 2:
                current_self_attention_cache = cache[0]
                current_cross_attention_cache = cache[1]
            elif cache is not None and not isinstance(cache, tuple):
                current_self_attention_cache = cache
                current_cross_attention_cache = cross_attention_cache
            else:
                cache = None
            cache_index = index - 1
            num_samples = keras.ops.shape(prompt)[0]
            next_token_input = keras.ops.slice(
                prompt, [0, cache_index], [num_samples, 1]
            )

            batch_size = keras.ops.shape(encoder_input_values)[0]

            def repeat_tensor(x):
                if keras.ops.shape(x)[0] == num_samples:
                    return x
                return keras.ops.repeat(
                    x, repeats=num_samples // batch_size, axis=0
                )

            cross_attention_cache_repeated = repeat_tensor(
                current_cross_attention_cache
            )
            logits, hidden_states, new_self_attention_cache, _ = (
                self.call_decoder_with_cache(
                    encoder_hidden_states=repeat_tensor(encoder_hidden_states),
                    encoder_padding_mask=repeat_tensor(
                        encoder_attention_mask_for_decoder
                    ),
                    decoder_token_ids=next_token_input,
                    self_attention_cache=current_self_attention_cache,
                    self_attention_cache_update_index=cache_index,
                    cross_attention_cache=cross_attention_cache_repeated,
                )
            )
            return (
                logits[:, 0, :],
                hidden_states[:, 0, :],
                (new_self_attention_cache, current_cross_attention_cache),
            )

        decoder_token_ids = self.sampler(
            next=next,
            prompt=decoder_token_ids,
            cache=(self_attention_cache, cross_attention_cache),
            index=index,
            mask=keras.ops.cast(
                decoder_token_ids != self.preprocessor.tokenizer.pad_token_id
                if self.preprocessor is not None
                else decoder_padding_mask,
                dtype="bool",
            ),
            stop_token_ids=stop_token_ids,
            hidden_states=hidden_states,
            model=self,
        )

        if stop_token_ids is not None:
            end_locations = any_equal(
                decoder_token_ids,
                stop_token_ids,
                decoder_token_ids == self.preprocessor.tokenizer.pad_token_id
                if self.preprocessor is not None
                else False,
            )
            end_locations = keras.ops.cast(end_locations, "int32")
            cumsum = keras.ops.cumsum(end_locations, axis=-1)
            overflow = cumsum - end_locations
            decoder_padding_mask = keras.ops.logical_not(
                keras.ops.cast(overflow, "bool")
            )
        else:
            decoder_padding_mask = keras.ops.ones_like(
                decoder_token_ids, dtype="bool"
            )

        return {
            "decoder_token_ids": decoder_token_ids,
            "decoder_padding_mask": decoder_padding_mask,
        }
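For reference, the cache plumbing in `_build_cache()` and `call_decoder_with_cache()` above stacks one self-attention cache per decoder block into a single tensor laid out as `(batch_size, num_layers, 2, seq_len, num_heads, head_dim)`, with keys and values stacked on axis 2. A short sketch of that layout and of the per-layer slice each decoder block reads (the sizes here are illustrative, not a real Moonshine configuration):

```python
import keras

# Illustrative sizes only; the layout follows the comment in
# call_decoder_with_cache(): (batch, num_layers, 2, seq_len, heads, head_dim).
batch_size, num_layers, seq_len, num_heads, head_dim = 2, 6, 16, 8, 64
cache = keras.ops.zeros(
    (batch_size, num_layers, 2, seq_len, num_heads, head_dim)
)

# Each decoder block i reads its own slice, as in `self_attention_cache[:, i, ...]`.
layer_cache = cache[:, 0, ...]  # -> (batch, 2, seq_len, heads, head_dim)
keys, values = layer_cache[:, 0], layer_cache[:, 1]
print(keras.ops.shape(keys))  # (2, 16, 8, 64)
```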