keras-hub-nightly 0.21.0.dev202505140407-py3-none-any.whl → 0.21.0.dev202505150407-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +3 -0
- keras_hub/models/__init__.py +12 -0
- keras_hub/src/models/moonshine/__init__.py +0 -0
- keras_hub/src/models/moonshine/moonshine_audio_converter.py +301 -0
- keras_hub/src/models/moonshine/moonshine_audio_to_text.py +383 -0
- keras_hub/src/models/moonshine/moonshine_audio_to_text_preprocessor.py +267 -0
- keras_hub/src/models/moonshine/moonshine_backbone.py +478 -0
- keras_hub/src/models/moonshine/moonshine_decoder.py +313 -0
- keras_hub/src/models/moonshine/moonshine_encoder.py +212 -0
- keras_hub/src/models/moonshine/moonshine_layers.py +239 -0
- keras_hub/src/models/moonshine/moonshine_multi_head_attention.py +355 -0
- keras_hub/src/models/moonshine/moonshine_presets.py +25 -0
- keras_hub/src/models/moonshine/moonshine_tokenizer.py +62 -0
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +3 -0
- {keras_hub_nightly-0.21.0.dev202505140407.dist-info → keras_hub_nightly-0.21.0.dev202505150407.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.21.0.dev202505140407.dist-info → keras_hub_nightly-0.21.0.dev202505150407.dist-info}/RECORD +19 -8
- {keras_hub_nightly-0.21.0.dev202505140407.dist-info → keras_hub_nightly-0.21.0.dev202505150407.dist-info}/WHEEL +1 -1
- {keras_hub_nightly-0.21.0.dev202505140407.dist-info → keras_hub_nightly-0.21.0.dev202505150407.dist-info}/top_level.txt +0 -0
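Taken together, the new files add the Moonshine speech-to-text model family: an audio converter and tokenizer, encoder/decoder blocks and a backbone, an audio-to-text task with its preprocessor, and preset registrations, all exported through the updated keras_hub/layers, keras_hub/models, and keras_hub/tokenizers __init__.py files. Below is a minimal sketch of the expected public surface, assuming this nightly wheel is installed; the class names follow KerasHub naming conventions for the files listed above and are assumptions, as are the preset identifiers (which live in moonshine_presets.py).

# Hedged sketch, not part of the diff. Class names are inferred from the file
# names above and KerasHub naming conventions; verify against the generated
# __init__.py files in the wheel.
import keras_hub

backbone_cls = keras_hub.models.MoonshineBackbone
task_cls = keras_hub.models.MoonshineAudioToText
tokenizer_cls = keras_hub.tokenizers.MoonshineTokenizer
audio_converter_cls = keras_hub.layers.MoonshineAudioConverter

# Presets registered in moonshine_presets.py would typically load as:
# task = keras_hub.models.MoonshineAudioToText.from_preset("<preset_name>")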
keras_hub/src/models/moonshine/moonshine_decoder.py
@@ -0,0 +1,313 @@
import keras

from keras_hub.src.layers.modeling.transformer_decoder import TransformerDecoder
from keras_hub.src.layers.modeling.transformer_layer_utils import (
    compute_causal_mask,
)
from keras_hub.src.layers.modeling.transformer_layer_utils import (
    merge_padding_and_attention_mask,
)
from keras_hub.src.models.moonshine.moonshine_layers import MoonshineMLP
from keras_hub.src.models.moonshine.moonshine_layers import (
    moonshine_kernel_initializer,
)
from keras_hub.src.models.moonshine.moonshine_multi_head_attention import (
    MoonshineMultiHeadAttention,
)
from keras_hub.src.utils.keras_utils import clone_initializer


@keras.saving.register_keras_serializable(package="keras_hub")
class MoonshineDecoderBlock(TransformerDecoder):
    """Moonshine decoder block for sequence processing.

    This layer implements a decoder block that includes self-attention with
    causal masking, cross-attention with precomputed key/value pairs, and a
    feedforward network.

    Args:
        hidden_dim: int. The dimensionality of the model's hidden
            representations.
        intermediate_dim: int. The dimensionality of the intermediate
            representations in the feedforward network.
        num_heads: int. The number of attention heads for multi-head attention
            mechanisms.
        feedforward_expansion_factor: int, optional. A multiplicative factor
            for scaling the feedforward network dimension. Defaults to 4.
        use_swiglu_activation: bool, optional. Whether to use the SwiGLU
            activation in the feedforward network for improved performance.
            Defaults to True.
        pad_head_dim_to_multiple_of: int, optional. If specified, pads the head
            dimension to be a multiple of this value for performance
            optimization. Defaults to None.
        initializer_range: float, optional. The standard deviation of the
            truncated normal distribution used to initialize model weights.
            Defaults to 0.02.
        attention_bias: bool, optional. Whether to add a bias term to the
            attention computations. Defaults to False.
        attention_dropout: float, optional. The dropout rate applied to
            attention weights during training. Defaults to 0.0.
        dtype: str, optional. The data type to use for model computations and
            weights. Defaults to None.
        **kwargs: Additional keyword arguments passed to the base layer.
    """

    # References:
    # Defined and formulated based on the UsefulSensors implementation of the
    # DecoderLayer class (https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/model.py#L348-L466).

    def __init__(
        self,
        hidden_dim,
        intermediate_dim,
        num_heads,
        feedforward_expansion_factor=4,
        use_swiglu_activation=True,
        pad_head_dim_to_multiple_of=None,
        initializer_range=0.02,
        attention_bias=False,
        attention_dropout=0.0,
        dtype=None,
        **kwargs,
    ):
        kwargs.pop("dropout", None)
        kwargs.pop("activation", None)
        kwargs.pop("kernel_initializer", None)
        self.kernel_initializer = moonshine_kernel_initializer(
            initializer_range=initializer_range
        )
        super().__init__(
            intermediate_dim=intermediate_dim,
            num_heads=num_heads,
            dropout=attention_dropout,
            activation="gelu" if use_swiglu_activation else "silu",
            kernel_initializer=clone_initializer(self.kernel_initializer),
            dtype=dtype,
            **kwargs,
        )
        self.initializer_range = initializer_range
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.num_heads = num_heads
        self.feedforward_expansion_factor = feedforward_expansion_factor
        self.use_swiglu_activation = use_swiglu_activation
        self.pad_head_dim_to_multiple_of = pad_head_dim_to_multiple_of
        self.attention_dropout = attention_dropout
        self.attention_bias = attention_bias

        self.head_dim = hidden_dim // num_heads
        if pad_head_dim_to_multiple_of is not None:
            self.head_dim = (
                (self.head_dim + pad_head_dim_to_multiple_of - 1)
                // pad_head_dim_to_multiple_of
            ) * pad_head_dim_to_multiple_of

        self.norm1 = keras.layers.LayerNormalization(
            axis=-1,
            epsilon=1e-5,
            center=False,
            scale=True,
            dtype=self.dtype,
        )
        self.self_attention = MoonshineMultiHeadAttention(
            num_heads=num_heads,
            key_dim=self.head_dim,
            use_bias=False,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            use_causal_mask=True,
            apply_rotary_embedding=True,
            dtype=self.dtype,
        )
        self.norm2 = keras.layers.LayerNormalization(
            axis=-1,
            epsilon=1e-5,
            center=False,
            scale=True,
            dtype=self.dtype,
        )
        self.cross_attention = MoonshineMultiHeadAttention(
            num_heads=num_heads,
            key_dim=self.head_dim,
            use_bias=False,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            use_causal_mask=False,
            apply_rotary_embedding=False,
            dtype=self.dtype,
        )
        self.norm3 = keras.layers.LayerNormalization(
            axis=-1,
            epsilon=1e-5,
            center=False,
            scale=True,
            dtype=self.dtype,
        )
        self.ff = MoonshineMLP(
            hidden_dim=hidden_dim,
            feedforward_expansion_factor=feedforward_expansion_factor,
            use_swiglu_activation=use_swiglu_activation,
            initializer_range=initializer_range,
            dtype=self.dtype,
        )

    def build(self, decoder_sequence_shape, encoder_sequence_shape=None):
        if encoder_sequence_shape is None:
            raise ValueError(
                "Encoder sequence shape must be provided for "
                "MoonshineDecoderBlock."
            )
        context_shape = encoder_sequence_shape  # Shape of context.

        # Build sublayers.
        self.norm1.build(decoder_sequence_shape)
        self.norm2.build(decoder_sequence_shape)
        self.norm3.build(decoder_sequence_shape)

        self.self_attention.build(
            query_shape=decoder_sequence_shape,
            key_shape=decoder_sequence_shape,
            value_shape=decoder_sequence_shape,
        )

        self.cross_attention.build(
            query_shape=decoder_sequence_shape,
            key_shape=context_shape,
            value_shape=context_shape,
        )

        self.ff.build(decoder_sequence_shape)
        self.built = True

    def _compute_self_attention_mask(
        self,
        decoder_sequence,
        decoder_padding_mask,
        self_attention_cache=None,
        self_attention_cache_update_index=None,
    ):
        decoder_mask = merge_padding_and_attention_mask(
            inputs=decoder_sequence,
            padding_mask=decoder_padding_mask,
            attention_mask=None,
        )
        if self.self_attention.use_causal_mask:
            batch_size = keras.ops.shape(decoder_sequence)[0]
            output_length = keras.ops.shape(decoder_sequence)[1]
            current_cache_update_index = (
                0
                if self_attention_cache_update_index is None
                else self_attention_cache_update_index
            )
            if self_attention_cache is not None:
                input_length = keras.ops.shape(self_attention_cache)[2]
            else:
                input_length = output_length
            causal_mask = compute_causal_mask(
                batch_size,
                input_length,
                output_length,
                current_cache_update_index,
            )
            return (
                keras.ops.minimum(decoder_mask, causal_mask)
                if decoder_mask is not None
                else causal_mask
            )
        return decoder_mask

    def call(
        self,
        decoder_sequence,
        encoder_sequence,
        rotary_embedding,
        encoder_attention_mask=None,
        decoder_padding_mask=None,
        encoder_padding_mask=None,
        self_attention_cache=None,
        self_attention_cache_update_index=None,
        cross_attention_cache=None,
        cross_attention_cache_update_index=None,
        training=None,
    ):
        x = decoder_sequence
        context = encoder_sequence
        has_self_attention_cache = self_attention_cache is not None
        has_cross_attention_cache = cross_attention_cache is not None

        self_attention_mask = self._compute_self_attention_mask(
            decoder_sequence=x,
            decoder_padding_mask=decoder_padding_mask,
            self_attention_cache=self_attention_cache,
            self_attention_cache_update_index=self_attention_cache_update_index,
        )

        # Self attention block.
        residual = x
        x_norm1 = self.norm1(x)
        x_self_attn = self.self_attention(
            query=x_norm1,
            key=x_norm1,
            value=x_norm1,
            rotary_embedding=rotary_embedding,
            cache=self_attention_cache,
            cache_update_index=self_attention_cache_update_index,
            attention_mask=self_attention_mask,
            training=training,
        )
        if has_self_attention_cache:
            x_self_attn, self_attention_cache = x_self_attn
        x = x_self_attn + residual
        # Cross attention block.
        residual = x
        x_norm2 = self.norm2(x)
        cross_attention_mask = merge_padding_and_attention_mask(
            inputs=encoder_sequence,
            padding_mask=encoder_padding_mask,
            attention_mask=encoder_attention_mask,
        )
        x_cross_attn = self.cross_attention(
            query=x_norm2,
            key=context,
            value=context,
            cache=cross_attention_cache,
            cache_update_index=cross_attention_cache_update_index,
            attention_mask=cross_attention_mask,
            training=training,
        )
        if has_cross_attention_cache:
            x_cross_attn, cross_attention_cache = x_cross_attn
        x = x_cross_attn + residual
        residual = x
        x_norm3 = self.norm3(x)
        x_ff = self.ff(x_norm3)
        x = x_ff + residual

        if has_self_attention_cache:
            return x, self_attention_cache
        return x

    def compute_output_shape(
        self, decoder_sequence_shape, encoder_sequence_shape=None
    ):
        return decoder_sequence_shape

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "num_heads": self.num_heads,
                "feedforward_expansion_factor": self.feedforward_expansion_factor,  # noqa: E501
                "use_swiglu_activation": self.use_swiglu_activation,
                "pad_head_dim_to_multiple_of": self.pad_head_dim_to_multiple_of,  # noqa: E501
                "initializer_range": self.initializer_range,
                "attention_bias": self.attention_bias,
                "attention_dropout": self.attention_dropout,
                "dtype": self.dtype,
            }
        )
        return config
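As an aside (not part of the wheel contents above), the mask logic in _compute_self_attention_mask can be checked in isolation: the padding mask and the causal mask are merged with an element-wise minimum, so a key position is attendable only when it is both non-padding and not in the query's future. A minimal sketch, assuming the nightly wheel above is installed and using toy shapes chosen here for illustration:

# Illustration of the mask merging performed by
# MoonshineDecoderBlock._compute_self_attention_mask (no-cache case).
from keras import ops
from keras_hub.src.layers.modeling.transformer_layer_utils import (
    compute_causal_mask,
    merge_padding_and_attention_mask,
)

batch_size, seq_len, hidden_dim = 1, 4, 8
decoder_sequence = ops.zeros((batch_size, seq_len, hidden_dim))
decoder_padding_mask = ops.convert_to_tensor([[1, 1, 1, 0]])  # last token is padding

# (batch, 1, seq_len) padding mask, broadcast over query positions.
padding_mask = merge_padding_and_attention_mask(
    inputs=decoder_sequence,
    padding_mask=decoder_padding_mask,
    attention_mask=None,
)
# (batch, seq_len, seq_len) lower-triangular causal mask.
causal_mask = compute_causal_mask(batch_size, seq_len, seq_len, 0)
# A position is attendable only if it passes both masks.
combined = ops.minimum(
    ops.cast(padding_mask, "int32"), ops.cast(causal_mask, "int32")
)
print(ops.convert_to_numpy(combined)[0])
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 0]]

During cached generation the same path is taken, except that input_length is read from the cache shape and the cache update index offsets the causal pattern.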
keras_hub/src/models/moonshine/moonshine_encoder.py
@@ -0,0 +1,212 @@
import keras

from keras_hub.src.layers.modeling.transformer_encoder import TransformerEncoder
from keras_hub.src.models.moonshine.moonshine_layers import MoonshineMLP
from keras_hub.src.models.moonshine.moonshine_layers import (
    moonshine_kernel_initializer,
)
from keras_hub.src.models.moonshine.moonshine_multi_head_attention import (
    MoonshineMultiHeadAttention,
)
from keras_hub.src.utils.keras_utils import clone_initializer


@keras.saving.register_keras_serializable(package="keras_hub")
class MoonshineEncoderBlock(TransformerEncoder):
    """
    Moonshine encoder block for sequence processing.

    Implements a standard encoder block with self-attention and feedforward
    sublayers, including residual connections and layer normalization. The
    implementation utilizes Moonshine-specific attention and feedforward
    mechanisms.

    Args:
        hidden_dim: int. The dimensionality of the model's hidden
            representations throughout the block.
        intermediate_dim: int. The dimensionality used in projections before
            applying non-linearities.
        num_heads: int. The number of attention heads for multi-head attention
            computation.
        feedforward_expansion_factor: int, optional. A multiplier for expanding
            the dimension in the feedforward network. Defaults to 4.
        use_swiglu_activation: bool, optional. Whether to use SwiGLU activation
            (True) or LinearGeLU (False) in the feedforward sublayer. Defaults
            to False.
        pad_head_dim_to_multiple_of: int, optional. If specified, pads the head
            dimension to be a multiple of this value for hardware optimization.
            Defaults to None.
        initializer_range: float, optional. The standard deviation of the
            truncated normal distribution used for weight initialization.
            Defaults to 0.02.
        attention_bias: bool, optional. Whether to use a bias term in the
            attention mechanism. Defaults to False.
        attention_dropout: float, optional. The dropout rate applied to the
            attention weights. Defaults to 0.0.
        dtype: str, optional. The data type to use for model computations and
            weights. Defaults to None.
        **kwargs: Additional keyword arguments passed to the base layer.
    """

    # References:
    # Defined and formulated based on the UsefulSensors implementation of the
    # EncoderLayer class (https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/model.py#L124-L161).

    def __init__(
        self,
        hidden_dim,
        intermediate_dim,
        num_heads,
        feedforward_expansion_factor=4,
        use_swiglu_activation=False,
        pad_head_dim_to_multiple_of=None,
        dtype=None,
        initializer_range=0.02,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        kwargs.pop("dropout", None)
        kwargs.pop("activation", None)
        kwargs.pop("kernel_initializer", None)
        self.kernel_initializer = moonshine_kernel_initializer(
            initializer_range=initializer_range
        )
        super().__init__(
            intermediate_dim=intermediate_dim,
            num_heads=num_heads,
            dropout=attention_dropout,
            activation="gelu" if use_swiglu_activation else "silu",
            kernel_initializer=clone_initializer(self.kernel_initializer),
            dtype=dtype,
            **kwargs,
        )
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.num_heads = num_heads
        self.feedforward_expansion_factor = feedforward_expansion_factor
        self.use_swiglu_activation = use_swiglu_activation

        # Self-attention sublayers.
        self.pad_head_dim_to_multiple_of = pad_head_dim_to_multiple_of

        self.head_dim = hidden_dim // num_heads
        if pad_head_dim_to_multiple_of is not None:
            self.head_dim = (
                (self.head_dim + pad_head_dim_to_multiple_of - 1)
                // pad_head_dim_to_multiple_of
            ) * pad_head_dim_to_multiple_of

        self.self_attention_layer = MoonshineMultiHeadAttention(
            num_heads=num_heads,
            key_dim=self.head_dim,
            use_bias=False,
            kernel_initializer=clone_initializer(self.kernel_initializer),
            attention_bias=attention_bias,
            attention_dropout=attention_dropout,
            use_causal_mask=False,
            apply_rotary_embedding=True,
            name="self_attention_layer",
            dtype=self.dtype,
        )
        self.self_attention_layer_norm = keras.layers.LayerNormalization(
            axis=-1,
            epsilon=1e-5,
            center=False,
            scale=True,
            name="self_attention_layer_norm",
            dtype=self.dtype,
        )

        # Feedforward sublayers.
        self.feedforward_layer_norm = keras.layers.LayerNormalization(
            axis=-1,
            epsilon=1e-5,
            center=False,
            scale=True,
            name="feedforward_layer_norm",
            dtype=self.dtype,
        )
        self.feedforward = MoonshineMLP(
            hidden_dim=hidden_dim,
            feedforward_expansion_factor=feedforward_expansion_factor,
            use_swiglu_activation=use_swiglu_activation,
            initializer_range=initializer_range,
            name="feedforward",
            dtype=self.dtype,
        )

    def build(self, input_shape):
        if isinstance(input_shape, dict):
            encoder_input_shape = input_shape["input_values"]
        else:
            encoder_input_shape = input_shape
        # Build self-attention branch.
        self.self_attention_layer_norm.build(encoder_input_shape)
        self.self_attention_layer.build(
            encoder_input_shape, encoder_input_shape, encoder_input_shape
        )
        # Build feedforward branch.
        self.feedforward_layer_norm.build(encoder_input_shape)
        # The feedforward layer expects the last dimension to be hidden_dim.
        feed_forward_input_shape = list(encoder_input_shape)
        feed_forward_input_shape[-1] = self.hidden_dim
        self.feedforward.build(tuple(feed_forward_input_shape))
        self.built = True

    def call(
        self,
        inputs,
        rotary_embedding,
        attention_mask=None,
        training=None,
        **kwargs,
    ):
        x = inputs

        # Self-attention block with residual connection.
        attention_residual = x
        x = self.self_attention_layer_norm(x)
        x = self.self_attention_layer(
            query=x,
            value=x,
            key=x,
            rotary_embedding=rotary_embedding,
            attention_mask=attention_mask,
            training=training,
            **kwargs,
        )
        x = x + attention_residual

        # Feedforward block with residual connection.
        ff_residual = x
        x = self.feedforward_layer_norm(x)
        x = self.feedforward(x)
        x = x + ff_residual

        return x

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        # ==== Config ====
        config = super().get_config()
        config.update(
            {
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "num_heads": self.num_heads,
                "feedforward_expansion_factor": self.feedforward_expansion_factor,  # noqa: E501
                "use_swiglu_activation": self.use_swiglu_activation,
                "pad_head_dim_to_multiple_of": self.pad_head_dim_to_multiple_of,
                "initializer_range": self.initializer_range,
                "attention_bias": self.attention_bias,
                "attention_dropout": self.attention_dropout,
                "dtype": self.dtype,
            }
        )
        return config
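Both blocks use the same pre-norm residual wiring: normalize, apply the sublayer, then add the input back. The standalone sketch below (not from the diff) reproduces that call flow with stock Keras layers standing in for MoonshineMultiHeadAttention and MoonshineMLP, which additionally handle rotary embeddings and the SwiGLU option; all dimensions are illustrative assumptions.

# Simplified sketch of the pre-norm residual pattern used by
# MoonshineEncoderBlock.call. Stock Keras layers stand in for the
# Moonshine-specific attention and MLP; shapes are illustrative only.
import keras

hidden_dim, num_heads, seq_len = 64, 4, 16

attention_norm = keras.layers.LayerNormalization(
    epsilon=1e-5, center=False, scale=True  # scale-only norm, as in the diff
)
attention = keras.layers.MultiHeadAttention(
    num_heads=num_heads, key_dim=hidden_dim // num_heads
)
feedforward_norm = keras.layers.LayerNormalization(
    epsilon=1e-5, center=False, scale=True
)
feedforward = keras.Sequential(
    [
        keras.layers.Dense(hidden_dim * 4, activation="gelu"),
        keras.layers.Dense(hidden_dim),
    ]
)

x = keras.random.uniform((2, seq_len, hidden_dim))

# Self-attention block with residual connection.
residual = x
h = attention_norm(x)
h = attention(query=h, value=h, key=h)
x = h + residual

# Feedforward block with residual connection.
residual = x
h = feedforward_norm(x)
h = feedforward(h)
x = h + residual
print(x.shape)  # batch, seq_len, hidden_dim are preserved

Note also the head-dimension rounding shared by both blocks: for example, with hidden_dim=416 and num_heads=8 the raw head_dim is 52, and pad_head_dim_to_multiple_of=8 rounds it up to ((52 + 7) // 8) * 8 = 56.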