keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
keras_hub/src/models/mix_transformer/mix_transformer_backbone.py (new file, `@@ -0,0 +1,181 @@`):

````python
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras
import numpy as np
from keras import ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
from keras_hub.src.models.mix_transformer.mix_transformer_layers import (
    HierarchicalTransformerEncoder,
)
from keras_hub.src.models.mix_transformer.mix_transformer_layers import (
    OverlappingPatchingAndEmbedding,
)


@keras_hub_export("keras_hub.models.MiTBackbone")
class MiTBackbone(FeaturePyramidBackbone):
    def __init__(
        self,
        depths,
        num_layers,
        blockwise_num_heads,
        blockwise_sr_ratios,
        end_value,
        patch_sizes,
        strides,
        include_rescaling=True,
        image_shape=(224, 224, 3),
        hidden_dims=None,
        **kwargs,
    ):
        """A Backbone implementing the MixTransformer.

        This architecture is to be used as a backbone for the SegFormer
        architecture [SegFormer: Simple and Efficient Design for Semantic
        Segmentation with Transformers](https://arxiv.org/abs/2105.15203)
        [Based on the TensorFlow implementation from DeepVision](
        https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/classification/mix_transformer)

        Args:
            depths: The number of transformer encoders to be used per layer in
                the network.
            num_layers: int. The number of Transformer layers.
            blockwise_num_heads: list of integers, the number of heads to use
                in the attention computation for each layer.
            blockwise_sr_ratios: list of integers, the sequence reduction
                ratio to perform for each layer on the sequence before key and
                value projections. If set to > 1, a `Conv2D` layer is used to
                reduce the length of the sequence.
            end_value: float. The end of the `[0.0, end_value]` linear ramp of
                drop path rates assigned across the transformer blocks
                (a stochastic depth schedule).
            patch_sizes: list of integers, the patch_size to apply for each layer.
            strides: list of integers, stride to apply for each layer.
            include_rescaling: bool, whether to rescale the inputs. If set
                to `True`, inputs will be passed through a `Rescaling(1/255.0)`
                layer. Defaults to `True`.
            image_shape: optional shape tuple, defaults to `(224, 224, 3)`.
            hidden_dims: the embedding dims per hierarchical layer, used as
                the levels of the feature pyramid.

        Examples:

        Using the class with a `backbone`:

        ```python
        images = np.ones(shape=(1, 96, 96, 3))
        labels = np.zeros(shape=(1, 96, 96, 1))
        backbone = keras_hub.models.MiTBackbone.from_preset("mit_b0_imagenet")

        # Evaluate backbone
        backbone(images)

        # Train backbone
        backbone.compile(
            optimizer="adam",
            loss=keras.losses.BinaryCrossentropy(from_logits=False),
            metrics=["accuracy"],
        )
        backbone.fit(images, labels, epochs=3)
        ```
        """
        dpr = [x for x in np.linspace(0.0, end_value, sum(depths))]

        # === Layers ===
        cur = 0
        patch_embedding_layers = []
        transformer_blocks = []
        layer_norms = []

        for i in range(num_layers):
            patch_embed_layer = OverlappingPatchingAndEmbedding(
                project_dim=hidden_dims[i],
                patch_size=patch_sizes[i],
                stride=strides[i],
                name=f"patch_and_embed_{i}",
            )
            patch_embedding_layers.append(patch_embed_layer)

            transformer_block = [
                HierarchicalTransformerEncoder(
                    project_dim=hidden_dims[i],
                    num_heads=blockwise_num_heads[i],
                    sr_ratio=blockwise_sr_ratios[i],
                    drop_prob=dpr[cur + k],
                    name=f"hierarchical_encoder_{i}_{k}",
                )
                for k in range(depths[i])
            ]
            transformer_blocks.append(transformer_block)
            cur += depths[i]
            layer_norms.append(keras.layers.LayerNormalization())

        # === Functional Model ===
        image_input = keras.layers.Input(shape=image_shape)
        x = image_input

        if include_rescaling:
            x = keras.layers.Rescaling(scale=1 / 255)(x)

        pyramid_outputs = {}
        for i in range(num_layers):
            # Compute new height/width after the `proj`
            # call in `OverlappingPatchingAndEmbedding`
            stride = strides[i]
            new_height, new_width = (
                int(ops.shape(x)[1] / stride),
                int(ops.shape(x)[2] / stride),
            )

            x = patch_embedding_layers[i](x)
            for blk in transformer_blocks[i]:
                x = blk(x)
            x = layer_norms[i](x)
            x = keras.layers.Reshape(
                (new_height, new_width, -1), name=f"output_level_{i}"
            )(x)
            pyramid_outputs[f"P{i + 1}"] = x

        super().__init__(inputs=image_input, outputs=x, **kwargs)

        # === Config ===
        self.depths = depths
        self.include_rescaling = include_rescaling
        self.image_shape = image_shape
        self.hidden_dims = hidden_dims
        self.pyramid_outputs = pyramid_outputs
        self.num_layers = num_layers
        self.blockwise_num_heads = blockwise_num_heads
        self.blockwise_sr_ratios = blockwise_sr_ratios
        self.end_value = end_value
        self.patch_sizes = patch_sizes
        self.strides = strides

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "depths": self.depths,
                "include_rescaling": self.include_rescaling,
                "hidden_dims": self.hidden_dims,
                "image_shape": self.image_shape,
                "num_layers": self.num_layers,
                "blockwise_num_heads": self.blockwise_num_heads,
                "blockwise_sr_ratios": self.blockwise_sr_ratios,
                "end_value": self.end_value,
                "patch_sizes": self.patch_sizes,
                "strides": self.strides,
            }
        )
        return config
````
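The backbone records each stage's reshaped output in `pyramid_outputs` (`P1` through `P4` for a four-stage model), which `FeaturePyramidBackbone` exposes to downstream decoders such as SegFormer. A minimal sketch of reading those levels; the hyperparameters below are assumed MiT-B0-style values from the SegFormer paper, not values shipped in this wheel:

```python
# A sketch, not from the package: build a small MiT directly and inspect the
# feature pyramid. Hyperparameters are assumed MiT-B0-style values.
import keras
import numpy as np
import keras_hub

backbone = keras_hub.models.MiTBackbone(
    depths=[2, 2, 2, 2],
    num_layers=4,
    blockwise_num_heads=[1, 2, 5, 8],
    blockwise_sr_ratios=[8, 4, 2, 1],
    end_value=0.1,
    patch_sizes=[7, 3, 3, 3],
    strides=[4, 2, 2, 2],
    hidden_dims=[32, 64, 160, 256],
    image_shape=(224, 224, 3),
)

# `pyramid_outputs` maps "P1".."P4" to the per-stage feature maps (spatial
# size shrinks by the cumulative stride: 1/4, 1/8, 1/16, 1/32 of the input).
extractor = keras.Model(backbone.inputs, backbone.pyramid_outputs)
features = extractor(np.ones((1, 224, 224, 3), dtype="float32"))
for name, feat in features.items():
    print(name, feat.shape)  # e.g. P1 (1, 56, 56, 32) ... P4 (1, 7, 7, 256)
```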
keras_hub/src/models/mix_transformer/mix_transformer_classifier.py (new file, `@@ -0,0 +1,133 @@`):

````python
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.image_classifier import ImageClassifier
from keras_hub.src.models.mix_transformer.mix_transformer_backbone import (
    MiTBackbone,
)


@keras_hub_export("keras_hub.models.MiTImageClassifier")
class MiTImageClassifier(ImageClassifier):
    """MiTImageClassifier image classifier model.

    Args:
        backbone: A `keras_hub.models.MiTBackbone` instance.
        num_classes: int. The number of classes to predict.
        activation: `None`, str or callable. The activation function to use on
            the `Dense` layer. Set `activation=None` to return the output
            logits. Defaults to `"softmax"`.

    To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
    where `x` is a tensor and `y` is an integer from `[0, num_classes)`.
    All `ImageClassifier` tasks include a `from_preset()` constructor which can
    be used to load a pre-trained config and weights.

    Examples:

    Call `predict()` to run inference.
    ```python
    # Load preset and run inference
    images = np.ones((2, 224, 224, 3), dtype="float32")
    classifier = keras_hub.models.MiTImageClassifier.from_preset(
        "mit_b0_imagenet")
    classifier.predict(images)
    ```

    Call `fit()` on a single batch.
    ```python
    # Load preset and train
    images = np.ones((2, 224, 224, 3), dtype="float32")
    labels = [0, 3]
    classifier = keras_hub.models.MiTImageClassifier.from_preset(
        "mit_b0_imagenet")
    classifier.fit(x=images, y=labels, batch_size=2)
    ```

    Call `fit()` with custom loss, optimizer and backbone.
    ```python
    classifier = keras_hub.models.MiTImageClassifier.from_preset(
        "mit_b0_imagenet")
    classifier.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=keras.optimizers.Adam(5e-5),
    )
    classifier.backbone.trainable = False
    classifier.fit(x=images, y=labels, batch_size=2)
    ```

    Custom backbone.
    ```python
    images = np.ones((2, 224, 224, 3), dtype="float32")
    labels = [0, 3]
    backbone = keras_hub.models.MiTBackbone(
        depths=[2, 2, 2, 2],
        num_layers=4,
        blockwise_num_heads=[1, 2, 5, 8],
        blockwise_sr_ratios=[8, 4, 2, 1],
        end_value=0.1,
        patch_sizes=[7, 3, 3, 3],
        strides=[4, 2, 2, 2],
        hidden_dims=[32, 64, 160, 256],
        image_shape=(224, 224, 3),
    )
    classifier = keras_hub.models.MiTImageClassifier(
        backbone=backbone,
        num_classes=4,
    )
    classifier.fit(x=images, y=labels, batch_size=2)
    ```
    """

    backbone_cls = MiTBackbone

    def __init__(
        self,
        backbone,
        num_classes,
        activation="softmax",
        preprocessor=None,  # adding this dummy arg for saved model test
        # TODO: once preprocessor flow is figured out, this needs to be updated
        **kwargs,
    ):
        # === Layers ===
        self.backbone = backbone
        self.output_dense = keras.layers.Dense(
            num_classes,
            activation=activation,
            name="predictions",
        )

        # === Functional Model ===
        inputs = self.backbone.input
        x = self.backbone(inputs)
        outputs = self.output_dense(x)
        super().__init__(
            inputs=inputs,
            outputs=outputs,
            **kwargs,
        )

        # === Config ===
        self.num_classes = num_classes
        self.activation = activation

    def get_config(self):
        # Backbone serialized in `super`
        config = super().get_config()
        config.update(
            {
                "num_classes": self.num_classes,
                "activation": self.activation,
            }
        )
        return config
````

Note: the original docstring's "Custom backbone" example passed ResNet-style arguments (`stackwise_num_filters`, `stackwise_depth`, `block_type`) that `MiTBackbone` does not accept, and the second example named a nonexistent `MixTransformerImageClassifier` class; both are corrected above to match the actual `MiTBackbone.__init__` signature and the exported `MiTImageClassifier` name.
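Because `activation` is applied directly on the final `Dense` head, training on logits only requires `activation=None` plus a `from_logits=True` loss. A hedged sketch of the freeze-the-backbone pattern from the docstring, reusing the `backbone` built in the previous sketch (illustrative, not a tested recipe from this wheel):

```python
# Illustrative only; reuses the `backbone` from the previous sketch.
classifier = keras_hub.models.MiTImageClassifier(
    backbone=backbone,
    num_classes=10,
    activation=None,  # return logits; pair with a from_logits=True loss
)
classifier.backbone.trainable = False  # freeze all encoder stages
classifier.compile(
    optimizer=keras.optimizers.Adam(5e-5),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
classifier.summary()  # only the "predictions" Dense head should be trainable
```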
keras_hub/src/models/mix_transformer/mix_transformer_layers.py (new file, `@@ -0,0 +1,300 @@`):

```python
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import keras
from keras import ops
from keras import random


class OverlappingPatchingAndEmbedding(keras.layers.Layer):
    def __init__(self, project_dim=32, patch_size=7, stride=4, **kwargs):
        """Overlapping Patching and Embedding layer.

        Differs from `PatchingAndEmbedding` in that the patch size does not
        affect the sequence length. It's fully derived from the `stride`
        parameter. Additionally, no positional embedding is done
        as part of the layer - only a projection using a `Conv2D` layer.

        Args:
            project_dim: integer, the dimensionality of the projection.
                Defaults to `32`.
            patch_size: integer, the size of the patches to encode.
                Defaults to `7`.
            stride: integer, the stride to use for the patching before
                projection. Defaults to `4`.
        """
        super().__init__(**kwargs)

        self.project_dim = project_dim
        self.patch_size = patch_size
        self.stride = stride

        self.proj = keras.layers.Conv2D(
            filters=project_dim,
            kernel_size=patch_size,
            strides=stride,
            padding="same",
        )
        self.norm = keras.layers.LayerNormalization()

    def call(self, x):
        x = self.proj(x)
        # B, H, W, C
        shape = x.shape
        x = ops.reshape(x, (-1, shape[1] * shape[2], shape[3]))
        x = self.norm(x)
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "project_dim": self.project_dim,
                "patch_size": self.patch_size,
                "stride": self.stride,
            }
        )
        return config


class HierarchicalTransformerEncoder(keras.layers.Layer):
    """Hierarchical transformer encoder block implementation as a Keras Layer.

    The layer uses `SegFormerMultiheadAttention` as a `MultiHeadAttention`
    alternative for computational efficiency, and is meant to be used
    within the SegFormer architecture.

    Args:
        project_dim: integer, the dimensionality of the projection of the
            encoder, and output of the `SegFormerMultiheadAttention` layer.
            Due to the residual addition the input dimensionality has to be
            equal to the output dimensionality.
        num_heads: integer, the number of heads for the
            `SegFormerMultiheadAttention` layer.
        drop_prob: float, the probability of dropping a random
            sample using the `DropPath` layer. Defaults to `0.0`.
        layer_norm_epsilon: float, the epsilon for
            `LayerNormalization` layers. Defaults to `1e-06`.
        sr_ratio: integer, the ratio to use within
            `SegFormerMultiheadAttention`. If set to > 1, a `Conv2D`
            layer is used to reduce the length of the sequence. Defaults to `1`.
    """

    def __init__(
        self,
        project_dim,
        num_heads,
        sr_ratio=1,
        drop_prob=0.0,
        layer_norm_epsilon=1e-6,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.project_dim = project_dim
        self.num_heads = num_heads
        self.drop_prop = drop_prob

        self.norm1 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon)
        self.attn = SegFormerMultiheadAttention(
            project_dim, num_heads, sr_ratio
        )
        self.drop_path = DropPath(drop_prob)
        self.norm2 = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon)
        self.mlp = MixFFN(
            channels=project_dim,
            mid_channels=int(project_dim * 4),
        )

    def build(self, input_shape):
        super().build(input_shape)
        self.H = ops.sqrt(ops.cast(input_shape[1], "float32"))
        self.W = ops.sqrt(ops.cast(input_shape[2], "float32"))

    def call(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "mlp": keras.saving.serialize_keras_object(self.mlp),
                "project_dim": self.project_dim,
                "num_heads": self.num_heads,
                "drop_prop": self.drop_prop,
            }
        )
        return config


class MixFFN(keras.layers.Layer):
    def __init__(self, channels, mid_channels):
        super().__init__()
        self.fc1 = keras.layers.Dense(mid_channels)
        self.dwconv = keras.layers.DepthwiseConv2D(
            kernel_size=3,
            strides=1,
            padding="same",
        )
        self.fc2 = keras.layers.Dense(channels)

    def call(self, x):
        x = self.fc1(x)
        shape = ops.shape(x)
        H, W = int(math.sqrt(shape[1])), int(math.sqrt(shape[1]))
        B, C = shape[0], shape[2]
        x = ops.reshape(x, (B, H, W, C))
        x = self.dwconv(x)
        x = ops.reshape(x, (B, -1, C))
        x = ops.nn.gelu(x)
        x = self.fc2(x)
        return x


class SegFormerMultiheadAttention(keras.layers.Layer):
    def __init__(self, project_dim, num_heads, sr_ratio):
        """Efficient MultiHeadAttention implementation as a Keras layer.

        A huge bottleneck in scaling transformers is the self-attention layer
        with an O(n^2) complexity.

        SegFormerMultiheadAttention performs a sequence reduction (SR) operation
        with a given ratio, to reduce the sequence length before performing key
        and value projections, reducing the O(n^2) complexity to O(n^2/R) where
        R is the sequence reduction ratio.

        Args:
            project_dim: integer, the dimensionality of the projection
                of the `SegFormerMultiheadAttention` layer.
            num_heads: integer, the number of heads to use in the
                attention computation.
            sr_ratio: integer, the sequence reduction ratio to perform
                on the sequence before key and value projections.
        """
        super().__init__()
        self.num_heads = num_heads
        self.sr_ratio = sr_ratio
        self.scale = (project_dim // num_heads) ** -0.5
        self.q = keras.layers.Dense(project_dim)
        self.k = keras.layers.Dense(project_dim)
        self.v = keras.layers.Dense(project_dim)
        self.proj = keras.layers.Dense(project_dim)

        if sr_ratio > 1:
            self.sr = keras.layers.Conv2D(
                filters=project_dim,
                kernel_size=sr_ratio,
                strides=sr_ratio,
                padding="same",
            )
            self.norm = keras.layers.LayerNormalization()

    def call(self, x):
        input_shape = ops.shape(x)
        H, W = int(math.sqrt(input_shape[1])), int(math.sqrt(input_shape[1]))
        B, C = input_shape[0], input_shape[2]

        q = self.q(x)
        q = ops.reshape(
            q,
            (
                input_shape[0],
                input_shape[1],
                self.num_heads,
                input_shape[2] // self.num_heads,
            ),
        )
        q = ops.transpose(q, [0, 2, 1, 3])

        if self.sr_ratio > 1:
            x = ops.reshape(
                ops.transpose(x, [0, 2, 1]),
                (B, H, W, C),
            )
            x = self.sr(x)
            x = ops.reshape(x, [input_shape[0], input_shape[2], -1])
            x = ops.transpose(x, [0, 2, 1])
            x = self.norm(x)

        k = self.k(x)
        v = self.v(x)

        k = ops.transpose(
            ops.reshape(
                k,
                [B, -1, self.num_heads, C // self.num_heads],
            ),
            [0, 2, 1, 3],
        )

        v = ops.transpose(
            ops.reshape(
                v,
                [B, -1, self.num_heads, C // self.num_heads],
            ),
            [0, 2, 1, 3],
        )

        attn = (q @ ops.transpose(k, [0, 1, 3, 2])) * self.scale
        attn = ops.nn.softmax(attn, axis=-1)

        attn = attn @ v
        attn = ops.reshape(
            ops.transpose(attn, [0, 2, 1, 3]),
            [input_shape[0], input_shape[1], input_shape[2]],
        )

        x = self.proj(attn)
        return x


class DropPath(keras.layers.Layer):
    """Implements the DropPath layer.

    DropPath randomly drops samples during
    training with a probability of `rate`. Note that this layer drops individual
    samples within a batch and not the entire batch, whereas StochasticDepth
    randomly drops the entire batch.

    Args:
        rate: float, the probability of the residual branch being dropped.
        seed: (Optional) integer. Used to create a random seed.
    """

    def __init__(self, rate=0.5, seed=None, **kwargs):
        super().__init__(**kwargs)
        self.rate = rate
        self._seed_val = seed
        self.seed = random.SeedGenerator(seed=seed)

    def call(self, x, training=None):
        if self.rate == 0.0 or not training:
            return x
        else:
            batch_size = x.shape[0] or ops.shape(x)[0]
            drop_map_shape = (batch_size,) + (1,) * (len(x.shape) - 1)
            drop_map = ops.cast(
                random.uniform(drop_map_shape, seed=self.seed) > self.rate,
                x.dtype,
            )
            x = x / (1.0 - self.rate)
            x = x * drop_map
            return x

    def get_config(self):
        config = super().get_config()
        config.update({"rate": self.rate, "seed": self._seed_val})
        return config
```
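Two properties claimed in the docstrings above are easy to verify directly: the patch embedding's sequence length is derived from the stride alone, and the encoder block preserves its input shape while `sr_ratio` only shrinks the key/value sequence internally. A small smoke test, assuming the layers are imported from their `src` module (they are not exported under `keras_hub.models`):

```python
import numpy as np

from keras_hub.src.models.mix_transformer.mix_transformer_layers import (
    DropPath,
    HierarchicalTransformerEncoder,
    OverlappingPatchingAndEmbedding,
)

x = np.ones((1, 224, 224, 3), dtype="float32")

# Sequence length is set by the stride alone: 224 / 4 = 56, so 56 * 56 tokens.
patches = OverlappingPatchingAndEmbedding(project_dim=32, patch_size=7, stride=4)(x)
print(patches.shape)  # (1, 3136, 32)

# The encoder keeps (batch, tokens, project_dim); sr_ratio only reduces the
# key/value sequence inside the attention, from 3136 to (56 / 8) ** 2 = 49.
encoder = HierarchicalTransformerEncoder(project_dim=32, num_heads=1, sr_ratio=8)
print(encoder(patches).shape)  # (1, 3136, 32)

# DropPath is the identity at inference time; it only drops samples in training.
print(np.allclose(DropPath(rate=0.5)(patches, training=False), patches))  # True
```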
keras_hub/src/models/opt/__init__.py (new file, `@@ -0,0 +1,20 @@`):

```python
# Copyright 2024 The KerasHub Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_hub.src.models.opt.opt_backbone import OPTBackbone
from keras_hub.src.models.opt.opt_presets import backbone_presets
from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer
from keras_hub.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (OPTBackbone, OPTTokenizer))
```
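`register_presets` wires the entries of `backbone_presets` into the `from_preset()` machinery for the paired backbone and tokenizer classes. A hedged sketch of the effect; the exact preset keys live in `opt_presets.py`, so they are listed rather than hard-coded:

```python
import keras_hub

# Importing the package runs the module above, which registers the OPT
# presets. `presets` is assumed here to expose the registered mapping.
print(list(keras_hub.models.OPTBackbone.presets))

# Any listed key can then be resolved by name:
# backbone = keras_hub.models.OPTBackbone.from_preset("<one of the keys above>")
```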