keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +21 -3
- keras_hub/api/models/__init__.py +71 -12
- keras_hub/api/tokenizers/__init__.py +1 -1
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
- keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
- keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
- keras_hub/src/layers/modeling/rms_normalization.py +36 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
- keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
- keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
- keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
- keras_hub/src/layers/preprocessing/image_converter.py +170 -34
- keras_hub/src/metrics/bleu.py +4 -3
- keras_hub/src/models/albert/albert_presets.py +4 -12
- keras_hub/src/models/albert/albert_text_classifier.py +7 -7
- keras_hub/src/models/backbone.py +3 -14
- keras_hub/src/models/bart/bart_backbone.py +4 -4
- keras_hub/src/models/bart/bart_presets.py +3 -9
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
- keras_hub/src/models/basnet/__init__.py +5 -0
- keras_hub/src/models/basnet/basnet.py +122 -0
- keras_hub/src/models/basnet/basnet_backbone.py +366 -0
- keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
- keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
- keras_hub/src/models/basnet/basnet_presets.py +17 -0
- keras_hub/src/models/bert/bert_presets.py +14 -32
- keras_hub/src/models/bert/bert_text_classifier.py +3 -3
- keras_hub/src/models/bloom/bloom_presets.py +8 -24
- keras_hub/src/models/causal_lm.py +56 -12
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
- keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
- keras_hub/src/models/densenet/densenet_backbone.py +6 -4
- keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
- keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
- keras_hub/src/models/densenet/densenet_presets.py +9 -15
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
- keras_hub/src/models/efficientnet/mbconv.py +53 -22
- keras_hub/src/models/electra/electra_backbone.py +2 -2
- keras_hub/src/models/electra/electra_presets.py +6 -18
- keras_hub/src/models/f_net/f_net_presets.py +2 -6
- keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
- keras_hub/src/models/falcon/falcon_backbone.py +5 -3
- keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
- keras_hub/src/models/falcon/falcon_presets.py +1 -3
- keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
- keras_hub/src/models/feature_pyramid_backbone.py +1 -1
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +496 -0
- keras_hub/src/models/flux/flux_maths.py +225 -0
- keras_hub/src/models/flux/flux_model.py +236 -0
- keras_hub/src/models/flux/flux_presets.py +3 -0
- keras_hub/src/models/flux/flux_text_to_image.py +146 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_backbone.py +35 -20
- keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
- keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
- keras_hub/src/models/gemma/gemma_presets.py +29 -63
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
- keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
- keras_hub/src/models/image_classifier.py +147 -2
- keras_hub/src/models/image_classifier_preprocessor.py +6 -3
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_segmenter.py +0 -5
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/image_to_image.py +417 -0
- keras_hub/src/models/inpaint.py +520 -0
- keras_hub/src/models/llama/llama_backbone.py +138 -12
- keras_hub/src/models/llama/llama_causal_lm.py +3 -1
- keras_hub/src/models/llama/llama_presets.py +10 -20
- keras_hub/src/models/llama3/llama3_backbone.py +12 -11
- keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
- keras_hub/src/models/llama3/llama3_presets.py +4 -12
- keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
- keras_hub/src/models/mistral/mistral_backbone.py +16 -15
- keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
- keras_hub/src/models/mistral/mistral_presets.py +3 -9
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
- keras_hub/src/models/mit/__init__.py +6 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
- keras_hub/src/models/mit/mit_image_classifier.py +12 -0
- keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/mit/mit_image_converter.py +8 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
- keras_hub/src/models/mit/mit_presets.py +139 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
- keras_hub/src/models/opt/opt_causal_lm.py +2 -2
- keras_hub/src/models/opt/opt_presets.py +4 -12
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
- keras_hub/src/models/phi3/phi3_decoder.py +0 -1
- keras_hub/src/models/phi3/phi3_presets.py +2 -6
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
- keras_hub/src/models/preprocessor.py +25 -11
- keras_hub/src/models/resnet/resnet_backbone.py +3 -14
- keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
- keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
- keras_hub/src/models/resnet/resnet_presets.py +127 -18
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
- keras_hub/src/models/roberta/roberta_backbone.py +2 -2
- keras_hub/src/models/roberta/roberta_presets.py +6 -8
- keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
- keras_hub/src/models/sam/__init__.py +5 -0
- keras_hub/src/models/sam/sam_backbone.py +2 -3
- keras_hub/src/models/sam/sam_image_converter.py +2 -4
- keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
- keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
- keras_hub/src/models/sam/sam_layers.py +5 -3
- keras_hub/src/models/sam/sam_presets.py +3 -9
- keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
- keras_hub/src/models/sam/sam_transformer.py +5 -4
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +167 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +136 -0
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +47 -19
- keras_hub/src/models/task.py +47 -39
- keras_hub/src/models/text_classifier.py +2 -2
- keras_hub/src/models/text_to_image.py +106 -41
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +184 -0
- keras_hub/src/models/vae/vae_layers.py +739 -0
- keras_hub/src/models/vgg/__init__.py +5 -0
- keras_hub/src/models/vgg/vgg_backbone.py +4 -24
- keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
- keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
- keras_hub/src/models/vgg/vgg_presets.py +48 -0
- keras_hub/src/models/vit/__init__.py +5 -0
- keras_hub/src/models/vit/vit_backbone.py +152 -0
- keras_hub/src/models/vit/vit_image_classifier.py +187 -0
- keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vit/vit_image_converter.py +73 -0
- keras_hub/src/models/vit/vit_layers.py +391 -0
- keras_hub/src/models/vit/vit_presets.py +126 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
- keras_hub/src/models/vit_det/vit_layers.py +3 -3
- keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
- keras_hub/src/models/whisper/whisper_backbone.py +6 -5
- keras_hub/src/models/whisper/whisper_decoder.py +3 -5
- keras_hub/src/models/whisper/whisper_presets.py +10 -30
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
- keras_hub/src/models/xlnet/relative_attention.py +20 -19
- keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
- keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
- keras_hub/src/samplers/contrastive_sampler.py +2 -3
- keras_hub/src/samplers/sampler.py +2 -1
- keras_hub/src/tests/test_case.py +41 -6
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
- keras_hub/src/tokenizers/tokenizer.py +10 -13
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
- keras_hub/src/utils/keras_utils.py +2 -13
- keras_hub/src/utils/pipeline_model.py +3 -3
- keras_hub/src/utils/preset_utils.py +196 -144
- keras_hub/src/utils/tensor_utils.py +4 -4
- keras_hub/src/utils/timm/convert_densenet.py +6 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
- keras_hub/src/utils/timm/convert_resnet.py +1 -1
- keras_hub/src/utils/timm/convert_vgg.py +85 -0
- keras_hub/src/utils/timm/preset_loader.py +14 -9
- keras_hub/src/utils/transformers/convert_llama3.py +21 -5
- keras_hub/src/utils/transformers/convert_vit.py +150 -0
- keras_hub/src/utils/transformers/preset_loader.py +23 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
- keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
- keras_hub/src/models/mix_transformer/__init__.py +0 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/src/models/mit/mit_presets.py
@@ -0,0 +1,139 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MiT model preset configurations."""
+
+backbone_presets_with_weights = {
+    "mit_b0_ade20k_512": {
+        "metadata": {
+            "description": (
+                "MiT (MixTransformer) model with 8 transformer blocks."
+            ),
+            "params": 3321962,
+            "path": "mit",
+        },
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b0_ade20k_512/4",
+    },
+    "mit_b1_ade20k_512": {
+        "metadata": {
+            "description": (
+                "MiT (MixTransformer) model with 8 transformer blocks."
+            ),
+            "params": 13156554,
+            "path": "mit",
+        },
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b1_ade20k_512/4",
+    },
+    "mit_b2_ade20k_512": {
+        "metadata": {
+            "description": (
+                "MiT (MixTransformer) model with 16 transformer blocks."
+            ),
+            "params": 24201418,
+            "path": "mit",
+        },
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b2_ade20k_512/4",
+    },
+    "mit_b3_ade20k_512": {
+        "metadata": {
+            "description": (
+                "MiT (MixTransformer) model with 28 transformer blocks."
+            ),
+            "params": 44077258,
+            "path": "mit",
+        },
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b3_ade20k_512/3",
+    },
+    "mit_b4_ade20k_512": {
+        "metadata": {
+            "description": (
+                "MiT (MixTransformer) model with 41 transformer blocks."
+            ),
+            "params": 60847818,
+            "path": "mit",
+        },
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b4_ade20k_512/3",
+    },
+    "mit_b5_ade20k_640": {
+        "metadata": {
+            "description": (
+                "MiT (MixTransformer) model with 52 transformer blocks."
+            ),
+            "params": 81448138,
+            "path": "mit",
+        },
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b5_ade20k_640/3",
+    },
+    "mit_b0_cityscapes_1024": {
+        "metadata": {
+            "description": (
+                "MiT (MixTransformer) model with 8 transformer blocks."
+            ),
+            "params": 3321962,
+            "path": "mit",
+        },
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b0_cityscapes_1024/3",
+    },
+    "mit_b1_cityscapes_1024": {
+        "metadata": {
+            "description": (
+                "MiT (MixTransformer) model with 8 transformer blocks."
+            ),
+            "params": 13156554,
+            "path": "mit",
+        },
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b1_cityscapes_1024/3",
+    },
+    "mit_b2_cityscapes_1024": {
+        "metadata": {
+            "description": (
+                "MiT (MixTransformer) model with 16 transformer blocks."
+            ),
+            "params": 24201418,
+            "path": "mit",
+        },
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b2_cityscapes_1024/3",
+    },
+    "mit_b3_cityscapes_1024": {
+        "metadata": {
+            "description": (
+                "MiT (MixTransformer) model with 28 transformer blocks."
+            ),
+            "params": 44077258,
+            "path": "mit",
+        },
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b3_cityscapes_1024/3",
+    },
+    "mit_b4_cityscapes_1024": {
+        "metadata": {
+            "description": (
+                "MiT (MixTransformer) model with 41 transformer blocks."
+            ),
+            "params": 60847818,
+            "path": "mit",
+        },
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b4_cityscapes_1024/3",
+    },
+    "mit_b5_cityscapes_1024": {
+        "metadata": {
+            "description": (
+                "MiT (MixTransformer) model with 52 transformer blocks."
+            ),
+            "params": 81448138,
+            "path": "mit",
+        },
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b5_cityscapes_1024/3",
+    },
+}
+
+backbone_presets = {
+    **backbone_presets_with_weights,
+}
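The new `mit_presets.py` registers twelve MixTransformer checkpoints (ADE20K and Cityscapes variants) under the `"mit"` path. As a rough loading sketch, assuming these preset names resolve through the standard `from_preset()` machinery:

```python
import numpy as np
import keras_hub

# Instantiate the smallest ADE20K-pretrained MiT backbone by preset name;
# the name maps to the "kaggle_handle" entry shown above.
backbone = keras_hub.models.Backbone.from_preset("mit_b0_ade20k_512")

# Run a forward pass on a dummy batch to inspect the output features.
images = np.ones((1, 512, 512, 3), dtype="float32")
print(backbone.predict(images).shape)
```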
keras_hub/src/models/mobilenet/mobilenet_backbone.py
@@ -47,11 +47,11 @@ class MobileNetBackbone(Backbone):
             of filters in each layer.
         - If `depth_multiplier` > 1.0, proportionally increases the number
             of filters in each layer.
-        - If `depth_multiplier` = 1, default number of filters from the
-            are used at each layer.
+        - If `depth_multiplier` = 1, default number of filters from the
+            paper are used at each layer.
         input_num_filters: number of filters in first convolution layer
-        output_num_filters: specifies whether to add conv and batch_norm in the
-            if set to None, it will not add these layers in the end.
+        output_num_filters: specifies whether to add conv and batch_norm in the
+            end, if set to None, it will not add these layers in the end.
             'None' for MobileNetV1
         input_activation: activation function to be used in the input layer
             'hard_swish' for MobileNetV3,
@@ -96,7 +96,7 @@ class MobileNetBackbone(Backbone):
         stackwise_activation,
         output_num_filters,
         inverted_res_block,
-        image_shape=(
+        image_shape=(None, None, 3),
         input_activation="hard_swish",
         output_activation="hard_swish",
         depth_multiplier=1.0,
@@ -365,7 +365,7 @@ def apply_depthwise_conv_block(
     batch normalization and relu6 activation.

     Args:
-        x: Input tensor of shape `(rows, cols, channels)
+        x: Input tensor of shape `(rows, cols, channels)`
         filters: Integer, the dimensionality of the output space
            (i.e. the number of output filters in the pointwise convolution).
         depth_multiplier: controls the width of the network.
@@ -383,8 +383,8 @@ def apply_depthwise_conv_block(
         block_id: Integer, a unique identification designating the block number.

     Input shape:
-        4D tensor with shape
-        4D tensor with shape
+        4D tensor with shape `(batch, rows, cols, channels)` in "channels_last"
+        4D tensor with shape `(batch, channels, rows, cols)` in "channels_first"
     Returns:
         Output tensor of block.
     """
keras_hub/src/models/mobilenet/mobilenet_image_classifier.py
@@ -1,5 +1,3 @@
-import keras
-
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.image_classifier import ImageClassifier
 from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone
@@ -7,94 +5,4 @@ from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone

 @keras_hub_export("keras_hub.models.MobileNetImageClassifier")
 class MobileNetImageClassifier(ImageClassifier):
-    """MobileNetV3 image classifier task model.
-
-    To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
-    where `x` is a tensor and `y` is a integer from `[0, num_classes)`.
-    All `ImageClassifier` tasks include a `from_preset()` constructor which can
-    be used to load a pre-trained config and weights.
-
-    Args:
-        backbone: A `keras_hub.models.MobileNetBackbone` instance.
-        num_classes: int. The number of classes to predict.
-        activation: `None`, str or callable. The activation function to use on
-            the `Dense` layer. Set `activation=None` to return the output
-            logits. Defaults to `"softmax"`.
-
-    Examples:
-
-    Call `predict()` to run inference.
-    ```python
-    # Load preset and train
-    images = np.ones((2, 224, 224, 3), dtype="float32")
-    classifier = keras_hub.models.MobileNetImageClassifier.from_preset(
-        "mobilenet_v3_small_imagenet")
-    classifier.predict(images)
-    ```
-
-    Custom backbone.
-    ```python
-    images = np.ones((2, 224, 224, 3), dtype="float32")
-    labels = [0, 3]
-    model = MobileNetBackbone(
-        stackwise_expansion = [1, 4, 6],
-        stackwise_filters = [4, 8, 16],
-        stackwise_kernel_size = [3, 3, 5],
-        stackwise_stride = [2, 2, 1],
-        stackwise_se_ratio = [ 0.25, None, 0.25],
-        stackwise_activation = ["relu", "relu", "hard_swish"],
-        output_filter=1280,
-        activation="hard_swish",
-        inverted_res_block=True,
-    )
-    classifier = keras_hub.models.MobileNetImageClassifier(
-        backbone=backbone,
-        num_classes=4,
-    )
-    classifier.fit(x=images, y=labels, batch_size=2)
-    ```
-    """
-
     backbone_cls = MobileNetBackbone
-
-    def __init__(
-        self,
-        backbone,
-        num_classes,
-        activation="softmax",
-        preprocessor=None,  # adding this dummy arg for saved model test
-        # TODO: once preprocessor flow is figured out, this needs to be updated
-        **kwargs,
-    ):
-        # === Layers ===
-        self.backbone = backbone
-        self.output_dense = keras.layers.Dense(
-            num_classes,
-            activation=activation,
-            name="predictions",
-        )
-
-        # === Functional Model ===
-        inputs = self.backbone.input
-        x = self.backbone(inputs)
-        outputs = self.output_dense(x)
-        super().__init__(
-            inputs=inputs,
-            outputs=outputs,
-            **kwargs,
-        )
-
-        # === Config ===
-        self.num_classes = num_classes
-        self.activation = activation
-
-    def get_config(self):
-        # Backbone serialized in `super`
-        config = super().get_config()
-        config.update(
-            {
-                "num_classes": self.num_classes,
-                "activation": self.activation,
-            }
-        )
-        return config
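After this removal, the MobileNet task module carries no head of its own; the pooling and `Dense` head now come from the shared `ImageClassifier` base (see `keras_hub/src/models/image_classifier.py`, +147 lines, in the listing above). The resulting module is essentially the minimal task-subclass pattern sketched below, assembled from the context lines of this hunk:

```python
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.image_classifier import ImageClassifier
from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone


# The task subclass only binds its backbone class; the classification head,
# get_config(), and fit()/predict() behavior are inherited from ImageClassifier.
@keras_hub_export("keras_hub.models.MobileNetImageClassifier")
class MobileNetImageClassifier(ImageClassifier):
    backbone_cls = MobileNetBackbone
```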
keras_hub/src/models/opt/opt_causal_lm.py
@@ -171,8 +171,8 @@ class OPTCausalLM(CausalLM):
         Args:
             token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
             cache: a dense float Tensor, the cache of key and value.
-            cache_update_index: int, or int Tensor. The index of current inputs
-                whole sequence.
+            cache_update_index: int, or int Tensor. The index of current inputs
+                in the whole sequence.

         Returns:
             A (logits, hidden_states, cache) tuple. Where `logits` is the
keras_hub/src/models/opt/opt_presets.py
@@ -9,11 +9,9 @@ backbone_presets = {
                 "BookCorpus, CommonCrawl, Pile, and PushShift.io corpora."
             ),
             "params": 125237760,
-            "official_name": "OPT",
             "path": "opt",
-            "model_card": "https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/model_card.md",
         },
-        "kaggle_handle": "kaggle://keras/opt/keras/opt_125m_en/
+        "kaggle_handle": "kaggle://keras/opt/keras/opt_125m_en/3",
     },
     # We skip the 350m checkpoint because it does not match the structure of
     # other checkpoints.
@@ -24,11 +22,9 @@ backbone_presets = {
                 "BookCorpus, CommonCrawl, Pile, and PushShift.io corpora."
             ),
             "params": 1315753984,
-            "official_name": "OPT",
             "path": "opt",
-            "model_card": "https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/model_card.md",
         },
-        "kaggle_handle": "kaggle://keras/opt/keras/opt_1.3b_en/
+        "kaggle_handle": "kaggle://keras/opt/keras/opt_1.3b_en/3",
     },
     "opt_2.7b_en": {
         "metadata": {
@@ -37,11 +33,9 @@ backbone_presets = {
                 "BookCorpus, CommonCrawl, Pile, and PushShift.io corpora."
             ),
             "params": 2700000000,
-            "official_name": "OPT",
             "path": "opt",
-            "model_card": "https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/model_card.md",
         },
-        "kaggle_handle": "kaggle://keras/opt/keras/opt_2.7b_en/
+        "kaggle_handle": "kaggle://keras/opt/keras/opt_2.7b_en/3",
     },
     "opt_6.7b_en": {
         "metadata": {
@@ -50,10 +44,8 @@ backbone_presets = {
                 "BookCorpus, CommonCrawl, Pile, and PushShift.io corpora."
             ),
             "params": 6700000000,
-            "official_name": "OPT",
             "path": "opt",
-            "model_card": "https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/model_card.md",
         },
-        "kaggle_handle": "kaggle://keras/opt/keras/opt_6.7b_en/
+        "kaggle_handle": "kaggle://keras/opt/keras/opt_6.7b_en/3",
     },
 }
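The OPT presets drop the `official_name` and `model_card` metadata keys and pin each `kaggle_handle` to a concrete version. A short usage sketch, assuming the preset names above remain registered for `OPTCausalLM`:

```python
import keras_hub

# Load the 125M-parameter OPT preset by name; the name resolves to the
# pinned "kaggle_handle" shown above.
causal_lm = keras_hub.models.OPTCausalLM.from_preset("opt_125m_en")
print(causal_lm.generate("The quick brown fox", max_length=30))
```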
keras_hub/src/models/pali_gemma/pali_gemma_backbone.py
@@ -48,24 +48,40 @@ class PaliGemmaBackbone(Backbone):
             a two-layer feedforward network for each transformer decoder block.
         head_dim: int. The size of each attention head in the mixed decoder.
         vit_patch_size: int. The size of each square patch in the input image.
-        vit_num_heads: int. The number of attention heads for the vision(image)
+        vit_num_heads: int. The number of attention heads for the vision (image)
             transformer encoder.
         vit_hidden_dim: int. The size of the transformer hidden state at the end
             of each vision transformer layer.
         vit_num_layers: int. The number of vision transformer layers.
         vit_intermediate_dim: int. The output dimension of the first Dense layer
-            in a two-layer feedforward network for vision transformer.
-
-
-
+            in a two-layer feedforward network for vision transformer. Defaults
+            to `4304`.
+        vit_pooling: `None` or string. The encoded vision embeddings are pooled
+            using the specified polling setting. The accepted values are
+            `"map"`, `"gap"`, `"0"` or `None`. Defaults to `None`.
         vit_classifier_activation: activation function. The activation that
             is used for final output classification in the vision transformer.
+            Defaults to `None`.
         vit_name: string. The name used for vision transformer layers.
-
-
+        query_head_dim_normalize: boolean. If `True` normalize the query before
+            attention with `head_dim`. If `False`, normalize the query with
+            `hidden_dim / num_query_heads`. Defaults to `True`.
+        use_post_ffw_norm: boolean. Whether to normalize after the feedforward
+            block. Defaults to `False`.
+        use_post_attention_norm: boolean. Whether to normalize after the
+            attention block. Defaults to `False`.
+        attention_logit_soft_cap: `None` or int. Soft cap for the attention
+            logits. Defaults to `None`.
+        final_logit_soft_cap: `None` or int. Soft cap for the final logits.
+            Defaults to `None`.
+        use_sliding_window_attention: boolean. Whether to use sliding local
+            window attention. Defaults to `False`.
+        sliding_window_size: int. Size of the sliding local window. Defaults to
+            `4096`.
         layer_norm_epsilon: float. The epsilon value user for every layer norm
-            in all transformer blocks.
+            in all transformer blocks. Defaults to `1e-6`.
         dropout: float. Dropout probability for the Transformer decoder blocks.
+            Defaults to `0`.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
             for the models computations and weights. Note that some
             computations, such as softmax and layer normalization will always
@@ -121,7 +137,13 @@ class PaliGemmaBackbone(Backbone):
         vit_pooling=None,
         vit_classifier_activation=None,
         vit_name=None,
-
+        query_head_dim_normalize=True,
+        use_post_ffw_norm=False,
+        use_post_attention_norm=False,
+        attention_logit_soft_cap=None,
+        final_logit_soft_cap=None,
+        use_sliding_window_attention=False,
+        sliding_window_size=4096,
         layer_norm_epsilon=1e-6,
         dropout=0,
         dtype=None,
@@ -139,13 +161,13 @@ class PaliGemmaBackbone(Backbone):
                 seed=None,
             ),
             dtype=dtype,
+            logit_soft_cap=final_logit_soft_cap,
             name="token_embedding",
         )
         # TODO Remove this. Work around for previous serialization bug.
         vit_intermediate_dim = vit_intermediate_dim or 4304
         self.vit_encoder = PaliGemmaVit(
             image_size=image_size,
-            include_rescaling=include_rescaling,
             patch_size=vit_patch_size,
             num_heads=vit_num_heads,
             hidden_dim=vit_hidden_dim,
@@ -159,12 +181,19 @@ class PaliGemmaBackbone(Backbone):
         )
         self.transformer_layers = []
         for i in range(num_layers):
+            sliding_window = use_sliding_window_attention and (i % 2 == 0)
             layer = PaliGemmaDecoderBlock(
                 hidden_dim=hidden_dim,
                 intermediate_dim=intermediate_dim,
-                num_query_heads=num_query_heads,
                 head_dim=head_dim,
+                num_query_heads=num_query_heads,
                 num_key_value_heads=num_key_value_heads,
+                query_head_dim_normalize=query_head_dim_normalize,
+                use_post_ffw_norm=use_post_ffw_norm,
+                use_post_attention_norm=use_post_attention_norm,
+                logit_soft_cap=attention_logit_soft_cap,
+                use_sliding_window_attention=sliding_window,
+                sliding_window_size=sliding_window_size,
                 dropout=dropout,
                 dtype=dtype,
                 name=f"decoder_block_{i}",
@@ -177,7 +206,9 @@ class PaliGemmaBackbone(Backbone):
         )

         # === Functional Model ===
-        image_input =
+        image_input = keras.Input(
+            shape=(image_size, image_size, 3), name="images"
+        )
         token_id_input = keras.Input(
             shape=(None,), dtype="int32", name="token_ids"
         )
@@ -215,7 +246,6 @@ class PaliGemmaBackbone(Backbone):
         # === Config ===
         self.vocabulary_size = vocabulary_size
         self.image_size = image_size
-        self.include_rescaling = include_rescaling
         self.num_layers = num_layers
         self.num_query_heads = num_query_heads
         self.num_key_value_heads = num_key_value_heads
@@ -224,7 +254,15 @@ class PaliGemmaBackbone(Backbone):
         self.head_dim = head_dim
         self.layer_norm_epsilon = layer_norm_epsilon
         self.dropout = dropout
-        #
+        # Gemma2 params
+        self.query_head_dim_normalize = query_head_dim_normalize
+        self.use_post_ffw_norm = use_post_ffw_norm
+        self.use_post_attention_norm = use_post_attention_norm
+        self.attention_logit_soft_cap = attention_logit_soft_cap
+        self.final_logit_soft_cap = final_logit_soft_cap
+        self.sliding_window_size = sliding_window_size
+        self.use_sliding_window_attention = use_sliding_window_attention
+        # ViT params
         self.vit_patch_size = vit_patch_size
         self.vit_num_heads = vit_num_heads
         self.vit_hidden_dim = vit_hidden_dim
@@ -242,15 +280,12 @@ class PaliGemmaBackbone(Backbone):
             {
                 "vocabulary_size": self.vocabulary_size,
                 "image_size": self.image_size,
-                "include_rescaling": self.include_rescaling,
                 "num_layers": self.num_layers,
                 "num_query_heads": self.num_query_heads,
                 "num_key_value_heads": self.num_key_value_heads,
                 "hidden_dim": self.hidden_dim,
                 "intermediate_dim": self.intermediate_dim,
                 "head_dim": self.head_dim,
-                "layer_norm_epsilon": self.layer_norm_epsilon,
-                "dropout": self.dropout,
                 "vit_patch_size": self.vit_patch_size,
                 "vit_num_heads": self.vit_num_heads,
                 "vit_hidden_dim": self.vit_hidden_dim,
@@ -259,6 +294,17 @@ class PaliGemmaBackbone(Backbone):
                 "vit_pooling": self.vit_pooling,
                 "vit_classifier_activation": self.vit_classifier_activation,
                 "vit_name": self.vit_name,
+                "query_head_dim_normalize": self.query_head_dim_normalize,
+                "use_post_ffw_norm": self.use_post_ffw_norm,
+                "use_post_attention_norm": self.use_post_attention_norm,
+                "final_logit_soft_cap": self.final_logit_soft_cap,
+                "attention_logit_soft_cap": self.attention_logit_soft_cap,
+                "sliding_window_size": self.sliding_window_size,
+                "use_sliding_window_attention": (
+                    self.use_sliding_window_attention
+                ),
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "dropout": self.dropout,
             }
         )
         return config
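The backbone now exposes the Gemma2-style options (post-attention and post-FFW norms, logit soft caps, interleaved sliding-window attention) and builds its own `keras.Input` for images instead of relying on `include_rescaling`. A toy-sized construction sketch, assuming the keyword arguments shown in the docstring and `get_config()` above; the dimensions are arbitrary and chosen only so the model builds quickly:

```python
import keras_hub

backbone = keras_hub.models.PaliGemmaBackbone(
    vocabulary_size=256,
    image_size=28,  # must be divisible by vit_patch_size
    num_layers=2,
    num_query_heads=4,
    num_key_value_heads=1,
    hidden_dim=32,
    intermediate_dim=64,
    head_dim=8,
    vit_patch_size=14,
    vit_num_heads=2,
    vit_hidden_dim=16,
    vit_num_layers=2,
    # Gemma2-style options added in this change.
    use_post_ffw_norm=True,
    use_post_attention_norm=True,
    attention_logit_soft_cap=50.0,
    final_logit_soft_cap=30.0,
    use_sliding_window_attention=True,  # applied on every other decoder block
    sliding_window_size=256,
)
print(backbone.count_params())
```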
keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py
@@ -110,7 +110,9 @@ class PaliGemmaCausalLM(CausalLM):
         self.backbone = backbone

         # === Functional Model ===
-
+        # This must be "backbone.input" i.e. the full input structure,
+        # rather than "backbone.inputs" which is the flattened list of inputs.
+        inputs = backbone.input
         hidden_state = backbone(inputs=inputs)
         outputs = backbone.token_embedding(hidden_state, reverse=True)
         outputs = outputs[:, backbone.image_sequence_length :, :]
keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py
@@ -31,33 +31,25 @@ class PaliGemmaDecoderBlock(GemmaDecoderBlock):
             the attention layer.
         num_key_value_heads: int. The number of heads for the key and value
             projections in the attention layer.
+        query_head_dim_normalize: boolean. If `True` normalize the query before
+            attention with `head_dim`. If `False`, normalize the query with
+            `hidden_dim / num_query_heads`. Defaults to `True`.
+        use_post_ffw_norm: boolean. Whether to normalize after the feedforward
+            block. Defaults to `False`.
+        use_post_attention_norm: boolean. Whether to normalize after the
+            attention block. Defaults to `False`.
+        logit_soft_cap: `None` or int. Soft cap for the attention logits.
+            Defaults to `None`.
+        use_sliding_window_attention: boolean. Whether to use sliding local
+            window attention. Defaults to `False`.
+        sliding_window_size: int. Size of the sliding local window. Defaults to
+            `4096`.
         layer_norm_epsilon: float. The epsilon hyperparameter used for layer
-            normalization.
+            normalization. Defaults to `1e-6`.
         dropout: float. The dropout rate for the transformer attention layer.
+            Defaults to `0`.
     """

-    def __init__(
-        self,
-        hidden_dim,
-        intermediate_dim,
-        head_dim,
-        num_query_heads,
-        num_key_value_heads,
-        layer_norm_epsilon=1e-6,
-        dropout=0,
-        **kwargs,
-    ):
-        super().__init__(
-            hidden_dim=hidden_dim,
-            intermediate_dim=intermediate_dim,
-            head_dim=head_dim,
-            num_query_heads=num_query_heads,
-            num_key_value_heads=num_key_value_heads,
-            layer_norm_epsilon=layer_norm_epsilon,
-            dropout=dropout,
-            **kwargs,
-        )
-
     def call(
         self,
         x,
@@ -83,6 +75,9 @@ class PaliGemmaDecoderBlock(GemmaDecoderBlock):
             attention_mask=attention_mask,
         )

+        if self.use_post_attention_norm:
+            attention = self.post_attention_norm(attention)
+
         if self.dropout:
             attention = self.attention_dropout(attention)

@@ -94,6 +89,9 @@ class PaliGemmaDecoderBlock(GemmaDecoderBlock):
         x = keras.activations.gelu(x1, approximate=True) * x2
         x = self.ffw_linear(x)

+        if self.use_post_ffw_norm:
+            x = self.post_ffw_norm(x)
+
         x = x + attention_x

         if cache is not None:
keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py
@@ -1,12 +1,10 @@
 from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.preprocessing.
-    ResizingImageConverter,
-)
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
 from keras_hub.src.models.pali_gemma.pali_gemma_backbone import (
     PaliGemmaBackbone,
 )


 @keras_hub_export("keras_hub.layers.PaliGemmaImageConverter")
-class PaliGemmaImageConverter(
+class PaliGemmaImageConverter(ImageConverter):
     backbone_cls = PaliGemmaBackbone
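With `ResizingImageConverter` removed (see `resizing_image_converter.py`, -138 lines, in the listing above), the PaliGemma converter derives directly from `ImageConverter`, which now handles resizing and rescaling itself. A usage sketch, assuming `image_size`, `scale`, and `offset` constructor arguments on the shared `ImageConverter`:

```python
import numpy as np
import keras_hub

# Resize to the model's expected resolution and rescale pixels to [-1, 1].
converter = keras_hub.layers.PaliGemmaImageConverter(
    image_size=(224, 224),
    scale=1.0 / 127.5,
    offset=-1.0,
)
images = np.random.uniform(0, 255, size=(1, 448, 448, 3)).astype("float32")
print(converter(images).shape)  # expected: (1, 224, 224, 3)
```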