keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
keras_hub/src/models/bert/bert_presets.py
@@ -0,0 +1,147 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""BERT model preset configurations."""
+
+backbone_presets = {
+    "bert_tiny_en_uncased": {
+        "metadata": {
+            "description": (
+                "2-layer BERT model where all input is lowercased. "
+                "Trained on English Wikipedia + BooksCorpus."
+            ),
+            "params": 4385920,
+            "official_name": "BERT",
+            "path": "bert",
+            "model_card": "https://github.com/google-research/bert/blob/master/README.md",
+        },
+        "kaggle_handle": "kaggle://keras/bert/keras/bert_tiny_en_uncased/2",
+    },
+    "bert_small_en_uncased": {
+        "metadata": {
+            "description": (
+                "4-layer BERT model where all input is lowercased. "
+                "Trained on English Wikipedia + BooksCorpus."
+            ),
+            "params": 28763648,
+            "official_name": "BERT",
+            "path": "bert",
+            "model_card": "https://github.com/google-research/bert/blob/master/README.md",
+        },
+        "kaggle_handle": "kaggle://keras/bert/keras/bert_small_en_uncased/2",
+    },
+    "bert_medium_en_uncased": {
+        "metadata": {
+            "description": (
+                "8-layer BERT model where all input is lowercased. "
+                "Trained on English Wikipedia + BooksCorpus."
+            ),
+            "params": 41373184,
+            "official_name": "BERT",
+            "path": "bert",
+            "model_card": "https://github.com/google-research/bert/blob/master/README.md",
+        },
+        "kaggle_handle": "kaggle://keras/bert/keras/bert_medium_en_uncased/2",
+    },
+    "bert_base_en_uncased": {
+        "metadata": {
+            "description": (
+                "12-layer BERT model where all input is lowercased. "
+                "Trained on English Wikipedia + BooksCorpus."
+            ),
+            "params": 109482240,
+            "official_name": "BERT",
+            "path": "bert",
+            "model_card": "https://github.com/google-research/bert/blob/master/README.md",
+        },
+        "kaggle_handle": "kaggle://keras/bert/keras/bert_base_en_uncased/2",
+    },
+    "bert_base_en": {
+        "metadata": {
+            "description": (
+                "12-layer BERT model where case is maintained. "
+                "Trained on English Wikipedia + BooksCorpus."
+            ),
+            "params": 108310272,
+            "official_name": "BERT",
+            "path": "bert",
+            "model_card": "https://github.com/google-research/bert/blob/master/README.md",
+        },
+        "kaggle_handle": "kaggle://keras/bert/keras/bert_base_en/2",
+    },
+    "bert_base_zh": {
+        "metadata": {
+            "description": (
+                "12-layer BERT model. Trained on Chinese Wikipedia."
+            ),
+            "params": 102267648,
+            "official_name": "BERT",
+            "path": "bert",
+            "model_card": "https://github.com/google-research/bert/blob/master/README.md",
+        },
+        "kaggle_handle": "kaggle://keras/bert/keras/bert_base_zh/2",
+    },
+    "bert_base_multi": {
+        "metadata": {
+            "description": (
+                "12-layer BERT model where case is maintained. "
+                "Trained on Wikipedias of 104 languages."
+            ),
+            "params": 177853440,
+            "official_name": "BERT",
+            "path": "bert",
+            "model_card": "https://github.com/google-research/bert/blob/master/README.md",
+        },
+        "kaggle_handle": "kaggle://keras/bert/keras/bert_base_multi/2",
+    },
+    "bert_large_en_uncased": {
+        "metadata": {
+            "description": (
+                "24-layer BERT model where all input is lowercased. "
+                "Trained on English Wikipedia + BooksCorpus."
+            ),
+            "params": 335141888,
+            "official_name": "BERT",
+            "path": "bert",
+            "model_card": "https://github.com/google-research/bert/blob/master/README.md",
+        },
+        "kaggle_handle": "kaggle://keras/bert/keras/bert_large_en_uncased/2",
+    },
+    "bert_large_en": {
+        "metadata": {
+            "description": (
+                "24-layer BERT model where case is maintained. "
+                "Trained on English Wikipedia + BooksCorpus."
+            ),
+            "params": 333579264,
+            "official_name": "BERT",
+            "path": "bert",
+            "model_card": "https://github.com/google-research/bert/blob/master/README.md",
+        },
+        "kaggle_handle": "kaggle://keras/bert/keras/bert_large_en/2",
+    },
+}
+
+classifier_presets = {
+    "bert_tiny_en_uncased_sst2": {
+        "metadata": {
+            "description": (
+                "The bert_tiny_en_uncased backbone model fine-tuned on "
+                "the SST-2 sentiment analysis dataset."
+            ),
+            "params": 4385920,
+            "official_name": "BERT",
+            "path": "bert",
+            "model_card": "https://github.com/google-research/bert/blob/master/README.md",
+        },
+        "kaggle_handle": "kaggle://keras/bert/keras/bert_tiny_en_uncased_sst2/4",
+    }
+}
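Each preset above pairs a human-readable `metadata` block with a `kaggle_handle` pointing at hosted weights. As a rough illustration of how these dictionaries are consumed (the actual resolution logic lives in `keras_hub/src/utils/preset_utils.py`, which is not shown in this hunk), a preset name flows through the public `from_preset()` entry point like this:

```python
import keras_hub

# "bert_tiny_en_uncased" is looked up among the registered presets; its
# "kaggle_handle" ("kaggle://keras/bert/keras/bert_tiny_en_uncased/2")
# tells keras_hub where to fetch the config and weights.
backbone = keras_hub.models.BertBackbone.from_preset("bert_tiny_en_uncased")

# The "params" metadata (4385920 for the tiny preset) describes the
# materialized backbone.
backbone.summary()
```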
keras_hub/src/models/bert/bert_tokenizer.py
@@ -0,0 +1,112 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.tokenizers.word_piece_tokenizer import WordPieceTokenizer
+
+
+@keras_hub_export("keras_hub.models.BertTokenizer")
+class BertTokenizer(WordPieceTokenizer):
+    """A BERT tokenizer using WordPiece subword segmentation.
+
+    This tokenizer class will tokenize raw strings into integer sequences and
+    is based on `keras_hub.tokenizers.WordPieceTokenizer`. Unlike the
+    underlying tokenizer, it will check for all special tokens needed by BERT
+    models and provides a `from_preset()` method to automatically download
+    a matching vocabulary for a BERT preset.
+
+    This tokenizer does not provide truncation or padding of inputs. It can be
+    combined with a `keras_hub.models.BertPreprocessor` layer for input packing.
+
+    If input is a batch of strings (rank > 0), the layer will output a
+    `tf.RaggedTensor` where the last dimension of the output is ragged.
+
+    If input is a scalar string (rank == 0), the layer will output a dense
+    `tf.Tensor` with static shape `[None]`.
+
+    Args:
+        vocabulary: A list of strings or a string filename path. If
+            passing a list, each element of the list should be a single word
+            piece token string. If passing a filename, the file should be a
+            plain text file containing a single word piece token per line.
+        lowercase: If `True`, the input text will be first lowered before
+            tokenization.
+        special_tokens_in_strings: bool. A bool to indicate if the tokenizer
+            should expect special tokens in input strings that should be
+            tokenized and mapped correctly to their ids. Defaults to False.
+
+    Examples:
+    ```python
+    # Unbatched input.
+    tokenizer = keras_hub.models.BertTokenizer.from_preset(
+        "bert_base_en_uncased",
+    )
+    tokenizer("The quick brown fox jumped.")
+
+    # Batched input.
+    tokenizer(["The quick brown fox jumped.", "The fox slept."])
+
+    # Detokenization.
+    tokenizer.detokenize(tokenizer("The quick brown fox jumped."))
+
+    # Custom vocabulary.
+    vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
+    vocab += ["The", "quick", "brown", "fox", "jumped", "."]
+    tokenizer = keras_hub.models.BertTokenizer(vocabulary=vocab)
+    tokenizer("The quick brown fox jumped.")
+    ```
+    """
+
+    def __init__(
+        self,
+        vocabulary=None,
+        lowercase=False,
+        special_tokens_in_strings=False,
+        **kwargs,
+    ):
+        self.cls_token = "[CLS]"
+        self.sep_token = "[SEP]"
+        self.pad_token = "[PAD]"
+        self.mask_token = "[MASK]"
+        super().__init__(
+            vocabulary=vocabulary,
+            lowercase=lowercase,
+            special_tokens=[
+                self.cls_token,
+                self.sep_token,
+                self.pad_token,
+                self.mask_token,
+            ],
+            special_tokens_in_strings=special_tokens_in_strings,
+            **kwargs,
+        )
+
+    def set_vocabulary(self, vocabulary):
+        super().set_vocabulary(vocabulary)
+
+        if vocabulary is not None:
+            self.cls_token_id = self.token_to_id(self.cls_token)
+            self.sep_token_id = self.token_to_id(self.sep_token)
+            self.pad_token_id = self.token_to_id(self.pad_token)
+            self.mask_token_id = self.token_to_id(self.mask_token)
+        else:
+            self.cls_token_id = None
+            self.sep_token_id = None
+            self.pad_token_id = None
+            self.mask_token_id = None
+
+    def get_config(self):
+        config = super().get_config()
+        del config["special_tokens"]  # Not configurable; set in __init__.
+        return config
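A small sketch of the bookkeeping `set_vocabulary` performs; the toy vocabulary below is illustrative, not from the diff. Because `__init__` passes the four special tokens down to `WordPieceTokenizer`, the base class can verify they exist in the vocabulary, and their ids are then cached as attributes:

```python
import keras_hub

vocab = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "the", "fox"]
tokenizer = keras_hub.models.BertTokenizer(vocabulary=vocab, lowercase=True)

# set_vocabulary() resolved each special token to its vocabulary index.
print(tokenizer.cls_token_id)   # 1
print(tokenizer.sep_token_id)   # 2
print(tokenizer.pad_token_id)   # 3
print(tokenizer.mask_token_id)  # 4
```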
keras_hub/src/models/bloom/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from keras_hub.src.models.bloom.bloom_backbone import BloomBackbone
+from keras_hub.src.models.bloom.bloom_presets import backbone_presets
+from keras_hub.src.models.bloom.bloom_tokenizer import BloomTokenizer
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, (BloomBackbone, BloomTokenizer))
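This import-time `register_presets` call is what makes the BLOOM preset names resolvable by every class in the tuple. Assuming the presets defined in `bloom_presets.py` (not shown in this hunk), usage would look like:

```python
import keras_hub

# Both classes passed to register_presets() resolve the same preset names.
backbone = keras_hub.models.BloomBackbone.from_preset("bloom_560m_multi")
tokenizer = keras_hub.models.BloomTokenizer.from_preset("bloom_560m_multi")
```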
keras_hub/src/models/bloom/bloom_attention.py
@@ -0,0 +1,186 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import keras
+from keras import ops
+
+from keras_hub.src.layers.modeling.alibi_bias import AlibiBias
+from keras_hub.src.utils.keras_utils import clone_initializer
+
+
+class BloomAttention(keras.layers.Layer):
+    def __init__(
+        self,
+        num_heads,
+        dropout=0.0,
+        kernel_initializer="glorot_uniform",
+        bias_initializer="zeros",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.kernel_initializer = keras.initializers.get(kernel_initializer)
+        self.bias_initializer = keras.initializers.get(bias_initializer)
+
+    def build(self, inputs_shape):
+        batch_size, seq_length, hidden_dim = inputs_shape
+
+        self.head_dim = hidden_dim // self.num_heads
+
+        # Layer-wise attention scaling
+        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
+
+        self._query_dense = keras.layers.EinsumDense(
+            equation="btm,mnh->btnh",
+            output_shape=(None, self.num_heads, self.head_dim),
+            bias_axes="nh",
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
+            name="query_dense",
+        )
+        self._query_dense.build(inputs_shape)
+
+        self._key_dense = keras.layers.EinsumDense(
+            equation="bsm,mnh->bsnh",
+            output_shape=(None, self.num_heads, self.head_dim),
+            bias_axes="nh",
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
+            name="key_dense",
+        )
+        self._key_dense.build(inputs_shape)
+
+        self._value_dense = keras.layers.EinsumDense(
+            equation="bsm,mnh->bsnh",
+            output_shape=(None, self.num_heads, self.head_dim),
+            bias_axes="nh",
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
+            name="value_dense",
+        )
+        self._value_dense.build(inputs_shape)
+
+        self._alibi_layer = AlibiBias(
+            dtype=self.dtype_policy,
+        )
+
+        self._output_dense = keras.layers.Dense(
+            hidden_dim,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
+            name="output_dense",
+        )
+        self._output_dense.build(inputs_shape)
+
+        self._dropout_layer = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+            name="dropout",
+        )
+        self._softmax = keras.layers.Softmax(
+            dtype="float32",
+            name="softmax",
+        )
+
+        self.built = True
+
+    def call(
+        self,
+        hidden_states,
+        attention_mask=None,
+        cache=None,
+        cache_update_index=None,
+    ):
+        batch_size, seq_length, hidden_dim = ops.shape(hidden_states)
+
+        query = self._query_dense(hidden_states)
+        key = self._key_dense(hidden_states)
+        value = self._value_dense(hidden_states)
+
+        if cache is not None:
+            key_cache = cache[:, 0, ...]
+            value_cache = cache[:, 1, ...]
+            if cache_update_index is None:
+                key = key_cache
+                value = value_cache
+            else:
+                start = [0, cache_update_index, 0, 0]
+                key = ops.slice_update(key_cache, start, key)
+                value = ops.slice_update(value_cache, start, value)
+                cache = ops.stack((key, value), axis=1)
+        else:
+            if cache_update_index is not None:
+                raise ValueError(
+                    "`cache_update_index` should not be set if `cache` is "
+                    f"`None`. Received: cache={cache}, "
+                    f"cache_update_index={cache_update_index}"
+                )
+
+        # query (batch_size, num_heads, query_length, head_dim)
+        query = ops.transpose(query, [0, 2, 1, 3])
+        # value (batch_size, num_heads, kv_length, head_dim)
+        value = ops.transpose(value, [0, 2, 1, 3])
+        # key (batch_size, num_heads, head_dim, kv_length)
+        key = ops.transpose(key, [0, 2, 3, 1])
+
+        attention_scores = (
+            ops.matmul(query, key) * self.inv_norm_factor
+        )  # [batch_size, num_heads, query_length, kv_length]
+        attention_scores = self._alibi_layer(attention_scores)
+        attention_scores = self._softmax(
+            attention_scores, ops.expand_dims(attention_mask, 1)
+        )
+        attention_scores = self._dropout_layer(attention_scores)
+
+        attention_output = ops.matmul(
+            attention_scores, value
+        )  # [batch_size, num_heads, query_length, head_dim]
+
+        attention_output = ops.transpose(
+            attention_output, [0, 2, 1, 3]
+        )  # [batch_size, query_length, num_heads, head_dim]
+        attention_output = ops.reshape(
+            attention_output,
+            [batch_size, seq_length, self.num_heads * self.head_dim],
+        )  # [batch_size, query_length, hidden_dim]
+
+        attention_output = self._output_dense(attention_output)
+        attention_output = self._dropout_layer(attention_output)
+
+        if cache is not None:
+            return attention_output, cache
+
+        return attention_output
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_heads": self.num_heads,
+                "dropout": self.dropout,
+                "kernel_initializer": keras.initializers.serialize(
+                    self.kernel_initializer
+                ),
+                "bias_initializer": keras.initializers.serialize(
+                    self.bias_initializer
+                ),
+            }
+        )
+        return config
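The fused cache that `call()` slices apart has shape `(batch, 2, max_seq_len, num_heads, head_dim)`, with keys at index 0 and values at index 1 along axis 1. A minimal sketch of the `ops.slice_update` write pattern, using toy shapes chosen purely for illustration:

```python
from keras import ops

batch, max_len, num_heads, head_dim = 1, 8, 2, 4
cache = ops.zeros((batch, 2, max_len, num_heads, head_dim))

# One decoding step yields a single new key of shape
# (batch, 1, num_heads, head_dim); cache_update_index is the position
# along the sequence axis where it gets written.
new_key = ops.ones((batch, 1, num_heads, head_dim))
cache_update_index = 3

key_cache = cache[:, 0, ...]
key_cache = ops.slice_update(key_cache, [0, cache_update_index, 0, 0], new_key)
print(ops.convert_to_numpy(key_cache)[0, :, 0, 0])  # 1.0 only at index 3
```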
keras_hub/src/models/bloom/bloom_backbone.py
@@ -0,0 +1,173 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.modeling.reversible_embedding import (
+    ReversibleEmbedding,
+)
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.bloom.bloom_decoder import BloomDecoder
+
+
+def _bloom_kernel_initializer(stddev=0.02):
+    return keras.initializers.RandomNormal(stddev=stddev)
+
+
+@keras_hub_export("keras_hub.models.BloomBackbone")
+class BloomBackbone(Backbone):
+    """A BLOOM decoder network.
+
+    This network implements a Transformer-based decoder network, BigScience
+    Language Open-science Open-access Multilingual (BLOOM), as described in
+    ["BLOOM: A 176B-Parameter Open-Access Multilingual Language Model"](https://arxiv.org/pdf/2211.05100.pdf).
+
+    The default constructor gives a fully customizable, randomly initialized
+    Bloom model with any number of layers, heads, and embedding dimensions. To
+    load preset architectures and weights, use the `from_preset()` constructor.
+
+    Disclaimer: Pre-trained models are provided on an "as is" basis, without
+    warranties or conditions of any kind. The underlying model is provided by a
+    third party and subject to a separate license, available [here](https://huggingface.co/spaces/bigscience/license).
+
+    Args:
+        vocabulary_size: int. The size of the token vocabulary.
+        num_layers: int. The number of transformer layers.
+        num_heads: int. The number of attention heads for each transformer.
+            The hidden size must be divisible by the number of attention heads.
+        hidden_dim: int. The dimensionality of the embeddings and hidden states.
+        intermediate_dim: int. The output dimension of the first Dense layer in
+            the MLP network of each transformer.
+        dropout: float. Dropout probability for the Transformer decoder.
+        layer_norm_epsilon: float. Epsilon for the layer normalization layers in
+            the transformer decoder.
+        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
+            for model computations and weights. Note that some computations,
+            such as softmax and layer normalization, will always be done at
+            float32 precision regardless of dtype.
+
+    Example:
+    ```python
+    input_data = {
+        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
+        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
+    }
+
+    # Pretrained BLOOM decoder.
+    model = keras_hub.models.BloomBackbone.from_preset("bloom_560m_multi")
+    model(input_data)
+
+    # Randomly initialized BLOOM decoder with a custom config.
+    model = keras_hub.models.BloomBackbone(
+        vocabulary_size=10,
+        num_layers=2,
+        num_heads=2,
+        hidden_dim=32,
+        intermediate_dim=32*4,
+        dropout=0.0,
+        layer_norm_epsilon=1e-5,
+    )
+    model(input_data)
+    ```
+    """
+
+    def __init__(
+        self,
+        vocabulary_size,
+        num_layers,
+        num_heads,
+        hidden_dim,
+        intermediate_dim,
+        dropout=0.0,
+        layer_norm_epsilon=1e-5,
+        dtype=None,
+        **kwargs,
+    ):
+        # === Layers ===
+        self.token_embedding = ReversibleEmbedding(
+            input_dim=vocabulary_size,
+            output_dim=hidden_dim,
+            embeddings_initializer=_bloom_kernel_initializer(stddev=0.02),
+            dtype=dtype,
+            name="token_embedding",
+        )
+        self.embeddings_layer_norm = keras.layers.LayerNormalization(
+            epsilon=layer_norm_epsilon,
+            dtype=dtype,
+            name="embedding_layernorm",
+        )
+        self.transformer_layers = []
+        for i in range(num_layers):
+            layer = BloomDecoder(
+                num_heads=num_heads,
+                intermediate_dim=intermediate_dim,
+                dropout=dropout,
+                layer_norm_epsilon=layer_norm_epsilon,
+                dtype=dtype,
+                name=f"transformer_layer_{i}",
+            )
+            self.transformer_layers.append(layer)
+        self.layer_norm = keras.layers.LayerNormalization(
+            epsilon=layer_norm_epsilon,
+            dtype=dtype,
+            name="final_layernorm",
+        )
+
+        # === Functional Model ===
+        token_id_input = keras.Input(
+            shape=(None,), dtype="int32", name="token_ids"
+        )
+        padding_mask_input = keras.Input(
+            shape=(None,), dtype="int32", name="padding_mask"
+        )
+        x = self.token_embedding(token_id_input)
+        x = self.embeddings_layer_norm(x)
+        for transformer_layer in self.transformer_layers:
+            x = transformer_layer(x, decoder_padding_mask=padding_mask_input)
+        sequence_output = self.layer_norm(x)
+        super().__init__(
+            inputs={
+                "token_ids": token_id_input,
+                "padding_mask": padding_mask_input,
+            },
+            outputs=sequence_output,
+            dtype=dtype,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.vocabulary_size = vocabulary_size
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.hidden_dim = hidden_dim
+        self.intermediate_dim = intermediate_dim
+        self.dropout = dropout
+        self.layer_norm_epsilon = layer_norm_epsilon
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "vocabulary_size": self.vocabulary_size,
+                "num_layers": self.num_layers,
+                "num_heads": self.num_heads,
+                "hidden_dim": self.hidden_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "dropout": self.dropout,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+            }
+        )
+        return config
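One design note worth calling out: the backbone embeds tokens with `ReversibleEmbedding` rather than a plain `Embedding`, so the same weight matrix can also project hidden states back to vocabulary logits, which is how keras_hub causal LM heads typically reuse it. A hedged sketch building on the random-weights example from the docstring:

```python
import numpy as np
import keras_hub

backbone = keras_hub.models.BloomBackbone(
    vocabulary_size=10,
    num_layers=2,
    num_heads=2,
    hidden_dim=32,
    intermediate_dim=32 * 4,
)
input_data = {
    "token_ids": np.ones((1, 12), dtype="int32"),
    "padding_mask": np.ones((1, 12), dtype="int32"),
}
hidden = backbone(input_data)  # shape (1, 12, 32)

# reverse=True reuses the embedding matrix as an output projection.
logits = backbone.token_embedding(hidden, reverse=True)  # shape (1, 12, 10)
```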