keras-hub-nightly 0.21.0.dev202505130407__py3-none-any.whl → 0.21.0.dev202505150407__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/layers/__init__.py +3 -0
- keras_hub/models/__init__.py +12 -0
- keras_hub/src/models/moonshine/__init__.py +0 -0
- keras_hub/src/models/moonshine/moonshine_audio_converter.py +301 -0
- keras_hub/src/models/moonshine/moonshine_audio_to_text.py +383 -0
- keras_hub/src/models/moonshine/moonshine_audio_to_text_preprocessor.py +267 -0
- keras_hub/src/models/moonshine/moonshine_backbone.py +478 -0
- keras_hub/src/models/moonshine/moonshine_decoder.py +313 -0
- keras_hub/src/models/moonshine/moonshine_encoder.py +212 -0
- keras_hub/src/models/moonshine/moonshine_layers.py +239 -0
- keras_hub/src/models/moonshine/moonshine_multi_head_attention.py +355 -0
- keras_hub/src/models/moonshine/moonshine_presets.py +25 -0
- keras_hub/src/models/moonshine/moonshine_tokenizer.py +62 -0
- keras_hub/src/version.py +1 -1
- keras_hub/tokenizers/__init__.py +3 -0
- {keras_hub_nightly-0.21.0.dev202505130407.dist-info → keras_hub_nightly-0.21.0.dev202505150407.dist-info}/METADATA +1 -1
- {keras_hub_nightly-0.21.0.dev202505130407.dist-info → keras_hub_nightly-0.21.0.dev202505150407.dist-info}/RECORD +19 -8
- {keras_hub_nightly-0.21.0.dev202505130407.dist-info → keras_hub_nightly-0.21.0.dev202505150407.dist-info}/WHEEL +1 -1
- {keras_hub_nightly-0.21.0.dev202505130407.dist-info → keras_hub_nightly-0.21.0.dev202505150407.dist-info}/top_level.txt +0 -0
keras_hub/src/models/moonshine/moonshine_backbone.py
@@ -0,0 +1,478 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.modeling.reversible_embedding import (
    ReversibleEmbedding,
)
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.moonshine.moonshine_decoder import (
    MoonshineDecoderBlock,
)
from keras_hub.src.models.moonshine.moonshine_encoder import (
    MoonshineEncoderBlock,
)
from keras_hub.src.models.moonshine.moonshine_layers import (
    MoonshineRotaryEmbedding,
)
from keras_hub.src.models.moonshine.moonshine_layers import (
    moonshine_kernel_initializer,
)


def compute_output_lengths(input_lengths):
    # Propagate raw audio sample counts through the three strided conv
    # stages (kernel/stride 127/64, 7/3, 3/2) to get encoder frame counts.
    lengths = keras.ops.cast(input_lengths, "float32")
    lengths = keras.ops.floor((lengths - 127) / 64) + 1
    lengths = keras.ops.floor((lengths - 7) / 3) + 1
    lengths = keras.ops.floor((lengths - 3) / 2) + 1
    return keras.ops.maximum(keras.ops.cast(lengths, "int32"), 0)
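
# Worked example: a 1-second, 16 kHz clip has 16000 samples, so
#   conv1 (kernel 127, stride 64): floor((16000 - 127) / 64) + 1 = 249
#   conv2 (kernel 7, stride 3): floor((249 - 7) / 3) + 1 = 81
#   conv3 (kernel 3, stride 2): floor((81 - 3) / 2) + 1 = 40
# and compute_output_lengths([16000]) yields [40] encoder frames.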


@keras.saving.register_keras_serializable(package="keras_hub")
class ComputeAttentionMask(keras.layers.Layer):
    def call(self, features_for_shape, output_lengths):
        # True for valid encoder frames, False for padded positions.
        max_output_length = keras.ops.shape(features_for_shape)[1]
        indices = keras.ops.arange(max_output_length, dtype="int32")
        attention_mask = indices[None, :] < output_lengths[:, None]
        attention_mask = keras.ops.cast(attention_mask, "bool")
        return attention_mask

    def compute_output_shape(self, input_shapes):
        batch_dim = None
        if isinstance(input_shapes, (list, tuple)) and len(input_shapes) > 0:
            features_shape = input_shapes[0]
            if (
                isinstance(features_shape, (list, tuple))
                and len(features_shape) > 0
            ):
                batch_dim = features_shape[0]
        return (batch_dim, None)
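
# Example: with encoder features padded to 40 frames and
# output_lengths = [40, 25], ComputeAttentionMask()(features, output_lengths)
# returns a (2, 40) boolean mask whose rows hold 40 and 25 True values.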


@keras.saving.register_keras_serializable(package="keras_hub")
class Arange(keras.layers.Layer):
    def call(self, inputs):
        # Emit position indices [0, 1, ..., T - 1] for a (batch, T, ...)
        # input.
        sequence_length = keras.ops.shape(inputs)[1]
        return keras.ops.arange(sequence_length, dtype="int32")
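
# Example: for decoder_token_ids of shape (2, 20), Arange()(decoder_token_ids)
# returns [0, 1, ..., 19], the position indices fed to the rotary embeddings.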


@keras_hub_export("keras_hub.models.MoonshineBackbone")
class MoonshineBackbone(Backbone):
    """Moonshine backbone with integrated audio feature extraction.

    This class implements an encoder-decoder backbone, as used in the
    Moonshine ASR system. It includes initial convolutional layers for audio
    feature extraction, followed by `MoonshineEncoderBlock` instances for
    processing these features and `MoonshineDecoderBlock` instances for
    generating output sequences.

    Args:
        vocabulary_size: int. The size of the vocabulary for the embedding
            layers.
        filter_dim: int. The number of filters for the initial convolutional
            feature extractor layers. Typically matches `hidden_dim`.
        encoder_num_layers: int. The number of stacked encoder blocks.
        decoder_num_layers: int. The number of stacked decoder blocks.
        hidden_dim: int. The dimensionality of the model's hidden
            representations and embeddings.
        intermediate_dim: int. The dimensionality of the intermediate
            representations in feedforward networks.
        encoder_num_heads: int. The number of attention heads in the
            encoder's multi-head attention.
        decoder_num_heads: int. The number of attention heads in the
            decoder's multi-head attention.
        feedforward_expansion_factor: int, optional. A multiplier applied to
            `intermediate_dim` to determine the total width of the
            feedforward network. Defaults to 4.
        encoder_use_swiglu_activation: bool, optional. When True, uses SwiGLU
            in the encoder feedforward network. Defaults to False.
        decoder_use_swiglu_activation: bool, optional. When True, uses SwiGLU
            in the decoder feedforward network. Defaults to True.
        max_position_embeddings: int, optional. The maximum sequence length
            for position embeddings. Defaults to 2048.
        pad_head_dim_to_multiple_of: int, optional. If specified, pads the
            head dimension to be a multiple of this value for performance
            optimization. Defaults to None.
        partial_rotary_factor: float, optional. The fraction of dimensions
            to apply rotary position embeddings to. Defaults to 0.62.
        dropout: float, optional. The dropout probability for input dropout
            layers. Defaults to 0.0.
        initializer_range: float, optional. The standard deviation of the
            truncated normal initializer for weights. Defaults to 0.02.
        rope_theta: float, optional. The base frequency for rotary position
            embeddings. Defaults to 10,000.0.
        attention_bias: bool, optional. Whether to use bias in attention
            mechanisms. Defaults to False.
        attention_dropout: float, optional. The dropout probability for
            attention mechanisms. Defaults to 0.0.
        dtype: str, optional. The dtype to use for model computations and
            weights. Defaults to None.

    Examples:
    ```python
    import numpy as np
    import keras
    from keras_hub.models import MoonshineBackbone

    # Create random input data for demonstration. The encoder consumes raw
    # audio (e.g., features from `MoonshineAudioConverter`).
    encoder_raw_input_values = np.random.rand(1, 16000, 1).astype("float32")
    # Mask corresponding to the raw input time dimension.
    encoder_padding_mask = np.ones((1, 16000), dtype="bool")
    decoder_token_ids = np.random.randint(
        0, 1000, size=(1, 20), dtype="int32"
    )
    decoder_padding_mask = np.ones((1, 20), dtype="bool")

    # Initialize the Moonshine backbone with specific parameters.
    backbone = MoonshineBackbone(
        vocabulary_size=10000,
        filter_dim=256,
        encoder_num_layers=6,
        decoder_num_layers=6,
        hidden_dim=256,
        intermediate_dim=512,
        encoder_num_heads=8,
        decoder_num_heads=8,
        feedforward_expansion_factor=4,
        decoder_use_swiglu_activation=True,
        encoder_use_swiglu_activation=False,
    )

    # Forward pass through the model.
    outputs = backbone(
        {
            "encoder_input_values": encoder_raw_input_values,
            "encoder_padding_mask": encoder_padding_mask,
            "decoder_token_ids": decoder_token_ids,
            "decoder_padding_mask": decoder_padding_mask,
        }
    )

    # Display the outputs.
    print("Encoder output shape:", outputs["encoder_sequence_output"].shape)
    print("Decoder output shape:", outputs["decoder_sequence_output"].shape)
    ```
    """

    # References:
    # Feature extractor: the UsefulSensors implementation
    # (https://github.com/usefulsensors/moonshine/blob/4a000427bd36a1c2c6d20a86c672dbd850b44c88/moonshine/model.py#L6-L32).
    # Transformer backbone: the Hugging Face implementation
    # (https://github.com/huggingface/transformers/blob/dcbdf7e962c4b36140cc9ee76f870016121e69e5/src/transformers/models/moonshine/modeling_moonshine.py#L1326-L1486).

    def __init__(
        self,
        vocabulary_size,
        filter_dim,
        encoder_num_layers,
        decoder_num_layers,
        hidden_dim,
        intermediate_dim,
        encoder_num_heads,
        decoder_num_heads,
        feedforward_expansion_factor=4,
        encoder_use_swiglu_activation=False,
        decoder_use_swiglu_activation=True,
        max_position_embeddings=2048,
        pad_head_dim_to_multiple_of=None,
        partial_rotary_factor=0.62,
        dropout=0.0,
        initializer_range=0.02,
        rope_theta=10000.0,
        attention_bias=False,
        attention_dropout=0.0,
        dtype=None,
        **kwargs,
    ):
        # === Layers ===
        self._compute_mask_layer = ComputeAttentionMask(
            name="compute_attention_mask"
        )

        # Feature extractor layers.
        self.conv1 = keras.layers.Conv1D(
            filters=filter_dim,
            kernel_size=127,
            strides=64,
            use_bias=False,
            padding="valid",
            kernel_initializer=moonshine_kernel_initializer(
                initializer_range=initializer_range
            ),
            name="conv1",
            dtype=dtype,
        )
        self.group_norm = keras.layers.GroupNormalization(
            groups=1,
            axis=-1,
            epsilon=1e-5,
            center=True,
            scale=True,
            name="group_norm",
            dtype=dtype,
        )
        self.tanh_after_conv1 = keras.layers.Activation(
            "tanh", name="tanh_after_conv1", dtype=dtype
        )
        self.conv2 = keras.layers.Conv1D(
            filters=2 * filter_dim,
            kernel_size=7,
            strides=3,
            use_bias=True,
            padding="valid",
            kernel_initializer=moonshine_kernel_initializer(
                initializer_range=initializer_range
            ),
            name="conv2",
            dtype=dtype,
        )
        self.gelu_after_conv2 = keras.layers.Activation(
            "gelu", name="gelu_after_conv2", dtype=dtype
        )
        self.conv3 = keras.layers.Conv1D(
            filters=filter_dim,
            kernel_size=3,
            strides=2,
            use_bias=True,
            padding="valid",
            kernel_initializer=moonshine_kernel_initializer(
                initializer_range=initializer_range
            ),
            name="conv3",
            dtype=dtype,
        )
        self.gelu_after_conv3 = keras.layers.Activation(
            "gelu", name="gelu_after_conv3", dtype=dtype
        )

        # Transformer layers.
        encoder_head_dim = hidden_dim // encoder_num_heads
        if pad_head_dim_to_multiple_of:
            # Round the head dimension up to the nearest multiple.
            encoder_head_dim = (
                (encoder_head_dim + pad_head_dim_to_multiple_of - 1)
                // pad_head_dim_to_multiple_of
            ) * pad_head_dim_to_multiple_of

        decoder_head_dim = hidden_dim // decoder_num_heads
        if pad_head_dim_to_multiple_of:
            decoder_head_dim = (
                (decoder_head_dim + pad_head_dim_to_multiple_of - 1)
                // pad_head_dim_to_multiple_of
            ) * pad_head_dim_to_multiple_of
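
        # Worked example (hypothetical values): hidden_dim=288 with 8 heads
        # gives head_dim 36; pad_head_dim_to_multiple_of=8 rounds it up to
        # (36 + 7) // 8 * 8 = 40.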

        # Embedding layer for decoder.
        self.token_embedding = ReversibleEmbedding(
            input_dim=vocabulary_size,
            output_dim=hidden_dim,
            embeddings_initializer=moonshine_kernel_initializer(
                initializer_range=initializer_range
            ),
            name="token_embedding",
            dtype=dtype,
        )

        # Rotary embeddings for encoder and decoder.
        self.encoder_rotary_embedding = MoonshineRotaryEmbedding(
            head_dim=encoder_head_dim,
            max_position_embeddings=max_position_embeddings,
            partial_rotary_factor=partial_rotary_factor,
            base_value=rope_theta,
            name="encoder_rotary_embedding",
            dtype=dtype,
        )

        self.decoder_rotary_embedding = MoonshineRotaryEmbedding(
            head_dim=decoder_head_dim,
            max_position_embeddings=max_position_embeddings,
            partial_rotary_factor=partial_rotary_factor,
            base_value=rope_theta,
            name="decoder_rotary_embedding",
            dtype=dtype,
        )

        # Dropout for encoder.
        self.encoder_dropout = keras.layers.Dropout(
            dropout, name="encoder_dropout", dtype=dtype
        )
        # Dropout for decoder.
        self.decoder_dropout = keras.layers.Dropout(
            dropout, name="decoder_dropout", dtype=dtype
        )

        # Encoder blocks.
        self.encoder_blocks = []
        for i in range(encoder_num_layers):
            encoder_block = MoonshineEncoderBlock(
                hidden_dim=hidden_dim,
                intermediate_dim=intermediate_dim,
                num_heads=encoder_num_heads,
                feedforward_expansion_factor=feedforward_expansion_factor,
                use_swiglu_activation=encoder_use_swiglu_activation,
                pad_head_dim_to_multiple_of=pad_head_dim_to_multiple_of,
                initializer_range=initializer_range,
                attention_bias=attention_bias,
                attention_dropout=attention_dropout,
                name=f"encoder_block_{i}",
                dtype=dtype,
            )
            self.encoder_blocks.append(encoder_block)

        # Layer normalization for encoder.
        self.encoder_final_layer_norm = keras.layers.LayerNormalization(
            epsilon=1e-5,
            center=False,
            scale=True,
            name="encoder_final_layer_norm",
            dtype=dtype,
        )

        # Decoder blocks.
        self.decoder_blocks = []
        for i in range(decoder_num_layers):
            decoder_block = MoonshineDecoderBlock(
                hidden_dim=hidden_dim,
                intermediate_dim=intermediate_dim,
                num_heads=decoder_num_heads,
                feedforward_expansion_factor=feedforward_expansion_factor,
                use_swiglu_activation=decoder_use_swiglu_activation,
                pad_head_dim_to_multiple_of=pad_head_dim_to_multiple_of,
                initializer_range=initializer_range,
                attention_bias=attention_bias,
                attention_dropout=attention_dropout,
                name=f"decoder_block_{i}",
                dtype=dtype,
            )
            self.decoder_blocks.append(decoder_block)

        # Layer normalization for decoder.
        self.decoder_post_norm = keras.layers.LayerNormalization(
            epsilon=1e-5,
            center=False,
            scale=True,
            name="decoder_post_norm",
            dtype=dtype,
        )

        # === Functional Model ===
        encoder_raw_input_values = keras.Input(
            shape=(None, 1), name="encoder_input_values", dtype=dtype
        )
        encoder_input_padding_mask = keras.Input(
            shape=(None,), name="encoder_padding_mask", dtype="bool"
        )
        decoder_input = keras.Input(
            shape=(None,), name="decoder_token_ids", dtype="int32"
        )
        decoder_padding_mask = keras.Input(
            shape=(None,), name="decoder_padding_mask", dtype="bool"
        )

        # Feature extraction.
        encoder_hidden_states = self.conv1(encoder_raw_input_values)
        encoder_hidden_states = self.tanh_after_conv1(encoder_hidden_states)
        encoder_hidden_states = self.group_norm(encoder_hidden_states)
        encoder_hidden_states = self.conv2(encoder_hidden_states)
        encoder_hidden_states = self.gelu_after_conv2(encoder_hidden_states)
        encoder_hidden_states = self.conv3(encoder_hidden_states)
        encoder_hidden_states = self.gelu_after_conv3(encoder_hidden_states)
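
        # Shape trace for the docstring example (filter_dim=256, a
        # 16000-sample input): (1, 16000, 1) -> conv1 (1, 249, 256) ->
        # conv2 (1, 81, 512) -> conv3 (1, 40, 256).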

        # Compute mask for encoder features.
        original_lengths = keras.ops.sum(
            keras.ops.cast(encoder_input_padding_mask, "int32"), axis=1
        )
        output_lengths = compute_output_lengths(original_lengths)
        encoder_attention_mask = self._compute_mask_layer(
            encoder_hidden_states, output_lengths
        )

        # Encoder.
        encoder_positions = Arange(name="encoder_positions")(
            encoder_hidden_states
        )
        encoder_rotary_emb = self.encoder_rotary_embedding(encoder_positions)
        encoder_hidden_states = self.encoder_dropout(encoder_hidden_states)
        for encoder_block in self.encoder_blocks:
            encoder_hidden_states = encoder_block(
                encoder_hidden_states,
                encoder_rotary_emb,
                attention_mask=encoder_attention_mask,
            )
        encoder_output = self.encoder_final_layer_norm(encoder_hidden_states)

        # Decoder.
        decoder_positions = Arange(name="decoder_positions")(decoder_input)
        decoder_rotary_emb = self.decoder_rotary_embedding(decoder_positions)
        decoder_hidden_states = self.token_embedding(decoder_input)
        decoder_hidden_states = self.decoder_dropout(decoder_hidden_states)
        for decoder_block in self.decoder_blocks:
            decoder_hidden_states = decoder_block(
                decoder_sequence=decoder_hidden_states,
                encoder_sequence=encoder_output,
                rotary_embedding=decoder_rotary_emb,
                decoder_padding_mask=decoder_padding_mask,
                encoder_padding_mask=encoder_attention_mask,
            )
        decoder_output = self.decoder_post_norm(decoder_hidden_states)

        super().__init__(
            inputs={
                "encoder_input_values": encoder_raw_input_values,
                "encoder_padding_mask": encoder_input_padding_mask,
                "decoder_token_ids": decoder_input,
                "decoder_padding_mask": decoder_padding_mask,
            },
            outputs={
                "encoder_sequence_output": encoder_output,
                "decoder_sequence_output": decoder_output,
                "encoder_attention_mask": encoder_attention_mask,
            },
            dtype=dtype,
            **kwargs,
        )
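
        # For the docstring example, the named outputs have shapes
        # "encoder_sequence_output" (1, 40, 256), "decoder_sequence_output"
        # (1, 20, 256), and "encoder_attention_mask" (1, 40).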

        # === Config ===
        self.vocabulary_size = vocabulary_size
        self.filter_dim = filter_dim
        self.encoder_num_layers = encoder_num_layers
        self.decoder_num_layers = decoder_num_layers
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.encoder_num_heads = encoder_num_heads
        self.decoder_num_heads = decoder_num_heads
        self.feedforward_expansion_factor = feedforward_expansion_factor
        self.encoder_use_swiglu_activation = encoder_use_swiglu_activation
        self.decoder_use_swiglu_activation = decoder_use_swiglu_activation
        self.max_position_embeddings = max_position_embeddings
        self.pad_head_dim_to_multiple_of = pad_head_dim_to_multiple_of
        self.partial_rotary_factor = partial_rotary_factor
        self.dropout = dropout
        self.initializer_range = initializer_range
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "filter_dim": self.filter_dim,
                "encoder_num_layers": self.encoder_num_layers,
                "decoder_num_layers": self.decoder_num_layers,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "encoder_num_heads": self.encoder_num_heads,
                "decoder_num_heads": self.decoder_num_heads,
                "feedforward_expansion_factor": self.feedforward_expansion_factor,  # noqa: E501
                "encoder_use_swiglu_activation": self.encoder_use_swiglu_activation,  # noqa: E501
                "decoder_use_swiglu_activation": self.decoder_use_swiglu_activation,  # noqa: E501
                "max_position_embeddings": self.max_position_embeddings,
                "pad_head_dim_to_multiple_of": self.pad_head_dim_to_multiple_of,
                "partial_rotary_factor": self.partial_rotary_factor,
                "dropout": self.dropout,
                "initializer_range": self.initializer_range,
                "rope_theta": self.rope_theta,
                "attention_bias": self.attention_bias,
                "attention_dropout": self.attention_dropout,
                "dtype": self.dtype,
            }
        )
        return config
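
# Serialization round trip (a sketch): `get_config` pairs with the
# `from_config` constructor that `MoonshineBackbone` inherits from `Backbone`:
#   restored = MoonshineBackbone.from_config(backbone.get_config())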