keras-hub-nightly 0.21.0.dev202505140407__py3-none-any.whl → 0.21.0.dev202505160409__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -0,0 +1,478 @@
+ import keras
+
+ from keras_hub.src.api_export import keras_hub_export
+ from keras_hub.src.layers.modeling.reversible_embedding import (
+     ReversibleEmbedding,
+ )
+ from keras_hub.src.models.backbone import Backbone
+ from keras_hub.src.models.moonshine.moonshine_decoder import (
+     MoonshineDecoderBlock,
+ )
+ from keras_hub.src.models.moonshine.moonshine_encoder import (
+     MoonshineEncoderBlock,
+ )
+ from keras_hub.src.models.moonshine.moonshine_layers import (
+     MoonshineRotaryEmbedding,
+ )
+ from keras_hub.src.models.moonshine.moonshine_layers import (
+     moonshine_kernel_initializer,
+ )
+
+
+ def compute_output_lengths(input_lengths):
+     lengths = keras.ops.cast(input_lengths, "float32")
+     lengths = keras.ops.floor((lengths - 127) / 64) + 1
+     lengths = keras.ops.floor((lengths - 7) / 3) + 1
+     lengths = keras.ops.floor((lengths - 3) / 2) + 1
+     return keras.ops.maximum(keras.ops.cast(lengths, "int32"), 0)
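+ # Worked example for compute_output_lengths (illustrative, not part of the
+ # original file): each step applies floor((n - kernel_size) / stride) + 1
+ # for the conv1/conv2/conv3 stages defined below (127/64, 7/3, 3/2). A
+ # 16000-sample input gives floor((16000 - 127) / 64) + 1 = 249, then
+ # floor((249 - 7) / 3) + 1 = 81, then floor((81 - 3) / 2) + 1 = 40 frames.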
+
+
+ @keras.saving.register_keras_serializable(package="keras_hub")
+ class ComputeAttentionMask(keras.layers.Layer):
+     def call(self, features_for_shape, output_lengths):
+         max_output_length = keras.ops.shape(features_for_shape)[1]
+         indices = keras.ops.arange(max_output_length, dtype="int32")
+         attention_mask = indices[None, :] < output_lengths[:, None]
+         attention_mask = keras.ops.cast(attention_mask, "bool")
+         return attention_mask
+
+     def compute_output_shape(self, input_shapes):
+         batch_dim = None
+         if isinstance(input_shapes, (list, tuple)) and len(input_shapes) > 0:
+             features_shape = input_shapes[0]
+             if (
+                 isinstance(features_shape, (list, tuple))
+                 and len(features_shape) > 0
+             ):
+                 batch_dim = features_shape[0]
+         return (batch_dim, None)
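+ # Illustrative behavior (assumed shapes, not part of the original file):
+ # given features of shape (2, 40, dim) and output_lengths of [40, 12], the
+ # layer returns a (2, 40) boolean mask whose second row is True only for the
+ # first 12 positions, so padded feature frames are excluded from attention.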
+
+
+ @keras.saving.register_keras_serializable(package="keras_hub")
+ class Arange(keras.layers.Layer):
+     def call(self, inputs):
+         sequence_length = keras.ops.shape(inputs)[1]
+         return keras.ops.arange(sequence_length, dtype="int32")
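+ # Wrapping `arange` in a registered layer, rather than calling the op
+ # directly, presumably keeps the position computation inside the functional
+ # graph so the model remains serializable. Illustratively, for inputs of
+ # shape (batch, 20, ...) it returns [0, 1, ..., 19].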
+
+
+ @keras_hub_export("keras_hub.models.MoonshineBackbone")
+ class MoonshineBackbone(Backbone):
+     """Moonshine backbone with integrated audio feature extraction.
+
+     This class implements an encoder-decoder backbone, as used in the
+     Moonshine ASR system. It includes initial convolutional layers for audio
+     feature extraction, followed by `MoonshineEncoderBlock` instances for
+     processing these features and `MoonshineDecoderBlock` instances for
+     generating output sequences.
+
+     Args:
+         vocabulary_size: int. The size of the vocabulary for the embedding
+             layers.
+         filter_dim: int. The number of filters for the initial convolutional
+             feature extractor layers. Typically matches `hidden_dim`.
+         encoder_num_layers: int. The number of stacked encoder blocks.
+         decoder_num_layers: int. The number of stacked decoder blocks.
+         hidden_dim: int. The dimensionality of the model's hidden
+             representations and embeddings.
+         intermediate_dim: int. The dimensionality of the intermediate
+             representations in feedforward networks.
+         encoder_num_heads: int. The number of attention heads in the
+             encoder's multi-head attention.
+         decoder_num_heads: int. The number of attention heads in the
+             decoder's multi-head attention.
+         feedforward_expansion_factor: int, optional. A multiplier applied to
+             `intermediate_dim` to determine the total width of the
+             feedforward network. Defaults to 4.
+         encoder_use_swiglu_activation: bool, optional. When True, uses SwiGLU
+             in the encoder feedforward network. Defaults to False.
+         decoder_use_swiglu_activation: bool, optional. When True, uses SwiGLU
+             in the decoder feedforward network. Defaults to True.
+         max_position_embeddings: int, optional. The maximum sequence length
+             for position embeddings. Defaults to 2048.
+         pad_head_dim_to_multiple_of: int, optional. If specified, pads the
+             head dimension to be a multiple of this value for performance
+             optimization. Defaults to None.
+         partial_rotary_factor: float, optional. The fraction of dimensions to
+             apply rotary position embeddings to. Defaults to 0.62.
+         dropout: float, optional. The dropout probability for input dropout
+             layers. Defaults to 0.0.
+         initializer_range: float, optional. The standard deviation of the
+             truncated normal initializer for weights. Defaults to 0.02.
+         rope_theta: float, optional. The base frequency for rotary position
+             embeddings. Defaults to 10000.0.
+         attention_bias: bool, optional. Whether to use bias in attention
+             mechanisms. Defaults to False.
+         attention_dropout: float, optional. The dropout probability for
+             attention mechanisms. Defaults to 0.0.
+         dtype: str, optional. The dtype to use for model computations and
+             weights. Defaults to None.
+
+     Examples:
+     ```python
+     import numpy as np
+     import keras
+     from keras_hub.models import MoonshineBackbone
+
+     # Create random input data for demonstration.
+     # The encoder consumes raw audio values (e.g., as produced by
+     # MoonshineAudioConverter).
+     encoder_raw_input_values = np.random.rand(1, 16000, 1).astype("float32")
+     # Mask over the raw input time dimension.
+     encoder_padding_mask = np.ones((1, 16000), dtype="bool")
+     decoder_token_ids = np.random.randint(
+         0, 1000, size=(1, 20), dtype="int32"
+     )
+     decoder_padding_mask = np.ones((1, 20), dtype="bool")
+
+     # Initialize the Moonshine backbone with specific parameters.
+     backbone = MoonshineBackbone(
+         vocabulary_size=10000,
+         filter_dim=256,
+         encoder_num_layers=6,
+         decoder_num_layers=6,
+         hidden_dim=256,
+         intermediate_dim=512,
+         encoder_num_heads=8,
+         decoder_num_heads=8,
+         feedforward_expansion_factor=4,
+         decoder_use_swiglu_activation=True,
+         encoder_use_swiglu_activation=False,
+     )
+
+     # Forward pass through the model.
+     outputs = backbone(
+         {
+             "encoder_input_values": encoder_raw_input_values,
+             "encoder_padding_mask": encoder_padding_mask,
+             "decoder_token_ids": decoder_token_ids,
+             "decoder_padding_mask": decoder_padding_mask,
+         }
+     )
+
+     # Display the outputs.
+     print("Encoder output shape:", outputs["encoder_sequence_output"].shape)
+     print("Decoder output shape:", outputs["decoder_sequence_output"].shape)
+     ```
+     """
+
+     # References:
+     # Feature extractor: UsefulSensors implementation
+     # (https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/model.py#L6-L32).
+     # Transformer backbone: Hugging Face implementation
+     # (https://github.com/huggingface/transformers/blob/dcbdf7e962c4b36140cc9ee76f870016121e69e5/src/transformers/models/moonshine/modeling_moonshine.py#L1326-L1486).
+
+     def __init__(
+         self,
+         vocabulary_size,
+         filter_dim,
+         encoder_num_layers,
+         decoder_num_layers,
+         hidden_dim,
+         intermediate_dim,
+         encoder_num_heads,
+         decoder_num_heads,
+         feedforward_expansion_factor=4,
+         encoder_use_swiglu_activation=False,
+         decoder_use_swiglu_activation=True,
+         max_position_embeddings=2048,
+         pad_head_dim_to_multiple_of=None,
+         partial_rotary_factor=0.62,
+         dropout=0.0,
+         initializer_range=0.02,
+         rope_theta=10000.0,
+         attention_bias=False,
+         attention_dropout=0.0,
+         dtype=None,
+         **kwargs,
+     ):
+         # === Layers ===
+         self._compute_mask_layer = ComputeAttentionMask(
+             name="compute_attention_mask"
+         )
+
+         # Feature extractor layers.
+         self.conv1 = keras.layers.Conv1D(
+             filters=filter_dim,
+             kernel_size=127,
+             strides=64,
+             use_bias=False,
+             padding="valid",
+             kernel_initializer=moonshine_kernel_initializer(
+                 initializer_range=initializer_range
+             ),
+             name="conv1",
+             dtype=dtype,
+         )
+         self.group_norm = keras.layers.GroupNormalization(
+             groups=1,
+             axis=-1,
+             epsilon=1e-5,
+             center=True,
+             scale=True,
+             name="group_norm",
+             dtype=dtype,
+         )
+         self.tanh_after_conv1 = keras.layers.Activation(
+             "tanh", name="tanh_after_conv1", dtype=dtype
+         )
+         self.conv2 = keras.layers.Conv1D(
+             filters=2 * filter_dim,
+             kernel_size=7,
+             strides=3,
+             use_bias=True,
+             padding="valid",
+             kernel_initializer=moonshine_kernel_initializer(
+                 initializer_range=initializer_range
+             ),
+             name="conv2",
+             dtype=dtype,
+         )
+         self.gelu_after_conv2 = keras.layers.Activation(
+             "gelu", name="gelu_after_conv2", dtype=dtype
+         )
+         self.conv3 = keras.layers.Conv1D(
+             filters=filter_dim,
+             kernel_size=3,
+             strides=2,
+             use_bias=True,
+             padding="valid",
+             kernel_initializer=moonshine_kernel_initializer(
+                 initializer_range=initializer_range
+             ),
+             name="conv3",
+             dtype=dtype,
+         )
+         self.gelu_after_conv3 = keras.layers.Activation(
+             "gelu", name="gelu_after_conv3", dtype=dtype
+         )
+
+         # Transformer layers.
+         encoder_head_dim = hidden_dim // encoder_num_heads
+         if pad_head_dim_to_multiple_of:
+             encoder_head_dim = (
+                 (encoder_head_dim + pad_head_dim_to_multiple_of - 1)
+                 // pad_head_dim_to_multiple_of
+             ) * pad_head_dim_to_multiple_of
+
+         decoder_head_dim = hidden_dim // decoder_num_heads
+         if pad_head_dim_to_multiple_of:
+             decoder_head_dim = (
+                 (decoder_head_dim + pad_head_dim_to_multiple_of - 1)
+                 // pad_head_dim_to_multiple_of
+             ) * pad_head_dim_to_multiple_of
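+         # Illustrative arithmetic (not part of the original file): this is a
+         # round-up to the next multiple. With hidden_dim=256, num_heads=8,
+         # and pad_head_dim_to_multiple_of=64, the raw head dim 256 // 8 = 32
+         # becomes ((32 + 63) // 64) * 64 = 64.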
+
+         # Embedding layer for decoder.
+         self.token_embedding = ReversibleEmbedding(
+             input_dim=vocabulary_size,
+             output_dim=hidden_dim,
+             embeddings_initializer=moonshine_kernel_initializer(
+                 initializer_range=initializer_range
+             ),
+             name="token_embedding",
+             dtype=dtype,
+         )
+
+         # Rotary embeddings for encoder and decoder.
+         self.encoder_rotary_embedding = MoonshineRotaryEmbedding(
+             head_dim=encoder_head_dim,
+             max_position_embeddings=max_position_embeddings,
+             partial_rotary_factor=partial_rotary_factor,
+             base_value=rope_theta,
+             name="encoder_rotary_embedding",
+             dtype=dtype,
+         )
+
+         self.decoder_rotary_embedding = MoonshineRotaryEmbedding(
+             head_dim=decoder_head_dim,
+             max_position_embeddings=max_position_embeddings,
+             partial_rotary_factor=partial_rotary_factor,
+             base_value=rope_theta,
+             name="decoder_rotary_embedding",
+             dtype=dtype,
+         )
+
+         # Dropout for encoder.
+         self.encoder_dropout = keras.layers.Dropout(
+             dropout, name="encoder_dropout", dtype=dtype
+         )
+         # Dropout for decoder.
+         self.decoder_dropout = keras.layers.Dropout(
+             dropout, name="decoder_dropout", dtype=dtype
+         )
+
+         # Encoder blocks.
+         self.encoder_blocks = []
+         for i in range(encoder_num_layers):
+             encoder_block = MoonshineEncoderBlock(
+                 hidden_dim=hidden_dim,
+                 intermediate_dim=intermediate_dim,
+                 num_heads=encoder_num_heads,
+                 feedforward_expansion_factor=feedforward_expansion_factor,
+                 use_swiglu_activation=encoder_use_swiglu_activation,
+                 pad_head_dim_to_multiple_of=pad_head_dim_to_multiple_of,
+                 initializer_range=initializer_range,
+                 attention_bias=attention_bias,
+                 attention_dropout=attention_dropout,
+                 name=f"encoder_block_{i}",
+                 dtype=dtype,
+             )
+             self.encoder_blocks.append(encoder_block)
+
+         # Layer normalization for encoder.
+         self.encoder_final_layer_norm = keras.layers.LayerNormalization(
+             epsilon=1e-5,
+             center=False,
+             scale=True,
+             name="encoder_final_layer_norm",
+             dtype=dtype,
+         )
+
+         # Decoder blocks.
+         self.decoder_blocks = []
+         for i in range(decoder_num_layers):
+             decoder_block = MoonshineDecoderBlock(
+                 hidden_dim=hidden_dim,
+                 intermediate_dim=intermediate_dim,
+                 num_heads=decoder_num_heads,
+                 feedforward_expansion_factor=feedforward_expansion_factor,
+                 use_swiglu_activation=decoder_use_swiglu_activation,
+                 pad_head_dim_to_multiple_of=pad_head_dim_to_multiple_of,
+                 initializer_range=initializer_range,
+                 attention_bias=attention_bias,
+                 attention_dropout=attention_dropout,
+                 name=f"decoder_block_{i}",
+                 dtype=dtype,
+             )
+             self.decoder_blocks.append(decoder_block)
+
+         # Layer normalization for decoder.
+         self.decoder_post_norm = keras.layers.LayerNormalization(
+             epsilon=1e-5,
+             center=False,
+             scale=True,
+             name="decoder_post_norm",
+             dtype=dtype,
+         )
+
+         # === Functional Model ===
+         encoder_raw_input_values = keras.Input(
+             shape=(None, 1), name="encoder_input_values", dtype=dtype
+         )
+         encoder_input_padding_mask = keras.Input(
+             shape=(None,), name="encoder_padding_mask", dtype="bool"
+         )
+         decoder_input = keras.Input(
+             shape=(None,), name="decoder_token_ids", dtype="int32"
+         )
+         decoder_padding_mask = keras.Input(
+             shape=(None,), name="decoder_padding_mask", dtype="bool"
+         )
+
+         # Feature extraction.
+         encoder_hidden_states = self.conv1(encoder_raw_input_values)
+         encoder_hidden_states = self.tanh_after_conv1(encoder_hidden_states)
+         encoder_hidden_states = self.group_norm(encoder_hidden_states)
+         encoder_hidden_states = self.conv2(encoder_hidden_states)
+         encoder_hidden_states = self.gelu_after_conv2(encoder_hidden_states)
+         encoder_hidden_states = self.conv3(encoder_hidden_states)
+         encoder_hidden_states = self.gelu_after_conv3(encoder_hidden_states)
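+         # Net effect (illustrative): the three conv stages downsample time
+         # by roughly 64 * 3 * 2 = 384x, e.g. 16000 raw samples -> 40 feature
+         # frames of width filter_dim (see compute_output_lengths above).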
+
+         # Compute mask for encoder features.
+         original_lengths = keras.ops.sum(
+             keras.ops.cast(encoder_input_padding_mask, "int32"), axis=1
+         )
+         output_lengths = compute_output_lengths(original_lengths)
+         encoder_attention_mask = self._compute_mask_layer(
+             encoder_hidden_states, output_lengths
+         )
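+         # For instance (assumed values, not part of the original file): if
+         # only 8000 of 16000 input samples are real, original_lengths is
+         # [8000], output_lengths is compute_output_lengths([8000]) = [19],
+         # and the mask is True for the first 19 of the 40 feature frames.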
+
+         # Encoder.
+         encoder_positions = Arange(name="encoder_positions")(
+             encoder_hidden_states
+         )
+         encoder_rotary_emb = self.encoder_rotary_embedding(encoder_positions)
+         encoder_hidden_states = self.encoder_dropout(encoder_hidden_states)
+         for encoder_block in self.encoder_blocks:
+             encoder_hidden_states = encoder_block(
+                 encoder_hidden_states,
+                 encoder_rotary_emb,
+                 attention_mask=encoder_attention_mask,
+             )
+         encoder_output = self.encoder_final_layer_norm(encoder_hidden_states)
+
+         # Decoder.
+         decoder_positions = Arange(name="decoder_positions")(decoder_input)
+         decoder_rotary_emb = self.decoder_rotary_embedding(decoder_positions)
+         decoder_hidden_states = self.token_embedding(decoder_input)
+         decoder_hidden_states = self.decoder_dropout(decoder_hidden_states)
+         for decoder_block in self.decoder_blocks:
+             decoder_hidden_states = decoder_block(
+                 decoder_sequence=decoder_hidden_states,
+                 encoder_sequence=encoder_output,
+                 rotary_embedding=decoder_rotary_emb,
+                 decoder_padding_mask=decoder_padding_mask,
+                 encoder_padding_mask=encoder_attention_mask,
+             )
+         decoder_output = self.decoder_post_norm(decoder_hidden_states)
+
+         super().__init__(
+             inputs={
+                 "encoder_input_values": encoder_raw_input_values,
+                 "encoder_padding_mask": encoder_input_padding_mask,
+                 "decoder_token_ids": decoder_input,
+                 "decoder_padding_mask": decoder_padding_mask,
+             },
+             outputs={
+                 "encoder_sequence_output": encoder_output,
+                 "decoder_sequence_output": decoder_output,
+                 "encoder_attention_mask": encoder_attention_mask,
+             },
+             dtype=dtype,
+             **kwargs,
+         )
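+         # Note (an assumption about intended usage, not stated in the
+         # original file): exposing `encoder_attention_mask` as a model output
+         # lets downstream task models reuse the feature-level mask without
+         # recomputing it from the raw padding mask.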
+
+         # === Config ===
+         self.vocabulary_size = vocabulary_size
+         self.filter_dim = filter_dim
+         self.encoder_num_layers = encoder_num_layers
+         self.decoder_num_layers = decoder_num_layers
+         self.hidden_dim = hidden_dim
+         self.intermediate_dim = intermediate_dim
+         self.encoder_num_heads = encoder_num_heads
+         self.decoder_num_heads = decoder_num_heads
+         self.feedforward_expansion_factor = feedforward_expansion_factor
+         self.encoder_use_swiglu_activation = encoder_use_swiglu_activation
+         self.decoder_use_swiglu_activation = decoder_use_swiglu_activation
+         self.max_position_embeddings = max_position_embeddings
+         self.pad_head_dim_to_multiple_of = pad_head_dim_to_multiple_of
+         self.partial_rotary_factor = partial_rotary_factor
+         self.dropout = dropout
+         self.initializer_range = initializer_range
+         self.rope_theta = rope_theta
+         self.attention_bias = attention_bias
+         self.attention_dropout = attention_dropout
+
+     def get_config(self):
+         config = super().get_config()
+         config.update(
+             {
+                 "vocabulary_size": self.vocabulary_size,
+                 "filter_dim": self.filter_dim,
+                 "encoder_num_layers": self.encoder_num_layers,
+                 "decoder_num_layers": self.decoder_num_layers,
+                 "hidden_dim": self.hidden_dim,
+                 "intermediate_dim": self.intermediate_dim,
+                 "encoder_num_heads": self.encoder_num_heads,
+                 "decoder_num_heads": self.decoder_num_heads,
+                 "feedforward_expansion_factor": self.feedforward_expansion_factor,  # noqa: E501
+                 "encoder_use_swiglu_activation": self.encoder_use_swiglu_activation,  # noqa: E501
+                 "decoder_use_swiglu_activation": self.decoder_use_swiglu_activation,  # noqa: E501
+                 "max_position_embeddings": self.max_position_embeddings,
+                 "pad_head_dim_to_multiple_of": self.pad_head_dim_to_multiple_of,
+                 "partial_rotary_factor": self.partial_rotary_factor,
+                 "dropout": self.dropout,
+                 "initializer_range": self.initializer_range,
+                 "rope_theta": self.rope_theta,
+                 "attention_bias": self.attention_bias,
+                 "attention_dropout": self.attention_dropout,
+                 "dtype": self.dtype,
+             }
+         )
+         return config
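+
+ # Minimal config round-trip sketch (illustrative, not part of the original
+ # file); `get_config` mirrors every `__init__` argument, so this should
+ # rebuild an equivalent backbone:
+ #
+ # backbone = MoonshineBackbone(
+ #     vocabulary_size=10000, filter_dim=256, encoder_num_layers=2,
+ #     decoder_num_layers=2, hidden_dim=256, intermediate_dim=512,
+ #     encoder_num_heads=8, decoder_num_heads=8,
+ # )
+ # restored = MoonshineBackbone.from_config(backbone.get_config())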