keras_hub_nightly-0.15.0.dev20240823171555-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
keras_hub/src/models/vit_det/vit_layers.py
@@ -0,0 +1,565 @@
+# Copyright 2024 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import keras
+from keras import ops
+
+
+class MLP(keras.layers.Layer):
+    """A MLP block with architecture.
+
+    The MLP block implements `input_dim -> [intermediate_dim] ->
+    hidden_dim`. The code has been adapted from [Segment Anything paper](
+    https://arxiv.org/abs/2304.02643), [Segment Anything GitHub](
+    https://github.com/facebookresearch/segment-anything) and [Detectron2](
+    https://github.com/facebookresearch/detectron2).
+
+    Args:
+        intermediate_dim (int): The number of units in the hidden layers.
+        hidden_dim (int): The number of units in the output layer.
+        activation (str): Activation to use in the hidden layers.
+            Default is `"relu"`.
+    """
+
+    def __init__(
+        self, intermediate_dim, hidden_dim, activation="relu", **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.intermediate_dim = intermediate_dim
+        self.hidden_dim = hidden_dim
+        self.activation = activation
+        h = [intermediate_dim]
+        self.dense_net = []
+        for intermediate_dim in h:
+            self.dense_net.append(keras.layers.Dense(intermediate_dim))
+            self.dense_net.append(keras.layers.Activation(activation))
+        self.dense_net.append(keras.layers.Dense(hidden_dim))
+        self.dense_net = keras.models.Sequential(self.dense_net)
+
+    def build(self, input_shape):
+        self.dense_net.build(input_shape)
+        self.built = True
+
+    def call(self, x):
+        return self.dense_net(x)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "intermediate_dim": self.intermediate_dim,
+                "hidden_dim": self.hidden_dim,
+                "activation": self.activation,
+            }
+        )
+        return config
+
+
+class AddRelativePositionalEmbedding(keras.layers.Layer):
+    def __init__(self, input_size, key_dim, **kwargs):
+        super().__init__(**kwargs)
+        self.input_size = input_size
+        self.key_dim = key_dim
+        self.rel_pos_h = self.add_weight(
+            name="rel_pos_h",
+            shape=(2 * self.input_size[0] - 1, self.key_dim),
+            initializer="zeros",
+        )
+        self.rel_pos_w = self.add_weight(
+            name="rel_pos_w",
+            shape=(2 * self.input_size[1] - 1, self.key_dim),
+            initializer="zeros",
+        )
+        self.built = True
+
+    def _get_rel_pos(self, query_size, key_size, rel_pos):
+        """Get relative positional embeddings.
+
+        Get relative positional embeddings according to the relative positions
+        of query and key sizes.
+
+        Args:
+            query_size (int): The number of features of the queries.
+            key_size (int): The number of features of the keys.
+            rel_pos (tensor): Relative positional embedding tensor.
+
+        Returns:
+            tensor: Extracted positional embeddings according to relative
+                positions.
+        """
+        max_rel_dist = 2 * max(query_size, key_size) - 1
+        if ops.shape(rel_pos)[0] != max_rel_dist:
+            rel_pos_resized = ops.image.resize(
+                image=ops.reshape(
+                    rel_pos,
+                    (1, ops.shape(rel_pos)[0], ops.shape(rel_pos)[1], 1),
+                ),
+                size=(max_rel_dist, ops.shape(rel_pos)[1]),
+                interpolation="bilinear",
+            )
+            rel_pos_resized = ops.squeeze(rel_pos_resized, axis=(0, -1))
+            return rel_pos_resized
+        else:
+            rel_pos_resized = rel_pos
+        # Query coordinates
+        query_coordinates = ops.cast(
+            ops.arange(query_size), dtype=self.compute_dtype
+        )[:, None] * (max(key_size / query_size, 1.0))
+        # Key coordinates
+        key_coordinates = ops.cast(
+            ops.arange(key_size), dtype=self.compute_dtype
+        )[None, :] * (max(query_size / key_size, 1.0))
+        # Relative coordinates
+        relative_coordinates = (query_coordinates - key_coordinates) + (
+            key_size - 1
+        ) * max(query_size / key_size, 1.0)
+        relative_coordinates = ops.cast(relative_coordinates, dtype="int32")
+        return ops.take(rel_pos_resized, relative_coordinates, 0)
+
+    def call(self, attention_map, queries, query_size, key_size):
+        """Calculate decomposed Relative Positional Embeddings
+
+        The code has been adapted based on
+        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa: E501
+
+        Args:
+            attention_map (tensor): Attention map.
+            queries (tensor): Queries in the attention layer with shape
+                `(batch, query_height * query_width, channels)`.
+            query_size (tuple[int, int]): Spatial sequence size of queries with
+                `(query_height, query_width)`.
+            key_size (tuple[int, int]): Spatial sequence size of keys with
+                `(key_height, key_width)`.
+
+        Returns:
+            tensor: attention map with added relative positional embeddings.
+        """
+        query_height, query_width = query_size[0], query_size[1]
+        key_height, key_width = key_size[0], key_size[1]
+        rel_heights = self._get_rel_pos(
+            query_height, key_height, self.rel_pos_h
+        )
+        rel_widths = self._get_rel_pos(query_width, key_width, self.rel_pos_w)
+        shape = ops.shape(queries)
+        batch, channels = shape[0], shape[2]
+        rel_queries = ops.reshape(
+            queries, (batch, query_height, query_width, channels)
+        )
+        rel_heights = ops.einsum("bhwc,hkc->bhwk", rel_queries, rel_heights)
+        rel_widths = ops.einsum("bhwc,wkc->bhwk", rel_queries, rel_widths)
+        attention_map = ops.reshape(
+            attention_map,
+            (batch, query_height, query_width, key_height, key_width),
+        )
+        attention_map = attention_map + rel_heights[..., :, None]
+        attention_map = attention_map + rel_widths[..., None, :]
+        attention_map = ops.reshape(
+            attention_map,
+            (batch, query_height * query_width, key_height * key_width),
+        )
+        return attention_map
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"input_size": self.input_size, "key_dim": self.key_dim})
+        return config
+
+
+class MultiHeadAttentionWithRelativePE(keras.layers.Layer):
+    """Multi-head Attention block with relative position embeddings.
+
+    The code has been adapted from [Segment Anything paper](
+    https://arxiv.org/abs/2304.02643), [Segment Anything GitHub](
+    https://github.com/facebookresearch/segment-anything) and [Detectron2](
+    https://github.com/facebookresearch/detectron2).
+
+    Args:
+        num_heads (int): Number of attention heads.
+        key_dim (int): Size of each attention head for query, key, and
+            value.
+        use_bias (bool, optional): Whether to use bias when projecting
+            the queries, keys, and values. Defaults to `True`.
+        use_rel_pos (bool, optional): Whether to use relative positional
+            embeddings or not. Defaults to `False`.
+        input_size (tuple[int, int], optional): Size of the input image.
+            Must be provided when using relative positional embeddings.
+            Defaults to `None`.
+
+    Raises:
+        ValueError: When `input_size = None` with `use_rel_pos = True`.
+    """
+
+    def __init__(
+        self,
+        num_heads,
+        key_dim,
+        use_bias=True,
+        use_rel_pos=False,
+        input_size=None,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.num_heads = num_heads
+        self.key_dim = key_dim
+        self.scale = self.key_dim**-0.5
+        self.use_bias = use_bias
+        self.input_size = input_size
+        self.use_rel_pos = use_rel_pos
+        self.qkv = keras.layers.Dense(
+            key_dim * self.num_heads * 3, use_bias=self.use_bias
+        )
+        self.projection = keras.layers.Dense(key_dim * self.num_heads)
+        if self.use_rel_pos:
+            if input_size is None:
+                raise ValueError(
+                    "Input size must be provided if using relative "
+                    "positional encoding."
+                )
+            self.add_decomposed_reative_pe = AddRelativePositionalEmbedding(
+                self.input_size, self.key_dim
+            )
+
+    def build(self, input_shape=None):
+        self.qkv.build([self.key_dim * self.num_heads])
+        self.projection.build([self.key_dim * self.num_heads])
+        self.built = True
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
+
+    def call(self, x):
+        batch, height, width, channels = ops.shape(x)
+        qkv = ops.transpose(
+            ops.reshape(
+                self.qkv(x),
+                (batch, height * width, 3, self.num_heads, self.key_dim),
+            ),
+            axes=(2, 0, 3, 1, 4),
+        )
+        qkv = ops.reshape(
+            qkv, (3, batch * self.num_heads, height * width, self.key_dim)
+        )
+        queries, keys, values = ops.unstack(qkv, axis=0)
+        attention_map = (queries * self.scale) @ ops.transpose(
+            keys, axes=(0, 2, 1)
+        )
+        if self.use_rel_pos:
+            attention_map = self.add_decomposed_reative_pe(
+                attention_map,
+                queries=queries,
+                query_size=(height, width),
+                key_size=(height, width),
+            )
+        attention_map = ops.softmax(attention_map, axis=-1)
+        x = ops.reshape(
+            attention_map @ values,
+            (batch, self.num_heads, height, width, self.key_dim),
+        )
+        x = ops.transpose(x, axes=(0, 2, 3, 1, 4))
+        x = ops.reshape(x, (batch, height, width, channels))
+        x = self.projection(x)
+
+        return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "num_heads": self.num_heads,
+                "key_dim": self.key_dim,
+                "use_bias": self.use_bias,
+                "use_rel_pos": self.use_rel_pos,
+                "input_size": self.input_size,
+            }
+        )
+        return config
+
+
+class WindowPartitioning(keras.layers.Layer):
+    def __init__(self, window_size, **kwargs):
+        super().__init__(**kwargs)
+        self.window_size = window_size
+        self.built = True
+
+    def partition(self, x):
+        batch, height, width, channels = ops.shape(x)
+        pad_height = (
+            self.window_size - height % self.window_size
+        ) % self.window_size
+        pad_width = (
+            self.window_size - width % self.window_size
+        ) % self.window_size
+        if pad_height > 0 or pad_width > 0:
+            x = ops.pad(x, ((0, 0), (0, pad_height), (0, pad_width), (0, 0)))
+        height_padded, width_padded = height + pad_height, width + pad_width
+        x = ops.reshape(
+            x,
+            (
+                batch,
+                height_padded // self.window_size,
+                self.window_size,
+                width_padded // self.window_size,
+                self.window_size,
+                channels,
+            ),
+        )
+        windows = ops.reshape(
+            ops.transpose(x, axes=(0, 1, 3, 2, 4, 5)),
+            (-1, self.window_size, self.window_size, channels),
+        )
+        return windows, (height_padded, width_padded)
+
+    def unpartition(self, windows, height_width_padded, height_width):
+        height_padded, width_padded = height_width_padded
+        height, width = height_width
+        batch = ops.shape(windows)[0] // (
+            (height_padded // self.window_size)
+            * (width_padded // self.window_size)
+        )
+        x = ops.reshape(
+            windows,
+            (
+                batch,
+                height_padded // self.window_size,
+                width_padded // self.window_size,
+                self.window_size,
+                self.window_size,
+                -1,
+            ),
+        )
+        x = ops.reshape(
+            ops.transpose(x, axes=(0, 1, 3, 2, 4, 5)),
+            (batch, height_padded, width_padded, -1),
+        )
+        return x[:, :height, :width, :]
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"window_size": self.window_size})
+        return config
+
+
+class WindowedTransformerEncoder(keras.layers.Layer):
+    """Implements windowed transformer encoder.
+
+    Transformer blocks with support of window attention and residual
+    propagation blocks. The code has been adapted from [Segment Anything paper](
+    https://arxiv.org/abs/2304.02643), [Segment Anything GitHub](
+    https://github.com/facebookresearch/segment-anything) and [Detectron2](
+    https://github.com/facebookresearch/detectron2).
+
+    Args:
+        project_dim (int): the dimensionality of the projection of the
+            encoder, and output of the `MultiHeadAttention`.
+        intermediate_dim (int): the intermediate dimensionality of the MLP head
+            before projecting to `project_dim`.
+        num_heads (int): the number of heads for the `MultiHeadAttention`
+            layer.
+        use_bias (bool, optional): Whether to use bias to project the keys,
+            queries, and values in the attention layer. Defaults to `True`.
+        use_rel_pos (bool, optional): Whether to use relative positional
+            emcodings in the attention layer. Defaults to `False`.
+        window_size (int, optional): Window size for windowed attention.
+            Defaults to `0`.
+        input_size (tuple[int, int], optional): Height and width of the input
+            image as a tuple of integers. Must be provided when using relative
+            positional embeddings. Defaults to `None`.
+        activation (str, optional): the activation function to apply in the
+            MLP head - should be a function. Defaults to `"gelu"`.
+        layer_norm_epsilon (float, optional): The epsilon to use in the layer
+            normalization layers. Defaults to `1e-6`.
+    """
+
+    def __init__(
+        self,
+        project_dim,
+        intermediate_dim,
+        num_heads,
+        use_bias=True,
+        use_rel_pos=False,
+        window_size=0,
+        input_size=None,
+        activation="gelu",
+        layer_norm_epsilon=1e-6,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.project_dim = project_dim
+        self.intermediate_dim = intermediate_dim
+        self.num_heads = num_heads
+        self.use_bias = use_bias
+        self.input_size = input_size
+        self.activation = activation
+        self.layer_norm_epsilon = layer_norm_epsilon
+        self.window_size = window_size
+        self.use_rel_pos = use_rel_pos
+
+        self.layer_norm1 = keras.layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon
+        )
+        self.layer_norm2 = keras.layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon
+        )
+        self.attention = MultiHeadAttentionWithRelativePE(
+            num_heads=self.num_heads,
+            key_dim=self.project_dim // self.num_heads,
+            use_bias=use_bias,
+            use_rel_pos=use_rel_pos,
+            input_size=(
+                input_size if window_size == 0 else (window_size, window_size)
+            ),
+        )
+        self.mlp_block = MLP(
+            intermediate_dim,
+            project_dim,
+            activation="gelu",
+        )
+        self.window_partitioning = WindowPartitioning(window_size)
+
+    def build(self, input_shape=None):
+        self.layer_norm1.build([None, None, None, self.project_dim])
+        self.layer_norm2.build([None, None, None, self.project_dim])
+        self.attention.build()
+        self.mlp_block.build([None, None, None, self.project_dim])
+        self.built = True
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
+
+    def call(self, x):
+        shortcut = x
+        x = self.layer_norm1(x)
+        # Window Partition
+        if self.window_size > 0:
+            height, width = ops.shape(x)[1], ops.shape(x)[2]
+            x, height_width_padded = self.window_partitioning.partition(x)
+
+        x = self.attention(x)
+        # Reverse Window Partition
+        if self.window_size > 0:
+            x = self.window_partitioning.unpartition(
+                x,
+                height_width_padded=height_width_padded,
+                height_width=(height, width),
+            )
+        x = shortcut + x
+        x = x + self.mlp_block(self.layer_norm2(x))
+        return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "project_dim": self.project_dim,
+                "intermediate_dim": self.intermediate_dim,
+                "num_heads": self.num_heads,
+                "use_bias": self.use_bias,
+                "use_rel_pos": self.use_rel_pos,
+                "window_size": self.window_size,
+                "input_size": self.input_size,
+                "activation": self.activation,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+            }
+        )
+        return config
+
+
+class ViTDetPatchingAndEmbedding(keras.layers.Layer):
+    """
+    Implements a image patch and embedding layer.
+
+    Image to Patch Embedding using only a conv layer (without
+    layer normalization).The code has been adapted from [Segment Anything
+    paper](https://arxiv.org/abs/2304.02643), [Segment Anything GitHub](
+    https://github.com/facebookresearch/segment-anything) and [Detectron2](
+    https://github.com/facebookresearch/detectron2).
+
+    Args:
+        kernel_size (tuple[int, int], optional): Kernel size of the
+            projection layer. Defaults to `(16, 16)`.
+        strides (tuple, optional): Strides of the projection layer.
+            Defaults to `(16, 16)`.
+        embed_dim (int, optional): Number of filters to use in the
+            projection layer i.e. projection size. Defaults to `768`.
+    """
+
+    def __init__(
+        self, kernel_size=(16, 16), strides=(16, 16), embed_dim=768, **kwargs
+    ):
+        super().__init__(**kwargs)
+
+        self.projection = keras.layers.Conv2D(
+            embed_dim, kernel_size=kernel_size, strides=strides
+        )
+        self.kernel_size = kernel_size
+        self.strides = strides
+        self.embed_dim = embed_dim
+
+    def build(self, input_shape):
+        self.projection.build(input_shape)
+        self.built = True
+
+    def compute_output_shape(self, input_shape):
+        return self.projection.compute_output_shape(input_shape)
+
+    def call(self, x):
+        x = self.projection(x)
+        return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "kernel_size": self.kernel_size,
+                "strides": self.strides,
+                "embed_dim": self.embed_dim,
+            }
+        )
+        return config
+
+
+class AddPositionalEmbedding(keras.layers.Layer):
+    def __init__(self, img_size, patch_size, embed_dim, **kwargs):
+        super().__init__(**kwargs)
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.embed_dim = embed_dim
+        self.pos_embed = self.add_weight(
+            name="pos_embed",
+            shape=(
+                1,
+                img_size // patch_size,
+                img_size // patch_size,
+                embed_dim,
+            ),
+            initializer="zeros",
+        )
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
+
+    def call(self, x):
+        return x + self.pos_embed
+
+    def get_confg(self):
+        config = super().get_config()
+        config.update(
+            {
+                "img_size": self.img_size,
+                "patch_size": self.patch_size,
+                "embed_dim": self.embed_dim,
+            }
+        )
+        return config
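For orientation, here is a minimal usage sketch (not part of the diff) of how the layers in this hunk compose into a ViTDet-style feature extractor. It assumes the module path keras_hub/src/models/vit_det/vit_layers.py from the file list above; the shapes and hyperparameters are illustrative only, not taken from the package:

import keras

from keras_hub.src.models.vit_det.vit_layers import (
    AddPositionalEmbedding,
    ViTDetPatchingAndEmbedding,
    WindowedTransformerEncoder,
)

# A 64x64 RGB image with 16x16 patches yields a (1, 4, 4, 768) feature map.
images = keras.random.normal((1, 64, 64, 3))
x = ViTDetPatchingAndEmbedding(
    kernel_size=(16, 16), strides=(16, 16), embed_dim=768
)(images)
x = AddPositionalEmbedding(img_size=64, patch_size=16, embed_dim=768)(x)
# One windowed encoder block: attention runs inside 2x2 windows, then the
# residual MLP runs over the reassembled token grid.
x = WindowedTransformerEncoder(
    project_dim=768, intermediate_dim=3072, num_heads=12, window_size=2
)(x)
print(x.shape)  # (1, 4, 4, 768)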
keras_hub/src/models/whisper/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2024 The KerasHub Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone
+from keras_hub.src.models.whisper.whisper_presets import backbone_presets
+from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, (WhisperBackbone, WhisperTokenizer))
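A hedged note on the final line of this hunk: register_presets runs at package import time and wires the Whisper preset table into the from_preset constructors. Assuming WhisperBackbone and WhisperTokenizer are re-exported under keras_hub.models and keras_hub.tokenizers (as the api/ listings above suggest), usage would look like the sketch below; the preset name is illustrative, since whisper_presets.py is not shown in this diff:

import keras_hub

# "whisper_tiny_en" is an illustrative preset name; the real names live in
# keras_hub/src/models/whisper/whisper_presets.py, which this diff omits.
backbone = keras_hub.models.WhisperBackbone.from_preset("whisper_tiny_en")
tokenizer = keras_hub.tokenizers.WhisperTokenizer.from_preset("whisper_tiny_en")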