keras-hub-nightly 0.21.0.dev202505140407__py3-none-any.whl → 0.21.0.dev202505150407__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,239 @@
1
+ import keras
2
+
3
+ from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
4
+ from keras_hub.src.utils.keras_utils import clone_initializer
5
+
6
+
7
def moonshine_kernel_initializer(initializer_range=0.02):
    """Build the truncated-normal kernel initializer used by Moonshine layers.

    Args:
        initializer_range: float. Standard deviation of the truncated normal
            distribution. Defaults to 0.02.

    Returns:
        A fresh `keras.initializers.TruncatedNormal` instance.
    """
    initializer = keras.initializers.TruncatedNormal(stddev=initializer_range)
    return initializer
9
+
10
+
11
@keras.saving.register_keras_serializable(package="keras_hub")
class MoonshineRotaryEmbedding(RotaryEmbedding):
    """
    Moonshine rotary embedding layer.

    Produces the per-position angle table for rotary positional embeddings
    from a non-trainable `inv_freq` weight. Only part of each head is
    covered: the rotary width is `int(head_dim * partial_rotary_factor)`
    rounded down to an even number, and `inv_freq` holds half that many
    frequencies, `1 / base_value ** (i / (rotary_dim // 2))`.

    Unlike KerasHub's generic `RotaryEmbedding` class, which is configured
    through `max_wavelength` and has no notion of partial rotation, this
    layer requires `head_dim` explicitly and applies
    `partial_rotary_factor`.

    Args:
        head_dim: int. Size of each attention head; together with
            `partial_rotary_factor` it fixes the rotary width.
        max_position_embeddings: int, optional. Stored on the layer and
            written to the config. It is not used by any computation in this
            layer. Defaults to 2048.
        base_value: float, optional. Base of the geometric frequency
            progression. Larger values give longer wavelengths. Defaults to
            10000.
        partial_rotary_factor: float, optional. Fraction of `head_dim`
            channels that receive rotary embeddings. Defaults to 0.62.
        dtype: string, optional. Data type for computations and weights.
            Defaults to None.
        **kwargs: Additional keyword arguments forwarded to the parent class.
    """

    # References:
    # Based on the UsefulSensors implementation of the RotaryEmbedding class (https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/model.py#L176-L193).

    def __init__(
        self,
        head_dim,
        max_position_embeddings=2048,
        base_value=10000,
        partial_rotary_factor=0.62,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.head_dim = head_dim
        self.max_position_embeddings = max_position_embeddings
        self.base_value = base_value
        self.partial_rotary_factor = partial_rotary_factor
        # Force the layer to be considered unbuilt so that `build()` below
        # does not return early and actually creates `inv_freq`.
        self.built = False
        # Both are populated by `build()`.
        self.rotary_dim = None
        self.inv_freq = None

    def build(self, input_shape):
        """Create the non-trainable `inv_freq` weight (once).

        Raises:
            ValueError: If the even-rounded rotary width is not positive.
        """
        if self.built:
            return

        # Rotary channels come in (cos, sin) pairs, so round down to even.
        rotary_dim = 2 * (int(self.head_dim * self.partial_rotary_factor) // 2)
        if rotary_dim <= 0:
            raise ValueError(
                f"Calculated rotary_dim ({rotary_dim}) must be a positive even "
                f"number. Check head_dim ({self.head_dim}) and "
                f"partial_rotary_factor ({self.partial_rotary_factor})."
            )
        self.rotary_dim = rotary_dim
        half_dim = rotary_dim // 2

        # Geometric progression of inverse frequencies over half the width.
        exponents = keras.ops.arange(0, half_dim, dtype=self.dtype) / half_dim
        inv_freq_values = 1.0 / (self.base_value**exponents)

        self.inv_freq = self.add_weight(
            name="inv_freq",
            shape=(half_dim,),
            initializer=keras.initializers.Constant(inv_freq_values),
            trainable=False,
            dtype=self.dtype,
        )
        self.built = True

    def call(self, t):
        """Return the angle table for the 1-D position tensor `t`.

        The output has shape `(len(t), rotary_dim)`. Each frequency appears
        twice in a row, i.e. columns are `[f0, f0, f1, f1, ...]` scaled by
        the position.
        """
        positions = keras.ops.cast(t, keras.ops.dtype(self.inv_freq))
        # Outer product: angles[i, j] = positions[i] * inv_freq[j].
        angles = keras.ops.einsum("i,j->ij", positions, self.inv_freq)
        paired = keras.ops.stack((angles, angles), axis=-1)
        # Merge the trailing (half_dim, 2) axes into one interleaved axis.
        target_shape = list(keras.ops.shape(paired))
        target_shape[-2:] = [-1]
        return keras.ops.reshape(paired, target_shape)

    def get_config(self):
        """Return the parent config extended with this layer's arguments."""
        config = super().get_config()
        config.update(
            head_dim=self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base_value=self.base_value,
            partial_rotary_factor=self.partial_rotary_factor,
            dtype=self.dtype,
        )
        return config
117
+
118
+
119
@keras.saving.register_keras_serializable(package="keras_hub")
class MoonshineMLP(keras.layers.Layer):
    """
    Moonshine MLP layer.

    Implements a Multi-Layer Perceptron (MLP) for Moonshine models with support
    for both `SwiGLU` and `LinearGeLU` activation patterns. The MLP consists of
    two dense layers with an activation function in between, expanding the
    input dimension before projecting back to `hidden_dim`.

    - SwiGLU (`use_swiglu_activation=True`): `dense_1` produces
      `int(hidden_dim * feedforward_expansion_factor * 2)` features. These are
      split into two halves `(x, gate)` along the last axis, and
      `x * silu(gate)` is fed to `dense_2`.
    - GeLU (`use_swiglu_activation=False`): `dense_1` produces
      `int(hidden_dim * feedforward_expansion_factor)` features, which pass
      through `gelu` before `dense_2`.

    Args:
        hidden_dim: int. The dimensionality of the input and output tensors.
        feedforward_expansion_factor: float. The factor by which to expand the
            hidden dimension in the intermediate layer.
        use_swiglu_activation: bool, optional. If `True`, uses SwiGLU activation
            (SiLU with gating). If `False`, uses standard GeLU activation.
            Defaults to `True`.
        initializer_range: float, optional. The standard deviation for kernel
            initialization. Defaults to 0.02.
        dtype: string or dtype policy, optional. The data type for model
            computations and weights. Defaults to `None`.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    # References:
    # Based on the HuggingFace implementation of the MoonshineEncoderMLP and
    # MoonshineDecoderMLP classes (https://github.com/huggingface/transformers/blob/fc8764c9a618add64c33e83720f974750bcd0978/src/transformers/models/moonshine/modeling_moonshine.py#L66-L94).

    def __init__(
        self,
        hidden_dim,
        feedforward_expansion_factor,
        use_swiglu_activation=True,
        initializer_range=0.02,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.hidden_dim = hidden_dim
        self.feedforward_expansion_factor = feedforward_expansion_factor
        self.use_swiglu_activation = use_swiglu_activation
        self.kernel_initializer = moonshine_kernel_initializer(
            initializer_range=initializer_range
        )
        self.initializer_range = initializer_range

        if use_swiglu_activation:
            # dense_1 emits both the value and the gate halves.
            dense_1_units = int(hidden_dim * feedforward_expansion_factor * 2)
            activation_name = "silu"
        else:
            # Taken from pretrained weights.
            dense_1_units = int(hidden_dim * feedforward_expansion_factor)
            activation_name = "gelu"

        # Sublayers get the full dtype policy rather than `self.dtype`.
        # `self.dtype` is only the variable dtype, so passing it would force
        # e.g. float32 compute under a mixed-precision policy.
        self.dense_1 = keras.layers.Dense(
            dense_1_units,
            use_bias=True,
            name="dense_1",
            dtype=self.dtype_policy,
            kernel_initializer=clone_initializer(self.kernel_initializer),
        )
        self.activation = keras.layers.Activation(
            activation_name, name="activation", dtype=self.dtype_policy
        )
        # Second dense layer projects back to hidden_dim.
        self.dense_2 = keras.layers.Dense(
            hidden_dim,
            use_bias=True,
            name="dense_2",
            dtype=self.dtype_policy,
            kernel_initializer=clone_initializer(self.kernel_initializer),
        )

    def build(self, input_shape):
        """Build both dense sublayers for inputs of `input_shape`."""
        super().build(input_shape)
        self.dense_1.build(input_shape)
        # dense_2 consumes what `call` produces from dense_1's output: the full
        # output for GeLU, or one half of it after the SwiGLU split. Deriving
        # the width from the integer `dense_1.units` keeps the two layers
        # consistent and avoids a float shape entry when
        # `feedforward_expansion_factor` is a float.
        intermediate_dim = self.dense_1.units
        if self.use_swiglu_activation:
            intermediate_dim //= 2
        dense_2_input_shape = list(input_shape)
        dense_2_input_shape[-1] = intermediate_dim
        self.dense_2.build(tuple(dense_2_input_shape))

    def call(self, inputs):
        """Apply the feed-forward block.

        Args:
            inputs: Tensor whose last axis is the feature axis (presumably
                `(batch, seq_len, hidden_dim)`; TODO confirm with callers).

        Returns:
            Tensor with the same leading axes and a last axis of `hidden_dim`.
        """
        x = self.dense_1(inputs)
        if self.use_swiglu_activation:
            x1, gate = keras.ops.split(x, 2, axis=-1)
            x = x1 * self.activation(gate)
        else:
            x = self.activation(x)
        return self.dense_2(x)

    def get_config(self):
        """Return the layer config, including the MLP hyperparameters."""
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "feedforward_expansion_factor": self.feedforward_expansion_factor,  # noqa: E501
                "use_swiglu_activation": self.use_swiglu_activation,
                "initializer_range": self.initializer_range,
                # NOTE(review): this stores only the variable dtype, which
                # overrides the parent's serialized dtype policy; confirm
                # whether mixed-precision policies should round-trip.
                "dtype": self.dtype,
            }
        )
        return config
@@ -0,0 +1,355 @@
1
+ import keras
2
+ from keras import backend
3
+
4
+ from keras_hub.src.layers.modeling.cached_multi_head_attention import (
5
+ CachedMultiHeadAttention,
6
+ )
7
+ from keras_hub.src.models.whisper.whisper_cached_multi_head_attention import (
8
+ _build_proj_equation,
9
+ )
10
+ from keras_hub.src.models.whisper.whisper_cached_multi_head_attention import (
11
+ _get_output_shape,
12
+ )
13
+
14
+
15
# Removed dependence on einops.
# Source: https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/model.py#L35
def _rotate_half(x):
    """
    Rotates adjacent channel pairs of the last dimension.

    Despite the name, this does not swap the first and second halves of the
    last axis. The last dimension (length `2*d`) is reshaped to `[..., d, 2]`,
    i.e. viewed as `d` consecutive pairs `(a, b)`, and each pair becomes
    `(-b, a)`. For a last axis `[x0, x1, x2, x3]` the result is
    `[-x1, x0, -x3, x2]`.

    Args:
        x: Tensor. Shape `[..., 2*d]`. The input tensor to be rotated.

    Returns:
        Tensor: A tensor of shape `[..., 2*d]` with each channel pair rotated.

    Raises:
        NotImplementedError: If the Keras backend is not TensorFlow, PyTorch,
            or JAX.
    """
    backend_name = backend.backend()

    if backend_name == "tensorflow":
        # Assemble the intermediate `[..., d, 2]` shape as a tensor, so the
        # leading dims are taken from `keras.ops.shape(x)` as-is.
        x_shape = keras.ops.shape(x)
        d = x_shape[-1] // 2
        x_shape_tensor = keras.ops.convert_to_tensor(x_shape)
        new_shape = keras.ops.concatenate(
            [x_shape_tensor[:-1], keras.ops.convert_to_tensor([d, 2])], axis=0
        )
        pairs = keras.ops.reshape(x, new_shape)
        x1 = pairs[..., 0]
        x2 = pairs[..., 1]
        x_rotated = keras.ops.stack([-x2, x1], axis=-1)
        return keras.ops.reshape(x_rotated, x_shape)

    elif backend_name == "torch" or backend_name == "jax":
        # Materialize the shape as a tuple of Python ints.
        x_shape_tuple = tuple(
            int(keras.ops.convert_to_numpy(dim).item())
            for dim in keras.ops.shape(x)
        )
        d = x_shape_tuple[-1] // 2
        pairs = keras.ops.reshape(x, x_shape_tuple[:-1] + (d, 2))
        x1 = pairs[..., 0]
        x2 = pairs[..., 1]
        x_rotated = keras.ops.stack([-x2, x1], axis=-1)
        return keras.ops.reshape(x_rotated, x_shape_tuple)

    else:
        raise NotImplementedError(
            "Backend not supported. Please use TensorFlow, PyTorch, or JAX."
        )
68
+
69
+
70
def _apply_rotary_pos_emb(t, freqs):
    """
    Applies rotary positional embeddings to the input tensor. Used in on-the-fly
    computation of rotary positional embeddings in multi-head attention layers.

    Args:
        t: A rank-4 tensor whose axis 1 is the sequence axis, e.g.
            `[batch, seq_len, num_heads, head_dim]`. The frequencies are
            reshaped to `(1, seq_len, 1, rot_dim)` and broadcast against it.
            The rotary embedding is applied to the first `rot_dim` channels
            of the last dimension.
        freqs: A tensor of angle values with shape `[max_seq_len, rot_dim]`
            and at least `seq_len` rows. Only the FIRST `seq_len` rows are
            used (`freqs[:seq_len]`). Callers that need a different position
            offset must slice `freqs` themselves.

    Returns:
        Tensor: A tensor of the same shape and dtype as `t` with the rotary
        positional embeddings applied to the first `rot_dim` channels of the
        last dimension and the remaining channels concatenated unchanged.
    """
    rot_dim = keras.ops.shape(freqs)[-1]
    seq_len = keras.ops.shape(t)[1]
    orig_dtype = t.dtype

    # Take positions 0..seq_len-1 and broadcast over batch and heads.
    freqs = freqs[:seq_len, :]
    freqs = keras.ops.reshape(freqs, (1, seq_len, 1, rot_dim))

    t_rot = t[..., :rot_dim]
    t_nonrot = t[..., rot_dim:]
    t_rotated = t_rot * keras.ops.cos(freqs) + _rotate_half(
        t_rot
    ) * keras.ops.sin(freqs)
    out = keras.ops.concatenate([t_rotated, t_nonrot], axis=-1)
    # The products with `freqs` may promote the dtype; restore the input's.
    return keras.ops.cast(out, orig_dtype)
100
+
101
+
102
@keras.saving.register_keras_serializable(package="keras_hub")
class MoonshineMultiHeadAttention(CachedMultiHeadAttention):
    """
    Moonshine multi-head attention layer.

    Implements a multi-head attention mechanism for Moonshine models with
    support for rotary position embeddings and different caching strategies.
    This layer extends the `CachedMultiHeadAttention` base class to include
    specialized functionality for Moonshine models, such as rotary embeddings
    and causal-mask construction.

    Args:
        num_heads: int. Number of attention heads.
        key_dim: int. Size of each attention head for query and key.
        value_dim: int, optional. Size of each attention head for value. If
            None, defaults to `key_dim`.
        attention_bias: bool, optional. Whether to include bias in the
            query/key/value/output projection layers. Forwarded to the parent
            as `use_bias`. Defaults to `False`.
        attention_dropout: float, optional. Dropout probability for attention
            weights. Forwarded to the parent as `dropout`. Defaults to 0.0.
        use_causal_mask: bool, optional. Stored on the layer and serialized.
            `call` does not read this flag: any causal mask must be passed via
            `attention_mask` (it can be built with `_compute_causal_mask`).
            Defaults to `False`.
        apply_rotary_embedding: bool, optional. Whether `call` applies
            `rotary_embedding` to the keys. Queries are rotated whenever
            `rotary_embedding` is not None, regardless of this flag. Defaults
            to `True`.
        **kwargs: Additional keyword arguments passed to the parent class.
            Any `use_bias` / `dropout` entries are discarded, since those are
            driven by `attention_bias` / `attention_dropout`.
    """

    # References:
    # Based on the HuggingFace implementation of the MoonshineAttention class (https://github.com/huggingface/transformers/blob/fc8764c9a618add64c33e83720f974750bcd0978/src/transformers/models/moonshine/modeling_moonshine.py#L184-L315).

    def __init__(
        self,
        num_heads,
        key_dim,
        value_dim=None,
        attention_bias=False,
        attention_dropout=0.0,
        use_causal_mask=False,
        apply_rotary_embedding=True,
        **kwargs,
    ):
        # Drop copies arriving through **kwargs (e.g. from a config) so they
        # are not passed to the parent twice.
        kwargs.pop("use_bias", None)
        kwargs.pop("dropout", None)
        super().__init__(
            num_heads=num_heads,
            key_dim=key_dim,
            value_dim=value_dim,
            use_bias=attention_bias,
            dropout=attention_dropout,
            **kwargs,
        )
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.use_causal_mask = use_causal_mask
        self.apply_rotary_embedding = apply_rotary_embedding

    def build(self, query_shape, value_shape, key_shape=None):
        """Create the query, key, value and output projection sublayers.

        Args:
            query_shape: Shape of the `query` input.
            value_shape: Shape of the `value` input.
            key_shape: Shape of the `key` input. Defaults to `value_shape`.
        """
        # Ensure key_shape is defined.
        key_shape = value_shape if key_shape is None else key_shape
        query_rank = len(query_shape)
        value_rank = len(value_shape)
        key_rank = len(key_shape)

        # Build query projection layer: last axis -> (num_heads, key_dim).
        einsum_equation, bias_axes, output_rank = _build_proj_equation(
            free_dims=query_rank - 1, bound_dims=1, output_dims=2
        )
        self._query_dense = keras.layers.EinsumDense(
            einsum_equation,
            output_shape=_get_output_shape(
                output_rank - 1, [self._num_heads, self._key_dim]
            ),
            bias_axes=bias_axes if self._use_bias else None,
            name="query",
            **self._get_common_kwargs_for_sublayer(),
        )
        self._query_dense.build(query_shape)

        # Build key projection layer: last axis -> (num_heads, key_dim).
        einsum_equation, bias_axes, output_rank = _build_proj_equation(
            free_dims=key_rank - 1, bound_dims=1, output_dims=2
        )
        self._key_dense = keras.layers.EinsumDense(
            einsum_equation,
            output_shape=_get_output_shape(
                output_rank - 1, [self._num_heads, self._key_dim]
            ),
            bias_axes=bias_axes if self._use_bias else None,
            name="key",
            **self._get_common_kwargs_for_sublayer(),
        )
        self._key_dense.build(key_shape)

        # Build value projection layer: last axis -> (num_heads, value_dim).
        einsum_equation, bias_axes, output_rank = _build_proj_equation(
            free_dims=value_rank - 1, bound_dims=1, output_dims=2
        )
        self._value_dense = keras.layers.EinsumDense(
            einsum_equation,
            output_shape=_get_output_shape(
                output_rank - 1, [self._num_heads, self._value_dim]
            ),
            bias_axes=bias_axes if self._use_bias else None,
            name="value",
            **self._get_common_kwargs_for_sublayer(),
        )
        self._value_dense.build(value_shape)

        # Build the internal attention computation sublayer.
        self._build_attention(output_rank)

        # Build output projection layer: (num_heads, value_dim) -> output.
        output_shape = (
            query_shape[-1] if not self._output_shape else self._output_shape
        )
        if isinstance(output_shape, (list, tuple)):
            output_shape = list(output_shape)
        else:
            output_shape = [output_shape]

        einsum_equation, bias_axes, output_rank = _build_proj_equation(
            free_dims=query_rank - 1,
            bound_dims=2,
            output_dims=len(output_shape),
        )
        self._output_dense = keras.layers.EinsumDense(
            einsum_equation,
            output_shape=_get_output_shape(output_rank - 1, output_shape),
            bias_axes=bias_axes if self._use_bias else None,
            name="attention_output",
            **self._get_common_kwargs_for_sublayer(),
        )
        # The output projection sees the per-head attention output, whose
        # last axis is value_dim rather than key_dim.
        output_dense_input_shape = list(
            self._query_dense.compute_output_shape(query_shape)
        )
        output_dense_input_shape[-1] = self._value_dim
        self._output_dense.build(tuple(output_dense_input_shape))

        self.built = True

    def _compute_causal_mask(self, query, value=None, for_cache=False):
        """Build a lower-triangular boolean mask of shape `(1, rows, cols)`.

        `mask[0, i, j]` is True when `i >= j`. Columns span the value
        sequence. Without `for_cache`, rows span the query sequence. With
        `for_cache`, a square `(cols, cols)` mask is built and only its last
        `q_seq_length` rows are kept.

        NOTE(review): on TensorFlow with `for_cache=False` the column count is
        taken from `query`, while torch/jax use `value` when given. The two
        agree for self-attention; confirm for other uses.

        Raises:
            NotImplementedError: If the Keras backend is not TensorFlow,
                PyTorch, or JAX.
        """
        backend_name = backend.backend()
        if backend_name == "torch" or backend_name == "jax":
            q_seq_length = int(
                keras.ops.convert_to_numpy(keras.ops.shape(query)[1]).item()
            )
            v_seq_length = (
                int(
                    keras.ops.convert_to_numpy(keras.ops.shape(value)[1]).item()
                )
                if value is not None
                else q_seq_length
            )
        elif backend_name == "tensorflow":
            if for_cache:
                assert value is not None
                v_seq_length = keras.ops.shape(value)[1]
            else:
                v_seq_length = keras.ops.shape(query)[1]
            q_seq_length = keras.ops.shape(query)[1]
        else:
            # Previously fell through to an unbound-local error.
            raise NotImplementedError(
                "Backend not supported. Please use TensorFlow, PyTorch, or JAX."
            )
        n_rows = v_seq_length if for_cache else q_seq_length
        ones_mask = keras.ops.ones((1, n_rows, v_seq_length), dtype="int32")
        # 1-based row and column indices; the mask keeps row >= column.
        row_index = keras.ops.cumsum(ones_mask, axis=-2)
        col_index = keras.ops.cumsum(ones_mask, axis=-1)
        mask = keras.ops.greater_equal(row_index, col_index)

        if for_cache:
            mask = mask[:, -q_seq_length:, :]

        return mask

    def call(
        self,
        query,
        value,
        key,
        rotary_embedding=None,
        attention_mask=None,
        cache=None,
        cache_update_index=None,
        training=None,
        **kwargs,
    ):
        """Run attention, optionally reading from / updating a KV cache.

        Args:
            query, value, key: Input tensors.
            rotary_embedding: Optional angle table passed to
                `_apply_rotary_pos_emb`. It is always applied to the queries,
                and to new keys only if `apply_rotary_embedding` is True.
            attention_mask: Optional 2D, 3D or 4D mask. 2D and 3D masks get a
                heads axis inserted (`[:, None, None, :]` / `[:, None, :, :]`).
            cache: Optional tensor whose axis 1 stacks `(key, value)`.
            cache_update_index: If given together with `cache`, new key/value
                projections are written into the cache at this sequence index.
                If `cache` is given without it, the cached keys/values are used
                as-is and the cache is returned unchanged.
            training: Passed through to the attention computation.
            **kwargs: Passed to `_compute_attention`, except `padding_mask`.

        Returns:
            `output` if `cache` is None, otherwise `(output, cache)`.

        Raises:
            ValueError: If `cache_update_index` is set without `cache`.
        """
        # Project inputs.
        query_proj = self._query_dense(query)
        if rotary_embedding is not None:
            query_proj = _apply_rotary_pos_emb(query_proj, rotary_embedding)

        # Handle caching.
        if cache is not None:
            key_cache = cache[:, 0, ...]
            value_cache = cache[:, 1, ...]
            if cache_update_index is None:
                # Read-only cache: use the stored projections directly.
                key_proj = key_cache
                value_proj = value_cache
            else:
                new_key = self._key_dense(key)
                new_value = self._value_dense(value)
                if self.apply_rotary_embedding and rotary_embedding is not None:
                    new_key = _apply_rotary_pos_emb(new_key, rotary_embedding)
                update_shape = keras.ops.shape(new_key)
                start_indices = [0] * len(update_shape)
                start_indices[1] = cache_update_index
                key_proj = keras.ops.slice_update(
                    key_cache, tuple(start_indices), new_key
                )
                value_proj = keras.ops.slice_update(
                    value_cache, tuple(start_indices), new_value
                )
                cache = keras.ops.stack((key_proj, value_proj), axis=1)

        else:
            if cache_update_index is not None:
                raise ValueError(
                    "`cache_update_index` should not be set if `cache` is "
                    f"`None`. Received: cache={cache}, cache_update_index="
                    f"{cache_update_index}"
                )
            key_proj = self._key_dense(key)
            value_proj = self._value_dense(value)
            if self.apply_rotary_embedding and rotary_embedding is not None:
                key_proj = _apply_rotary_pos_emb(key_proj, rotary_embedding)

        # Compute attention mask: insert a heads axis for 2D/3D masks.
        final_mask = attention_mask

        if final_mask is not None:
            mask_shape = keras.ops.shape(final_mask)
            if len(mask_shape) == 2:
                final_mask = final_mask[:, None, None, :]
            elif len(mask_shape) == 3:
                final_mask = final_mask[:, None, :, :]

        attention_kwargs = {
            k: v for k, v in kwargs.items() if k != "padding_mask"
        }
        # Compute attention.
        attention_output, _ = self._compute_attention(
            query=query_proj,
            key=key_proj,
            value=value_proj,
            attention_mask=final_mask,
            training=training,
            **attention_kwargs,
        )

        # Project the attention output.
        output = self._output_dense(attention_output)

        # Return output + cache if cache is provided, otherwise return just
        # output.
        if cache is not None:
            return output, cache
        return output

    def get_config(self):
        """Return the layer config, including the Moonshine-specific args.

        `__init__` discards any `use_bias` / `dropout` entries. Without the
        keys added here, `from_config` would silently fall back to the
        constructor defaults for these four arguments.
        """
        config = super().get_config()
        config.update(
            {
                "attention_bias": self.attention_bias,
                "attention_dropout": self.attention_dropout,
                "use_causal_mask": self.use_causal_mask,
                "apply_rotary_embedding": self.apply_rotary_embedding,
            }
        )
        return config
@@ -0,0 +1,25 @@
1
# Pretrained Moonshine presets: each entry maps a preset name to its metadata
# and the handle used to fetch its weights.
# NOTE(review): both `kaggle_handle` values are empty strings, presumably
# placeholders until the weights are published; confirm before release.
backbone_presets = dict(
    moonshine_tiny_en=dict(
        metadata=dict(
            description=(
                "Moonshine tiny model for English speech recognition. "
                "Developed by Useful Sensors for real-time transcription."
            ),
            params=27092736,
            path="moonshine",
        ),
        kaggle_handle="",
    ),
    moonshine_base_en=dict(
        metadata=dict(
            description=(
                "Moonshine base model for English speech recognition. "
                "Developed by Useful Sensors for real-time transcription."
            ),
            params=61513920,
            path="moonshine",
        ),
        kaggle_handle="",
    ),
)
@@ -0,0 +1,62 @@
1
+ from keras_hub.src.api_export import keras_hub_export
2
+ from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer
3
+
4
+
5
@keras_hub_export(
    [
        "keras_hub.tokenizers.MoonshineTokenizer",
        "keras_hub.models.MoonshineTokenizer",
    ]
)
class MoonshineTokenizer(LlamaTokenizer):
    """
    Moonshine tokenizer, a subclass of `keras_hub.models.LlamaTokenizer`.

    This class defines no overrides of its own: all tokenization and
    detokenization behavior comes from `LlamaTokenizer`, which works from a
    SentencePiece vocabulary. It is exported under the Moonshine names
    `keras_hub.tokenizers.MoonshineTokenizer` and
    `keras_hub.models.MoonshineTokenizer`.

    Args:
        proto: `str` or `bytes`. Either a path to a SentencePiece proto file
            or a bytes object containing a serialized SentencePiece proto.
            See the [SentencePiece repository](https://github.com/google/sentencepiece)
            for details on the format.
        **kwargs: Forwarded unchanged to `LlamaTokenizer`.

    Examples:
    ```python
    from keras_hub.tokenizers import MoonshineTokenizer

    tokenizer = MoonshineTokenizer(
        "keras_hub/src/tests/test_data/llama_test_vocab.spm"
    )

    # Tokenize one string.
    text = "the quick brown fox"
    token_ids = tokenizer(text)
    print(f"{text!r} -> {token_ids}")

    # Tokenize a batch of strings.
    texts = ["the quick brown fox", "the earth is round"]
    print(f"{texts!r} -> {tokenizer(texts)}")

    # Round-trip back to text.
    print(f"Decoded: {tokenizer.detokenize(token_ids)}")
    ```
    """

    # NOTE: The 768 future-use tokens defined in Section 3.1 of the Moonshine
    # paper, "Moonshine: Speech Recognition for Live Transcription and Voice
    # Commands" (https://arxiv.org/pdf/2410.15608.pdf), are not part of this
    # vocabulary because the tokenizer has no use for them at present.
keras_hub/src/version.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from keras_hub.src.api_export import keras_hub_export
2
2
 
3
3
  # Unique source of truth for the version number.
4
- __version__ = "0.21.0.dev202505140407"
4
+ __version__ = "0.21.0.dev202505150407"
5
5
 
6
6
 
7
7
  @keras_hub_export("keras_hub.version")