keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,313 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
from keras import ops
|
15
|
+
|
16
|
+
from keras_hub.src.api_export import keras_hub_export
|
17
|
+
from keras_hub.src.models.causal_lm import CausalLM
|
18
|
+
from keras_hub.src.models.pali_gemma.pali_gemma_backbone import (
|
19
|
+
PaliGemmaBackbone,
|
20
|
+
)
|
21
|
+
from keras_hub.src.models.pali_gemma.pali_gemma_causal_lm_preprocessor import (
|
22
|
+
PaliGemmaCausalLMPreprocessor,
|
23
|
+
)
|
24
|
+
from keras_hub.src.utils.tensor_utils import any_equal
|
25
|
+
|
26
|
+
|
27
|
+
@keras_hub_export("keras_hub.models.PaliGemmaCausalLM")
class PaliGemmaCausalLM(CausalLM):
    """An end-to-end multi modal PaliGemma model for causal language modeling.

    A causal language model (LM) predicts the next token based on previous
    tokens. This task setup can be used to train the model unsupervised on
    image and plain text input, or to autoregressively generate plain text
    similar to the data used for training.

    This model has a `generate()` method, which generates text based on a
    prompt. The generation strategy used is controlled by an additional
    `sampler` argument on `compile()`. You can recompile the model with
    different `keras_hub.samplers` objects to control the generation. By
    default, `"greedy"` sampling will be used.

    This model can optionally be configured with a `preprocessor` layer, in
    which case it will automatically apply preprocessing to string inputs during
    `fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default
    when creating the model with `from_preset()`.

    Args:
        backbone: A `keras_hub.models.PaliGemmaBackbone` instance.
        preprocessor: A `keras_hub.models.PaliGemmaCausalLMPreprocessor` or
            `None`. If `None`, this model will not apply preprocessing, and
            inputs should be preprocessed before calling the model.

    Examples:

    Use `generate()` to do text generation.
    ```python
    image = np.random.rand(224, 224, 3)
    pali_gemma_lm = keras_hub.models.PaliGemmaCausalLM.from_preset(
        "pali_gemma_3b_mix_224"
    )
    pali_gemma_lm.generate(
      {
        "images": image,
        "text": ["answer en where is the cow standing?\\n"]
      }
    )

    # Generate with batched prompts.
    pali_gemma_lm.generate(
      {
        "images": [image, image],
        "text": ["answer en where is the cow standing?\\n", "caption en\\n"]
      }
    )
    ```

    Use `generate()` without preprocessing.
    ```python
    image = np.random.rand(224, 224, 3)
    inputs = {
        "images": [image, image],
        # Token ids for "<bos> Keras is".
        "token_ids": np.array([[2, 214064, 603, 0, 0, 0, 0]] * 2),
        # Use `"padding_mask"` to indicate values that should not be overridden.
        "padding_mask": np.array([[1, 1, 1, 0, 0, 0, 0]] * 2),
    }

    pali_gemma_lm = keras_hub.models.PaliGemmaCausalLM.from_preset(
        "pali_gemma_3b_mix_224",
        preprocessor=None,
    )
    pali_gemma_lm.generate(inputs)
    ```

    Custom backbone and vocabulary.
    ```python
    tokenizer = keras_hub.models.PaliGemmaTokenizer(
        proto="proto.spm",
    )
    preprocessor = keras_hub.models.PaliGemmaCausalLMPreprocessor(
        tokenizer=tokenizer,
        sequence_length=128,
    )
    backbone = keras_hub.models.PaliGemmaBackbone()
    pali_gemma_lm = keras_hub.models.PaliGemmaCausalLM(
        backbone=backbone,
        preprocessor=preprocessor,
    )
    ```
    """

    backbone_cls = PaliGemmaBackbone
    preprocessor_cls = PaliGemmaCausalLMPreprocessor

    def __init__(
        self,
        preprocessor,
        backbone,
        **kwargs,
    ):
        # === Layers ===
        self.preprocessor = preprocessor
        self.backbone = backbone

        # === Functional Model ===
        inputs = backbone.inputs
        hidden_state = backbone(inputs=inputs)
        outputs = backbone.token_embedding(hidden_state, reverse=True)
        # The sequence is `[image positions, text positions]`; only the text
        # positions are exposed as language-model logits.
        outputs = outputs[:, backbone.image_sequence_length :, :]
        super().__init__(
            inputs=inputs,
            outputs=outputs,
            **kwargs,
        )

    def compile(
        self,
        optimizer="auto",
        loss="auto",
        *,
        weighted_metrics="auto",
        sampler="greedy",
        **kwargs,
    ):
        """Configure the model for training and generation.

        All arguments are forwarded to `CausalLM.compile()`. `sampler`
        selects the generation strategy used by `generate()` and defaults
        to `"greedy"`.
        """
        super().compile(
            optimizer=optimizer,
            loss=loss,
            weighted_metrics=weighted_metrics,
            sampler=sampler,
            **kwargs,
        )

    def call_with_cache(
        self,
        token_ids,
        cache,
        cache_update_index,
        img_embeddings=None,
        padding_mask=None,
    ):
        """Forward pass of `PaliGemmaCausalLM` with cache.

        `call_with_cache` adds an additional forward pass for the model for
        autoregressive inference. Unlike calling the model directly, this method
        allows caching previous key/value Tensors in multi-head attention layer,
        and avoids recomputing the outputs of seen tokens.

        Args:
            token_ids: a dense int Tensor with shape
                `(batch_size, sequence_length)`. This is the full prompt when
                seeding the cache and a single token while decoding.
            cache: a dense float Tensor, the cache of key and value, laid out
                as `(batch_size, num_layers, 2, max_length, num_heads,
                head_dim)` when built by `_build_cache()`.
            cache_update_index: int, or int Tensor. The index of current inputs
                in the whole sequence (image positions included).
            img_embeddings: a dense float Tensor with shape
                `(batch_size, image_sequence_length, hidden_dim)`. When given,
                it is prepended to the text embeddings.
            padding_mask: a dense int Tensor with shape
                `(batch_size, max_length)`.

        Returns:
            A (logits, hidden_states, cache) tuple. Where `logits` is the
            language model logits for the input token_ids, `hidden_states` is
            the final hidden representation of the input tokens, and `cache` is
            the decoding cache.
        """
        text_embeddings = self.backbone.token_embedding(token_ids)
        # Scale text embeddings by sqrt(hidden_dim) before the decoder stack.
        text_embeddings = text_embeddings * ops.cast(
            ops.sqrt(self.backbone.hidden_dim), text_embeddings.dtype
        )

        if img_embeddings is not None:
            x = ops.concatenate((img_embeddings, text_embeddings), axis=1)
        else:
            x = text_embeddings

        # Each decoder layer has a cache; we update them separately.
        caches = []
        for i, transformer_layer in enumerate(self.backbone.transformer_layers):
            current_cache = cache[:, i, ...]
            x, next_cache = transformer_layer(
                x,
                cache=current_cache,
                cache_update_index=cache_update_index,
                padding_mask=padding_mask,
            )
            caches.append(next_cache)
        cache = ops.stack(caches, axis=1)
        hidden_states = x = self.backbone.layer_norm(x)
        logits = self.backbone.token_embedding(x, reverse=True)
        return logits, hidden_states, cache

    def _build_cache(self, token_ids, img_embeddings, padding_mask):
        """Build an empty cache for use with `call_with_cache()` and seed it.

        The cache has room for `image_sequence_length` image positions plus
        every text position of `token_ids`. It is seeded with one forward
        pass over the image embeddings and the prompt.

        Returns:
            A `(hidden_states, cache)` tuple. `hidden_states` covers the
            image positions followed by the text positions.
        """
        batch_size = ops.shape(token_ids)[0]
        max_length = (
            ops.shape(token_ids)[1] + self.backbone.image_sequence_length
        )
        num_layers = self.backbone.num_layers
        num_heads = self.backbone.num_key_value_heads
        head_dim = self.backbone.head_dim
        shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
        cache = ops.zeros(shape, dtype=self.compute_dtype)
        # Seed the cache.
        _, hidden_states, cache = self.call_with_cache(
            token_ids=token_ids,
            img_embeddings=img_embeddings,
            cache=cache,
            cache_update_index=0,
            padding_mask=padding_mask,
        )
        return hidden_states, cache

    def generate_step(self, inputs, stop_token_ids=None):
        """A compilable generation function for a single batch of inputs.

        This function represents the inner, XLA-compilable, generation function
        for a single batch of inputs. Inputs should have the same structure as
        model inputs, a dictionary with keys `"token_ids"`, `"padding_mask"`
        and `"images"`.

        Args:
            inputs: A dictionary with keys `"token_ids"`, `"padding_mask"` and
                `"images"` and batched tensor values. A single unbatched
                image (rank 3) is also accepted and given a batch dimension.
            stop_token_ids: Tuple of ids of end tokens to stop on. If all
                sequences have produced a new stop token, generation
                will stop.

        Returns:
            A dictionary with keys `"token_ids"`, `"padding_mask"` and
            `"images"`. `"token_ids"` holds the prompt with generated tokens
            filled in. `"padding_mask"` is True up to and including the first
            generated stop token, or everywhere when `stop_token_ids` is None.
        """
        token_ids, padding_mask, images = (
            inputs["token_ids"],
            inputs["padding_mask"],
            inputs["images"],
        )
        if len(ops.shape(images)) == 3:
            # Handle an unbatched image. Unlike `token_ids` and `padding_mask`
            # this will not automatically be upranked.
            images = ops.expand_dims(images, axis=0)
        img_embeddings = self.backbone.vit_encoder(images)
        image_sequence_length = self.backbone.image_sequence_length

        # Create and seed cache with a single forward pass.
        hidden_states, cache = self._build_cache(
            token_ids, img_embeddings, padding_mask
        )
        # The seeded hidden states start with the image positions. Samplers
        # that use `hidden_states` (e.g. `ContrastiveSampler`) index them with
        # the same positions as `token_ids`. Drop the image prefix so that the
        # two stay aligned.
        hidden_states = hidden_states[:, image_sequence_length:, :]
        # Compute the lengths of all user inputted tokens ids.
        row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
        # Start at the first index that has no user inputted id.
        index = ops.min(row_lengths)

        def next_fn(prompt, cache, index):
            # The cache index is the index of our previous token, offset by
            # the image positions that precede the text in the cache.
            cache_update_index = index - 1 + image_sequence_length
            batch_size = ops.shape(prompt)[0]
            prompt = ops.slice(prompt, [0, index - 1], [batch_size, 1])
            logits, hidden_states, cache = self.call_with_cache(
                token_ids=prompt,
                cache=cache,
                cache_update_index=cache_update_index,
            )
            return (
                ops.squeeze(logits, axis=1),
                ops.squeeze(hidden_states, axis=1),
                cache,
            )

        token_ids = self.sampler(
            next=next_fn,
            prompt=token_ids,
            cache=cache,
            index=index,
            mask=padding_mask,
            stop_token_ids=stop_token_ids,
            hidden_states=hidden_states,
            model=self,
        )

        # Compute an output padding mask with the token ids we updated.
        if stop_token_ids is not None:
            # Build a mask of `stop_token_ids` locations not in the original
            # prompt (not in locations where `padding_mask` is True).
            end_locations = any_equal(
                token_ids, stop_token_ids, ops.logical_not(padding_mask)
            )

            end_locations = ops.cast(end_locations, "int32")
            # Use cumsum to get ones in all locations after end_locations.
            cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
            overflow = cumsum - end_locations
            # Our padding mask is the inverse of these overflow locations.
            padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
        else:
            # Without early stopping, all locations will have been updated.
            padding_mask = ops.ones_like(token_ids, dtype="bool")
        return {
            "token_ids": token_ids,
            "padding_mask": padding_mask,
            "images": images,
        }
|
@@ -0,0 +1,147 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import keras
|
15
|
+
from absl import logging
|
16
|
+
from keras import ops
|
17
|
+
|
18
|
+
from keras_hub.src.api_export import keras_hub_export
|
19
|
+
from keras_hub.src.layers.preprocessing.multi_segment_packer import (
|
20
|
+
MultiSegmentPacker,
|
21
|
+
)
|
22
|
+
from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import (
|
23
|
+
GemmaCausalLMPreprocessor,
|
24
|
+
)
|
25
|
+
from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import (
|
26
|
+
PaliGemmaTokenizer,
|
27
|
+
)
|
28
|
+
from keras_hub.src.utils.keras_utils import (
|
29
|
+
convert_inputs_to_list_of_tensor_segments,
|
30
|
+
)
|
31
|
+
|
32
|
+
|
33
|
+
@keras_hub_export("keras_hub.models.PaliGemmaCausalLMPreprocessor")
class PaliGemmaCausalLMPreprocessor(GemmaCausalLMPreprocessor):
    """Preprocessing layer for PaliGemma causal language modeling.

    Expects `x` to be a dict with `"images"`, `"prompts"` and `"responses"`
    entries. Prompts and responses are tokenized with `self.tokenizer` and
    packed into one sequence by a `MultiSegmentPacker` configured with no
    separator token. Images are forwarded as-is, except that on the
    TensorFlow backend they are converted to a tensor during training
    preprocessing. Training labels are the packed tokens shifted left by
    one, and the sample weight is the (equally shifted) response mask, so
    only response tokens contribute to the loss.

    Args:
        tokenizer: The tokenizer applied to prompts and responses.
        sequence_length: int. Length of the token sequences this layer
            emits. Defaults to `512`.
        add_start_token: bool. Whether training preprocessing asks the
            packer to add the tokenizer's start token.
        add_end_token: bool. Whether training preprocessing asks the packer
            to add the tokenizer's end token.
        **kwargs: Forwarded to `GemmaCausalLMPreprocessor`.
    """

    tokenizer_cls = PaliGemmaTokenizer

    def __init__(
        self,
        tokenizer,
        sequence_length=512,
        add_start_token=True,
        add_end_token=True,
        **kwargs,
    ):
        super().__init__(
            tokenizer, sequence_length, add_start_token, add_end_token, **kwargs
        )

    def build(self, input_shape):
        # The packer is created here instead of in `__init__` so that the
        # tokenizer assets are guaranteed to be loaded when a saved model is
        # being restored.
        tok = self.tokenizer
        self.packer = MultiSegmentPacker(
            start_value=tok.start_token_id,
            end_value=tok.end_token_id,
            pad_value=tok.pad_token_id,
            sep_value=[],
            sequence_length=self.sequence_length,
        )
        self.built = True

    def call(
        self,
        x,
        y=None,
        sample_weight=None,
        sequence_length=None,
    ):
        """Tokenize and pack one batch for training.

        Args:
            x: dict with `"images"`, `"prompts"` and `"responses"` entries.
            y: Ignored (a warning is logged if given); labels are derived
                from `x`.
            sample_weight: Ignored (a warning is logged if given); weights
                are derived from the response mask.
            sequence_length: Optional override for `self.sequence_length`.

        Returns:
            The result of `keras.utils.pack_x_y_sample_weight` for the
            model inputs dict, the next-token labels, and the response-only
            sample weights.
        """
        if not (y is None and sample_weight is None):
            logging.warning(
                "`PaliGemmaCausalLMPreprocessor` generates `y` and `sample_weight` "
                "based on your input data, but your data already contains `y` "
                "or `sample_weight`. Your `y` and `sample_weight` will be "
                "ignored."
            )
        sequence_length = sequence_length or self.sequence_length

        images = x["images"]
        prompts = x["prompts"]
        responses = x["responses"]
        if keras.config.backend() == "tensorflow":
            # The TensorFlow backend needs uniform output types.
            images = ops.convert_to_tensor(images)

        prompt_segment = convert_inputs_to_list_of_tensor_segments(prompts)[0]
        prompt_ids = self.tokenizer(prompt_segment)
        response_segment = convert_inputs_to_list_of_tensor_segments(
            responses
        )[0]
        response_ids = self.tokenizer(response_segment)

        # Pack one token past `sequence_length`; the input/label shift below
        # trims it off again.
        token_ids, segment_ids = self.packer(
            (prompt_ids, response_ids),
            sequence_length=sequence_length + 1,
            add_start_value=self.add_start_token,
            add_end_value=self.add_end_token,
        )
        padding_mask = token_ids != self.tokenizer.pad_token_id
        # Segment 1 is the second packed segment, i.e. the response.
        response_mask = segment_ids == 1

        # Inputs drop the final token (it has no successor to predict);
        # labels drop the first, so position i is trained on token i + 1.
        model_inputs = {
            "token_ids": token_ids[..., :-1],
            "response_mask": response_mask[..., :-1],
            "padding_mask": padding_mask[..., :-1],
            "images": images,
        }
        labels = token_ids[..., 1:]
        # Restrict the loss to response-token labels.
        weights = response_mask[..., 1:]
        return keras.utils.pack_x_y_sample_weight(model_inputs, labels, weights)

    def generate_preprocess(
        self,
        x,
        sequence_length=None,
    ):
        """Turn string inputs into packed integer inputs for generation.

        Like the training path, this tokenizes and packs the text and
        derives padding and response masks. Unlike it, no labels are
        produced and the packer is never asked to append
        `tokenizer.end_token_id`, because generation continues from the end
        of the packed text.

        Args:
            x: dict with `"images"` and `"prompts"` entries, plus an
                optional `"responses"` entry whose tokens are packed after
                the prompt.
            sequence_length: Optional override for `self.sequence_length`.

        Returns:
            dict with `"images"` (passed through unchanged), `"token_ids"`,
            `"response_mask"` and `"padding_mask"`.
        """
        if not self.built:
            self.build(None)
        sequence_length = sequence_length or self.sequence_length

        images = x["images"]
        prompts = x["prompts"]
        prompt_segment = convert_inputs_to_list_of_tensor_segments(prompts)[0]
        segments = [self.tokenizer(prompt_segment)]
        if "responses" in x:
            response_segment = convert_inputs_to_list_of_tensor_segments(
                x["responses"]
            )[0]
            segments.append(self.tokenizer(response_segment))

        token_ids, segment_ids = self.packer(
            segments,
            sequence_length=sequence_length,
            add_end_value=False,
        )
        return {
            "images": images,
            "token_ids": token_ids,
            "response_mask": segment_ids == 1,
            "padding_mask": token_ids != self.tokenizer.pad_token_id,
        }
|
@@ -0,0 +1,160 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import keras
|
16
|
+
from keras import ops
|
17
|
+
|
18
|
+
from keras_hub.src.layers.modeling.transformer_layer_utils import (
|
19
|
+
compute_causal_mask,
|
20
|
+
)
|
21
|
+
from keras_hub.src.models.gemma.gemma_decoder_block import GemmaDecoderBlock
|
22
|
+
|
23
|
+
|
24
|
+
class PaliGemmaDecoderBlock(GemmaDecoderBlock):
    """PaliGemma mixed decoder block.

    A `GemmaDecoderBlock` whose only behavioral difference is how the
    attention mask is built. The input sequence is expected to hold any
    embedded image tokens first, followed by text tokens. The token-level
    `padding_mask` / `response_mask` passed to `call` cover only the
    trailing text positions; the leading positions they do not cover are
    treated as image tokens.

    Mask semantics (see `_compute_attention_mask`):

    - With no `padding_mask`, a plain causal mask is used.
    - With a `padding_mask` and a `response_mask`, every query may attend
      to all non-padding, non-response keys (image and prompt tokens) in
      both directions. Response keys are attended causally.
    - With a `padding_mask` but no `response_mask`, every query may attend
      to all non-padding keys, plus (via the causal term) any key at or
      before its own position.

    Args:
        hidden_dim: int. The size of the transformer hidden state at the end
            of the block.
        intermediate_dim: int. The output dimension of the first Dense layer
            in the two-layer feedforward network.
        head_dim: int. The size of each attention head.
        num_query_heads: int. The number of heads for the query projections
            in the attention layer.
        num_key_value_heads: int. The number of heads for the key and value
            projections in the attention layer.
        layer_norm_epsilon: float. The epsilon hyperparameter used for layer
            normalization.
        dropout: float. The dropout rate for the transformer attention layer.
    """

    def __init__(
        self,
        hidden_dim,
        intermediate_dim,
        head_dim,
        num_query_heads,
        num_key_value_heads,
        layer_norm_epsilon=1e-6,
        dropout=0,
        **kwargs,
    ):
        super().__init__(
            hidden_dim=hidden_dim,
            intermediate_dim=intermediate_dim,
            head_dim=head_dim,
            num_query_heads=num_query_heads,
            num_key_value_heads=num_key_value_heads,
            layer_norm_epsilon=layer_norm_epsilon,
            dropout=dropout,
            **kwargs,
        )

    def call(
        self,
        x,
        padding_mask=None,
        response_mask=None,
        cache=None,
        cache_update_index=0,
    ):
        """Apply attention and the gated feedforward, each with a residual.

        Args:
            x: Input hidden states; presumably `(batch, seq_len, hidden_dim)`
                -- TODO confirm against the backbone.
            padding_mask: Optional token mask over the trailing (text)
                positions.
            response_mask: Optional token mask marking response tokens.
            cache: Optional cache forwarded to `self.attention`.
            cache_update_index: Forwarded to `self.attention` and used as
                `cache_index` when building the causal mask.

        Returns:
            The block output, or `(output, new_cache)` when `cache` is given.
        """
        attention_input = self.pre_attention_norm(x)
        attention_mask = self._compute_attention_mask(
            attention_input,
            padding_mask,
            cache,
            cache_update_index,
            response_mask,
        )

        if cache is None:
            attention = self.attention(
                attention_input,
                attention_mask=attention_mask,
            )
        else:
            attention, new_cache = self.attention(
                attention_input,
                attention_mask=attention_mask,
                cache=cache,
                cache_update_index=cache_update_index,
            )

        if self.dropout:
            attention = self.attention_dropout(attention)

        # Residual connection around the attention sub-layer.
        hidden = x + attention
        ffw_input = self.pre_ffw_norm(hidden)

        # Gated feedforward: gelu(gate) * up, then `ffw_linear`.
        gate = self.gating_ffw(ffw_input)
        up = self.gating_ffw_2(ffw_input)
        ffw_output = self.ffw_linear(
            keras.activations.gelu(gate, approximate=True) * up
        )

        # Residual connection around the feedforward sub-layer.
        output = ffw_output + hidden

        if cache is None:
            return output
        return output, new_cache

    def _compute_attention_mask(
        self,
        x,
        padding_mask,
        cache,
        cache_update_index,
        response_mask=None,
    ):
        """Build the `(batch, query, key)` boolean attention mask.

        Token masks are left-padded up to the key length. Padded positions
        are assumed to be the leading image tokens: they are filled with 1
        in `padding_mask` and 0 in `response_mask`. The padded masks are
        then broadcast over the query axis.
        """
        x_shape = ops.shape(x)
        batch_size = x_shape[0]
        output_length = x_shape[1]
        # With a cache, keys span the cache's axis 2 rather than just `x`
        # (presumably the cached sequence length -- confirm the layout).
        input_length = output_length if cache is None else ops.shape(cache)[2]

        causal_mask = compute_causal_mask(
            batch_size=batch_size,
            input_length=input_length,
            output_length=output_length,
            cache_index=cache_update_index,
        )
        if padding_mask is None:
            # Only expected during generative decoding, where the causal
            # mask alone is sufficient.
            return causal_mask

        def _to_key_mask(token_mask, image_fill):
            # Left-pad the token mask over the image positions, then add a
            # query axis so it broadcasts against `causal_mask`.
            token_mask = ops.cast(token_mask, "int32")
            num_image_positions = input_length - ops.shape(token_mask)[1]
            token_mask = ops.pad(
                token_mask,
                ((0, 0), (num_image_positions, 0)),
                constant_values=image_fill,
            )
            return ops.expand_dims(token_mask, axis=1)

        prefix_keys = _to_key_mask(padding_mask, 1)
        if response_mask is not None:
            response_keys = _to_key_mask(response_mask, 0)
            # Response keys are attended causally only...
            causal_mask = ops.logical_and(causal_mask, response_keys)
            # ...and every other non-padding key is attended bidirectionally.
            prefix_keys = ops.logical_and(
                prefix_keys, ops.logical_not(response_keys)
            )

        return ops.logical_or(prefix_keys, causal_mask)
|
@@ -0,0 +1,78 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
"""PaliGemma model preset configurations."""
|
15
|
+
|
16
|
+
# Registry of pretrained PaliGemma backbones, keyed by preset name. Each
# entry carries display metadata (description, parameter count, model card)
# plus the Kaggle handle the weights are loaded from.
backbone_presets = {
    "pali_gemma_3b_mix_224": {
        "metadata": {
            "description": (
                "image size 224, mix fine tuned, text sequence length is 256"
            ),
            "params": 2923335408,
            "official_name": "PaliGemma",
            "path": "pali_gemma",
            "model_card": "https://www.kaggle.com/models/google/paligemma",
        },
        "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_mix_224/1",
    },
    "pali_gemma_3b_mix_448": {
        "metadata": {
            "description": (
                "image size 448, mix fine tuned, text sequence length is 512"
            ),
            "params": 2924220144,
            "official_name": "PaliGemma",
            "path": "pali_gemma",
            "model_card": "https://www.kaggle.com/models/google/paligemma",
        },
        "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_mix_448/1",
    },
    "pali_gemma_3b_224": {
        "metadata": {
            "description": (
                "image size 224, pre trained, text sequence length is 128"
            ),
            "params": 2923335408,
            "official_name": "PaliGemma",
            "path": "pali_gemma",
            "model_card": "https://www.kaggle.com/models/google/paligemma",
        },
        "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_224/1",
    },
    "pali_gemma_3b_448": {
        "metadata": {
            "description": (
                "image size 448, pre trained, text sequence length is 512"
            ),
            "params": 2924220144,
            "official_name": "PaliGemma",
            "path": "pali_gemma",
            "model_card": "https://www.kaggle.com/models/google/paligemma",
        },
        "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_448/1",
    },
    "pali_gemma_3b_896": {
        "metadata": {
            "description": (
                "image size 896, pre trained, text sequence length is 512"
            ),
            "params": 2927759088,
            "official_name": "PaliGemma",
            "path": "pali_gemma",
            "model_card": "https://www.kaggle.com/models/google/paligemma",
        },
        "kaggle_handle": "kaggle://keras/paligemma/keras/pali_gemma_3b_896/1",
    },
}
|