keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
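The listing above is the complete set of files added in this wheel (no files were removed). As a local cross-check, a `.whl` is a standard zip archive, so its contents can be listed with Python's `zipfile` module. This is a minimal sketch, assuming the wheel has been downloaded to the working directory under its canonical filename (the path is a placeholder; adjust it to wherever the file was saved).

import zipfile

# Assumed local filename of the downloaded wheel; adjust as needed.
wheel_path = "keras_hub_nightly-0.15.0.dev20240823171555-py3-none-any.whl"

with zipfile.ZipFile(wheel_path) as wheel:
    for info in wheel.infolist():
        # Each entry is a packaged module or dist-info file; print its
        # path and uncompressed size in bytes.
        print(f"{info.filename} ({info.file_size} bytes)")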
keras_hub/src/models/gemma/gemma_attention.py
@@ -0,0 +1,250 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras
import numpy as np
from keras import ops

from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.utils.keras_utils import clone_initializer


class CachedGemmaAttention(keras.layers.Layer):
    """A cached grouped query attention layer."""

    def __init__(
        self,
        head_dim,
        num_query_heads,
        num_key_value_heads,
        kernel_initializer="glorot_uniform",
        logit_soft_cap=None,
        use_sliding_window_attention=False,
        sliding_window_size=4096,
        query_head_dim_normalize=True,
        dropout=0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_query_heads = num_query_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.logit_soft_cap = logit_soft_cap
        self.use_sliding_window_attention = use_sliding_window_attention
        self.sliding_window_size = sliding_window_size
        self.query_head_dim_normalize = query_head_dim_normalize
        self.dropout = dropout

        self._kernel_initializer = keras.initializers.get(
            clone_initializer(kernel_initializer)
        )
        self.num_key_value_groups = num_query_heads // num_key_value_heads
        self.query_head_dim_normalize = query_head_dim_normalize

    def build(self, inputs_shape):
        self.hidden_dim = inputs_shape[-1]

        self.query_dense = keras.layers.EinsumDense(
            "btd,ndh->btnh",
            output_shape=(None, self.num_query_heads, self.head_dim),
            kernel_initializer=self._kernel_initializer,
            dtype=self.dtype_policy,
            name="query",
        )
        self.query_dense.build(inputs_shape)

        self.key_dense = keras.layers.EinsumDense(
            "bsd,kdh->bskh",
            output_shape=(None, self.num_key_value_heads, self.head_dim),
            kernel_initializer=self._kernel_initializer,
            dtype=self.dtype_policy,
            name="key",
        )
        self.key_dense.build(inputs_shape)

        self.value_dense = keras.layers.EinsumDense(
            "bsd,kdh->bskh",
            output_shape=(None, self.num_key_value_heads, self.head_dim),
            kernel_initializer=self._kernel_initializer,
            dtype=self.dtype_policy,
            name="value",
        )
        self.value_dense.build(inputs_shape)

        self.dropout_layer = keras.layers.Dropout(
            rate=self.dropout,
            dtype=self.dtype_policy,
        )

        self.output_dense = keras.layers.EinsumDense(
            equation="btnh,nhd->btd",
            output_shape=(None, self.hidden_dim),
            kernel_initializer=self._kernel_initializer,
            dtype=self.dtype_policy,
            name="attention_output",
        )
        self.output_dense.build(
            (None, None, self.num_query_heads, self.head_dim)
        )
        self.softmax = keras.layers.Softmax(dtype="float32")

        self.rope_layer = RotaryEmbedding(
            max_wavelength=10_000.0, dtype=self.dtype_policy
        )

        self.built = True

    def _apply_rope(self, x, start_index):
        """Rope rotate q or k."""
        x = self.rope_layer(x, start_index=start_index)
        # Gemma uses a different layout for positional embeddings.
        # The transformation below ensures the embeddings are numerically
        # equivalent to the original gemma implementation.
        x = ops.reshape(
            ops.stack(ops.split(x, 2, axis=-1), axis=-1), ops.shape(x)
        )
        return x

    def _compute_attention(
        self,
        q,
        k,
        v,
        attention_mask,
        training=False,
        cache_update_index=0,
    ):
        if self.query_head_dim_normalize:
            query_normalization = 1 / np.sqrt(self.head_dim)
        else:
            query_normalization = 1 / np.sqrt(
                self.hidden_dim // self.num_query_heads
            )

        q *= ops.cast(query_normalization, dtype=q.dtype)
        q_shape = ops.shape(q)
        q = ops.reshape(
            q,
            (
                *q_shape[:-2],
                self.num_key_value_heads,
                self.num_query_heads // self.num_key_value_heads,
                q_shape[-1],
            ),
        )
        b, q_len, _, _, h = ops.shape(q)

        attention_logits = ops.einsum("btkgh,bskh->bkgts", q, k)

        if self.logit_soft_cap is not None:
            attention_logits = ops.divide(attention_logits, self.logit_soft_cap)
            attention_logits = ops.multiply(
                ops.tanh(attention_logits), self.logit_soft_cap
            )

        if self.use_sliding_window_attention:
            attention_mask = self._mask_sliding_window(
                attention_mask,
                cache_update_index=cache_update_index,
            )

        attention_mask = attention_mask[:, None, None, :, :]
        orig_dtype = attention_logits.dtype
        attention_softmax = self.softmax(attention_logits, mask=attention_mask)
        attention_softmax = ops.cast(attention_softmax, orig_dtype)

        if self.dropout:
            attention_softmax = self.dropout_layer(
                attention_softmax, training=training
            )

        results = ops.einsum("bkgts,bskh->btkgh", attention_softmax, v)
        return ops.reshape(results, (b, q_len, self.num_query_heads, h))

    def _mask_sliding_window(
        self,
        attention_mask,
        cache_update_index=0,
    ):
        batch_size, query_len, key_len = ops.shape(attention_mask)
        # Compute the sliding window for square attention.
        all_ones = ops.ones((key_len, key_len), "bool")
        if keras.config.backend() == "tensorflow":
            # TODO: triu/tril has issues with dynamic shape on the tensorflow
            # backend. We should fix, but use `band_part` for now.
            import tensorflow as tf

            band_size = ops.minimum(key_len, self.sliding_window_size - 1)
            band_size = ops.cast(band_size, "int32")
            sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size)
        else:
            sliding_mask = ops.triu(
                all_ones, -1 * self.sliding_window_size + 1
            ) * ops.tril(all_ones, self.sliding_window_size - 1)
        # Slice the window for short queries during generation.
        start = (cache_update_index, 0)
        sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len))
        sliding_mask = ops.expand_dims(sliding_mask, 0)
        return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool"))

    def call(
        self,
        x,
        attention_mask=None,
        cache=None,
        cache_update_index=0,
        training=False,
    ):
        query = self.query_dense(x)
        query = self._apply_rope(query, cache_update_index)

        if cache is not None:
            key_cache = cache[:, 0, ...]
            value_cache = cache[:, 1, ...]
            key_update = self.key_dense(x)
            key_update = self._apply_rope(key_update, cache_update_index)
            value_update = self.value_dense(x)
            start = [0, cache_update_index, 0, 0]
            key = ops.slice_update(key_cache, start, key_update)
            value = ops.slice_update(value_cache, start, value_update)
            cache = ops.stack((key, value), axis=1)
        else:
            key = self.key_dense(x)
            key = self._apply_rope(key, cache_update_index)
            value = self.value_dense(x)

        attention_vec = self._compute_attention(
            query,
            key,
            value,
            attention_mask,
            training=training,
            cache_update_index=cache_update_index,
        )

        # Wipe attn vec if there are no attended tokens.
        no_attended_tokens = ops.all(
            ops.equal(attention_mask, 0), axis=-1, keepdims=True
        )[..., None]
        attention_vec = ops.where(
            no_attended_tokens, ops.zeros_like(attention_vec), attention_vec
        )

        attention_output = self.output_dense(attention_vec)

        if cache is not None:
            return attention_output, cache
        return attention_output

    def compute_output_shape(self, input_shape):
        return input_shape
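A minimal usage sketch for the CachedGemmaAttention layer above (not part of the package diff). It assumes the nightly wheel is installed, imports the layer from the internal module path listed in the manifest, and uses illustrative shapes with a simple causal mask; in the released models the layer is driven by GemmaDecoderBlock rather than called directly.

import numpy as np

from keras_hub.src.models.gemma.gemma_attention import CachedGemmaAttention

batch, seq_len, hidden_dim = 2, 8, 64

# Grouped-query attention: 4 query heads sharing 1 key/value head.
attention = CachedGemmaAttention(
    head_dim=16,
    num_query_heads=4,
    num_key_value_heads=1,
)

x = np.random.uniform(size=(batch, seq_len, hidden_dim)).astype("float32")

# Boolean mask of shape (batch, query_len, key_len); True where a query
# position may attend to a key position (causal lower triangle here).
mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))
mask = np.broadcast_to(mask, (batch, seq_len, seq_len))

y = attention(x, attention_mask=mask)  # shape (batch, seq_len, hidden_dim)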
keras_hub/src/models/gemma/gemma_backbone.py
@@ -0,0 +1,316 @@
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import keras
from keras import ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.modeling.reversible_embedding import (
    ReversibleEmbedding,
)
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.gemma.gemma_decoder_block import GemmaDecoderBlock
from keras_hub.src.models.gemma.rms_normalization import RMSNormalization


@keras_hub_export("keras_hub.models.GemmaBackbone")
class GemmaBackbone(Backbone):
    """Gemma core network with hyperparameters.

    This backbone implements the base Transformer network for the Gemma model.
    It includes the embedding lookups and transformer layers. This backbone
    will output the final hidden states for each token, not generative
    predictions over the vocabulary space. For a higher-level object for text
    generation, see `keras_hub.models.GemmaCausalLM`.

    The default constructor gives a fully customizable, randomly initialized
    Gemma model with any number of layers, heads, and embedding dimensions. To
    load preset architectures and weights, use the `from_preset` constructor.

    Args:
        vocabulary_size: int. The size of the token vocabulary.
        num_layers: int. The number of transformer layers.
        num_query_heads: int. The number of heads for the query projections in
            the attention layer.
        num_key_value_heads: int. The number of heads for the key and value
            projections in the attention layer.
        hidden_dim: int. The size of the transformer hidden state at the end
            of each transformer layer.
        intermediate_dim: int. The output dimension of the first Dense layer in
            a two-layer feedforward network for each transformer.
        head_dim: int. The size of each attention head.
        layer_norm_epsilon: float. The epsilon value used for every layer norm
            in the transformer model.
        dropout: float. Dropout probability for the Transformer encoder.
        query_head_dim_normalize: boolean. If `True` normalize the query before
            attention with `head_dim`. If `False`, normalize the query with
            `hidden_dim / num_query_heads`. Defaults to True.
        use_post_ffw_norm: boolean. Whether to normalize after the feedforward
            block. Defaults to False.
        use_post_attention_norm: boolean. Whether to normalize after the attention
            block. Defaults to False.
        attention_logit_soft_cap: None or int. Soft cap for the attention logits.
            Defaults to None.
        final_logit_soft_cap: None or int. Soft cap for the final logits.
            Defaults to None.
        use_sliding_window_attention: boolean. Whether to use sliding local
            window attention. Defaults to False.
        sliding_window_size: int. Size of the sliding local window. Defaults to
            4096.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for the model's computations and weights. Note that some
            computations, such as softmax and layer normalization, will always
            be done in float32 precision regardless of dtype.

    Example:
    ```python
    input_data = {
        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
    }

    # Pretrained Gemma decoder.
    model = keras_hub.models.GemmaBackbone.from_preset("gemma_2b_en")
    model(input_data)

    # Randomly initialized Gemma decoder with custom config.
    model = keras_hub.models.GemmaBackbone(
        vocabulary_size=50257,
        num_layers=12,
        num_query_heads=12,
        num_key_value_heads=1,
        hidden_dim=768,
        intermediate_dim=3072,
        head_dim=64,
    )
    model(input_data)
    ```
    """

    def __init__(
        self,
        vocabulary_size,
        num_layers,
        num_query_heads,
        num_key_value_heads,
        hidden_dim,
        intermediate_dim,
        head_dim,
        query_head_dim_normalize=True,
        use_post_ffw_norm=False,
        use_post_attention_norm=False,
        attention_logit_soft_cap=None,
        final_logit_soft_cap=None,
        use_sliding_window_attention=False,
        sliding_window_size=4096,
        layer_norm_epsilon=1e-6,
        dropout=0,
        dtype=None,
        **kwargs,
    ):
        # === Layers ===
        self.token_embedding = ReversibleEmbedding(
            input_dim=vocabulary_size,
            output_dim=hidden_dim,
            tie_weights=True,
            embeddings_initializer=keras.initializers.VarianceScaling(
                scale=1.0,
                mode="fan_in",
                distribution="untruncated_normal",
                seed=None,
            ),
            dtype=dtype,
            logit_soft_cap=final_logit_soft_cap,
            name="token_embedding",
        )
        self.transformer_layers = []
        for i in range(num_layers):
            sliding_window = use_sliding_window_attention and (i % 2 == 0)
            layer = GemmaDecoderBlock(
                intermediate_dim=intermediate_dim,
                hidden_dim=hidden_dim,
                num_query_heads=num_query_heads,
                head_dim=head_dim,
                num_key_value_heads=num_key_value_heads,
                query_head_dim_normalize=query_head_dim_normalize,
                use_post_ffw_norm=use_post_ffw_norm,
                use_post_attention_norm=use_post_attention_norm,
                logit_soft_cap=attention_logit_soft_cap,
                use_sliding_window_attention=sliding_window,
                sliding_window_size=sliding_window_size,
                dropout=dropout,
                dtype=dtype,
                name=f"decoder_block_{i}",
            )
            self.transformer_layers.append(layer)
        self.layer_norm = RMSNormalization(
            epsilon=layer_norm_epsilon,
            dtype=dtype,
            name="final_normalization",
        )

        # === Functional Model ===
        token_id_input = keras.Input(
            shape=(None,), dtype="float32", name="token_ids"
        )
        padding_mask_input = keras.Input(
            shape=(None,), dtype="float32", name="padding_mask"
        )
        x = self.token_embedding(token_id_input)
        x = x * ops.cast(ops.sqrt(hidden_dim), x.dtype)
        for transformer_layer in self.transformer_layers:
            x = transformer_layer(x, padding_mask=padding_mask_input)
        sequence_output = self.layer_norm(x)
        super().__init__(
            inputs={
                "token_ids": token_id_input,
                "padding_mask": padding_mask_input,
            },
            outputs=sequence_output,
            dtype=dtype,
            **kwargs,
        )

        # === Config ===
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.num_query_heads = num_query_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.head_dim = head_dim
        self.layer_norm_epsilon = layer_norm_epsilon
        self.dropout = dropout
        self.query_head_dim_normalize = query_head_dim_normalize
        self.use_post_ffw_norm = use_post_ffw_norm
        self.use_post_attention_norm = use_post_attention_norm
        self.attention_logit_soft_cap = attention_logit_soft_cap
        self.final_logit_soft_cap = final_logit_soft_cap
        self.sliding_window_size = sliding_window_size
        self.use_sliding_window_attention = use_sliding_window_attention

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "num_layers": self.num_layers,
                "num_query_heads": self.num_query_heads,
                "num_key_value_heads": self.num_key_value_heads,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "head_dim": self.head_dim,
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "dropout": self.dropout,
                "query_head_dim_normalize": self.query_head_dim_normalize,
                "use_post_ffw_norm": self.use_post_ffw_norm,
                "use_post_attention_norm": self.use_post_attention_norm,
                "final_logit_soft_cap": self.final_logit_soft_cap,
                "attention_logit_soft_cap": self.attention_logit_soft_cap,
                "sliding_window_size": self.sliding_window_size,
                "use_sliding_window_attention": self.use_sliding_window_attention,
            }
        )
        return config

    @staticmethod
    def get_layout_map(
        device_mesh,
        model_parallel_dim_name="model",
        data_parallel_dim_name="batch",
    ):
        """Get a `keras.distribution.LayoutMap` for model parallel distribution.

        The returned `LayoutMap` contains the sharding spec for the gemma
        backbone weights, so that you can use it to distribute weights across
        the accelerators.

        Example:
        ```
        # Feel free to change the mesh shape to balance data and model parallel
        mesh = keras.distribution.DeviceMesh(
            shape=(1, 8), axis_names=('batch', 'model'),
            devices=keras.distribution.list_devices())
        layout_map = GemmaBackbone.get_layout_map(
            mesh, model_parallel_dim_name="model")

        distribution = keras.distribution.ModelParallel(
            mesh, layout_map, batch_dim_name='batch')
        with distribution.scope():
            gemma_model = keras_hub.models.GemmaCausalLM.from_preset()
        ```

        Args:
            device_mesh: The `keras.distribution.DeviceMesh` instance for
                distribution.
            model_parallel_dim_name: The axis name of the device mesh, where
                the weights should be partitioned on.
            data_parallel_dim_name: The axis name of the device mesh, where
                the data should be partitioned on.
        Return:
            `keras.distribution.LayoutMap` that contains the sharding spec
            of all the model weights.
        """
        # The weight path and shape of the Gemma backbone is like below (for 2G)
        # token_embedding/embeddings, (256128, 2048), 524550144
        # repeat block for decoder
        # ...
        # decoder_block_17/pre_attention_norm/scale, (2048,), 2048
        # decoder_block_17/attention/query/kernel, (8, 2048, 256), 4194304
        # decoder_block_17/attention/key/kernel, (8, 2048, 256), 4194304
        # decoder_block_17/attention/value/kernel, (8, 2048, 256), 4194304
        # decoder_block_17/attention/attention_output/kernel, (8, 256, 2048), 4194304
        # decoder_block_17/pre_ffw_norm/scale, (2048,), 2048
        # decoder_block_17/ffw_gating/kernel, (2048, 16384), 33554432
        # decoder_block_17/ffw_gating_2/kernel, (2048, 16384), 33554432
        # decoder_block_17/ffw_linear/kernel, (16384, 2048), 33554432
        if not isinstance(device_mesh, keras.distribution.DeviceMesh):
            raise ValueError(
                "Invalid device_mesh type. Expected `keras.distribution.Device`,"
                f" got {type(device_mesh)}"
            )
        if model_parallel_dim_name not in device_mesh.axis_names:
            raise ValueError(
                f"{model_parallel_dim_name} is not found in the "
                f"device_mesh.axis_names. {device_mesh.axis_name=}"
            )
        if data_parallel_dim_name not in device_mesh.axis_names:
            raise ValueError(
                f"{data_parallel_dim_name} is not found in the "
                f"device_mesh.axis_names. {device_mesh.axis_name=}"
            )
        # Note that it is possible to further config the mesh to be 3D, eg
        # (data, seq, model). We leave it as 2D for now for simplicity.
        data_dim = data_parallel_dim_name
        model_dim = model_parallel_dim_name
        # The sharding config is based on the Gemma team training config.
        # See https://arxiv.org/abs/2403.08295
        layout_map = keras.distribution.LayoutMap(device_mesh)
        layout_map["token_embedding/embeddings"] = (model_dim, data_dim)
        layout_map["decoder_block.*attention.*(query|key|value).kernel"] = (
            model_dim,
            data_dim,
            None,
        )
        layout_map["decoder_block.*attention_output.kernel"] = (
            model_dim,
            None,
            data_dim,
        )
        layout_map["decoder_block.*ffw_gating.kernel"] = (data_dim, model_dim)
        layout_map["decoder_block.*ffw_gating_2.kernel"] = (data_dim, model_dim)
        layout_map["decoder_block.*ffw_linear.kernel"] = (model_dim, data_dim)

        return layout_map
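The keys set on `layout_map` above are treated as patterns matched against variable paths such as the ones listed in the comment near the top of `get_layout_map`. The sketch below is illustrative only, assuming plain `re.search` semantics; it checks that each weight path is picked up by exactly one of the key patterns assigned above.

import re

# Weight paths copied from the comment in `get_layout_map` (Gemma 2B).
weight_paths = [
    "token_embedding/embeddings",
    "decoder_block_17/attention/query/kernel",
    "decoder_block_17/attention/attention_output/kernel",
    "decoder_block_17/ffw_gating/kernel",
    "decoder_block_17/ffw_gating_2/kernel",
    "decoder_block_17/ffw_linear/kernel",
]

# The same key strings assigned into the layout map above.
layout_keys = [
    "token_embedding/embeddings",
    "decoder_block.*attention.*(query|key|value).kernel",
    "decoder_block.*attention_output.kernel",
    "decoder_block.*ffw_gating.kernel",
    "decoder_block.*ffw_gating_2.kernel",
    "decoder_block.*ffw_linear.kernel",
]

for path in weight_paths:
    # Each weight path should resolve to a single sharding spec key.
    matches = [key for key in layout_keys if re.search(key, path)]
    print(path, "->", matches)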