keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,143 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import math
|
15
|
+
|
16
|
+
import keras
|
17
|
+
from keras import ops
|
18
|
+
|
19
|
+
from keras_hub.src.api_export import keras_hub_export
|
20
|
+
|
21
|
+
|
22
|
+
@keras_hub_export("keras_hub.layers.AlibiBias")
class AlibiBias(keras.layers.Layer):
    """A layer that adds the alibi bias to attention scores.

    This layer adds the alibi bias to the attention scores. Alibi bias is a
    linear, non-learned bias. Defined and formalized in
    [Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation](https://arxiv.org/abs/2108.12409).

    This layer takes as input the attention scores and returns the attention
    scores after adding the alibi bias to them. The output will have the same
    shape as the input.

    Args:
        alibi_bias_max: int. This value will be used to compute the slope of
            each head. The heads' slopes are a geometric sequence that starts
            at `2**(-alibi_bias_max/num_heads)` and uses that same value as
            its ratio. Defaults to 8.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `trainable`, `dtype` etc.

    Call arguments:
        attention_scores: The result of multiplying the query and the key of
            the multi-head attention layer of the transformer to add alibi
            bias to it. With shape
            `(batch_size, num_heads, query_length, key_length)`.

    Example:
    ```python
    query_length = 10
    key_length = 10
    num_heads = 4
    batch_size = 2
    hidden_dim = 8

    # Create new alibi layer.
    alibi_layer = keras_hub.layers.AlibiBias()

    query = np.zeros((batch_size, num_heads, query_length, hidden_dim))
    key = np.zeros((batch_size, num_heads, hidden_dim, key_length))

    attention_scores = keras.ops.matmul(query, key)

    # Add alibi bias to attention scores.
    attention_scores = alibi_layer(attention_scores)
    ```

    References:
     - [Press et al., 2021](https://arxiv.org/abs/2108.12409)
    """

    def __init__(
        self,
        alibi_bias_max=8,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.alibi_bias_max = alibi_bias_max

    def call(self, attention_scores):
        """Return `attention_scores` with the alibi bias added.

        Raises:
            ValueError: if `attention_scores` does not have exactly four
                dimensions.
        """
        shape = ops.shape(attention_scores)
        if len(shape) != 4:
            raise ValueError(
                "Expected `attention_scores` shape to be "
                "`(batch_size, num_heads, query_length, key_length)`."
                f" Received shape={shape}"
            )

        key_length = shape[-1]
        num_heads = shape[-3]

        alibi_bias = self._get_alibi_bias(num_heads, key_length)

        # The bias has singleton batch and query axes, so the addition
        # broadcasts it over those two dimensions.
        return ops.add(attention_scores, alibi_bias)

    def _get_alibi_bias(self, num_heads, key_length):
        """Build the bias tensor of shape `(1, num_heads, 1, key_length)`."""
        # One slope per head, as a column: `(num_heads, 1)`.
        slopes = ops.convert_to_tensor(
            self._get_slopes(num_heads), dtype=self.compute_dtype
        )
        slopes = ops.expand_dims(slopes, 1)

        # Relative offsets `1 - key_length, ..., -1, 0` as a row:
        # `(1, key_length)`. The last key position gets a zero bias.
        seq_range = ops.expand_dims(
            ops.arange(1 - key_length, 1, dtype="int32"), 0
        )
        seq_range = ops.cast(seq_range, dtype=self.compute_dtype)

        # Outer product -> `(num_heads, key_length)`, then add the query axis.
        alibi_bias = ops.multiply(slopes, seq_range)
        alibi_bias = ops.expand_dims(alibi_bias, 1)

        # return shape is `(1, num_heads, 1, key_length)`
        return ops.expand_dims(alibi_bias, 0)

    def _get_slopes(self, num_heads):
        """Return the list of per-head slopes (Python floats).

        For a power-of-two `num_heads` the slopes form a geometric sequence.
        Otherwise the sequence for the closest lower power of two is extended
        with every other slope of the sequence for twice that power of two.
        """
        # this function is adopted from Alibi original implementation.
        # https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
        def get_slopes_power_of_2(n):
            # Equivalent to `2 ** (-alibi_bias_max / n)`; the first slope is
            # also used as the common ratio of the sequence.
            start = 2 ** (
                -(2 ** -(math.log2(n) - math.log2(self.alibi_bias_max)))
            )
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(num_heads).is_integer():
            return get_slopes_power_of_2(num_heads)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
            return (
                get_slopes_power_of_2(closest_power_of_2)
                + self._get_slopes(2 * closest_power_of_2)[0::2][
                    : num_heads - closest_power_of_2
                ]
            )

    def compute_output_shape(self, input_shape):
        # Adding the bias never changes the shape of the scores.
        return input_shape

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "alibi_bias_max": self.alibi_bias_max,
            }
        )
        return config
|
@@ -0,0 +1,137 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import keras
|
16
|
+
from keras import ops
|
17
|
+
|
18
|
+
from keras_hub.src.api_export import keras_hub_export
|
19
|
+
|
20
|
+
|
21
|
+
@keras_hub_export("keras_hub.layers.CachedMultiHeadAttention")
class CachedMultiHeadAttention(keras.layers.MultiHeadAttention):
    """MultiHeadAttention layer with cache support.

    This layer is suitable for use in autoregressive decoding. It can be used
    to cache decoder self-attention and cross-attention. The forward pass
    can happen in one of three modes:

    - No cache, same as regular multi-head attention.
    - Static cache (`cache_update_index` is None). In this case, the
        cached key/value projections will be used and the input values will
        be ignored.
    - Updated cache (`cache_update_index` is not None). In this case, new
        key/value projections are computed using the input, and spliced into
        the cache at the specified index.

    Note that caching is useful only during inference and should not be used
    during training.

    We use the notation `B`, `T`, `S` below, where `B` is the batch dimension,
    `T` is the target sequence length, and `S` is the source sequence length.
    Note that during generative decoding, `T` is usually 1 (you are
    generating a target sequence of length one to predict the next token).

    Call arguments:
        query: Query `Tensor` of shape `(B, T, dim)`.
        value: Value `Tensor` of shape `(B, S*, dim)`. If `cache` is `None`,
            `S*` must equal `S` and match the shape of `attention_mask`. If
            `cache` is not `None`, `S*` can be any length less than `S`, and
            the computed value will be spliced into `cache` at
            `cache_update_index`.
        key: Optional key `Tensor` of shape `(B, S*, dim)`. If `cache` is
            `None`, `S*` must equal `S` and match the shape of
            `attention_mask`. If `cache` is not `None`, `S*` can be any length
            less than `S`, and the computed value will be spliced into `cache`
            at `cache_update_index`. Defaults to `value` when not given.
        attention_mask: a boolean mask of shape `(B, T, S)`. `attention_mask`
            prevents attention to certain positions. The boolean mask specifies
            which query elements can attend to which key elements, 1 indicates
            attention and 0 indicates no attention. Broadcasting can happen for
            the missing batch dimensions and the head dimension.
        cache: a dense float Tensor. The key/value cache, of shape
            `[B, 2, S, num_heads, key_dims]`, where `S` must agree with the
            `attention_mask` shape. This argument is intended for use during
            generation to avoid recomputing intermediate state.
        cache_update_index: an int or int Tensor, the index at which to update
            `cache` (usually the index of the current token being processed
            when running generation). If `cache_update_index=None` while `cache`
            is set, the cache will not be updated.
        training: a boolean indicating whether the layer should behave in
            training mode or in inference mode.

    Returns:
        If `cache` is `None`, only `attention_output` is returned. Otherwise
        an `(attention_output, cache)` tuple is returned. `attention_output`
        is the result of the computation, of shape `(B, T, dim)`, where `T`
        is for target sequence shapes and `dim` is the query input last
        dimension if `output_shape` is `None`. Otherwise, the multi-head
        outputs are projected to the shape specified by `output_shape`.
        `cache` is the updated cache.

    Raises:
        ValueError: if `cache_update_index` is set while `cache` is `None`.
    """

    def call(
        self,
        query,
        value,
        key=None,
        attention_mask=None,
        cache=None,
        cache_update_index=None,
        training=None,
    ):
        if key is None:
            key = value

        query = self._query_dense(query)

        # If cache is not `None`, we will use the cache to compute the final key
        # and value tensors. If `cache_update_index` is not None, we will first
        # update the cache before use. To do this, we first call the
        # `_key_dense` and `_value_dense` layers, and copy the outputs into the
        # cache at the specified index. `cache = None` handles the training
        # case, where we don't use the cache at all.
        if cache is not None:
            # Index 0 of the second axis holds keys, index 1 holds values.
            key_cache = cache[:, 0, ...]
            value_cache = cache[:, 1, ...]
            if cache_update_index is None:
                # Static cache: the `key`/`value` inputs are ignored.
                key = key_cache
                value = value_cache
            else:
                key_update = self._key_dense(key)
                value_update = self._value_dense(value)
                # Splice the new projections in along the sequence axis.
                start = [0, cache_update_index, 0, 0]
                key = ops.slice_update(key_cache, start, key_update)
                value = ops.slice_update(value_cache, start, value_update)
                cache = ops.stack((key, value), axis=1)
        else:
            if cache_update_index is not None:
                raise ValueError(
                    "`cache_update_index` should not be set if `cache` is "
                    f"`None`. Received: cache={cache}, "
                    f"cache_update_index={cache_update_index}"
                )
            key = self._key_dense(key)
            value = self._value_dense(value)

        # The attention scores are not needed here, only the output.
        attention_output, _ = self._compute_attention(
            query=query,
            key=key,
            value=value,
            attention_mask=attention_mask,
            training=training,
        )

        attention_output = self._output_dense(attention_output)

        # The return type depends on whether a cache was supplied.
        if cache is not None:
            return attention_output, cache
        return attention_output
|
@@ -0,0 +1,200 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import keras
|
16
|
+
from keras import ops
|
17
|
+
|
18
|
+
from keras_hub.src.api_export import keras_hub_export
|
19
|
+
from keras_hub.src.utils.keras_utils import clone_initializer
|
20
|
+
|
21
|
+
|
22
|
+
@keras_hub_export("keras_hub.layers.FNetEncoder")
|
23
|
+
class FNetEncoder(keras.layers.Layer):
|
24
|
+
"""FNet encoder.
|
25
|
+
|
26
|
+
This class follows the architecture of FNet encoder layer in the
|
27
|
+
[FNet paper](https://arxiv.org/abs/2105.03824). Users can instantiate
|
28
|
+
multiple instances of this class to stack up the encoder.
|
29
|
+
|
30
|
+
Note on masking: In the official FNet code, padding tokens are added to the
|
31
|
+
input. However, the padding masks are deleted, i.e., mixing of
|
32
|
+
all tokens is done. This is because certain frequencies will be zeroed
|
33
|
+
out if we apply padding masks in every encoder layer. Hence, we don't
|
34
|
+
take padding mask as input in the call() function.
|
35
|
+
|
36
|
+
Args:
|
37
|
+
intermediate_dim: int. The hidden size of feedforward network.
|
38
|
+
dropout: float. The dropout value, applied in the
|
39
|
+
feedforward network. Defaults to `0.`.
|
40
|
+
activation: string or `keras.activations`. The
|
41
|
+
activation function of feedforward network.
|
42
|
+
Defaults to `"relu"`.
|
43
|
+
layer_norm_epsilon: float. The epsilon value in layer
|
44
|
+
normalization components. Defaults to `1e-5`.
|
45
|
+
kernel_initializer: `str` or `keras.initializers` initializer.
|
46
|
+
The kernel initializer for the dense layers.
|
47
|
+
Defaults to `"glorot_uniform"`.
|
48
|
+
bias_initializer: "string" or `keras.initializers` initializer.
|
49
|
+
The bias initializer for the dense layers.
|
50
|
+
Defaults to `"zeros"`.
|
51
|
+
**kwargs: other keyword arguments passed to `keras.layers.Layer`,
|
52
|
+
including `name`, `trainable`, `dtype` etc.
|
53
|
+
|
54
|
+
Example:
|
55
|
+
|
56
|
+
```python
|
57
|
+
# Create a single FNet encoder layer.
|
58
|
+
encoder = keras_hub.layers.FNetEncoder(
|
59
|
+
intermediate_dim=64)
|
60
|
+
|
61
|
+
# Create a simple model containing the encoder.
|
62
|
+
input = keras.Input(shape=(10, 64))
|
63
|
+
output = encoder(input)
|
64
|
+
model = keras.Model(inputs=input, outputs=output)
|
65
|
+
|
66
|
+
# Call encoder on the inputs.
|
67
|
+
input_data = np.random.uniform(size=(1, 10, 64))
|
68
|
+
output = model(input_data)
|
69
|
+
```
|
70
|
+
|
71
|
+
References:
|
72
|
+
- [Lee-Thorp et al., 2021](https://arxiv.org/abs/2105.03824)
|
73
|
+
"""
|
74
|
+
|
75
|
+
def __init__(
    self,
    intermediate_dim,
    dropout=0,
    activation="relu",
    layer_norm_epsilon=1e-5,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
    **kwargs
):
    """Record the layer hyperparameters.

    No sublayers are created here; that happens in `build()`.

    Args:
        intermediate_dim: units of the first feedforward dense layer.
        dropout: rate handed to the feedforward dropout layer.
        activation: identifier resolved with `keras.activations.get`.
        layer_norm_epsilon: epsilon handed to both layer norm sublayers.
        kernel_initializer: identifier resolved with
            `keras.initializers.get`; used for the dense kernels.
        bias_initializer: identifier resolved with
            `keras.initializers.get`; used for the dense biases.
        **kwargs: forwarded unchanged to the parent constructor.
    """
    super().__init__(**kwargs)

    # Plain hyperparameters are kept exactly as passed in.
    self.intermediate_dim = intermediate_dim
    self.dropout = dropout
    self.layer_norm_epsilon = layer_norm_epsilon

    # Identifiers are resolved once here; `get_config()` serializes them
    # back when the layer is saved.
    self.activation = keras.activations.get(activation)
    self.kernel_initializer = keras.initializers.get(kernel_initializer)
    self.bias_initializer = keras.initializers.get(bias_initializer)
|
92
|
+
|
93
|
+
def build(self, inputs_shape):
    """Create and build every sublayer for the given input shape.

    Args:
        inputs_shape: shape of the layer input. Its last entry is used as
            the number of units of the final dense projection.
    """
    policy = self.dtype_policy
    epsilon = self.layer_norm_epsilon
    # The feedforward block projects back to the input feature size.
    feature_size = inputs_shape[-1]

    # Normalization used after the Fourier-mixing residual connection.
    self._mixing_layer_norm = keras.layers.LayerNormalization(
        epsilon=epsilon,
        dtype=policy,
        name="mixing_layer_norm",
    )
    self._mixing_layer_norm.build(inputs_shape)

    # Normalization used after the feedforward residual connection.
    self._output_layer_norm = keras.layers.LayerNormalization(
        epsilon=epsilon,
        dtype=policy,
        name="output_layer_norm",
    )
    self._output_layer_norm.build(inputs_shape)

    # First feedforward layer: expands to `intermediate_dim` and applies
    # the activation. Each dense layer gets its own initializer clones.
    self._intermediate_dense = keras.layers.Dense(
        self.intermediate_dim,
        activation=self.activation,
        kernel_initializer=clone_initializer(self.kernel_initializer),
        bias_initializer=clone_initializer(self.bias_initializer),
        dtype=policy,
        name="intermediate_dense",
    )
    self._intermediate_dense.build(inputs_shape)

    # Second feedforward layer: no activation, `feature_size` units.
    self._output_dense = keras.layers.Dense(
        feature_size,
        kernel_initializer=clone_initializer(self.kernel_initializer),
        bias_initializer=clone_initializer(self.bias_initializer),
        dtype=policy,
        name="output_dense",
    )
    # It consumes the first dense layer's output, so build it from that
    # layer's output shape rather than from `inputs_shape`.
    hidden_shape = self._intermediate_dense.compute_output_shape(
        inputs_shape
    )
    self._output_dense.build(hidden_shape)

    self._output_dropout = keras.layers.Dropout(
        rate=self.dropout,
        dtype=policy,
        name="output_dropout",
    )

    self.built = True
|
137
|
+
|
138
|
+
def call(self, inputs, training=None):
    """Forward pass of the FNetEncoder.

    Args:
        inputs: a Tensor. The input data to FNetEncoder, should be
            of shape [batch_size, sequence_length, feature_dim].
        training: a boolean indicating whether the layer should behave in
            training mode or in inference mode.

    Returns:
        A Tensor of the same shape as the `inputs`.
    """
    # --- Fourier mixing ---------------------------------------------------
    # FFT transforms do not support float16, so compute in float32 and
    # cast back to the incoming dtype afterwards.
    original_dtype = inputs.dtype
    real_part = ops.cast(inputs, "float32")
    imaginary_part = ops.zeros_like(real_part)
    # `ops.fft2` takes and returns a (real, imaginary) pair; only the real
    # part of the transform is kept.
    fft_real, _ = ops.fft2((real_part, imaginary_part))
    fft_real = ops.cast(fft_real, original_dtype)

    # Residual connection around the mixing step, then layer norm.
    mixing_output = self._mixing_layer_norm(inputs + fft_real)

    # --- Feedforward ------------------------------------------------------
    hidden = self._intermediate_dense(mixing_output)
    hidden = self._output_dense(hidden)
    feed_forward_output = self._output_dropout(hidden, training=training)

    # Residual connection around the feedforward block, then layer norm.
    return self._output_layer_norm(mixing_output + feed_forward_output)
|
180
|
+
|
181
|
+
def get_config(self):
    """Return the parent config extended with this layer's arguments.

    The activation and both initializers are stored in serialized form
    using the matching Keras `serialize` helpers.
    """
    config = super().get_config()

    config["intermediate_dim"] = self.intermediate_dim
    config["dropout"] = self.dropout
    config["activation"] = keras.activations.serialize(self.activation)
    config["layer_norm_epsilon"] = self.layer_norm_epsilon
    config["kernel_initializer"] = keras.initializers.serialize(
        self.kernel_initializer
    )
    config["bias_initializer"] = keras.initializers.serialize(
        self.bias_initializer
    )

    return config
|
198
|
+
|
199
|
+
def compute_output_shape(self, inputs_shape):
    """Return `inputs_shape` unchanged as the output shape."""
    # The very same object is handed back; no copy is made.
    return inputs_shape
|