keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +21 -3
- keras_hub/api/models/__init__.py +71 -12
- keras_hub/api/tokenizers/__init__.py +1 -1
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
- keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
- keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
- keras_hub/src/layers/modeling/rms_normalization.py +36 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
- keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
- keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
- keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
- keras_hub/src/layers/preprocessing/image_converter.py +170 -34
- keras_hub/src/metrics/bleu.py +4 -3
- keras_hub/src/models/albert/albert_presets.py +4 -12
- keras_hub/src/models/albert/albert_text_classifier.py +7 -7
- keras_hub/src/models/backbone.py +3 -14
- keras_hub/src/models/bart/bart_backbone.py +4 -4
- keras_hub/src/models/bart/bart_presets.py +3 -9
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
- keras_hub/src/models/basnet/__init__.py +5 -0
- keras_hub/src/models/basnet/basnet.py +122 -0
- keras_hub/src/models/basnet/basnet_backbone.py +366 -0
- keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
- keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
- keras_hub/src/models/basnet/basnet_presets.py +17 -0
- keras_hub/src/models/bert/bert_presets.py +14 -32
- keras_hub/src/models/bert/bert_text_classifier.py +3 -3
- keras_hub/src/models/bloom/bloom_presets.py +8 -24
- keras_hub/src/models/causal_lm.py +56 -12
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
- keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
- keras_hub/src/models/densenet/densenet_backbone.py +6 -4
- keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
- keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
- keras_hub/src/models/densenet/densenet_presets.py +9 -15
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
- keras_hub/src/models/efficientnet/mbconv.py +53 -22
- keras_hub/src/models/electra/electra_backbone.py +2 -2
- keras_hub/src/models/electra/electra_presets.py +6 -18
- keras_hub/src/models/f_net/f_net_presets.py +2 -6
- keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
- keras_hub/src/models/falcon/falcon_backbone.py +5 -3
- keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
- keras_hub/src/models/falcon/falcon_presets.py +1 -3
- keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
- keras_hub/src/models/feature_pyramid_backbone.py +1 -1
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +496 -0
- keras_hub/src/models/flux/flux_maths.py +225 -0
- keras_hub/src/models/flux/flux_model.py +236 -0
- keras_hub/src/models/flux/flux_presets.py +3 -0
- keras_hub/src/models/flux/flux_text_to_image.py +146 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_backbone.py +35 -20
- keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
- keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
- keras_hub/src/models/gemma/gemma_presets.py +29 -63
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
- keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
- keras_hub/src/models/image_classifier.py +147 -2
- keras_hub/src/models/image_classifier_preprocessor.py +6 -3
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_segmenter.py +0 -5
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/image_to_image.py +417 -0
- keras_hub/src/models/inpaint.py +520 -0
- keras_hub/src/models/llama/llama_backbone.py +138 -12
- keras_hub/src/models/llama/llama_causal_lm.py +3 -1
- keras_hub/src/models/llama/llama_presets.py +10 -20
- keras_hub/src/models/llama3/llama3_backbone.py +12 -11
- keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
- keras_hub/src/models/llama3/llama3_presets.py +4 -12
- keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
- keras_hub/src/models/mistral/mistral_backbone.py +16 -15
- keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
- keras_hub/src/models/mistral/mistral_presets.py +3 -9
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
- keras_hub/src/models/mit/__init__.py +6 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
- keras_hub/src/models/mit/mit_image_classifier.py +12 -0
- keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/mit/mit_image_converter.py +8 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
- keras_hub/src/models/mit/mit_presets.py +139 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
- keras_hub/src/models/opt/opt_causal_lm.py +2 -2
- keras_hub/src/models/opt/opt_presets.py +4 -12
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
- keras_hub/src/models/phi3/phi3_decoder.py +0 -1
- keras_hub/src/models/phi3/phi3_presets.py +2 -6
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
- keras_hub/src/models/preprocessor.py +25 -11
- keras_hub/src/models/resnet/resnet_backbone.py +3 -14
- keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
- keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
- keras_hub/src/models/resnet/resnet_presets.py +127 -18
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
- keras_hub/src/models/roberta/roberta_backbone.py +2 -2
- keras_hub/src/models/roberta/roberta_presets.py +6 -8
- keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
- keras_hub/src/models/sam/__init__.py +5 -0
- keras_hub/src/models/sam/sam_backbone.py +2 -3
- keras_hub/src/models/sam/sam_image_converter.py +2 -4
- keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
- keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
- keras_hub/src/models/sam/sam_layers.py +5 -3
- keras_hub/src/models/sam/sam_presets.py +3 -9
- keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
- keras_hub/src/models/sam/sam_transformer.py +5 -4
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +167 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +136 -0
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +47 -19
- keras_hub/src/models/task.py +47 -39
- keras_hub/src/models/text_classifier.py +2 -2
- keras_hub/src/models/text_to_image.py +106 -41
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +184 -0
- keras_hub/src/models/vae/vae_layers.py +739 -0
- keras_hub/src/models/vgg/__init__.py +5 -0
- keras_hub/src/models/vgg/vgg_backbone.py +4 -24
- keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
- keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
- keras_hub/src/models/vgg/vgg_presets.py +48 -0
- keras_hub/src/models/vit/__init__.py +5 -0
- keras_hub/src/models/vit/vit_backbone.py +152 -0
- keras_hub/src/models/vit/vit_image_classifier.py +187 -0
- keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vit/vit_image_converter.py +73 -0
- keras_hub/src/models/vit/vit_layers.py +391 -0
- keras_hub/src/models/vit/vit_presets.py +126 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
- keras_hub/src/models/vit_det/vit_layers.py +3 -3
- keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
- keras_hub/src/models/whisper/whisper_backbone.py +6 -5
- keras_hub/src/models/whisper/whisper_decoder.py +3 -5
- keras_hub/src/models/whisper/whisper_presets.py +10 -30
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
- keras_hub/src/models/xlnet/relative_attention.py +20 -19
- keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
- keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
- keras_hub/src/samplers/contrastive_sampler.py +2 -3
- keras_hub/src/samplers/sampler.py +2 -1
- keras_hub/src/tests/test_case.py +41 -6
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
- keras_hub/src/tokenizers/tokenizer.py +10 -13
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
- keras_hub/src/utils/keras_utils.py +2 -13
- keras_hub/src/utils/pipeline_model.py +3 -3
- keras_hub/src/utils/preset_utils.py +196 -144
- keras_hub/src/utils/tensor_utils.py +4 -4
- keras_hub/src/utils/timm/convert_densenet.py +6 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
- keras_hub/src/utils/timm/convert_resnet.py +1 -1
- keras_hub/src/utils/timm/convert_vgg.py +85 -0
- keras_hub/src/utils/timm/preset_loader.py +14 -9
- keras_hub/src/utils/transformers/convert_llama3.py +21 -5
- keras_hub/src/utils/transformers/convert_vit.py +150 -0
- keras_hub/src/utils/transformers/preset_loader.py +23 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
- keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
- keras_hub/src/models/mix_transformer/__init__.py +0 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
keras_hub/src/models/llama/llama_causal_lm.py
@@ -42,7 +42,9 @@ class LlamaCausalLM(CausalLM):
         self.preprocessor = preprocessor

         # === Functional Model ===
-
+        # This must be "backbone.input" i.e. the full input structure,
+        # rather than "backbone.inputs" which is the flattened list of inputs.
+        inputs = backbone.input
         hidden_states = backbone(inputs)
         outputs = backbone.token_embedding(hidden_states, reverse=True)
         super().__init__(
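The comment added here (and in the identical MistralCausalLM change further down) is the whole point of the edit: for a functional Keras model built from a dict of named inputs, `model.input` preserves that dict structure, while `model.inputs` is the flattened list of tensors. A minimal standalone sketch of the distinction using plain Keras rather than keras-hub code; the toy model below is only an illustration:

```python
import keras

# A hypothetical functional model with a dict of named inputs, loosely
# mirroring how keras-hub backbones take {"token_ids", "padding_mask"}.
token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
padding_mask = keras.Input(shape=(None,), dtype="int32", name="padding_mask")
x = keras.layers.Embedding(100, 8)(token_ids)
mask = keras.ops.expand_dims(keras.ops.cast(padding_mask, x.dtype), -1)
outputs = x * mask
model = keras.Model(
    inputs={"token_ids": token_ids, "padding_mask": padding_mask},
    outputs=outputs,
)

print(type(model.input))   # dict: the full, named input structure
print(type(model.inputs))  # list: the flattened inputs
```

Reusing the full structure keeps the named inputs intact when they are fed back into `backbone(inputs)`.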
keras_hub/src/models/llama/llama_presets.py
@@ -6,11 +6,9 @@ backbone_presets = {
         "metadata": {
             "description": "7 billion parameter, 32-layer, base LLaMA 2 model.",
             "params": 6738415616,
-            "
-            "path": "llama2",
-            "model_card": "https://github.com/meta-llama/llama",
+            "path": "llama",
         },
-        "kaggle_handle": "kaggle://keras/llama2/keras/llama2_7b_en/
+        "kaggle_handle": "kaggle://keras/llama2/keras/llama2_7b_en/2",
     },
     "llama2_7b_en_int8": {
         "metadata": {
@@ -19,11 +17,9 @@ backbone_presets = {
                 "activation and weights quantized to int8."
             ),
             "params": 6739839488,
-            "
-            "path": "llama2",
-            "model_card": "https://github.com/meta-llama/llama",
+            "path": "llama",
         },
-        "kaggle_handle": "kaggle://keras/llama2/keras/llama2_7b_en_int8/
+        "kaggle_handle": "kaggle://keras/llama2/keras/llama2_7b_en_int8/2",
     },
     "llama2_instruct_7b_en": {
         "metadata": {
@@ -32,11 +28,9 @@ backbone_presets = {
                 "model."
             ),
             "params": 6738415616,
-            "
-            "path": "llama2",
-            "model_card": "https://github.com/meta-llama/llama",
+            "path": "llama",
         },
-        "kaggle_handle": "kaggle://keras/llama2/keras/llama2_instruct_7b_en/
+        "kaggle_handle": "kaggle://keras/llama2/keras/llama2_instruct_7b_en/2",
     },
     "llama2_instruct_7b_en_int8": {
         "metadata": {
@@ -45,11 +39,9 @@ backbone_presets = {
                 "model with activation and weights quantized to int8."
             ),
             "params": 6739839488,
-            "
-            "path": "llama2",
-            "model_card": "https://github.com/meta-llama/llama",
+            "path": "llama",
         },
-        "kaggle_handle": "kaggle://keras/llama2/keras/llama2_instruct_7b_en_int8/
+        "kaggle_handle": "kaggle://keras/llama2/keras/llama2_instruct_7b_en_int8/2",
     },
     "vicuna_1.5_7b_en": {
         "metadata": {
@@ -58,10 +50,8 @@ backbone_presets = {
                 "model."
             ),
             "params": 6738415616,
-            "
-            "path": "vicuna",
-            "model_card": "https://github.com/lm-sys/FastChat",
+            "path": "llama",
         },
-        "kaggle_handle": "kaggle://keras/vicuna/keras/vicuna_1.5_7b_en/
+        "kaggle_handle": "kaggle://keras/vicuna/keras/vicuna_1.5_7b_en/2",
     },
 }
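These preset entries trim metadata fields, change `path`, and update the Kaggle handles; the preset names themselves are unchanged, so existing loading code keeps working and simply resolves to the updated weights. A brief sketch of the usual keras-hub preset API (the choice of preset here is just an example):

```python
import keras_hub

# "llama2_7b_en" now resolves to kaggle://keras/llama2/keras/llama2_7b_en/2.
backbone = keras_hub.models.LlamaBackbone.from_preset("llama2_7b_en")
```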
keras_hub/src/models/llama3/llama3_backbone.py
@@ -24,17 +24,18 @@ class Llama3Backbone(LlamaBackbone):
         num_layers (int): The number of transformer layers.
         num_query_heads (int): The number of query attention heads for
             each transformer.
-        hidden_dim (int): The size of the transformer encoding and pooling
-
-
-
-
-
-
-
-
-
-
+        hidden_dim (int): The size of the transformer encoding and pooling
+            layers.
+        intermediate_dim (int): The output dimension of the first Dense layer in
+            a three-layer feedforward network for each transformer.
+        num_key_value_heads (int): The number of key and value attention heads
+            fo each transformer.
+        rope_max_wavelength (int, optional): The maximum angular wavelength of
+            the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
+        rope_scaling_factor (float, optional): The scaling factor for
+            calculation of roatary embedding. Defaults to `1.0`.
+        layer_norm_epsilon (float, optional): Epsilon for the layer
+            normalization layers in the transformer decoder. Defaults to `1e-6`.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
             for model computations and weights. Note that some computations,
             such as softmax and layer normalization, will always be done at
keras_hub/src/models/llama3/llama3_causal_lm.py
@@ -1,9 +1,9 @@
 from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.llama.llama_causal_lm import LlamaCausalLM
 from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone
 from keras_hub.src.models.llama3.llama3_causal_lm_preprocessor import (
     Llama3CausalLMPreprocessor,
 )
-from keras_hub.src.models.llama.llama_causal_lm import LlamaCausalLM


 @keras_hub_export("keras_hub.models.Llama3CausalLM")
keras_hub/src/models/llama3/llama3_presets.py
@@ -6,11 +6,9 @@ backbone_presets = {
         "metadata": {
             "description": "8 billion parameter, 32-layer, base LLaMA 3 model.",
             "params": 8030261248,
-            "official_name": "LLaMA 3",
             "path": "llama3",
-            "model_card": "https://github.com/meta-llama/llama3",
         },
-        "kaggle_handle": "kaggle://keras/llama3/keras/llama3_8b_en/
+        "kaggle_handle": "kaggle://keras/llama3/keras/llama3_8b_en/4",
     },
     "llama3_8b_en_int8": {
         "metadata": {
@@ -19,11 +17,9 @@ backbone_presets = {
                 "activation and weights quantized to int8."
             ),
             "params": 8031894016,
-            "official_name": "LLaMA 3",
             "path": "llama3",
-            "model_card": "https://github.com/meta-llama/llama3",
         },
-        "kaggle_handle": "kaggle://keras/llama3/keras/llama3_8b_en_int8/
+        "kaggle_handle": "kaggle://keras/llama3/keras/llama3_8b_en_int8/2",
     },
     "llama3_instruct_8b_en": {
         "metadata": {
@@ -32,11 +28,9 @@ backbone_presets = {
                 "model."
             ),
             "params": 8030261248,
-            "official_name": "LLaMA 3",
             "path": "llama3",
-            "model_card": "https://github.com/meta-llama/llama3",
         },
-        "kaggle_handle": "kaggle://keras/llama3/keras/llama3_instruct_8b_en/
+        "kaggle_handle": "kaggle://keras/llama3/keras/llama3_instruct_8b_en/4",
     },
     "llama3_instruct_8b_en_int8": {
         "metadata": {
@@ -45,12 +39,10 @@ backbone_presets = {
                 "model with activation and weights quantized to int8."
             ),
             "params": 8031894016,
-            "official_name": "LLaMA 3",
             "path": "llama3",
-            "model_card": "https://github.com/meta-llama/llama3",
         },
         "kaggle_handle": (
-            "kaggle://keras/llama3/keras/llama3_instruct_8b_en_int8/
+            "kaggle://keras/llama3/keras/llama3_instruct_8b_en_int8/2"
         ),
     },
 }
keras_hub/src/models/llama3/llama3_tokenizer.py
@@ -16,10 +16,33 @@ class Llama3Tokenizer(BytePairTokenizer):
         self,
         vocabulary=None,
         merges=None,
+        bos_token="<|begin_of_text|>",
+        eos_token="<|end_of_text|>",
+        misc_special_tokens={"<|start_header_id|>", "<|end_header_id|>"},
         **kwargs,
     ):
-
-
+        # Note: all special tokens must also appear in "vocabulary"
+
+        self._add_special_token(bos_token, "start_token")
+        misc_special_tokens -= {bos_token}
+        self._add_special_token(eos_token, "end_token")
+        misc_special_tokens -= {eos_token}
+        for i, token in enumerate(misc_special_tokens):
+            self._add_special_token(token, f"special_token_{i:03d}")
+
+        # Hack:
+        # Llama models use the <|end_of_text|> or the <|eot_id|> as the stop
+        # token. This info can be read from config when loading a Hugging Face
+        # checkpoint but no such config exists for Keras checkpoints.
+        # Setting both probable end tokens when no config is availble will
+        # make text generation work in all cases as it will stop
+        # on both end tokens. However, the packer will always use
+        # "<|end_of_text|>" , which will be the wrong eos_token for "instruct"
+        # variants of Llama3.
+        # TODO: load this correctly from a Keras tokenizer config.
+        if eos_token == "<|end_of_text|>":
+            self._add_special_token("<|eot_id|>", "end_token2")
+
         self.pad_token_id = 0
         super().__init__(
             vocabulary=vocabulary,
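Note the set arithmetic in the new constructor: `bos_token` and `eos_token` are registered under fixed roles first and then discarded from `misc_special_tokens`, so listing them in that set as well cannot register them twice. A standalone trace of the same bookkeeping in plain Python, with a stand-in for `_add_special_token` (the registry dict below is only for illustration):

```python
registered = {}  # stand-in for the tokenizer's special-token registry

def add_special_token(token, name):
    registered[name] = token

bos_token = "<|begin_of_text|>"
eos_token = "<|end_of_text|>"
# Deliberately include the bos token here to show it is not added twice.
misc_special_tokens = {"<|start_header_id|>", "<|end_header_id|>", bos_token}

add_special_token(bos_token, "start_token")
misc_special_tokens -= {bos_token}
add_special_token(eos_token, "end_token")
misc_special_tokens -= {eos_token}
for i, token in enumerate(sorted(misc_special_tokens)):
    add_special_token(token, f"special_token_{i:03d}")

# Mirrors the "hack" above: with the default eos token, "<|eot_id|>" is also
# registered so generation stops on either end token.
if eos_token == "<|end_of_text|>":
    add_special_token("<|eot_id|>", "end_token2")

print(registered)
```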
keras_hub/src/models/mistral/mistral_backbone.py
@@ -38,22 +38,23 @@ class MistralBackbone(Backbone):
         num_layers (int): The number of transformer layers.
         num_query_heads (int): The number of query attention heads for
             each transformer.
-        hidden_dim (int): The size of the transformer encoding and pooling
-
-
-
-
-
-
-
-
-
-
+        hidden_dim (int): The size of the transformer encoding and pooling
+            layers.
+        intermediate_dim (int): The output dimension of the first Dense layer
+            in a three-layer feedforward network for each transformer.
+        num_key_value_heads (int): The number of key and value attention heads
+            for each transformer.
+        rope_max_wavelength (int, optional): The maximum angular wavelength of
+            the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
+        rope_scaling_factor (float, optional): The scaling factor for
+            calculation of roatary embedding. Defaults to `1.0`.
+        layer_norm_epsilon (float, optional): Epsilon for the layer
+            normalization layers in the transformer decoder. Defaults to `1e-6`.
         sliding_window (int, optional): The sliding window for the mistral
-            attention layers. This controls the maximum cache size for the
-            layers in each transformer decoder. Only `sliding_window`
-            are saved in the cache and used to generate the
-            Defaults to `512`.
+            attention layers. This controls the maximum cache size for the
+            attention layers in each transformer decoder. Only `sliding_window`
+            number of tokens are saved in the cache and used to generate the
+            next token. Defaults to `512`.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
             for model computations and weights. Note that some computations,
             such as softmax and layer normalization, will always be done at
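Since the rewritten docstring above now documents the full set of constructor arguments, a small randomly initialized configuration can make their roles concrete. This is a hedged sketch: it assumes the usual `vocabulary_size` argument and the dict inputs (`token_ids`, `padding_mask`) that keras-hub backbones take, neither of which appears in this hunk, and the values are illustrative rather than a released configuration:

```python
import numpy as np
import keras_hub

# Tiny, illustrative hyperparameters; not a real Mistral configuration.
backbone = keras_hub.models.MistralBackbone(
    vocabulary_size=1000,
    num_layers=2,
    num_query_heads=8,
    num_key_value_heads=4,
    hidden_dim=64,
    intermediate_dim=128,
    sliding_window=512,
)
outputs = backbone(
    {
        "token_ids": np.ones((1, 12), dtype="int32"),
        "padding_mask": np.ones((1, 12), dtype="int32"),
    }
)
```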
keras_hub/src/models/mistral/mistral_causal_lm.py
@@ -28,9 +28,9 @@ class MistralCausalLM(CausalLM):

     Args:
         backbone: A `keras_hub.models.MistralBackbone` instance.
-        preprocessor: A `keras_hub.models.MistralCausalLMPreprocessor` or
-            If `None`, this model will not apply preprocessing, and
-            should be preprocessed before calling the model.
+        preprocessor: A `keras_hub.models.MistralCausalLMPreprocessor` or
+            `None`. If `None`, this model will not apply preprocessing, and
+            inputs should be preprocessed before calling the model.
     """

     backbone_cls = MistralBackbone
@@ -42,7 +42,9 @@ class MistralCausalLM(CausalLM):
         self.preprocessor = preprocessor

         # === Functional Model ===
-
+        # This must be "backbone.input" i.e. the full input structure,
+        # rather than "backbone.inputs" which is the flattened list of inputs.
+        inputs = backbone.input
         hidden_states = backbone(inputs)
         outputs = backbone.token_embedding(hidden_states, reverse=True)
         super().__init__(
keras_hub/src/models/mistral/mistral_presets.py
@@ -6,30 +6,24 @@ backbone_presets = {
         "metadata": {
             "description": "Mistral 7B base model",
             "params": 7241732096,
-            "official_name": "Mistral",
             "path": "mistral",
-            "model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md",
         },
-        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_7b_en/
+        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_7b_en/7",
     },
     "mistral_instruct_7b_en": {
         "metadata": {
             "description": "Mistral 7B instruct model",
             "params": 7241732096,
-            "official_name": "Mistral",
             "path": "mistral",
-            "model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md",
         },
-        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_instruct_7b_en/
+        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_instruct_7b_en/7",
     },
     "mistral_0.2_instruct_7b_en": {
         "metadata": {
             "description": "Mistral 7B instruct Version 0.2 model",
             "params": 7241732096,
-            "official_name": "Mistral",
             "path": "mistral",
-            "model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md",
         },
-        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_0.2_instruct_7b_en/
+        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_0.2_instruct_7b_en/2",
     },
 }
keras_hub/src/models/mistral/mistral_transformer_decoder.py
@@ -215,7 +215,8 @@ class MistralTransformerDecoder(keras.layers.Layer):
         # Mistral uses a banded attention mask if sliding window is not None
         if self.sliding_window is not None:
             # Below is a workaround for `ops.triu` for Keras 2.
-            # TODO(tirthasheshpatel): Use `ops.triu` once Keras 2 support is
+            # TODO(tirthasheshpatel): Use `ops.triu` once Keras 2 support is
+            # removed.
             # causal_mask = ops.triu(causal_mask, k=-self.sliding_window)
             i = ops.arange(output_length)[:, None] + cache_update_index
             j = ops.arange(input_length)[None, :]
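For context on the comment being rewrapped here: the surrounding code builds Mistral's banded (sliding-window) causal mask from index comparisons because `ops.triu` could not be used while Keras 2 was still supported. A rough NumPy illustration of the mask that this kind of index arithmetic produces (a sketch, not the keras-hub implementation itself):

```python
import numpy as np

def banded_causal_mask(output_length, input_length, cache_update_index, sliding_window):
    # Query positions (offset by the current cache index) and key positions.
    i = np.arange(output_length)[:, None] + cache_update_index
    j = np.arange(input_length)[None, :]
    causal = j <= i                  # no attending to future tokens
    banded = j > i - sliding_window  # only the last `sliding_window` keys
    return causal & banded

# Each row attends to at most `sliding_window` previous positions.
print(banded_causal_mask(6, 6, 0, 3).astype(int))
```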
keras_hub/src/models/mit/__init__.py
@@ -0,0 +1,6 @@
+from keras_hub.src.models.mit.mit_backbone import MiTBackbone
+from keras_hub.src.models.mit.mit_image_classifier import MiTImageClassifier
+from keras_hub.src.models.mit.mit_presets import backbone_presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, MiTBackbone)
keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py}
@@ -1,28 +1,35 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import keras
 import numpy as np
 from keras import ops

 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
-from keras_hub.src.models.
-
-)
-from keras_hub.src.models.mix_transformer.mix_transformer_layers import (
-    OverlappingPatchingAndEmbedding,
-)
+from keras_hub.src.models.mit.mit_layers import HierarchicalTransformerEncoder
+from keras_hub.src.models.mit.mit_layers import OverlappingPatchingAndEmbedding


 @keras_hub_export("keras_hub.models.MiTBackbone")
 class MiTBackbone(FeaturePyramidBackbone):
     def __init__(
         self,
-
+        layerwise_depths,
         num_layers,
-
-
+        layerwise_num_heads,
+        layerwise_sr_ratios,
         max_drop_path_rate,
-
-
+        layerwise_patch_sizes,
+        layerwise_strides,
         image_shape=(None, None, 3),
         hidden_dims=None,
         **kwargs,
@@ -36,12 +43,12 @@ class MiTBackbone(FeaturePyramidBackbone):
         https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/classification/mix_transformer)

         Args:
-
-                network.
+            layerwise_depths: The number of transformer encoders to be used per
+                layer in the network.
             num_layers: int. The number of Transformer layers.
-
+            layerwise_num_heads: list of integers, the number of heads to use
                 in the attention computation for each layer.
-
+            layerwise_sr_ratios: list of integers, the sequence reduction
                 ratio to perform for each layer on the sequence before key and
                 value projections. If set to > 1, a `Conv2D` layer is used to
                 reduce the length of the sequence.
@@ -51,7 +58,8 @@ class MiTBackbone(FeaturePyramidBackbone):
             image_shape: optional shape tuple, defaults to (None, None, 3).
             hidden_dims: the embedding dims per hierarchical layer, used as
                 the levels of the feature pyramid.
-            patch_sizes: list of integers, the patch_size to apply for each
+            patch_sizes: list of integers, the patch_size to apply for each
+                layer.
             strides: list of integers, stride to apply for each layer.

         Examples:
@@ -61,7 +69,7 @@ class MiTBackbone(FeaturePyramidBackbone):
         ```python
         images = np.ones(shape=(1, 96, 96, 3))
         labels = np.zeros(shape=(1, 96, 96, 1))
-        backbone = keras_hub.models.MiTBackbone.from_preset("
+        backbone = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512")

         # Evaluate model
         model(images)
@@ -75,7 +83,10 @@ class MiTBackbone(FeaturePyramidBackbone):
         model.fit(images, labels, epochs=3)
         ```
         """
-        dpr = [
+        dpr = [
+            x
+            for x in np.linspace(0.0, max_drop_path_rate, sum(layerwise_depths))
+        ]

         # === Layers ===
         cur = 0
@@ -86,8 +97,8 @@ class MiTBackbone(FeaturePyramidBackbone):
         for i in range(num_layers):
             patch_embed_layer = OverlappingPatchingAndEmbedding(
                 project_dim=hidden_dims[i],
-                patch_size=
-                stride=
+                patch_size=layerwise_patch_sizes[i],
+                stride=layerwise_strides[i],
                 name=f"patch_and_embed_{i}",
             )
             patch_embedding_layers.append(patch_embed_layer)
@@ -95,16 +106,16 @@ class MiTBackbone(FeaturePyramidBackbone):
             transformer_block = [
                 HierarchicalTransformerEncoder(
                     project_dim=hidden_dims[i],
-                    num_heads=
-                    sr_ratio=
+                    num_heads=layerwise_num_heads[i],
+                    sr_ratio=layerwise_sr_ratios[i],
                     drop_prob=dpr[cur + k],
                     name=f"hierarchical_encoder_{i}_{k}",
                 )
-                for k in range(
+                for k in range(layerwise_depths[i])
             ]
             transformer_blocks.append(transformer_block)
-            cur +=
-            layer_norms.append(keras.layers.LayerNormalization())
+            cur += layerwise_depths[i]
+            layer_norms.append(keras.layers.LayerNormalization(epsilon=1e-5))

         # === Functional Model ===
         image_input = keras.layers.Input(shape=image_shape)
@@ -113,7 +124,7 @@ class MiTBackbone(FeaturePyramidBackbone):
         for i in range(num_layers):
             # Compute new height/width after the `proj`
             # call in `OverlappingPatchingAndEmbedding`
-            stride =
+            stride = layerwise_strides[i]
             new_height, new_width = (
                 int(ops.shape(x)[1] / stride),
                 int(ops.shape(x)[2] / stride),
@@ -131,30 +142,30 @@ class MiTBackbone(FeaturePyramidBackbone):
         super().__init__(inputs=image_input, outputs=x, **kwargs)

         # === Config ===
-        self.
+        self.layerwise_depths = layerwise_depths
         self.image_shape = image_shape
         self.hidden_dims = hidden_dims
         self.pyramid_outputs = pyramid_outputs
         self.num_layers = num_layers
-        self.
-        self.
+        self.layerwise_num_heads = layerwise_num_heads
+        self.layerwise_sr_ratios = layerwise_sr_ratios
         self.max_drop_path_rate = max_drop_path_rate
-        self.
-        self.
+        self.layerwise_patch_sizes = layerwise_patch_sizes
+        self.layerwise_strides = layerwise_strides

     def get_config(self):
         config = super().get_config()
         config.update(
             {
-                "
+                "layerwise_depths": self.layerwise_depths,
                 "hidden_dims": self.hidden_dims,
                 "image_shape": self.image_shape,
                 "num_layers": self.num_layers,
-                "
-                "
+                "layerwise_num_heads": self.layerwise_num_heads,
+                "layerwise_sr_ratios": self.layerwise_sr_ratios,
                 "max_drop_path_rate": self.max_drop_path_rate,
-                "
-                "
+                "layerwise_patch_sizes": self.layerwise_patch_sizes,
+                "layerwise_strides": self.layerwise_strides,
             }
         )
         return config
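The renames above (`depths` → `layerwise_depths`, `num_heads` → `layerwise_num_heads`, and so on) change both the constructor signature and the serialized `get_config()` keys. A construction sketch using the new argument names; the values below are MiT-B0-like guesses for illustration, not an official preset configuration:

```python
import numpy as np
import keras_hub

# Argument names come from the diff; the values are assumptions.
backbone = keras_hub.models.MiTBackbone(
    layerwise_depths=[2, 2, 2, 2],
    num_layers=4,
    layerwise_num_heads=[1, 2, 5, 8],
    layerwise_sr_ratios=[8, 4, 2, 1],
    max_drop_path_rate=0.1,
    layerwise_patch_sizes=[7, 3, 3, 3],
    layerwise_strides=[4, 2, 2, 2],
    hidden_dims=[32, 64, 160, 256],
    image_shape=(224, 224, 3),
)
features = backbone(np.ones((1, 224, 224, 3)))
```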
keras_hub/src/models/mit/mit_image_classifier.py
@@ -0,0 +1,12 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.image_classifier import ImageClassifier
+from keras_hub.src.models.mit.mit_backbone import MiTBackbone
+from keras_hub.src.models.mit.mit_image_classifier_preprocessor import (
+    MiTImageClassifierPreprocessor,
+)
+
+
+@keras_hub_export("keras_hub.models.MiTImageClassifier")
+class MiTImageClassifier(ImageClassifier):
+    backbone_cls = MiTBackbone
+    preprocessor_cls = MiTImageClassifierPreprocessor
keras_hub/src/models/mit/mit_image_classifier_preprocessor.py
@@ -0,0 +1,12 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.image_classifier_preprocessor import (
+    ImageClassifierPreprocessor,
+)
+from keras_hub.src.models.mit.mit_backbone import MiTBackbone
+from keras_hub.src.models.mit.mit_image_converter import MiTImageConverter
+
+
+@keras_hub_export("keras_hub.models.MiTImageClassifierPreprocessor")
+class MiTImageClassifierPreprocessor(ImageClassifierPreprocessor):
+    backbone_cls = MiTBackbone
+    image_converter_cls = MiTImageConverter
keras_hub/src/models/mit/mit_image_converter.py
@@ -0,0 +1,8 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+from keras_hub.src.models.mit import MiTBackbone
+
+
+@keras_hub_export("keras_hub.layers.MiTImageConverter")
+class MiTImageConverter(ImageConverter):
+    backbone_cls = MiTBackbone
keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py}
@@ -28,19 +28,23 @@ class OverlappingPatchingAndEmbedding(keras.layers.Layer):
         self.patch_size = patch_size
         self.stride = stride

+        padding_size = self.patch_size // 2
+
+        self.padding = keras.layers.ZeroPadding2D(
+            padding=(padding_size, padding_size)
+        )
         self.proj = keras.layers.Conv2D(
             filters=project_dim,
             kernel_size=patch_size,
             strides=stride,
-            padding="
+            padding="valid",
         )
-        self.norm = keras.layers.LayerNormalization()
+        self.norm = keras.layers.LayerNormalization(epsilon=1e-5)

     def call(self, x):
+        x = self.padding(x)
         x = self.proj(x)
-
-        shape = x.shape
-        x = ops.reshape(x, (-1, shape[1] * shape[2], shape[3]))
+        x = ops.reshape(x, (-1, x.shape[1] * x.shape[2], x.shape[3]))
         x = self.norm(x)
         return x

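The switch from `padding="same"` on the projection convolution to an explicit `ZeroPadding2D(patch_size // 2)` followed by a `"valid"` convolution presumably mirrors the reference SegFormer implementation. For the patch-size/stride pairs MiT uses (7/4 and 3/2), both forms produce the same output resolution, which a quick standalone check confirms (plain Keras, illustrative shapes only):

```python
import numpy as np
import keras

def out_size(size, patch_size, stride):
    """Spatial size after patch_size // 2 zero padding and a valid conv."""
    pad = patch_size // 2
    x = np.ones((1, size, size, 3), dtype="float32")
    x = keras.layers.ZeroPadding2D(padding=(pad, pad))(x)
    x = keras.layers.Conv2D(8, kernel_size=patch_size, strides=stride, padding="valid")(x)
    return x.shape[1]

# Matches ceil(size / stride), i.e. what padding="same" produced before.
for size, patch_size, stride in [(224, 7, 4), (56, 3, 2)]:
    print(out_size(size, patch_size, stride), -(-size // stride))
```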
@@ -76,7 +80,8 @@ class HierarchicalTransformerEncoder(keras.layers.Layer):
             `LayerNormalization` layers. Defaults to `1e-06`
         sr_ratio: integer, the ratio to use within
             `SegFormerMultiheadAttention`. If set to > 1, a `Conv2D`
-
+            layer is used to reduce the length of the sequence.
+            Defaults to `1`.
     """

     def __init__(
@@ -179,20 +184,21 @@ class SegFormerMultiheadAttention(keras.layers.Layer):
         self.k = keras.layers.Dense(project_dim)
         self.v = keras.layers.Dense(project_dim)
         self.proj = keras.layers.Dense(project_dim)
+        self.dropout = keras.layers.Dropout(0.1)
+        self.proj_drop = keras.layers.Dropout(0.1)

         if sr_ratio > 1:
             self.sr = keras.layers.Conv2D(
                 filters=project_dim,
                 kernel_size=sr_ratio,
                 strides=sr_ratio,
-                padding="same",
             )
-            self.norm = keras.layers.LayerNormalization()
+            self.norm = keras.layers.LayerNormalization(epsilon=1e-5)

     def call(self, x):
         input_shape = ops.shape(x)
         H, W = int(math.sqrt(input_shape[1])), int(math.sqrt(input_shape[1]))
-        B, C = input_shape[0], input_shape[2]
+        B, N, C = input_shape[0], input_shape[1], input_shape[2]

         q = self.q(x)
         q = ops.reshape(
@@ -208,12 +214,11 @@ class SegFormerMultiheadAttention(keras.layers.Layer):

         if self.sr_ratio > 1:
             x = ops.reshape(
-
+                x,
                 (B, H, W, C),
             )
             x = self.sr(x)
-            x = ops.reshape(x, [
-            x = ops.transpose(x, [0, 2, 1])
+            x = ops.reshape(x, [B, -1, C])
             x = self.norm(x)

         k = self.k(x)
@@ -237,14 +242,16 @@ class SegFormerMultiheadAttention(keras.layers.Layer):

         attn = (q @ ops.transpose(k, [0, 1, 3, 2])) * self.scale
         attn = ops.nn.softmax(attn, axis=-1)
+        attn = self.dropout(attn)

         attn = attn @ v
         attn = ops.reshape(
             ops.transpose(attn, [0, 2, 1, 3]),
-            [
+            [B, N, C],
         )

         x = self.proj(attn)
+        x = self.proj_drop(x)
         return x
