keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
keras_hub/src/models/whisper/whisper_preprocessor.py
@@ -0,0 +1,326 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import keras
+from absl import logging
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_hub.src.models.preprocessor import Preprocessor
+from keras_hub.src.models.whisper.whisper_audio_feature_extractor import (
+    WhisperAudioFeatureExtractor,
+)
+from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer
+from keras_hub.src.utils.keras_utils import (
+    convert_inputs_to_list_of_tensor_segments,
+)
+
+
+@keras_hub_export("keras_hub.models.WhisperPreprocessor")
+class WhisperPreprocessor(Preprocessor):
+    """A Whisper preprocessing layer which handles audio and text input.
+
+    This preprocessing layer will do four things:
+
+    1. Compute the log-mel spectrogram of the audio tensor inputs using
+       `audio_feature_extractor`.
+    2. Tokenize decoder inputs using the `tokenizer`.
+    3. Add the appropriate special tokens - `"<|startoftranscript|>"`, task
+       token, language token, `"<|endoftext|>"`, etc.
+    4. Construct a dictionary with keys `"encoder_features"`,
+       `"decoder_token_ids"`, `"decoder_padding_mask"` that can be passed
+       directly to a Whisper model.
+
+    Args:
+        tokenizer: A `keras_hub.models.WhisperTokenizer` instance.
+        audio_feature_extractor: A
+            `keras_hub.models.WhisperAudioFeatureExtractor` instance or `None`.
+            If `None` a feature extractor with default parameters will be
+            created.
+        decoder_sequence_length: The length of the packed decoder inputs.
+        language: string, language token. Should only be passed if your
+            tokenizer is multilingual.
+        task: string, task name. One of `"transcribe"`, `"translate"`. Should
+            only be passed if your tokenizer is multilingual.
+        no_timestamps: bool. If True, `"<|notimestamps|>"` will be added as a
+            special token to your input.
+
+    Call arguments:
+        x: A dictionary with `"encoder_audio"` and `"decoder_text"` as its keys.
+            `"encoder_audio"` should correspond to the input audio tensor.
+            `"decoder_text"` should be a tensor of single string sequences.
+            Inputs may be batched or unbatched. Raw python inputs will be
+            converted to tensors.
+        y: Any label data. Will be passed through unaltered.
+        sample_weight: Any label weight data. Will be passed through unaltered.
+
+    Examples:
+
+    Directly calling the layer on data.
+    ```python
+    preprocessor = keras_hub.models.WhisperPreprocessor.from_preset(
+        "whisper_tiny_en",
+    )
+
+    # Preprocess unbatched inputs.
+    input_data = {
+        "encoder_audio": tf.ones((200,)),
+        "decoder_text": "The quick brown fox jumped.",
+    }
+    preprocessor(input_data)
+
+    # Preprocess batched inputs.
+    input_data = {
+        "encoder_audio": tf.ones((2, 200)),
+        "decoder_text": ["The quick brown fox jumped.", "Call me Ishmael."],
+    }
+    preprocessor(input_data)
+
+    # Custom audio feature extractor and vocabulary.
+    audio_feature_extractor = keras_hub.models.WhisperAudioFeatureExtractor(
+        num_mels=80,
+        num_fft_bins=400,
+        stride=100,
+        sampling_rate=100,
+        max_audio_length=5,
+    )
+
+    features = ["a quick fox.", "a fox quick."]
+    vocab = {"<|endoftext|>": 0, "a": 4, "Ġquick": 5, "Ġfox": 6}
+    merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"]
+    merges += ["Ġ f", "o x", "Ġf ox"]
+    special_tokens = {
+        "<|startoftranscript|>": 9,
+        "<|endoftext|>": 10,
+        "<|notimestamps|>": 11,
+        "<|transcribe|>": 12,
+        "<|translate|>": 13,
+    }
+
+    tokenizer = keras_hub.models.WhisperTokenizer(
+        vocabulary=vocab,
+        merges=merges,
+        special_tokens=special_tokens,
+    )
+    preprocessor = keras_hub.models.WhisperPreprocessor(
+        audio_feature_extractor=audio_feature_extractor,
+        tokenizer=tokenizer,
+    )
+
+    input_data = {
+        "encoder_audio": tf.ones((200,)),
+        "decoder_text": "The quick brown fox jumped.",
+    }
+    preprocessor(input_data)
+    ```
+
+    Mapping with `tf.data.Dataset`.
+    ```python
+    preprocessor = keras_hub.models.WhisperPreprocessor.from_preset(
+        "whisper_tiny_en")
+
+    # Map labeled single sentences.
+    features = {
+        "encoder_audio": tf.ones((2, 200)),
+        "decoder_text": ["The quick brown fox jumped.", "Call me Ishmael."],
+    }
+    labels = tf.constant(["True", "False"])
+    ds = tf.data.Dataset.from_tensor_slices((features, labels))
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map unlabeled single sentences.
+    features = {
+        "encoder_audio": tf.ones((2, 200)),
+        "decoder_text": ["The quick brown fox jumped.", "Call me Ishmael."],
+    }
+    ds = tf.data.Dataset.from_tensor_slices(features)
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+    ```
+    """
+
+    tokenizer_cls = WhisperTokenizer
+
+    def __init__(
+        self,
+        tokenizer,
+        audio_feature_extractor=None,
+        decoder_sequence_length=448,
+        language=None,
+        task=None,
+        no_timestamps=True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if audio_feature_extractor is None:
+            audio_feature_extractor = WhisperAudioFeatureExtractor()
+        self.audio_feature_extractor = audio_feature_extractor
+        self.tokenizer = tokenizer
+        self.decoder_packer = None
+        self.decoder_sequence_length = decoder_sequence_length
+        self.language = language
+        self.task = task
+        self.no_timestamps = no_timestamps
+
+    def build(self, input_shape):
+        # Defer packer creation to `build()` so that we can be sure tokenizer
+        # assets have loaded when restoring a saved model.
+
+        # Create list of tokens to be prepended to decoder inputs.
+        bos_tokens = [self.tokenizer.bos_token_id]
+        if self.tokenizer.language_tokens is not None:
+            if (
+                self.language is None
+                or self.language not in self.tokenizer.language_tokens
+            ):
+                raise ValueError(
+                    "You must pass a non-None value for `language` when using "
+                    "a multilingual tokenizer. The value must be one of "
+                    f'{",".join(self.tokenizer.language_tokens.keys())}. '
+                    f"Received: language={self.language}."
+                )
+            if self.task is None or self.task not in [
+                "transcribe",
+                "translate",
+            ]:
+                raise ValueError(
+                    "You must pass a non-None value for `task` when using "
+                    "a multilingual tokenizer. The value must be one of "
+                    '`"transcribe"`, `"translate"`. '
+                    f"Received: task={self.task}."
+                )
+
+            bos_tokens += [self.tokenizer.language_tokens[self.language]]
+
+            if self.task == "transcribe":
+                bos_tokens += [self.tokenizer.special_tokens["<|transcribe|>"]]
+            elif self.task == "translate":
+                bos_tokens += [self.tokenizer.special_tokens["<|translate|>"]]
+        else:
+            if self.language is not None:
+                logging.info(
+                    "`tokenizer` is monolingual, and `language` has a "
+                    "non-`None` value. Setting `language` to `None`."
+                )
+                self.language = None
+            if self.task is not None:
+                logging.info(
+                    "`tokenizer` is monolingual, and `task` has a "
+                    "non-`None` value. Setting `task` to `None`."
+                )
+                self.task = None
+
+        if self.no_timestamps:
+            bos_tokens += [self.tokenizer.no_timestamps_token_id]
+
+        # TODO: Use `MultiSegmentPacker` instead of `StartEndPacker` once we
+        # want to move to multi-segment packing and have improved
+        # `MultiSegmentPacker`'s performance.
+        self.decoder_packer = StartEndPacker(
+            start_value=bos_tokens,
+            end_value=self.tokenizer.eos_token_id,
+            pad_value=self.tokenizer.pad_token_id,
+            sequence_length=self.decoder_sequence_length,
+            return_padding_mask=True,
+        )
+
+    def call(self, x, y=None, sample_weight=None, decoder_sequence_length=None):
+        if not (
+            isinstance(x, dict)
+            and ["encoder_audio", "decoder_text"] == list(x.keys())
+        ):
+            raise ValueError(
+                '`x` must be a dictionary, containing the keys `"encoder_audio"`'
+                f' and `"decoder_text"`. Received x={x}.'
+            )
+
+        encoder_audio = x["encoder_audio"]
+        decoder_text = x["decoder_text"]
+
+        encoder_audio = convert_inputs_to_list_of_tensor_segments(encoder_audio)
+        decoder_text = convert_inputs_to_list_of_tensor_segments(decoder_text)
+
+        if len(encoder_audio) > 1 or len(decoder_text) > 1:
+            raise ValueError(
+                '`WhisperPreprocessor` requires both `"encoder_audio"` and '
+                f'`"decoder_text"` to contain only one segment, but received '
+                f"{len(encoder_audio)} and {len(decoder_text)}, respectively."
+            )
+
+        encoder_features = self.audio_feature_extractor(encoder_audio[0])
+        decoder_sequence_length = (
+            decoder_sequence_length or self.decoder_sequence_length
+        )
+        decoder_inputs = self.tokenizer(decoder_text[0])
+        decoder_token_ids, decoder_padding_mask = self.decoder_packer(
+            decoder_inputs,
+            sequence_length=decoder_sequence_length,
+        )
+
+        x = {
+            "encoder_features": encoder_features,
+            "decoder_token_ids": decoder_token_ids,
+            "decoder_padding_mask": decoder_padding_mask,
+        }
+
+        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "audio_feature_extractor": keras.layers.serialize(
+                    self.audio_feature_extractor
+                ),
+                "decoder_sequence_length": self.decoder_sequence_length,
+                "language": self.language,
+                "task": self.task,
+                "no_timestamps": self.no_timestamps,
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        if "tokenizer" in config and isinstance(config["tokenizer"], dict):
+            config["tokenizer"] = keras.layers.deserialize(config["tokenizer"])
+
+        if "audio_feature_extractor" in config and isinstance(
+            config["audio_feature_extractor"], dict
+        ):
+            config["audio_feature_extractor"] = keras.layers.deserialize(
+                config["audio_feature_extractor"]
+            )
+
+        return cls(**config)
+
+    @property
+    def decoder_sequence_length(self):
+        """The padded length of decoder input sequences."""
+        return self._decoder_sequence_length
+
+    @decoder_sequence_length.setter
+    def decoder_sequence_length(self, value):
+        self._decoder_sequence_length = value
+        if self.decoder_packer is not None:
+            self.decoder_packer.sequence_length = value
+
+    @property
+    def sequence_length(self):
+        """Alias for `decoder_sequence_length`."""
+        return self.decoder_sequence_length
+
+    @sequence_length.setter
+    def sequence_length(self, value):
+        self.decoder_sequence_length = value
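Note: the decoder packing assembled in `build()` above reduces to a simple transform: prepend the BOS/language/task/no-timestamps token IDs, append EOS, then truncate or pad to `decoder_sequence_length`. A rough, unbatched sketch of that logic (illustrative only; the real `StartEndPacker` operates on batched and ragged tensors):

```python
def pack_decoder_ids(token_ids, bos_tokens, eos_id, pad_id, sequence_length):
    """Approximate StartEndPacker behavior for one unbatched sequence."""
    ids = list(bos_tokens) + list(token_ids) + [eos_id]
    ids = ids[:sequence_length]  # truncate
    mask = [True] * len(ids)
    ids += [pad_id] * (sequence_length - len(ids))  # pad
    mask += [False] * (sequence_length - len(mask))
    return ids, mask

# Using the toy IDs from the docstring above: bos=9, notimestamps=11,
# eos=pad=10, and hypothetical word IDs [5, 6].
ids, mask = pack_decoder_ids([5, 6], [9, 11], 10, 10, sequence_length=8)
# ids  == [9, 11, 5, 6, 10, 10, 10, 10]
# mask == [True, True, True, True, True, False, False, False]
```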
keras_hub/src/models/whisper/whisper_presets.py
@@ -0,0 +1,148 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Metadata for loading pretrained model weights.
+backbone_presets = {
+    "whisper_tiny_en": {
+        "metadata": {
+            "description": (
+                "4-layer Whisper model. Trained on 438,000 hours of labelled "
+                "English speech data."
+            ),
+            "params": 37184256,
+            "official_name": "Whisper",
+            "path": "whisper",
+            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
+        },
+        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_tiny_en/2",
+    },
+    "whisper_base_en": {
+        "metadata": {
+            "description": (
+                "6-layer Whisper model. Trained on 438,000 hours of labelled "
+                "English speech data."
+            ),
+            "params": 124439808,
+            "official_name": "Whisper",
+            "path": "whisper",
+            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
+        },
+        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_base_en/2",
+    },
+    "whisper_small_en": {
+        "metadata": {
+            "description": (
+                "12-layer Whisper model. Trained on 438,000 hours of labelled "
+                "English speech data."
+            ),
+            "params": 241734144,
+            "official_name": "Whisper",
+            "path": "whisper",
+            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
+        },
+        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_small_en/2",
+    },
+    "whisper_medium_en": {
+        "metadata": {
+            "description": (
+                "24-layer Whisper model. Trained on 438,000 hours of labelled "
+                "English speech data."
+            ),
+            "params": 763856896,
+            "official_name": "Whisper",
+            "path": "whisper",
+            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
+        },
+        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_medium_en/2",
+    },
+    "whisper_tiny_multi": {
+        "metadata": {
+            "description": (
+                "4-layer Whisper model. Trained on 680,000 hours of labelled "
+                "multilingual speech data."
+            ),
+            "params": 37760640,
+            "official_name": "Whisper",
+            "path": "whisper",
+            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
+        },
+        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_tiny_multi/2",
+    },
+    "whisper_base_multi": {
+        "metadata": {
+            "description": (
+                "6-layer Whisper model. Trained on 680,000 hours of labelled "
+                "multilingual speech data."
+            ),
+            "params": 72593920,
+            "official_name": "Whisper",
+            "path": "whisper",
+            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
+        },
+        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_base_multi/2",
+    },
+    "whisper_small_multi": {
+        "metadata": {
+            "description": (
+                "12-layer Whisper model. Trained on 680,000 hours of labelled "
+                "multilingual speech data."
+            ),
+            "params": 241734912,
+            "official_name": "Whisper",
+            "path": "whisper",
+            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
+        },
+        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_small_multi/2",
+    },
+    "whisper_medium_multi": {
+        "metadata": {
+            "description": (
+                "24-layer Whisper model. Trained on 680,000 hours of labelled "
+                "multilingual speech data."
+            ),
+            "params": 763857920,
+            "official_name": "Whisper",
+            "path": "whisper",
+            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
+        },
+        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_medium_multi/2",
+    },
+    "whisper_large_multi": {
+        "metadata": {
+            "description": (
+                "32-layer Whisper model. Trained on 680,000 hours of labelled "
+                "multilingual speech data."
+            ),
+            "params": 1543304960,
+            "official_name": "Whisper",
+            "path": "whisper",
+            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
+        },
+        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_large_multi/2",
+    },
+    "whisper_large_multi_v2": {
+        "metadata": {
+            "description": (
+                "32-layer Whisper model. Trained for 2.5 epochs on 680,000 "
+                "hours of labelled multilingual speech data. An improved "
+                "version of `whisper_large_multi`."
+            ),
+            "params": 1543304960,
+            "official_name": "Whisper",
+            "path": "whisper",
+            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
+        },
+        "kaggle_handle": "kaggle://keras/whisper/keras/whisper_large_multi_v2/2",
+    },
+}
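Each preset entry above pairs a name with download metadata; the keys of `backbone_presets` are exactly the names that `from_preset()` accepts once the presets are registered, and the weights resolve through the `"kaggle_handle"` entries. For example, following the pattern from the `WhisperPreprocessor` docstring (assuming network access to the Kaggle handles listed above):

```python
import keras_hub

# Resolves "whisper_base_en" through its "kaggle_handle" entry and loads
# the saved tokenizer assets and preprocessing config.
preprocessor = keras_hub.models.WhisperPreprocessor.from_preset(
    "whisper_base_en"
)
```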
keras_hub/src/models/whisper/whisper_tokenizer.py
@@ -0,0 +1,163 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
+
+
+def _load_dict(dict_or_path):
+    if isinstance(dict_or_path, str):
+        with open(dict_or_path, "r", encoding="utf-8") as f:
+            dict_or_path = json.load(f)
+    return dict_or_path
+
+
+@keras_hub_export("keras_hub.models.WhisperTokenizer")
+class WhisperTokenizer(BytePairTokenizer):
+    """Whisper text tokenizer using Byte-Pair Encoding subword segmentation.
+
+    This tokenizer class will tokenize raw strings into integer sequences and
+    is based on `keras_hub.tokenizers.BytePairTokenizer`.
+    This tokenizer does not provide truncation or padding of inputs.
+
+    Args:
+        vocabulary: string or dict, maps token to integer ids. If it is a
+            string, it should be the file path to a json file.
+        merges: string or list, contains the merge rule. If it is a string,
+            it should be the file path to merge rules. The merge rule file
+            should have one merge rule per line. Every merge rule contains
+            merge entities separated by a space.
+        special_tokens: string or dict, maps special tokens to integer IDs. If
+            it is a string, it should be the path to a JSON file.
+        language_tokens: string or dict, maps language tokens to integer IDs. If
+            not None, the tokenizer will be assumed to be a multilingual
+            tokenizer.
+    """
+
+    def __init__(
+        self,
+        vocabulary=None,
+        merges=None,
+        special_tokens=None,
+        language_tokens=None,
+        **kwargs,
+    ):
+        special_tokens = _load_dict(special_tokens)
+        if language_tokens is not None:
+            language_tokens = _load_dict(language_tokens)
+
+        # Necessary special tokens.
+        self.bos_token = "<|startoftranscript|>"
+        self.eos_token = "<|endoftext|>"
+        # TODO: The pad token for the multilingual tokenizer is actually
+        # "", but it errors out (OOM). After BPE is fixed, we can update
+        # this to "". For now, we will use `"<|endoftext|>"`.
+        self.pad_token = "<|endoftext|>"
+
+        self.no_timestamps_token = "<|notimestamps|>"
+        # Task special tokens.
+        self.translate_token = "<|translate|>"
+        self.transcribe_token = "<|transcribe|>"
+
+        for token in [
+            self.bos_token,
+            self.eos_token,
+            self.pad_token,
+            self.no_timestamps_token,
+            self.translate_token,
+            self.transcribe_token,
+        ]:
+            if token not in special_tokens:
+                raise ValueError(
+                    f"Cannot find token `'{token}'` in the provided "
+                    f"`special_tokens`. Please provide `'{token}'` in your "
+                    "`special_tokens`."
+                )
+
+        self.bos_token_id = special_tokens[self.bos_token]
+        self.eos_token_id = special_tokens[self.eos_token]
+        self.pad_token_id = special_tokens[self.pad_token]
+        self.no_timestamps_token_id = special_tokens[self.no_timestamps_token]
+        self.translate_token_id = special_tokens[self.translate_token]
+        self.transcribe_token_id = special_tokens[self.transcribe_token]
+
+        self.special_tokens = special_tokens
+        self.language_tokens = language_tokens
+
+        # TODO: Add language tokens to `unsplittable_tokens` once we figure
+        # out the performance issue with a large list.
+        unsplittable_tokens = list(special_tokens.keys())
+
+        super().__init__(
+            vocabulary=vocabulary,
+            merges=merges,
+            unsplittable_tokens=unsplittable_tokens,
+            **kwargs,
+        )
+
+    def save_assets(self, dir_path):
+        # TODO: whisper is currently mutating its vocabulary before passing
+        # it to the super class, so we need to restore the unmutated vocabulary
+        # before saving our assets. We should find a more robust (and memory
+        # efficient) way to do this.
+        vocabulary = self.vocabulary
+        self.vocabulary = self._initial_vocabulary
+        super().save_assets(dir_path)
+        self.vocabulary = vocabulary
+
+    def set_vocabulary_and_merges(self, vocabulary, merges):
+        if vocabulary is not None:
+            vocabulary = _load_dict(vocabulary)
+            self._initial_vocabulary = dict(vocabulary)
+
+            if self.language_tokens is not None:
+                # Multilingual tokenizer.
+                # Add language tokens to the vocabulary. This makes
+                # detokenization easier for us.
+                vocabulary = {
+                    **vocabulary,
+                    **self.language_tokens,
+                }
+
+            for token in [
+                self.bos_token,
+                self.eos_token,
+                self.pad_token,
+                self.no_timestamps_token,
+                self.translate_token,
+                self.transcribe_token,
+            ]:
+                vocabulary[token] = self.special_tokens[token]
+        else:
+            self._initial_vocabulary = None
+
+        super().set_vocabulary_and_merges(vocabulary, merges)
+
+    def get_config(self):
+        config = super().get_config()
+
+        # In the constructor, we pass the list of special tokens to the
+        # `unsplittable_tokens` arg of the superclass' constructor. Hence, we
+        # delete it from the config here.
+        del config["unsplittable_tokens"]
+
+        config.update(
+            {
+                "special_tokens": self.special_tokens,
+                "language_tokens": self.language_tokens,
+            }
+        )
+        return config
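To see the special-token checks above in action, the toy vocabulary from the `WhisperPreprocessor` docstring can be reused directly; all IDs here are illustrative:

```python
import keras_hub

vocab = {"<|endoftext|>": 0, "a": 4, "Ġquick": 5, "Ġfox": 6}
merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick", "Ġ f", "o x", "Ġf ox"]
special_tokens = {
    "<|startoftranscript|>": 9,
    "<|endoftext|>": 10,
    "<|notimestamps|>": 11,
    "<|transcribe|>": 12,
    "<|translate|>": 13,
}
# Omitting any of the six required special tokens raises the ValueError above.
tokenizer = keras_hub.models.WhisperTokenizer(
    vocabulary=vocab,
    merges=merges,
    special_tokens=special_tokens,
)
token_ids = tokenizer("a quick fox.")  # No truncation or padding applied.
```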
keras_hub/src/models/xlm_roberta/__init__.py
@@ -0,0 +1,26 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from keras_hub.src.models.xlm_roberta.xlm_roberta_backbone import (
+    XLMRobertaBackbone,
+)
+from keras_hub.src.models.xlm_roberta.xlm_roberta_presets import (
+    backbone_presets,
+)
+from keras_hub.src.models.xlm_roberta.xlm_roberta_tokenizer import (
+    XLMRobertaTokenizer,
+)
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, (XLMRobertaBackbone, XLMRobertaTokenizer))