keras-hub-nightly 0.21.0.dev202505140407__py3-none-any.whl → 0.21.0.dev202505150407__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +3 -0
- keras_hub/models/__init__.py +12 -0
- keras_hub/src/models/moonshine/__init__.py +0 -0
- keras_hub/src/models/moonshine/moonshine_audio_converter.py +301 -0
- keras_hub/src/models/moonshine/moonshine_audio_to_text.py +383 -0
- keras_hub/src/models/moonshine/moonshine_audio_to_text_preprocessor.py +267 -0
- keras_hub/src/models/moonshine/moonshine_backbone.py +478 -0
- keras_hub/src/models/moonshine/moonshine_decoder.py +313 -0
- keras_hub/src/models/moonshine/moonshine_encoder.py +212 -0
- keras_hub/src/models/moonshine/moonshine_layers.py +239 -0
- keras_hub/src/models/moonshine/moonshine_multi_head_attention.py +355 -0
- keras_hub/src/models/moonshine/moonshine_presets.py +25 -0
- keras_hub/src/models/moonshine/moonshine_tokenizer.py +62 -0
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +3 -0
- {keras_hub_nightly-0.21.0.dev202505140407.dist-info → keras_hub_nightly-0.21.0.dev202505150407.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.21.0.dev202505140407.dist-info → keras_hub_nightly-0.21.0.dev202505150407.dist-info}/RECORD +19 -8
- {keras_hub_nightly-0.21.0.dev202505140407.dist-info → keras_hub_nightly-0.21.0.dev202505150407.dist-info}/WHEEL +1 -1
- {keras_hub_nightly-0.21.0.dev202505140407.dist-info → keras_hub_nightly-0.21.0.dev202505150407.dist-info}/top_level.txt +0 -0
keras_hub/src/models/moonshine/moonshine_layers.py
@@ -0,0 +1,239 @@
import keras

from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.utils.keras_utils import clone_initializer


def moonshine_kernel_initializer(initializer_range=0.02):
    return keras.initializers.TruncatedNormal(stddev=initializer_range)


@keras.saving.register_keras_serializable(package="keras_hub")
class MoonshineRotaryEmbedding(RotaryEmbedding):
    """
    Moonshine rotary embedding layer.

    Computes rotary positional embeddings using precomputed inverse frequencies
    for a fraction of dimensions.

    The layer stores inverse frequency weights as a non-trainable parameter and
    computes sinusoidal embeddings based on input positions. Unlike KerasHub's
    `RotaryEmbedding` class, this implementation explicitly requires `head_dim`
    and applies `partial_rotary_factor` for selective rotary embedding, whereas
    KerasHub uses `max_wavelength` without partial application.

    Args:
        head_dim: int. The dimensionality of each attention head, determining
            the feature space for rotary embeddings.
        max_position_embeddings: int, optional. The maximum sequence length the
            model can process, controlling the positional embedding scale.
            Defaults to 2048.
        base_value: float, optional. Base value for computing inverse
            frequencies. Higher values result in longer wavelengths. Defaults to
            10000.
        partial_rotary_factor: float, optional. The fraction of `head_dim`
            dimensions that receive rotary embeddings, balancing rotary and
            non-rotary components. Defaults to 0.62.
        dtype: string, optional. The data type for model computations and
            weights. Defaults to None.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    # References:
    # Based on the UsefulSensors implementation of the RotaryEmbedding class (https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/model.py#L176-L193).

    def __init__(
        self,
        head_dim,
        max_position_embeddings=2048,
        base_value=10000,
        partial_rotary_factor=0.62,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.head_dim = head_dim
        self.max_position_embeddings = max_position_embeddings
        self.base_value = base_value
        self.partial_rotary_factor = partial_rotary_factor
        self.built = False
        self.rotary_dim = None
        self.inv_freq = None

    def build(self, input_shape):
        if self.built:
            return
        # Create and track the non-trainable weight immediately.
        rotary_dim = int(self.head_dim * self.partial_rotary_factor)
        rotary_dim = (rotary_dim // 2) * 2
        if rotary_dim <= 0:
            raise ValueError(
                f"Calculated rotary_dim ({rotary_dim}) must be a positive even "
                f"number. Check head_dim ({self.head_dim}) and "
                f"partial_rotary_factor ({self.partial_rotary_factor})."
            )
        self.rotary_dim = rotary_dim
        rotary_dim_half = rotary_dim // 2

        # Compute inv_freq.
        inv_freq = 1.0 / (
            self.base_value
            ** (
                keras.ops.arange(0, rotary_dim_half, dtype=self.dtype)
                / rotary_dim_half
            )
        )

        # Set the non-trainable weight using the computed tensor.
        self.inv_freq = self.add_weight(
            name="inv_freq",
            shape=(rotary_dim_half,),
            initializer=keras.initializers.Constant(inv_freq),
            trainable=False,
            dtype=self.dtype,
        )
        self.built = True

    def call(self, t):
        t_cast = keras.ops.cast(t, keras.ops.dtype(self.inv_freq))
        freqs = keras.ops.einsum("i,j->ij", t_cast, self.inv_freq)
        emb = keras.ops.stack((freqs, freqs), axis=-1)
        shape_list = list(keras.ops.shape(emb))
        shape_list[-2:] = [-1]
        return keras.ops.reshape(emb, shape_list)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "head_dim": self.head_dim,
                "max_position_embeddings": self.max_position_embeddings,
                "base_value": self.base_value,
                "partial_rotary_factor": self.partial_rotary_factor,
                "dtype": self.dtype,
            }
        )
        return config
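As an aside (not part of the diff), a minimal usage sketch of the layer above; `head_dim=64` is an assumed value and the shape follows the `build` logic shown here:

```python
import keras

# Minimal sketch with an assumed head_dim of 64; not part of the package.
rope = MoonshineRotaryEmbedding(head_dim=64, partial_rotary_factor=0.62)
positions = keras.ops.arange(8, dtype="float32")
# rotary_dim = (int(64 * 0.62) // 2) * 2 = 38, so each position gets 19
# position-scaled inverse frequencies, each repeated for its paired channel.
freqs = rope(positions)
print(keras.ops.shape(freqs))  # (8, 38)
```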
@keras.saving.register_keras_serializable(package="keras_hub")
class MoonshineMLP(keras.layers.Layer):
    """
    Moonshine MLP layer.

    Implements a Multi-Layer Perceptron (MLP) for Moonshine models with support
    for both `SwiGLU` and `LinearGeLU` activation patterns. The MLP consists of
    two dense layers with an activation function in between, expanding the input
    dimension before projecting back to the original dimension.

    Args:
        hidden_dim: int. The dimensionality of the input and output tensors.
        feedforward_expansion_factor: float. The factor by which to expand the
            hidden dimension in the intermediate layer.
        use_swiglu_activation: bool, optional. If `True`, uses SwiGLU activation
            (SiLU with gating). If `False`, uses standard GeLU activation.
            Defaults to `True`.
        initializer_range: float, optional. The standard deviation for kernel
            initialization. Defaults to 0.02.
        dtype: string, optional. The data type for model computations and
            weights. Defaults to `None`.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    # References:
    # Based on the HuggingFace implementation of the MoonshineEncoderMLP and
    # MoonshineDecoderMLP classes (https://github.com/huggingface/transformers/blob/fc8764c9a618add64c33e83720f974750bcd0978/src/transformers/models/moonshine/modeling_moonshine.py#L66-L94).

    def __init__(
        self,
        hidden_dim,
        feedforward_expansion_factor,
        use_swiglu_activation=True,
        initializer_range=0.02,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)
        self.hidden_dim = hidden_dim
        self.feedforward_expansion_factor = feedforward_expansion_factor
        self.use_swiglu_activation = use_swiglu_activation
        self.kernel_initializer = moonshine_kernel_initializer(
            initializer_range=initializer_range
        )
        self.initializer_range = initializer_range

        if use_swiglu_activation:
            # First dense layer produces (2 * feedforward_expansion_factor *
            # hidden_dim) outputs.
            self.dense_1 = keras.layers.Dense(
                int(hidden_dim * feedforward_expansion_factor * 2),
                use_bias=True,
                name="dense_1",
                dtype=self.dtype,
                kernel_initializer=clone_initializer(self.kernel_initializer),
            )
            # Activation layer using "silu" (Swish activation).
            self.activation = keras.layers.Activation(
                "silu", name="activation", dtype=self.dtype
            )
        else:
            # Taken from pretrained weights.
            # First dense layer: output dimension is (hidden_dim *
            # feedforward_expansion_factor).
            self.dense_1 = keras.layers.Dense(
                int(hidden_dim * feedforward_expansion_factor),
                use_bias=True,
                name="dense_1",
                dtype=self.dtype,
                kernel_initializer=clone_initializer(self.kernel_initializer),
            )
            self.activation = keras.layers.Activation(
                "gelu", name="activation", dtype=self.dtype
            )

        # Second dense layer projects back to hidden_dim.
        self.dense_2 = keras.layers.Dense(
            hidden_dim,
            use_bias=True,
            name="dense_2",
            dtype=self.dtype,
            kernel_initializer=clone_initializer(self.kernel_initializer),
        )

    def build(self, input_shape):
        super().build(input_shape)
        # Build the first dense layer using the original input shape.
        self.dense_1.build(input_shape)
        # After dense_1, the output shape becomes: (..., 2 *
        # feedforward_expansion_factor * hidden_dim).
        # When splitting, each part will have shape (...,
        # feedforward_expansion_factor * hidden_dim).
        new_input_shape = list(input_shape)
        new_input_shape[-1] = (
            self.hidden_dim * self.feedforward_expansion_factor
        )
        self.dense_2.build(tuple(new_input_shape))

    def call(self, inputs):
        x = self.dense_1(inputs)
        if self.use_swiglu_activation:
            x1, gate = keras.ops.split(x, 2, axis=-1)
            activated_gate = self.activation(gate)
            x = x1 * activated_gate
        else:
            x = self.activation(x)
        output = self.dense_2(x)
        return output

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "feedforward_expansion_factor": self.feedforward_expansion_factor,  # noqa: E501
                "use_swiglu_activation": self.use_swiglu_activation,
                "initializer_range": self.initializer_range,
                "dtype": self.dtype,
            }
        )
        return config
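Again as an annotation rather than part of the diff, a short sketch of the default SwiGLU path with assumed dimensions (`hidden_dim=288`, expansion factor 4):

```python
import keras

# Assumed dimensions; dense_1 expands to 2 * 4 * 288 = 2304 channels, the
# two halves are gated with SiLU, and dense_2 projects back to 288.
mlp = MoonshineMLP(hidden_dim=288, feedforward_expansion_factor=4)
x = keras.random.normal((2, 10, 288))
y = mlp(x)
print(keras.ops.shape(y))  # (2, 10, 288)
```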
keras_hub/src/models/moonshine/moonshine_multi_head_attention.py
@@ -0,0 +1,355 @@
import keras
from keras import backend

from keras_hub.src.layers.modeling.cached_multi_head_attention import (
    CachedMultiHeadAttention,
)
from keras_hub.src.models.whisper.whisper_cached_multi_head_attention import (
    _build_proj_equation,
)
from keras_hub.src.models.whisper.whisper_cached_multi_head_attention import (
    _get_output_shape,
)


# Removed dependence on einops.
# Source: https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/model.py#L35
def _rotate_half(x):
    """
    Rotates the two halves of the last dimension.

    This function splits the last dimension of the input tensor into two equal
    halves and swaps them with a sign inversion. Specifically, for an input of
    shape `[..., 2*d]`, it returns a tensor of the same shape where `[x1, x2]`
    is transformed into `[-x2, x1]`.

    Args:
        x: Tensor. Shape `[..., 2*d]`. The input tensor to be rotated.

    Returns:
        Tensor: A tensor of shape `[..., 2*d]` with the two halves rotated.
    """
    # Conditional for Tensorflow backend.
    if backend.backend() == "tensorflow":
        x_shape = keras.ops.shape(x)
        last_dim = x_shape[-1]
        d = last_dim // 2
        x_shape_tensor = keras.ops.convert_to_tensor(x_shape)
        new_shape = keras.ops.concatenate(
            [x_shape_tensor[:-1], keras.ops.convert_to_tensor([d, 2])], axis=0
        )
        x = keras.ops.reshape(x, new_shape)
        x1 = x[..., 0]
        x2 = x[..., 1]
        x_rotated = keras.ops.stack([-x2, x1], axis=-1)
        x_rotated = keras.ops.reshape(x_rotated, x_shape)
        return x_rotated

    # Conditional for PyTorch and JAX backends.
    if backend.backend() == "torch" or backend.backend() == "jax":
        x_shape = keras.ops.shape(x)
        x_shape_tuple = tuple(
            int(keras.ops.convert_to_numpy(dim).item()) for dim in x_shape
        )
        last_dim = x_shape_tuple[-1]
        d = last_dim // 2
        new_shape = x_shape_tuple[:-1] + (d, 2)
        x = keras.ops.reshape(x, new_shape)
        x1 = x[..., 0]
        x2 = x[..., 1]
        x_rotated = keras.ops.stack([-x2, x1], axis=-1)
        x_rotated = keras.ops.reshape(x_rotated, x_shape_tuple)
        return x_rotated

    else:
        raise NotImplementedError(
            "Backend not supported. Please use TensorFlow, PyTorch, or JAX."
        )


def _apply_rotary_pos_emb(t, freqs):
    """
    Applies rotary positional embeddings to the input tensor. Used in on-the-fly
    computation of rotary positional embeddings in multi-head attention layers.

    Args:
        t: A tensor with shape `[..., seq_len, ..., hidden_dim]` where the
            rotary embedding is applied to the first `rot_dim` channels of the
            last dimension.
        freqs: A tensor of frequency values with shape `[max_seq_len, rot_dim]`.
            The last `seq_len` entries are used to compute the rotary
            embeddings.

    Returns:
        Tensor: A tensor of the same shape as `t` with the rotary positional
        embeddings applied to the first `rot_dim` channels of the last dimension
        and the remaining channels concatenated unchanged.
    """
    rot_dim = keras.ops.shape(freqs)[-1]
    seq_len = keras.ops.shape(t)[1]
    orig_dtype = t.dtype
    freqs = freqs[:seq_len, :]
    freqs = keras.ops.reshape(freqs, (1, seq_len, 1, rot_dim))
    t_rot = t[..., :rot_dim]
    t_nonrot = t[..., rot_dim:]
    t_rotated = t_rot * keras.ops.cos(freqs) + _rotate_half(
        t_rot
    ) * keras.ops.sin(freqs)
    out = keras.ops.concatenate([t_rotated, t_nonrot], axis=-1)
    return keras.ops.cast(out, orig_dtype)
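A tiny worked example (annotation only, not part of the diff) of `_rotate_half` above: channels are rotated in interleaved pairs rather than as contiguous halves:

```python
import keras

# For a single row [1, 2, 3, 4], the pairs (1, 2) and (3, 4) each map to
# (-x2, x1), giving [-2, 1, -4, 3].
x = keras.ops.convert_to_tensor([[1.0, 2.0, 3.0, 4.0]])
print(_rotate_half(x))  # [[-2.  1. -4.  3.]]
```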
@keras.saving.register_keras_serializable(package="keras_hub")
class MoonshineMultiHeadAttention(CachedMultiHeadAttention):
    """
    Moonshine multi-head attention layer.

    Implements a multi-head attention mechanism for Moonshine models with
    support for rotary position embeddings and different caching strategies.
    This layer extends the `CachedMultiHeadAttention` base class to include
    specialized functionality for Moonshine models, such as rotary embeddings
    and causal masking.

    Args:
        num_heads: int. Number of attention heads.
        key_dim: int. Size of each attention head for key.
        value_dim: int, optional. Size of each attention head for value. If
            None, defaults to `key_dim`.
        attention_bias: bool, optional. Whether to include bias in attention
            projection layers. Defaults to `False`.
        attention_dropout: float, optional. Dropout probability for attention
            weights. Defaults to 0.0.
        use_causal_mask: bool, optional. Whether to apply causal masking to
            prevent positions from attending to subsequent positions. Defaults
            to `False`.
        apply_rotary_embedding: bool, optional. Whether to apply rotary position
            embeddings to queries and keys. Defaults to `True`.
        **kwargs: Additional keyword arguments passed to the parent class.
    """

    # References:
    # Based on the HuggingFace implementation of the MoonshineAttention class (https://github.com/huggingface/transformers/blob/fc8764c9a618add64c33e83720f974750bcd0978/src/transformers/models/moonshine/modeling_moonshine.py#L184-L315).

    def __init__(
        self,
        num_heads,
        key_dim,
        value_dim=None,
        attention_bias=False,
        attention_dropout=0.0,
        use_causal_mask=False,
        apply_rotary_embedding=True,
        **kwargs,
    ):
        kwargs.pop("use_bias", None)
        kwargs.pop("dropout", None)
        super().__init__(
            num_heads=num_heads,
            key_dim=key_dim,
            value_dim=value_dim,
            use_bias=attention_bias,
            dropout=attention_dropout,
            **kwargs,
        )
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.use_causal_mask = use_causal_mask
        self.apply_rotary_embedding = apply_rotary_embedding

    def build(self, query_shape, value_shape, key_shape=None):
        # Ensure key_shape is defined.
        key_shape = value_shape if key_shape is None else key_shape
        query_rank = len(query_shape)
        value_rank = len(value_shape)
        key_rank = len(key_shape)

        # Build query projection layer.
        einsum_equation, bias_axes, output_rank = _build_proj_equation(
            free_dims=query_rank - 1, bound_dims=1, output_dims=2
        )
        self._query_dense = keras.layers.EinsumDense(
            einsum_equation,
            output_shape=_get_output_shape(
                output_rank - 1, [self._num_heads, self._key_dim]
            ),
            bias_axes=bias_axes if self._use_bias else None,
            name="query",
            **self._get_common_kwargs_for_sublayer(),
        )
        self._query_dense.build(query_shape)

        # Build key projection layer.
        einsum_equation, bias_axes, output_rank = _build_proj_equation(
            free_dims=key_rank - 1, bound_dims=1, output_dims=2
        )
        self._key_dense = keras.layers.EinsumDense(
            einsum_equation,
            output_shape=_get_output_shape(
                output_rank - 1, [self._num_heads, self._key_dim]
            ),
            bias_axes=bias_axes if self._use_bias else None,
            name="key",
            **self._get_common_kwargs_for_sublayer(),
        )
        self._key_dense.build(key_shape)

        # Build value projection layer.
        einsum_equation, bias_axes, output_rank = _build_proj_equation(
            free_dims=value_rank - 1, bound_dims=1, output_dims=2
        )
        self._value_dense = keras.layers.EinsumDense(
            einsum_equation,
            output_shape=_get_output_shape(
                output_rank - 1, [self._num_heads, self._value_dim]
            ),
            bias_axes=bias_axes if self._use_bias else None,
            name="value",
            **self._get_common_kwargs_for_sublayer(),
        )
        self._value_dense.build(value_shape)

        # Build the internal attention computation sublayer.
        self._build_attention(output_rank)

        # Build output projection layer.
        output_shape = (
            query_shape[-1] if not self._output_shape else self._output_shape
        )
        if isinstance(output_shape, (list, tuple)):
            output_shape = list(output_shape)
        else:
            output_shape = [output_shape]

        einsum_equation, bias_axes, output_rank = _build_proj_equation(
            free_dims=query_rank - 1,
            bound_dims=2,
            output_dims=len(output_shape),
        )
        self._output_dense = keras.layers.EinsumDense(
            einsum_equation,
            output_shape=_get_output_shape(output_rank - 1, output_shape),
            bias_axes=bias_axes if self._use_bias else None,
            name="attention_output",
            **self._get_common_kwargs_for_sublayer(),
        )
        output_dense_input_shape = list(
            self._query_dense.compute_output_shape(query_shape)
        )
        output_dense_input_shape[-1] = self._value_dim
        self._output_dense.build(tuple(output_dense_input_shape))

        self.built = True

    def _compute_causal_mask(self, query, value=None, for_cache=False):
        if backend.backend() == "torch" or backend.backend() == "jax":
            q_seq_length = int(
                keras.ops.convert_to_numpy(keras.ops.shape(query)[1]).item()
            )
            v_seq_length = (
                int(
                    keras.ops.convert_to_numpy(keras.ops.shape(value)[1]).item()
                )
                if value is not None
                else q_seq_length
            )
        elif backend.backend() == "tensorflow":
            if for_cache:
                assert value is not None
                v_seq_length = keras.ops.shape(value)[1]
            else:
                v_seq_length = keras.ops.shape(query)[1]
            q_seq_length = keras.ops.shape(query)[1]
        n_rows = v_seq_length if for_cache else q_seq_length
        ones_mask = keras.ops.ones((1, n_rows, v_seq_length), dtype="int32")
        row_index = keras.ops.cumsum(ones_mask, axis=-2)
        col_index = keras.ops.cumsum(ones_mask, axis=-1)
        mask = keras.ops.greater_equal(row_index, col_index)

        if for_cache:
            mask = mask[:, -q_seq_length:, :]

        return mask

    def call(
        self,
        query,
        value,
        key,
        rotary_embedding=None,
        attention_mask=None,
        cache=None,
        cache_update_index=None,
        training=None,
        **kwargs,
    ):
        # Project inputs.
        query_proj = self._query_dense(query)
        if rotary_embedding is not None:
            query_proj = _apply_rotary_pos_emb(query_proj, rotary_embedding)

        # Handle caching.
        if cache is not None:
            key_cache = cache[:, 0, ...]
            value_cache = cache[:, 1, ...]
            if cache_update_index is None:
                key_proj = key_cache
                value_proj = value_cache
            else:
                new_key = self._key_dense(key)
                new_value = self._value_dense(value)
                if self.apply_rotary_embedding and rotary_embedding is not None:
                    new_key = _apply_rotary_pos_emb(new_key, rotary_embedding)
                update_shape = keras.ops.shape(new_key)
                start_indices = [0] * len(update_shape)
                start_indices[1] = cache_update_index
                key_proj = keras.ops.slice_update(
                    key_cache, tuple(start_indices), new_key
                )
                value_proj = keras.ops.slice_update(
                    value_cache, tuple(start_indices), new_value
                )
                cache = keras.ops.stack((key_proj, value_proj), axis=1)

        else:
            if cache_update_index is not None:
                raise ValueError(
                    "`cache_update_index` should not be set if `cache` is "
                    f"`None`. Received: cache={cache}, cache_update_index="
                    f"{cache_update_index}"
                )
            key_proj = self._key_dense(key)
            value_proj = self._value_dense(value)
            if self.apply_rotary_embedding and rotary_embedding is not None:
                key_proj = _apply_rotary_pos_emb(key_proj, rotary_embedding)

        # Compute attention mask.
        final_mask = attention_mask

        if final_mask is not None:
            mask_shape = keras.ops.shape(final_mask)
            if len(mask_shape) == 2:
                final_mask = final_mask[:, None, None, :]
            elif len(mask_shape) == 3:
                final_mask = final_mask[:, None, :, :]

        attention_kwargs = {
            k: v for k, v in kwargs.items() if k != "padding_mask"
        }
        # Compute attention.
        attention_output, _ = self._compute_attention(
            query=query_proj,
            key=key_proj,
            value=value_proj,
            attention_mask=final_mask,
            training=training,
            **attention_kwargs,
        )

        # Project the attention output.
        output = self._output_dense(attention_output)

        # Return output + cache if cache is provided, otherwise return just
        # output.
        if cache is not None:
            return output, cache
        return output
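A self-attention sketch of the layer above (annotation only, with assumed dimensions: 8 heads of size 36 for a 288-dim model). The actual wiring lives in moonshine_encoder.py and moonshine_decoder.py, which are not shown in this excerpt:

```python
import keras

# Assumed dimensions; the rotary angles come from MoonshineRotaryEmbedding
# defined in moonshine_layers.py.
attention = MoonshineMultiHeadAttention(
    num_heads=8, key_dim=36, apply_rotary_embedding=True
)
rope = MoonshineRotaryEmbedding(head_dim=36)
x = keras.random.normal((2, 10, 288))
rotary = rope(keras.ops.arange(10, dtype="float32"))
out = attention(query=x, value=x, key=x, rotary_embedding=rotary)
print(keras.ops.shape(out))  # (2, 10, 288)
```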
keras_hub/src/models/moonshine/moonshine_presets.py
@@ -0,0 +1,25 @@
# Metadata for loading pretrained model weights.
backbone_presets = {
    "moonshine_tiny_en": {
        "metadata": {
            "description": (
                "Moonshine tiny model for English speech recognition. "
                "Developed by Useful Sensors for real-time transcription."
            ),
            "params": 27092736,
            "path": "moonshine",
        },
        "kaggle_handle": "",
    },
    "moonshine_base_en": {
        "metadata": {
            "description": (
                "Moonshine base model for English speech recognition. "
                "Developed by Useful Sensors for real-time transcription."
            ),
            "params": 61513920,
            "path": "moonshine",
        },
        "kaggle_handle": "",
    },
}
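Note that both `kaggle_handle` fields are still empty in this nightly, so the presets cannot be downloaded yet; the metadata itself can already be inspected, for example:

```python
# Annotation: iterate over the preset metadata defined above.
for name, preset in backbone_presets.items():
    meta = preset["metadata"]
    print(f"{name}: {meta['params'] / 1e6:.1f}M params - {meta['description']}")
```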
keras_hub/src/models/moonshine/moonshine_tokenizer.py
@@ -0,0 +1,62 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer


@keras_hub_export(
    [
        "keras_hub.tokenizers.MoonshineTokenizer",
        "keras_hub.models.MoonshineTokenizer",
    ]
)
class MoonshineTokenizer(LlamaTokenizer):
    """
    Moonshine tokenizer layer based on `keras_hub.models.LlamaTokenizer`.

    This tokenizer class is an alias of `LlamaTokenizer` but for the Moonshine
    model. It uses a SentencePiece vocabulary to handle tokenization.

    Args:
        proto: `str` or `bytes`. Either a string path to a SentencePiece proto
            file or a bytes object containing a serialized SentencePiece proto.
            See the [SentencePiece repository](https://github.com/google/sentencepiece)
            for details on the format.
        **kwargs: Additional keyword arguments passed to the parent
            `LlamaTokenizer`.

    Examples:
    ```python
    from keras_hub.tokenizers import MoonshineTokenizer

    # Initialize tokenizer.
    tokenizer = MoonshineTokenizer(
        "keras_hub/src/tests/test_data/llama_test_vocab.spm"
    )

    # Single input example.
    single_input = "the quick brown fox"
    single_tokens = tokenizer(single_input)
    print("Single input tokenization:")
    print(f"Input text: {single_input}")
    print(f"Tokenized: {single_tokens}")

    # Batched input example.
    batch_input = ["the quick brown fox", "the earth is round"]
    batch_tokens = tokenizer(batch_input)
    print("Batch input tokenization:")
    print(f"Input texts: {batch_input}")
    print(f"Tokenized: {batch_tokens}")

    # Detokenization example.
    encoded = tokenizer(single_input)
    decoded = tokenizer.detokenize(encoded)
    print("Detokenization:")
    print(f"Original text: {single_input}")
    print(f"Encoded: {encoded}")
    print(f"Decoded: {decoded}")
    ```
    """

    # NOTE: The 768 future-use tokens defined in Section 3.1 of the Moonshine
    # paper, "Moonshine: Speech Recognition for Live Transcription and Voice
    # Commands" (https://arxiv.org/pdf/2410.15608.pdf) serve no purpose in the
    # tokenizer at the moment, and are hence not included in the vocabulary.
keras_hub/src/version.py
CHANGED