keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,169 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import keras
|
16
|
+
from keras import ops
|
17
|
+
|
18
|
+
from keras_hub.src.api_export import keras_hub_export
|
19
|
+
|
20
|
+
|
21
|
+
@keras_hub_export("keras_hub.layers.RotaryEmbedding")
class RotaryEmbedding(keras.layers.Layer):
    """Rotary positional encoding layer.

    Encodes absolute position information by rotating pairs of features
    with sine and cosine terms whose wavelengths increase geometrically.
    The formulation follows
    [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864v4).

    The input must be a tensor that has both a sequence dimension and a
    feature dimension. Typically, this will be an input with shape
    `(batch_size, sequence_length, feature_length)` or
    `(batch_size, sequence_length, num_heads, feature_length)`.
    The layer returns a new tensor of the same shape with the rotary
    embedding applied to the input tensor.

    Args:
        max_wavelength: int. The maximum angular wavelength of the sine/cosine
            curves.
        scaling_factor: float. The scaling factor used to scale positions of
            the tokens (positions are divided by this value).
        sequence_axis: int. Sequence axis in the input tensor.
        feature_axis: int. Feature axis in the input tensor.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `trainable`, `dtype` etc.

    Call arguments:
        inputs: The tensor inputs to apply the embedding to. This can have
            any shape, but must contain both a sequence and feature axis. The
            rotary embedding will be applied to `inputs` and returned.
        start_index: An integer or integer tensor. The starting position to
            compute the rotary embedding from. This is useful during cached
            decoding, where each position is predicted separately in a loop.
            Only used when `positions` is `None`.
        positions: Optional 1D tensor of token positions. When given, it is
            cast to float32 and used in place of the positions derived from
            the sequence length and `start_index`.

    Examples:

    ```python
    batch_size = 16
    feature_length = 18
    sequence_length = 256
    num_heads = 8

    # No multi-head dimension.
    tensor = np.ones((batch_size, sequence_length, feature_length))
    rot_emb_layer = RotaryEmbedding()
    tensor_rot = rot_emb_layer(tensor)

    # With multi-head dimension.
    tensor = np.ones((batch_size, sequence_length, num_heads, feature_length))
    tensor_rot = rot_emb_layer(tensor)
    ```

    References:
     - [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864v4)
    """

    def __init__(
        self,
        max_wavelength=10000,
        scaling_factor=1.0,
        sequence_axis=1,
        feature_axis=-1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.max_wavelength = max_wavelength
        self.scaling_factor = scaling_factor
        self.sequence_axis = sequence_axis
        self.feature_axis = feature_axis
        # The layer owns no weights, so there is nothing left to build.
        self.built = True

    def call(self, inputs, start_index=0, positions=None):
        # Work in a canonical layout (sequence on axis 1, features last) and
        # restore the caller's layout on the way out.
        caller_axes = (self.feature_axis, self.sequence_axis)
        canonical_axes = (-1, 1)

        inputs = ops.moveaxis(inputs, caller_axes, canonical_axes)
        cos_emb, sin_emb = self._compute_cos_sin_embedding(
            inputs, start_index, positions
        )
        rotated = self._apply_rotary_pos_emb(inputs, cos_emb, sin_emb)
        return ops.moveaxis(rotated, canonical_axes, caller_axes)

    def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
        """Rotate `tensor` using the precomputed cosine and sine tables."""
        first_half, second_half = ops.split(tensor, 2, axis=-1)
        # Avoid `ops.concatenate` for now, to avoid a obscure bug with XLA
        # compilation on jax. We should be able to remove this once the
        # following PR is in all jax releases we care about:
        # https://github.com/openxla/xla/pull/7875
        rotated_halves = ops.stack((-second_half, first_half), axis=-2)
        rotated_halves = ops.reshape(rotated_halves, ops.shape(tensor))
        return (tensor * cos_emb) + (rotated_halves * sin_emb)

    def _compute_positions(self, inputs, start_index=0):
        """Return float32 positions for axis 1 of `inputs`, offset by `start_index`."""
        offset = ops.cast(start_index, dtype="float32")
        base_positions = ops.arange(ops.shape(inputs)[1], dtype="float32")
        return base_positions + offset

    def _compute_cos_sin_embedding(self, inputs, start_index=0, positions=None):
        """Build cosine and sine tables broadcastable against `inputs`.

        `inputs` is expected in the canonical layout produced by `call`
        (sequence on axis 1, features on the last axis).
        """
        rank = len(inputs.shape)
        feature_axis = rank - 1
        sequence_axis = 1

        rotary_dim = ops.shape(inputs)[feature_axis]
        inverse_freq = self._get_inverse_freq(rotary_dim)

        if positions is not None:
            positions = ops.cast(positions, "float32")
        else:
            positions = self._compute_positions(inputs, start_index)

        positions = positions / ops.cast(self.scaling_factor, "float32")
        freq = ops.einsum("i,j->ij", positions, inverse_freq)

        # Duplicate the frequencies so the table spans the full feature width.
        freq_shape = ops.shape(freq)
        embedding = ops.reshape(
            ops.stack((freq, freq), axis=-2),
            (*freq_shape[:-1], freq_shape[-1] * 2),
        )

        # Reshape the embedding to be broadcastable with input shape.
        if feature_axis < sequence_axis:
            embedding = ops.transpose(embedding)
        for axis in range(rank):
            if axis not in (sequence_axis, feature_axis):
                embedding = ops.expand_dims(embedding, axis)

        # Angles are computed in float32; only the final tables are cast to
        # the layer's compute dtype.
        cos_emb = ops.cast(ops.cos(embedding), self.compute_dtype)
        sin_emb = ops.cast(ops.sin(embedding), self.compute_dtype)
        return cos_emb, sin_emb

    def _get_inverse_freq(self, rotary_dim):
        """Return `1 / max_wavelength ** (2i / rotary_dim)` for each feature pair."""
        exponents = ops.divide(
            ops.arange(0, rotary_dim, 2, dtype="float32"),
            ops.cast(rotary_dim, "float32"),
        )
        return 1.0 / (self.max_wavelength**exponents)

    def get_config(self):
        return {
            **super().get_config(),
            "max_wavelength": self.max_wavelength,
            "scaling_factor": self.scaling_factor,
            "sequence_axis": self.sequence_axis,
            "feature_axis": self.feature_axis,
        }

    def compute_output_shape(self, input_shape):
        # The rotation is applied element-wise, so the shape is unchanged.
        return input_shape
|
@@ -0,0 +1,108 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import keras
|
16
|
+
from keras import ops
|
17
|
+
|
18
|
+
from keras_hub.src.api_export import keras_hub_export
|
19
|
+
|
20
|
+
|
21
|
+
@keras_hub_export("keras_hub.layers.SinePositionEncoding")
class SinePositionEncoding(keras.layers.Layer):
    """Sinusoidal positional encoding layer.

    This layer calculates the position encoding as a mix of sine and cosine
    functions with geometrically increasing wavelengths. Defined and formulated
    in [Attention is All You Need](https://arxiv.org/abs/1706.03762).

    Takes as input an embedded token tensor. The input must have shape
    [batch_size, sequence_length, feature_size]. This layer will return a
    positional encoding the same size as the embedded token tensor, which
    can be added directly to the embedded token tensor.

    Args:
        max_wavelength: The maximum angular wavelength of the sine/cosine
            curves, as described in Attention is All You Need. Defaults to
            `10000`.
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `trainable`, `dtype` etc.

    Call arguments:
        inputs: The tensor inputs to compute an embedding for, with shape
            `(batch_size, sequence_length, hidden_dim)`.
        start_index: An integer or integer tensor. The starting position to
            compute the encoding from. This is useful during cached decoding,
            where each position is predicted separately in a loop.

    Example:
    ```python
    # create a simple embedding layer with sinusoidal positional encoding
    seq_len = 100
    vocab_size = 1000
    embedding_dim = 32
    inputs = keras.Input((seq_len,), dtype="int32")
    embedding = keras.layers.Embedding(
        input_dim=vocab_size, output_dim=embedding_dim
    )(inputs)
    positional_encoding = keras_hub.layers.SinePositionEncoding()(embedding)
    outputs = embedding + positional_encoding
    ```

    References:
     - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)
    """

    def __init__(
        self,
        max_wavelength=10000,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.max_wavelength = max_wavelength
        # The layer owns no weights, so there is nothing left to build.
        self.built = True

    def call(self, inputs, start_index=0):
        shape = ops.shape(inputs)
        seq_length = shape[-2]
        hidden_size = shape[-1]

        # Do all the position/angle math in float32 and cast only the final
        # result. With a half-precision compute dtype, large integer
        # positions are not exactly representable, so computing directly in
        # the compute dtype would give neighbouring positions the same
        # encoding on long sequences.
        positions = ops.cast(ops.arange(seq_length) + start_index, "float32")
        min_freq = ops.cast(1 / self.max_wavelength, dtype="float32")
        # Feature indices 2i and 2i + 1 share the exponent 2i / hidden_size.
        timescales = ops.power(
            min_freq,
            ops.cast(2 * (ops.arange(hidden_size) // 2), "float32")
            / ops.cast(hidden_size, "float32"),
        )
        angles = ops.expand_dims(positions, 1) * ops.expand_dims(timescales, 0)
        # even indices are sine, odd are cosine
        cos_mask = ops.cast(ops.arange(hidden_size) % 2, "float32")
        sin_mask = 1 - cos_mask
        # embedding shape is [seq_length, hidden_size]
        positional_encodings = (
            ops.sin(angles) * sin_mask + ops.cos(angles) * cos_mask
        )
        positional_encodings = ops.cast(
            positional_encodings, self.compute_dtype
        )

        return ops.broadcast_to(positional_encodings, shape)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "max_wavelength": self.max_wavelength,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        # The encoding is broadcast to the shape of the input.
        return input_shape
|
@@ -0,0 +1,150 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import keras
|
16
|
+
|
17
|
+
from keras_hub.src.api_export import keras_hub_export
|
18
|
+
from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
|
19
|
+
from keras_hub.src.layers.modeling.reversible_embedding import (
|
20
|
+
ReversibleEmbedding,
|
21
|
+
)
|
22
|
+
from keras_hub.src.utils.keras_utils import clone_initializer
|
23
|
+
|
24
|
+
|
25
|
+
@keras_hub_export("keras_hub.layers.TokenAndPositionEmbedding")
class TokenAndPositionEmbedding(keras.layers.Layer):
    """A layer which sums a token and position embedding.

    Token and position embeddings are ways of representing words and their order
    in a sentence. This layer creates a `ReversibleEmbedding` token embedding
    and a `keras_hub.layers.PositionEmbedding` position embedding and sums their
    output when called. This layer assumes that the last dimension in the input
    corresponds to the sequence dimension.

    Args:
        vocabulary_size: The size of the vocabulary.
        sequence_length: The maximum length of input sequence
        embedding_dim: The output dimension of the embedding layer
        tie_weights: Boolean, whether or not the matrix for embedding and
            the matrix for the `reverse` projection should share the same
            weights.
        embeddings_initializer: The initializer to use for the Embedding
            Layers
        mask_zero: Boolean, whether or not the input value 0 is a special
            "padding" value that should be masked out.
            This is useful when using recurrent layers which may take variable
            length input. If this is True, then all subsequent layers in the
            model need to support masking or an exception will be raised.
            If `mask_zero` is set to True, as a consequence, index 0 cannot be
            used in the vocabulary
            (`vocabulary_size` should equal size of vocabulary + 1).
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `trainable`, `dtype` etc.

    Example:
    ```python
    inputs = np.ones(shape=(1, 50), dtype="int32")
    embedding_layer = keras_hub.layers.TokenAndPositionEmbedding(
        vocabulary_size=10_000,
        sequence_length=50,
        embedding_dim=128,
    )
    outputs = embedding_layer(inputs)
    ```
    """

    def __init__(
        self,
        vocabulary_size,
        sequence_length,
        embedding_dim,
        tie_weights=True,
        embeddings_initializer="uniform",
        mask_zero=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Reject missing sizes up front, checked in signature order.
        required_sizes = (
            ("vocabulary_size", vocabulary_size),
            ("sequence_length", sequence_length),
            ("embedding_dim", embedding_dim),
        )
        for arg_name, arg_value in required_sizes:
            if arg_value is None:
                raise ValueError(
                    f"`{arg_name}` must be an Integer, received `None`."
                )
        self.vocabulary_size = int(vocabulary_size)
        self.sequence_length = int(sequence_length)
        self.embedding_dim = int(embedding_dim)
        self.embeddings_initializer = keras.initializers.get(
            embeddings_initializer
        )
        # Forward the int-coerced sizes (not the raw arguments) so the
        # sub-layers are configured with exactly the values that
        # `get_config()` reports.
        self.token_embedding = ReversibleEmbedding(
            self.vocabulary_size,
            self.embedding_dim,
            tie_weights=tie_weights,
            embeddings_initializer=clone_initializer(
                self.embeddings_initializer
            ),
            mask_zero=mask_zero,
            dtype=self.dtype_policy,
            name="token_embedding",
        )
        self.position_embedding = PositionEmbedding(
            sequence_length=self.sequence_length,
            initializer=clone_initializer(self.embeddings_initializer),
            dtype=self.dtype_policy,
            name="position_embedding",
        )
        # Masking support is delegated to the token embedding.
        self.supports_masking = self.token_embedding.supports_masking

    def build(self, input_shape):
        input_shape = tuple(input_shape)
        self.token_embedding.build(input_shape)
        # The position embedding is built against the token embedding's
        # output shape: the input shape plus a trailing embedding dimension.
        self.position_embedding.build(input_shape + (self.embedding_dim,))
        self.built = True

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "sequence_length": self.sequence_length,
                "embedding_dim": self.embedding_dim,
                "embeddings_initializer": keras.initializers.serialize(
                    self.embeddings_initializer
                ),
                "tie_weights": self.token_embedding.tie_weights,
                "mask_zero": self.token_embedding.mask_zero,
            }
        )
        return config

    def call(self, inputs, start_index=0):
        embedded_tokens = self.token_embedding(inputs)
        # The position embedding is called on the embedded tokens, not on
        # the raw token ids.
        embedded_positions = self.position_embedding(
            embedded_tokens,
            start_index=start_index,
        )
        outputs = embedded_tokens + embedded_positions
        return outputs

    def compute_mask(self, inputs, mask=None):
        # Delegate mask computation to the token embedding.
        return self.token_embedding.compute_mask(inputs, mask=mask)

    def compute_output_shape(self, input_shape):
        return tuple(input_shape) + (self.embedding_dim,)
|