keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
keras_hub/src/models/whisper/whisper_audio_feature_extractor.py
@@ -0,0 +1,260 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np

try:
    import tensorflow as tf
except ImportError:
    raise ImportError(
        "To use `keras_hub`, please install TensorFlow: `pip install tensorflow`. "
        "The TensorFlow package is required for data preprocessing with any backend."
    )

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.preprocessing_layer import (
    PreprocessingLayer,
)


@keras_hub_export("keras_hub.models.WhisperAudioFeatureExtractor")
class WhisperAudioFeatureExtractor(PreprocessingLayer):
    """
    Whisper audio feature extractor layer.

    This layer takes in a batch of audio tensors, and computes the log-mel
    spectrogram features for each audio tensor.

    The input audio tensor can either be of shape `(length_of_audio,)` or
    `(batch_size, length_of_audio)`. The output is a tensor of shape
    `(batch_size, num_frames, num_mels)`, where `num_frames` is
    `(max_audio_length * sampling_rate) / stride`.

    Args:
        num_mels: int. The number of mel-frequency filters. Defaults to `80`.
        num_fft_bins: int. The size of the Fourier Transform in STFT.
            Defaults to `400`.
        stride: int. The distance between neighboring
            sliding window frames while computing STFT.
            Defaults to `160`.
        sampling_rate: int. The sample rate of the audio. Defaults to `16000`.
        max_audio_length: int. The length of each audio chunk in
            seconds. The input audio tensor will be padded/trimmed to
            `max_audio_length * sampling_rate`. Defaults to `30`.

    Examples:

    ```python
    audio_tensor = tf.ones((8000,), dtype="float32")

    # Compute the log-mel spectrogram.
    whisper_audio_feature_extractor = keras_hub.models.WhisperAudioFeatureExtractor()
    whisper_audio_feature_extractor(audio_tensor)

    # Compute the log-mel spectrogram for a batch of audio tensors.
    audio_tensor_1 = tf.ones((8000,), dtype="float32")
    audio_tensor_2 = tf.ones((10000,), dtype="float32")
    audio_tensor = tf.ragged.stack([audio_tensor_1, audio_tensor_2], axis=0)
    whisper_audio_feature_extractor(audio_tensor)
    ```
    """

    def __init__(
        self,
        num_mels=80,
        num_fft_bins=400,
        stride=160,
        sampling_rate=16000,
        max_audio_length=30,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self._convert_input_args = False
        self._allow_non_tensor_positional_args = True
        self.built = True

        self.num_mels = num_mels
        self.num_fft_bins = num_fft_bins
        self.stride = stride
        self.sampling_rate = sampling_rate
        self.max_audio_length = max_audio_length
        self.num_samples = self.sampling_rate * self.max_audio_length

        # After transposition, `self.mel_filters`'s shape is
        # `(num_fft_bins // 2 + 1, num_mels)`.
        self.mel_filters = self._get_mel_filters()

    def _get_mel_filters(self):
        """
        Adapted from Hugging Face
        (https://github.com/huggingface/transformers/blob/v4.27.1/src/transformers/models/whisper/feature_extraction_whisper.py#L86)
        """

        # TODO: Convert to TensorFlow ops (if possible).

        dtype = np.float32
        # Initialize the weights
        weights = np.zeros(
            (self.num_mels, int(1 + self.num_fft_bins // 2)), dtype=dtype
        )

        # Center freqs of each FFT bin
        fftfreqs = np.fft.rfftfreq(
            n=self.num_fft_bins, d=1.0 / self.sampling_rate
        )

        # 'Center freqs' of mel bands - uniformly spaced between limits
        min_mel = 0.0
        max_mel = 45.245640471924965

        mels = np.linspace(min_mel, max_mel, self.num_mels + 2)

        mels = np.asanyarray(mels)

        # Fill in the linear scale
        f_min = 0.0
        f_sp = 200.0 / 3
        freqs = f_min + f_sp * mels

        # And now the nonlinear scale
        min_log_hz = 1000.0  # beginning of log region (Hz)
        min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
        logstep = np.log(6.4) / 27.0  # step size for log region

        # If we have vector data, vectorize
        log_t = mels >= min_log_mel
        freqs[log_t] = min_log_hz * np.exp(
            logstep * (mels[log_t] - min_log_mel)
        )

        mel_f = freqs

        fdiff = np.diff(mel_f)
        ramps = np.subtract.outer(mel_f, fftfreqs)

        for i in range(self.num_mels):
            # lower and upper slopes for all bins
            lower = -ramps[i] / fdiff[i]
            upper = ramps[i + 2] / fdiff[i + 1]

            # .. then intersect them with each other and zero
            weights[i] = np.maximum(0, np.minimum(lower, upper))

        # Slaney-style mel is scaled to be approx constant energy per channel
        enorm = 2.0 / (mel_f[2 : self.num_mels + 2] - mel_f[: self.num_mels])
        weights *= enorm[:, np.newaxis]

        weights = np.transpose(weights)
        return tf.constant(weights, dtype=self.compute_dtype)

    def _extract_audio_features(self, audio):
        audio = tf.cast(audio, self.compute_dtype)
        # Use "reflection" padding - `tf.signal.stft` uses symmetric padding
        # internally.
        audio = tf.pad(
            audio,
            paddings=[[0, 0], [self.num_fft_bins // 2, self.num_fft_bins // 2]],
            mode="REFLECT",
        )

        # Compute the mel spectrogram.
        stft = tf.signal.stft(
            audio,
            frame_length=self.num_fft_bins,
            frame_step=self.stride,
            fft_length=self.num_fft_bins,
        )
        magnitudes = tf.square(tf.abs(stft[:, :-1, :]))

        mel_spec = tf.matmul(
            magnitudes,
            self.mel_filters,
        )

        def tf_log10(x):
            """
            Computes log base 10 of input tensor using TensorFlow's natural log operator.
            """
            numerator = tf.math.log(x)
            denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
            return numerator / denominator

        # Clamp the values to a minimum value of 1e-10. This is done to avoid
        # taking the log of 0, i.e., for numerical stability.
        mel_spec = tf.maximum(mel_spec, 1e-10)

        # Calculate the log mel spectrogram.
        log_spec = tf_log10(mel_spec)
        # Dynamic range compression.
        log_spec_shape = tf.shape(log_spec)
        max_value_minus_eight = tf.math.subtract(
            tf.math.reduce_max(log_spec, axis=[1, 2]),
            tf.cast(8, dtype=log_spec.dtype),
        )
        max_value_minus_eight = tf.expand_dims(max_value_minus_eight, axis=1)
        max_value_minus_eight = tf.repeat(
            max_value_minus_eight,
            repeats=log_spec_shape[1] * log_spec_shape[2],
            axis=1,
        )
        max_value_minus_eight = tf.reshape(
            max_value_minus_eight, shape=log_spec_shape
        )
        log_spec = tf.maximum(log_spec, max_value_minus_eight)
        # Normalization.
        type_cast_four = tf.cast(4, dtype=log_spec.dtype)
        log_spec = tf.math.divide(
            tf.math.add(log_spec, type_cast_four),
            type_cast_four,
        )

        return log_spec

    def call(self, audio):
        if not isinstance(audio, (tf.Tensor, tf.RaggedTensor)):
            audio = tf.convert_to_tensor(audio)

        rank_1_input = audio.shape.rank == 1
        if rank_1_input:
            audio = tf.expand_dims(audio, 0)

        # Convert the tensor to a Ragged Tensor.
        if isinstance(audio, tf.Tensor):
            audio = tf.RaggedTensor.from_tensor(audio)

        # Pad audio.
        audio_shape = audio.shape.as_list()
        audio_shape[-1] = self.num_samples
        audio = audio.to_tensor(shape=audio_shape)

        # Find the log mel spectrogram.
        log_spec = self._extract_audio_features(audio)
        if rank_1_input:
            log_spec = tf.squeeze(log_spec, 0)
        return log_spec

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_mels": self.num_mels,
                "num_fft_bins": self.num_fft_bins,
                "stride": self.stride,
                "sampling_rate": self.sampling_rate,
                "max_audio_length": self.max_audio_length,
            }
        )
        return config
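As a quick sanity check on the shapes documented above, here is a minimal usage sketch. It is not part of the package diff; it assumes an installed `keras_hub` nightly with TensorFlow available. With the defaults, `num_frames = (max_audio_length * sampling_rate) / stride = (30 * 16000) / 160 = 3000`.

```python
# Hypothetical usage sketch (not from this diff): check the documented
# output shape (batch_size, num_frames, num_mels) for the default config.
import tensorflow as tf
import keras_hub

extractor = keras_hub.models.WhisperAudioFeatureExtractor()
# Two clips of different lengths; both are padded/trimmed to 30 s internally.
audio = tf.random.uniform((2, 8000))
features = extractor(audio)
print(features.shape)  # Expected: (2, 3000, 80)
```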
keras_hub/src/models/whisper/whisper_backbone.py
@@ -0,0 +1,305 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import keras
from keras import ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
from keras_hub.src.layers.modeling.token_and_position_embedding import (
    TokenAndPositionEmbedding,
)
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.whisper.whisper_decoder import WhisperDecoder
from keras_hub.src.models.whisper.whisper_encoder import WhisperEncoder
from keras_hub.src.utils.tensor_utils import assert_tf_backend


def whisper_kernel_initializer(stddev=0.02):
    return keras.initializers.TruncatedNormal(stddev=stddev)


class Padder(keras.layers.Layer):
    def call(self, x):
        return ops.pad(x, [[0, 0], [1, 1], [0, 0]])


@keras_hub_export("keras_hub.models.WhisperBackbone")
class WhisperBackbone(Backbone):
    """A Whisper encoder-decoder network for speech.

    This class implements a Transformer-based encoder-decoder model as
    described in
    ["Robust Speech Recognition via Large-Scale Weak Supervision"](https://arxiv.org/abs/2212.04356).
    It includes the embedding lookups and transformer layers, but not the head
    for predicting the next token.

    The default constructor gives a fully customizable, randomly initialized
    Whisper model with any number of layers, heads, and embedding dimensions.
    To load preset architectures and weights, use the `from_preset()`
    constructor.

    Disclaimer: Pre-trained models are provided on an "as is" basis, without
    warranties or conditions of any kind. The underlying model is provided by
    a third party and subject to a separate license, available
    [here](https://github.com/openai/whisper).

    Args:
        vocabulary_size: int. The size of the token vocabulary.
        num_layers: int. The number of transformer encoder layers and
            transformer decoder layers.
        num_heads: int. The number of attention heads for each transformer.
            The hidden size must be divisible by the number of attention heads.
        hidden_dim: int. The size of the transformer encoding and pooler layers.
        intermediate_dim: int. The output dimension of the first Dense layer in
            a two-layer feedforward network for each transformer.
        num_mels: int. The number of mel-frequency filters. Defaults to `80`.
        dropout: float. Dropout probability for the Transformer encoder.
        max_encoder_sequence_length: int. The maximum sequence length that the
            audio encoder can consume. Since the second convolutional layer in
            the encoder reduces the sequence length by half (stride of 2), we
            use `max_encoder_sequence_length // 2` as the sequence length for
            the positional embedding layer.
        max_decoder_sequence_length: int. The maximum sequence length that the
            text decoder can consume.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for model computations and weights. Note that some computations,
            such as softmax and layer normalization, will always be done at
            float32 precision regardless of dtype.

    Examples:

    ```python
    input_data = {
        "encoder_features": np.ones(shape=(1, 12, 80), dtype="int32"),
        "decoder_token_ids": np.ones(shape=(1, 12), dtype="int32"),
        "decoder_padding_mask": np.array(
            [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]
        ),
    }

    # Randomly initialized Whisper encoder-decoder model with a custom config.
    model = keras_hub.models.WhisperBackbone(
        vocabulary_size=51864,
        num_layers=4,
        num_heads=4,
        hidden_dim=256,
        intermediate_dim=512,
        max_encoder_sequence_length=128,
        max_decoder_sequence_length=128,
    )
    model(input_data)
    ```
    """

    def __init__(
        self,
        vocabulary_size,
        num_layers,
        num_heads,
        hidden_dim,
        intermediate_dim,
        num_mels=80,
        dropout=0.0,
        max_encoder_sequence_length=3000,
        max_decoder_sequence_length=448,
        dtype=None,
        **kwargs,
    ):
        assert_tf_backend(self.__class__.__name__)

        # === Layers ===
        self.encoder_conv_layer_1 = keras.layers.Conv1D(
            filters=hidden_dim,
            kernel_size=3,
            strides=1,
            padding="same",
            dtype=dtype,
            name="encoder_token_embedding_conv_layer_1",
        )
        self.encoder_conv_layer_2 = keras.layers.Conv1D(
            filters=hidden_dim,
            kernel_size=3,
            strides=2,
            padding="valid",
            dtype=dtype,
            name="encoder_token_embedding_conv_layer_2",
        )
        self.encoder_padder = Padder(
            dtype=dtype,
            name="encoder_padder",
        )
        self.encoder_position_embedding = PositionEmbedding(
            initializer=whisper_kernel_initializer(),
            sequence_length=max_encoder_sequence_length // 2,
            dtype=dtype,
            name="encoder_position_embedding",
            trainable=False,
        )
        self.encoder_embeddings_add = keras.layers.Add(
            dtype=dtype,
            name="encoder_embeddings_add",
        )
        self.encoder_embeddings_dropout = keras.layers.Dropout(
            dropout,
            dtype=dtype,
            name="encoder_embeddings_dropout",
        )
        self.encoder_transformer_layers = []
        for i in range(num_layers):
            layer = WhisperEncoder(
                num_heads=num_heads,
                intermediate_dim=intermediate_dim,
                activation=keras.activations.gelu,
                layer_norm_epsilon=1e-5,
                dropout=dropout,
                kernel_initializer=whisper_kernel_initializer(),
                normalize_first=True,
                dtype=dtype,
                name=f"transformer_encoder_layer_{i}",
            )
            self.encoder_transformer_layers.append(layer)
        self.encoder_layer_norm = keras.layers.LayerNormalization(
            axis=-1,
            epsilon=1e-5,
            dtype=dtype,
            name="encoder_layer_norm",
        )
        self.decoder_embeddings = TokenAndPositionEmbedding(
            vocabulary_size=vocabulary_size,
            sequence_length=max_decoder_sequence_length,
            embedding_dim=hidden_dim,
            embeddings_initializer=whisper_kernel_initializer(),
            dtype=dtype,
            name="decoder_token_and_position_embedding",
        )
        self.token_embedding = self.decoder_embeddings.token_embedding
        self.decoder_embeddings_dropout = keras.layers.Dropout(
            dropout,
            dtype=dtype,
            name="decoder_embeddings_dropout",
        )
        self.decoder_transformer_layers = []
        for i in range(num_layers):
            layer = WhisperDecoder(
                intermediate_dim=intermediate_dim,
                num_heads=num_heads,
                dropout=dropout,
                activation=keras.activations.gelu,
                layer_norm_epsilon=1e-5,
                kernel_initializer=whisper_kernel_initializer(),
                normalize_first=True,
                dtype=dtype,
                name=f"transformer_decoder_layer_{i}",
            )
            self.decoder_transformer_layers.append(layer)
        self.decoder_layer_norm = keras.layers.LayerNormalization(
            axis=-1,
            epsilon=1e-5,
            dtype=dtype,
            name="decoder_layer_norm",
        )

        # === Functional Model ===
        # Note that the encoder does not have a padding mask:
        # https://github.com/openai/whisper/blob/v20230124/whisper/model.py#L132.
        encoder_feature_input = keras.Input(
            shape=(None, num_mels), dtype="float32", name="encoder_features"
        )
        decoder_token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="decoder_token_ids"
        )
        decoder_padding_mask_input = keras.Input(
            shape=(None,), dtype="int32", name="decoder_padding_mask"
        )
        # Encoder.
        # Embed the input features. This consists of two 1D convolutional
        # layers.
        # For the first layer, we use `padding="same"` since that corresponds
        # to a padding size of 1.
        embedded_features = keras.activations.gelu(
            self.encoder_conv_layer_1(encoder_feature_input),
            approximate=False,
        )
        # For the second conv. layer, we cannot use `padding="same"` since
        # that corresponds to a padding size of 1.5 (since stride is 2). Hence,
        # we will manually pad the input.
        embedded_features = self.encoder_padder(embedded_features)
        embedded_features = keras.activations.gelu(
            self.encoder_conv_layer_2(embedded_features),
            approximate=False,
        )
        # The position embedding layer for the encoder is a sinusoidal embedding
        # layer: https://github.com/openai/whisper/blob/v20230124/whisper/model.py#L137.
        # Hence, we set it to be non-trainable.
        # TODO: We can use `keras_hub.layers.SinePositionEncoding` layer.
        positions = self.encoder_position_embedding(embedded_features)
        x = self.encoder_embeddings_add((embedded_features, positions))
        x = self.encoder_embeddings_dropout(x)
        for transformer_layer in self.encoder_transformer_layers:
            x = transformer_layer(x)
        x = self.encoder_layer_norm(x)
        encoder_output = x
        # Decoder.
        x = self.decoder_embeddings(decoder_token_id_input)
        x = self.decoder_embeddings_dropout(x)
        for transformer_layer in self.decoder_transformer_layers:
            x = transformer_layer(
                decoder_sequence=x,
                encoder_sequence=encoder_output,
                decoder_padding_mask=decoder_padding_mask_input,
            )
        x = self.decoder_layer_norm(x)
        decoder_output = x
        super().__init__(
            inputs={
                "encoder_features": encoder_feature_input,
                "decoder_token_ids": decoder_token_id_input,
                "decoder_padding_mask": decoder_padding_mask_input,
            },
            outputs={
                "encoder_sequence_output": encoder_output,
                "decoder_sequence_output": decoder_output,
            },
            dtype=dtype,
            **kwargs,
        )

        # === Config ===
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.num_mels = num_mels
        self.dropout = dropout
        self.max_encoder_sequence_length = max_encoder_sequence_length
        self.max_decoder_sequence_length = max_decoder_sequence_length

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "num_mels": self.num_mels,
                "dropout": self.dropout,
                "max_encoder_sequence_length": self.max_encoder_sequence_length,
                "max_decoder_sequence_length": self.max_decoder_sequence_length,
            }
        )
        return config
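The two files in this diff compose directly: the feature extractor emits `(batch_size, 3000, 80)` log-mel frames, which is exactly the shape the backbone's `encoder_features` input expects, and the stride-2 convolution then halves the 3000 frames to 1500 encoder positions. Below is a minimal end-to-end sketch, not part of the package, assuming the TensorFlow backend (which the backbone asserts) and an installed `keras_hub` nightly.

```python
# Hypothetical end-to-end sketch (not from this diff): feed extractor
# output into a small randomly initialized WhisperBackbone.
import numpy as np
import tensorflow as tf
import keras_hub

extractor = keras_hub.models.WhisperAudioFeatureExtractor()
features = extractor(tf.random.uniform((1, 16000)))  # -> (1, 3000, 80)

backbone = keras_hub.models.WhisperBackbone(
    vocabulary_size=51864,
    num_layers=2,
    num_heads=4,
    hidden_dim=256,
    intermediate_dim=512,
    max_encoder_sequence_length=3000,
    max_decoder_sequence_length=448,
)
outputs = backbone(
    {
        "encoder_features": features,
        "decoder_token_ids": np.ones((1, 12), dtype="int32"),
        "decoder_padding_mask": np.ones((1, 12), dtype="int32"),
    }
)
# The stride-2 conv halves the 3000 input frames to 1500 encoder positions.
print(outputs["encoder_sequence_output"].shape)  # Expected: (1, 1500, 256)
print(outputs["decoder_sequence_output"].shape)  # Expected: (1, 12, 256)
```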