keras-hub-nightly 0.15.0.dev20240823171555__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/__init__.py +52 -0
- keras_hub/api/__init__.py +27 -0
- keras_hub/api/layers/__init__.py +47 -0
- keras_hub/api/metrics/__init__.py +24 -0
- keras_hub/api/models/__init__.py +249 -0
- keras_hub/api/samplers/__init__.py +29 -0
- keras_hub/api/tokenizers/__init__.py +35 -0
- keras_hub/src/__init__.py +13 -0
- keras_hub/src/api_export.py +53 -0
- keras_hub/src/layers/__init__.py +13 -0
- keras_hub/src/layers/modeling/__init__.py +13 -0
- keras_hub/src/layers/modeling/alibi_bias.py +143 -0
- keras_hub/src/layers/modeling/cached_multi_head_attention.py +137 -0
- keras_hub/src/layers/modeling/f_net_encoder.py +200 -0
- keras_hub/src/layers/modeling/masked_lm_head.py +239 -0
- keras_hub/src/layers/modeling/position_embedding.py +123 -0
- keras_hub/src/layers/modeling/reversible_embedding.py +311 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +169 -0
- keras_hub/src/layers/modeling/sine_position_encoding.py +108 -0
- keras_hub/src/layers/modeling/token_and_position_embedding.py +150 -0
- keras_hub/src/layers/modeling/transformer_decoder.py +496 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +262 -0
- keras_hub/src/layers/modeling/transformer_layer_utils.py +106 -0
- keras_hub/src/layers/preprocessing/__init__.py +13 -0
- keras_hub/src/layers/preprocessing/masked_lm_mask_generator.py +220 -0
- keras_hub/src/layers/preprocessing/multi_segment_packer.py +319 -0
- keras_hub/src/layers/preprocessing/preprocessing_layer.py +62 -0
- keras_hub/src/layers/preprocessing/random_deletion.py +271 -0
- keras_hub/src/layers/preprocessing/random_swap.py +267 -0
- keras_hub/src/layers/preprocessing/start_end_packer.py +219 -0
- keras_hub/src/metrics/__init__.py +13 -0
- keras_hub/src/metrics/bleu.py +394 -0
- keras_hub/src/metrics/edit_distance.py +197 -0
- keras_hub/src/metrics/perplexity.py +181 -0
- keras_hub/src/metrics/rouge_base.py +204 -0
- keras_hub/src/metrics/rouge_l.py +97 -0
- keras_hub/src/metrics/rouge_n.py +125 -0
- keras_hub/src/models/__init__.py +13 -0
- keras_hub/src/models/albert/__init__.py +20 -0
- keras_hub/src/models/albert/albert_backbone.py +267 -0
- keras_hub/src/models/albert/albert_classifier.py +202 -0
- keras_hub/src/models/albert/albert_masked_lm.py +129 -0
- keras_hub/src/models/albert/albert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/albert/albert_preprocessor.py +206 -0
- keras_hub/src/models/albert/albert_presets.py +70 -0
- keras_hub/src/models/albert/albert_tokenizer.py +119 -0
- keras_hub/src/models/backbone.py +311 -0
- keras_hub/src/models/bart/__init__.py +20 -0
- keras_hub/src/models/bart/bart_backbone.py +261 -0
- keras_hub/src/models/bart/bart_preprocessor.py +276 -0
- keras_hub/src/models/bart/bart_presets.py +74 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +490 -0
- keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor.py +262 -0
- keras_hub/src/models/bart/bart_tokenizer.py +124 -0
- keras_hub/src/models/bert/__init__.py +23 -0
- keras_hub/src/models/bert/bert_backbone.py +227 -0
- keras_hub/src/models/bert/bert_classifier.py +183 -0
- keras_hub/src/models/bert/bert_masked_lm.py +131 -0
- keras_hub/src/models/bert/bert_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/bert/bert_preprocessor.py +184 -0
- keras_hub/src/models/bert/bert_presets.py +147 -0
- keras_hub/src/models/bert/bert_tokenizer.py +112 -0
- keras_hub/src/models/bloom/__init__.py +20 -0
- keras_hub/src/models/bloom/bloom_attention.py +186 -0
- keras_hub/src/models/bloom/bloom_backbone.py +173 -0
- keras_hub/src/models/bloom/bloom_causal_lm.py +298 -0
- keras_hub/src/models/bloom/bloom_causal_lm_preprocessor.py +176 -0
- keras_hub/src/models/bloom/bloom_decoder.py +206 -0
- keras_hub/src/models/bloom/bloom_preprocessor.py +185 -0
- keras_hub/src/models/bloom/bloom_presets.py +121 -0
- keras_hub/src/models/bloom/bloom_tokenizer.py +116 -0
- keras_hub/src/models/causal_lm.py +383 -0
- keras_hub/src/models/classifier.py +109 -0
- keras_hub/src/models/csp_darknet/__init__.py +13 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +410 -0
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +133 -0
- keras_hub/src/models/deberta_v3/__init__.py +24 -0
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +210 -0
- keras_hub/src/models/deberta_v3/deberta_v3_classifier.py +228 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm.py +135 -0
- keras_hub/src/models/deberta_v3/deberta_v3_masked_lm_preprocessor.py +191 -0
- keras_hub/src/models/deberta_v3/deberta_v3_preprocessor.py +206 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +82 -0
- keras_hub/src/models/deberta_v3/deberta_v3_tokenizer.py +155 -0
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +227 -0
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +412 -0
- keras_hub/src/models/deberta_v3/relative_embedding.py +94 -0
- keras_hub/src/models/densenet/__init__.py +13 -0
- keras_hub/src/models/densenet/densenet_backbone.py +210 -0
- keras_hub/src/models/densenet/densenet_image_classifier.py +131 -0
- keras_hub/src/models/distil_bert/__init__.py +26 -0
- keras_hub/src/models/distil_bert/distil_bert_backbone.py +187 -0
- keras_hub/src/models/distil_bert/distil_bert_classifier.py +208 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +137 -0
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +194 -0
- keras_hub/src/models/distil_bert/distil_bert_preprocessor.py +175 -0
- keras_hub/src/models/distil_bert/distil_bert_presets.py +57 -0
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +114 -0
- keras_hub/src/models/electra/__init__.py +20 -0
- keras_hub/src/models/electra/electra_backbone.py +247 -0
- keras_hub/src/models/electra/electra_preprocessor.py +154 -0
- keras_hub/src/models/electra/electra_presets.py +95 -0
- keras_hub/src/models/electra/electra_tokenizer.py +104 -0
- keras_hub/src/models/f_net/__init__.py +20 -0
- keras_hub/src/models/f_net/f_net_backbone.py +236 -0
- keras_hub/src/models/f_net/f_net_classifier.py +154 -0
- keras_hub/src/models/f_net/f_net_masked_lm.py +132 -0
- keras_hub/src/models/f_net/f_net_masked_lm_preprocessor.py +196 -0
- keras_hub/src/models/f_net/f_net_preprocessor.py +177 -0
- keras_hub/src/models/f_net/f_net_presets.py +43 -0
- keras_hub/src/models/f_net/f_net_tokenizer.py +95 -0
- keras_hub/src/models/falcon/__init__.py +20 -0
- keras_hub/src/models/falcon/falcon_attention.py +156 -0
- keras_hub/src/models/falcon/falcon_backbone.py +164 -0
- keras_hub/src/models/falcon/falcon_causal_lm.py +291 -0
- keras_hub/src/models/falcon/falcon_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/falcon/falcon_preprocessor.py +187 -0
- keras_hub/src/models/falcon/falcon_presets.py +30 -0
- keras_hub/src/models/falcon/falcon_tokenizer.py +110 -0
- keras_hub/src/models/falcon/falcon_transformer_decoder.py +255 -0
- keras_hub/src/models/feature_pyramid_backbone.py +73 -0
- keras_hub/src/models/gemma/__init__.py +20 -0
- keras_hub/src/models/gemma/gemma_attention.py +250 -0
- keras_hub/src/models/gemma/gemma_backbone.py +316 -0
- keras_hub/src/models/gemma/gemma_causal_lm.py +448 -0
- keras_hub/src/models/gemma/gemma_causal_lm_preprocessor.py +167 -0
- keras_hub/src/models/gemma/gemma_decoder_block.py +241 -0
- keras_hub/src/models/gemma/gemma_preprocessor.py +191 -0
- keras_hub/src/models/gemma/gemma_presets.py +248 -0
- keras_hub/src/models/gemma/gemma_tokenizer.py +103 -0
- keras_hub/src/models/gemma/rms_normalization.py +40 -0
- keras_hub/src/models/gpt2/__init__.py +20 -0
- keras_hub/src/models/gpt2/gpt2_backbone.py +199 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +437 -0
- keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/gpt2/gpt2_preprocessor.py +187 -0
- keras_hub/src/models/gpt2/gpt2_presets.py +82 -0
- keras_hub/src/models/gpt2/gpt2_tokenizer.py +110 -0
- keras_hub/src/models/gpt_neo_x/__init__.py +13 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +251 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone.py +175 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +201 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor.py +141 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +258 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_preprocessor.py +145 -0
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer.py +88 -0
- keras_hub/src/models/image_classifier.py +90 -0
- keras_hub/src/models/llama/__init__.py +20 -0
- keras_hub/src/models/llama/llama_attention.py +225 -0
- keras_hub/src/models/llama/llama_backbone.py +188 -0
- keras_hub/src/models/llama/llama_causal_lm.py +327 -0
- keras_hub/src/models/llama/llama_causal_lm_preprocessor.py +170 -0
- keras_hub/src/models/llama/llama_decoder.py +246 -0
- keras_hub/src/models/llama/llama_layernorm.py +48 -0
- keras_hub/src/models/llama/llama_preprocessor.py +189 -0
- keras_hub/src/models/llama/llama_presets.py +80 -0
- keras_hub/src/models/llama/llama_tokenizer.py +84 -0
- keras_hub/src/models/llama3/__init__.py +20 -0
- keras_hub/src/models/llama3/llama3_backbone.py +84 -0
- keras_hub/src/models/llama3/llama3_causal_lm.py +46 -0
- keras_hub/src/models/llama3/llama3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/llama3/llama3_preprocessor.py +21 -0
- keras_hub/src/models/llama3/llama3_presets.py +69 -0
- keras_hub/src/models/llama3/llama3_tokenizer.py +63 -0
- keras_hub/src/models/masked_lm.py +101 -0
- keras_hub/src/models/mistral/__init__.py +20 -0
- keras_hub/src/models/mistral/mistral_attention.py +238 -0
- keras_hub/src/models/mistral/mistral_backbone.py +203 -0
- keras_hub/src/models/mistral/mistral_causal_lm.py +328 -0
- keras_hub/src/models/mistral/mistral_causal_lm_preprocessor.py +175 -0
- keras_hub/src/models/mistral/mistral_layer_norm.py +48 -0
- keras_hub/src/models/mistral/mistral_preprocessor.py +190 -0
- keras_hub/src/models/mistral/mistral_presets.py +48 -0
- keras_hub/src/models/mistral/mistral_tokenizer.py +82 -0
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +265 -0
- keras_hub/src/models/mix_transformer/__init__.py +13 -0
- keras_hub/src/models/mix_transformer/mix_transformer_backbone.py +181 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +133 -0
- keras_hub/src/models/mix_transformer/mix_transformer_layers.py +300 -0
- keras_hub/src/models/opt/__init__.py +20 -0
- keras_hub/src/models/opt/opt_backbone.py +173 -0
- keras_hub/src/models/opt/opt_causal_lm.py +301 -0
- keras_hub/src/models/opt/opt_causal_lm_preprocessor.py +177 -0
- keras_hub/src/models/opt/opt_preprocessor.py +188 -0
- keras_hub/src/models/opt/opt_presets.py +72 -0
- keras_hub/src/models/opt/opt_tokenizer.py +116 -0
- keras_hub/src/models/pali_gemma/__init__.py +23 -0
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +277 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +313 -0
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm_preprocessor.py +147 -0
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +160 -0
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +78 -0
- keras_hub/src/models/pali_gemma/pali_gemma_tokenizer.py +79 -0
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +566 -0
- keras_hub/src/models/phi3/__init__.py +20 -0
- keras_hub/src/models/phi3/phi3_attention.py +260 -0
- keras_hub/src/models/phi3/phi3_backbone.py +224 -0
- keras_hub/src/models/phi3/phi3_causal_lm.py +218 -0
- keras_hub/src/models/phi3/phi3_causal_lm_preprocessor.py +173 -0
- keras_hub/src/models/phi3/phi3_decoder.py +260 -0
- keras_hub/src/models/phi3/phi3_layernorm.py +48 -0
- keras_hub/src/models/phi3/phi3_preprocessor.py +190 -0
- keras_hub/src/models/phi3/phi3_presets.py +50 -0
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +137 -0
- keras_hub/src/models/phi3/phi3_tokenizer.py +94 -0
- keras_hub/src/models/preprocessor.py +207 -0
- keras_hub/src/models/resnet/__init__.py +13 -0
- keras_hub/src/models/resnet/resnet_backbone.py +612 -0
- keras_hub/src/models/resnet/resnet_image_classifier.py +136 -0
- keras_hub/src/models/roberta/__init__.py +20 -0
- keras_hub/src/models/roberta/roberta_backbone.py +184 -0
- keras_hub/src/models/roberta/roberta_classifier.py +209 -0
- keras_hub/src/models/roberta/roberta_masked_lm.py +136 -0
- keras_hub/src/models/roberta/roberta_masked_lm_preprocessor.py +198 -0
- keras_hub/src/models/roberta/roberta_preprocessor.py +192 -0
- keras_hub/src/models/roberta/roberta_presets.py +43 -0
- keras_hub/src/models/roberta/roberta_tokenizer.py +132 -0
- keras_hub/src/models/seq_2_seq_lm.py +54 -0
- keras_hub/src/models/t5/__init__.py +20 -0
- keras_hub/src/models/t5/t5_backbone.py +261 -0
- keras_hub/src/models/t5/t5_layer_norm.py +35 -0
- keras_hub/src/models/t5/t5_multi_head_attention.py +324 -0
- keras_hub/src/models/t5/t5_presets.py +95 -0
- keras_hub/src/models/t5/t5_tokenizer.py +100 -0
- keras_hub/src/models/t5/t5_transformer_layer.py +178 -0
- keras_hub/src/models/task.py +419 -0
- keras_hub/src/models/vgg/__init__.py +13 -0
- keras_hub/src/models/vgg/vgg_backbone.py +158 -0
- keras_hub/src/models/vgg/vgg_image_classifier.py +124 -0
- keras_hub/src/models/vit_det/__init__.py +13 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +204 -0
- keras_hub/src/models/vit_det/vit_layers.py +565 -0
- keras_hub/src/models/whisper/__init__.py +20 -0
- keras_hub/src/models/whisper/whisper_audio_feature_extractor.py +260 -0
- keras_hub/src/models/whisper/whisper_backbone.py +305 -0
- keras_hub/src/models/whisper/whisper_cached_multi_head_attention.py +153 -0
- keras_hub/src/models/whisper/whisper_decoder.py +141 -0
- keras_hub/src/models/whisper/whisper_encoder.py +106 -0
- keras_hub/src/models/whisper/whisper_preprocessor.py +326 -0
- keras_hub/src/models/whisper/whisper_presets.py +148 -0
- keras_hub/src/models/whisper/whisper_tokenizer.py +163 -0
- keras_hub/src/models/xlm_roberta/__init__.py +26 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_backbone.py +81 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_classifier.py +225 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +141 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +195 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_preprocessor.py +205 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +43 -0
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +191 -0
- keras_hub/src/models/xlnet/__init__.py +13 -0
- keras_hub/src/models/xlnet/relative_attention.py +459 -0
- keras_hub/src/models/xlnet/xlnet_backbone.py +222 -0
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +133 -0
- keras_hub/src/models/xlnet/xlnet_encoder.py +378 -0
- keras_hub/src/samplers/__init__.py +13 -0
- keras_hub/src/samplers/beam_sampler.py +207 -0
- keras_hub/src/samplers/contrastive_sampler.py +231 -0
- keras_hub/src/samplers/greedy_sampler.py +50 -0
- keras_hub/src/samplers/random_sampler.py +77 -0
- keras_hub/src/samplers/sampler.py +237 -0
- keras_hub/src/samplers/serialization.py +97 -0
- keras_hub/src/samplers/top_k_sampler.py +92 -0
- keras_hub/src/samplers/top_p_sampler.py +113 -0
- keras_hub/src/tests/__init__.py +13 -0
- keras_hub/src/tests/test_case.py +608 -0
- keras_hub/src/tokenizers/__init__.py +13 -0
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +638 -0
- keras_hub/src/tokenizers/byte_tokenizer.py +299 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +267 -0
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +150 -0
- keras_hub/src/tokenizers/tokenizer.py +235 -0
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +355 -0
- keras_hub/src/tokenizers/word_piece_tokenizer.py +544 -0
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +176 -0
- keras_hub/src/utils/__init__.py +13 -0
- keras_hub/src/utils/keras_utils.py +130 -0
- keras_hub/src/utils/pipeline_model.py +293 -0
- keras_hub/src/utils/preset_utils.py +621 -0
- keras_hub/src/utils/python_utils.py +21 -0
- keras_hub/src/utils/tensor_utils.py +206 -0
- keras_hub/src/utils/timm/__init__.py +13 -0
- keras_hub/src/utils/timm/convert.py +37 -0
- keras_hub/src/utils/timm/convert_resnet.py +171 -0
- keras_hub/src/utils/transformers/__init__.py +13 -0
- keras_hub/src/utils/transformers/convert.py +101 -0
- keras_hub/src/utils/transformers/convert_bert.py +173 -0
- keras_hub/src/utils/transformers/convert_distilbert.py +184 -0
- keras_hub/src/utils/transformers/convert_gemma.py +187 -0
- keras_hub/src/utils/transformers/convert_gpt2.py +186 -0
- keras_hub/src/utils/transformers/convert_llama3.py +136 -0
- keras_hub/src/utils/transformers/convert_pali_gemma.py +303 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +97 -0
- keras_hub/src/version_utils.py +23 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/METADATA +34 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/RECORD +297 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/WHEEL +5 -0
- keras_hub_nightly-0.15.0.dev20240823171555.dist-info/top_level.txt +1 -0
@@ -0,0 +1,419 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import os
|
16
|
+
|
17
|
+
import keras
|
18
|
+
from rich import console as rich_console
|
19
|
+
from rich import markup
|
20
|
+
from rich import table as rich_table
|
21
|
+
|
22
|
+
from keras_hub.src.api_export import keras_hub_export
|
23
|
+
from keras_hub.src.utils.keras_utils import print_msg
|
24
|
+
from keras_hub.src.utils.pipeline_model import PipelineModel
|
25
|
+
from keras_hub.src.utils.preset_utils import CONFIG_FILE
|
26
|
+
from keras_hub.src.utils.preset_utils import MODEL_WEIGHTS_FILE
|
27
|
+
from keras_hub.src.utils.preset_utils import TASK_CONFIG_FILE
|
28
|
+
from keras_hub.src.utils.preset_utils import TASK_WEIGHTS_FILE
|
29
|
+
from keras_hub.src.utils.preset_utils import check_config_class
|
30
|
+
from keras_hub.src.utils.preset_utils import check_file_exists
|
31
|
+
from keras_hub.src.utils.preset_utils import check_format
|
32
|
+
from keras_hub.src.utils.preset_utils import get_file
|
33
|
+
from keras_hub.src.utils.preset_utils import jax_memory_cleanup
|
34
|
+
from keras_hub.src.utils.preset_utils import list_presets
|
35
|
+
from keras_hub.src.utils.preset_utils import list_subclasses
|
36
|
+
from keras_hub.src.utils.preset_utils import load_serialized_object
|
37
|
+
from keras_hub.src.utils.preset_utils import save_serialized_object
|
38
|
+
from keras_hub.src.utils.python_utils import classproperty
|
39
|
+
|
40
|
+
|
41
|
+
@keras_hub_export("keras_hub.models.Task")
|
42
|
+
class Task(PipelineModel):
|
43
|
+
"""Base class for all Task models.
|
44
|
+
|
45
|
+
A `Task` wraps a `keras_hub.models.Backbone` and
|
46
|
+
a `keras_hub.models.Preprocessor` to create a model that can be directly
|
47
|
+
used for training, fine-tuning, and prediction for a given text problem.
|
48
|
+
|
49
|
+
All `Task` models have `backbone` and `preprocessor` properties. By
|
50
|
+
default `fit()`, `predict()` and `evaluate()` will preprocess all inputs
|
51
|
+
automatically. To preprocess inputs separately or with a custom function,
|
52
|
+
you can set `task.preprocessor = None`, which disables any automatic
|
53
|
+
preprocessing on inputs.
|
54
|
+
|
55
|
+
All `Task` classes include a `from_preset()` constructor which can be used
|
56
|
+
to load a pre-trained config and weights. Calling `from_preset()` on a task
|
57
|
+
will automatically instantiate a `keras_hub.models.Backbone` and
|
58
|
+
`keras_hub.models.Preprocessor`.
|
59
|
+
"""
|
60
|
+
|
61
|
+
backbone_cls = None
|
62
|
+
preprocessor_cls = None
|
63
|
+
|
64
|
+
def __init__(self, *args, **kwargs):
    """Build the underlying model and record task bookkeeping state.

    All positional and keyword arguments are forwarded unchanged to the
    parent constructor.
    """
    super().__init__(*args, **kwargs)
    # Snapshot the identity of every layer that exists right after the
    # parent constructor has run.
    self._functional_layer_ids = {
        id(sublayer) for sublayer in self._flatten_layers()
    }
    # `__setattr__` checks for the presence of this flag.
    self._initialized = True
    backbone = self.backbone
    if backbone is not None:
        # The task adopts the dtype policy of its backbone.
        self.dtype_policy = backbone.dtype_policy
|
72
|
+
|
73
|
+
def preprocess_samples(self, x, y=None, sample_weight=None):
    """Preprocess one batch of samples.

    When a preprocessor is attached it is called with `x` and with `y` and
    `sample_weight` as keyword arguments; otherwise the parent
    implementation handles the samples.
    """
    preprocessor = self.preprocessor
    if preprocessor is None:
        return super().preprocess_samples(x, y, sample_weight)
    return preprocessor(x, y=y, sample_weight=sample_weight)
|
78
|
+
|
79
|
+
def __setattr__(self, name, value):
    """Set an attribute, sidestepping custom setattr logic where needed.

    Works around setattr issues for Keras 2 and the Keras 3 torch backend.
    Since all state is covered by the functional model, the parent's
    custom setattr handling can be bypassed: on the torch backend, plain
    `object.__setattr__` is used whenever `name` is a property defined on
    the class or the instance has not finished `__init__` yet.
    """
    class_attr = getattr(type(self), name, None)
    targets_property = isinstance(class_attr, property)
    not_yet_initialized = not hasattr(self, "_initialized")
    on_torch = keras.config.backend() == "torch"
    if not on_torch or not (targets_property or not_yet_initialized):
        return super().__setattr__(name, value)
    return object.__setattr__(self, name, value)
|
89
|
+
|
90
|
+
@property
def backbone(self):
    """A `keras_hub.models.Backbone` model with the core architecture."""
    # Resolves to `None` when no backbone has been assigned yet.
    try:
        return self._backbone
    except AttributeError:
        return None

@backbone.setter
def backbone(self, value):
    # Stored on a private attribute that the getter reads back.
    self._backbone = value
|
98
|
+
|
99
|
+
@property
def preprocessor(self):
    """A `keras_hub.models.Preprocessor` layer used to preprocess input."""
    # Resolves to `None` when no preprocessor has been assigned yet.
    try:
        return self._preprocessor
    except AttributeError:
        return None

@preprocessor.setter
def preprocessor(self, value):
    # Stored on a private attribute that the getter reads back.
    self._preprocessor = value
|
107
|
+
|
108
|
+
def get_config(self):
    """Return a flat config holding the backbone, preprocessor and name.

    The backbone and preprocessor are stored in serialized form.
    """
    # Deliberately not chained to super: the default `get_config()` for
    # functional models is nested and cannot be passed to our Task
    # constructors.
    config = {}
    config["backbone"] = keras.layers.serialize(self.backbone)
    config["preprocessor"] = keras.layers.serialize(self.preprocessor)
    config["name"] = self.name
    return config
|
116
|
+
|
117
|
+
@classmethod
def from_config(cls, config):
    """Create an instance of `cls` from a config dictionary.

    The default `from_config()` for functional models returns a vanilla
    `keras.Model`; this override returns a subclass instance instead.
    `backbone` and `preprocessor` entries that are still dictionaries are
    deserialized in place before `cls` is called with the config entries
    as keyword arguments.
    """
    for key in ("backbone", "preprocessor"):
        if isinstance(config.get(key), dict):
            config[key] = keras.layers.deserialize(config[key])
    return cls(**config)
|
130
|
+
|
131
|
+
@classproperty
def presets(cls):
    """List built-in presets for a `Task` subclass."""
    available = list_presets(cls)
    # Presets of the associated backbone class can be loaded as well.
    backbone_cls = cls.backbone_cls
    if backbone_cls is not None:
        available.update(backbone_cls.presets)
    # Fold in the presets reported by every listed subclass.
    for child in list_subclasses(cls):
        available.update(child.presets)
    return available
|
141
|
+
|
142
|
+
@classmethod
def from_preset(
    cls,
    preset,
    load_weights=True,
    **kwargs,
):
    """Instantiate a `keras_hub.models.Task` from a model preset.

    A preset is a directory of configs, weights and other file assets used
    to save and load a pre-trained model. The `preset` can be passed as
    one of:

    1. a built in preset identifier like `'bert_base_en'`
    2. a Kaggle Models handle like `'kaggle://user/bert/keras/bert_base_en'`
    3. a Hugging Face handle like `'hf://user/bert_base_en'`
    4. a path to a local preset directory like `'./bert_base_en'`

    For any `Task` subclass, you can run `cls.presets.keys()` to list all
    built-in presets available on the class.

    This constructor can be called in one of two ways. Either from a task
    specific base class like `keras_hub.models.CausalLM.from_preset()`, or
    from a model class like `keras_hub.models.BertClassifier.from_preset()`.
    If calling from a base class, the subclass of the returning object
    will be inferred from the config in the preset directory.

    Args:
        preset: string. A built in preset identifier, a Kaggle Models
            handle, a Hugging Face handle, or a path to a local directory.
        load_weights: bool. If `True`, the weights will be loaded into the
            model architecture. If `False`, the weights will be randomly
            initialized.
        **kwargs: Extra constructor arguments. When the task is built
            from a backbone preset, `dtype` (if given) is forwarded to the
            backbone, `preprocessor` (if given) is used instead of loading
            one from the preset, and the rest is passed to the task
            constructor. Passing `backbone` is an error.

    Returns:
        The task loaded from the preset's task config when one exists and
        is compatible with `cls`; otherwise a new task built from the
        preset's backbone and a preprocessor.

    Raises:
        ValueError: If the preset is in `transformers` format and the
            class has no `backbone_cls` or `preprocessor_cls`; if called
            on `Task` itself; if `backbone` is passed in `kwargs`; or if
            zero or several subclasses of `cls` match the preset's
            backbone class.

    Examples:
    ```python
    # Load a Gemma generative task.
    causal_lm = keras_hub.models.CausalLM.from_preset(
        "gemma_2b_en",
    )

    # Load a Bert classification task.
    model = keras_hub.models.Classifier.from_preset(
        "bert_base_en",
        num_classes=2,
    )
    ```
    """
    format = check_format(preset)

    if format == "transformers":
        # This path needs concrete classes on `cls`; no subclass lookup
        # is attempted here.
        if cls.backbone_cls is None:
            raise ValueError("Backbone class is None")
        if cls.preprocessor_cls is None:
            raise ValueError("Preprocessor class is None")

        # NOTE(review): `load_weights` is not forwarded on this path;
        # confirm whether that is intended.
        backbone = cls.backbone_cls.from_preset(preset)
        preprocessor = cls.preprocessor_cls.from_preset(preset)
        return cls(backbone=backbone, preprocessor=preprocessor, **kwargs)

    if cls == Task:
        raise ValueError(
            "Do not call `Task.from_preset()` directly. Instead call a "
            "particular task class, e.g. "
            "`keras_hub.models.Classifier.from_preset()` or "
            "`keras_hub.models.BertClassifier.from_preset()`."
        )
    if "backbone" in kwargs:
        raise ValueError(
            "You cannot pass a `backbone` argument to the `from_preset` "
            f"method. Instead, call the {cls.__name__} default "
            "constructor with a `backbone` argument. "
            f"Received: backbone={kwargs['backbone']}."
        )

    # Check if we should load a `task.json` directly.
    load_task_config = False
    if check_file_exists(preset, TASK_CONFIG_FILE):
        task_preset_cls = check_config_class(preset, TASK_CONFIG_FILE)
        # Only use the saved task when its class is `cls` or a subclass.
        if issubclass(task_preset_cls, cls):
            load_task_config = True
    if load_task_config:
        # Task case.
        # NOTE(review): this repeats the lookup made just above and the
        # result is not used again in this branch.
        task_preset_cls = check_config_class(preset, TASK_CONFIG_FILE)
        task = load_serialized_object(preset, TASK_CONFIG_FILE)
        if load_weights:
            jax_memory_cleanup(task)
            # Task-specific weights are optional; backbone weights are
            # always loaded when `load_weights` is set.
            if check_file_exists(preset, TASK_WEIGHTS_FILE):
                task.load_task_weights(get_file(preset, TASK_WEIGHTS_FILE))
            task.backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))
        # NOTE(review): assumed to run regardless of `load_weights`;
        # confirm the nesting against the upstream file.
        task.preprocessor.tokenizer.load_preset_assets(preset)
        return task

    # Backbone case.
    # If `task.json` doesn't exist or the task preset class is different
    # from the calling class, create the task based on `config.json`.
    backbone_preset_cls = check_config_class(preset, CONFIG_FILE)
    if backbone_preset_cls is not cls.backbone_cls:
        # Look for exactly one listed subclass whose backbone class
        # matches the one named by the preset.
        subclasses = list_subclasses(cls)
        subclasses = tuple(
            filter(
                lambda x: x.backbone_cls == backbone_preset_cls,
                subclasses,
            )
        )
        if len(subclasses) == 0:
            raise ValueError(
                f"No registered subclass of `{cls.__name__}` can load "
                f"a `{backbone_preset_cls.__name__}`."
            )
        if len(subclasses) > 1:
            names = ", ".join(f"`{x.__name__}`" for x in subclasses)
            raise ValueError(
                f"Ambiguous call to `{cls.__name__}.from_preset()`. "
                f"Found multiple possible subclasses {names}. "
                "Please call `from_preset` on a subclass directly."
            )
        cls = subclasses[0]
    # Forward dtype to the backbone.
    backbone_kwargs = {}
    if "dtype" in kwargs:
        backbone_kwargs = {"dtype": kwargs.pop("dtype")}
    backbone = backbone_preset_cls.from_preset(
        preset, load_weights=load_weights, **backbone_kwargs
    )
    # A caller-supplied preprocessor takes precedence over the preset's.
    if "preprocessor" in kwargs:
        preprocessor = kwargs.pop("preprocessor")
    else:
        preprocessor = cls.preprocessor_cls.from_preset(preset)
    return cls(backbone=backbone, preprocessor=preprocessor, **kwargs)
|
272
|
+
|
273
|
+
def load_task_weights(self, filepath):
    """Load only the task-specific weights that are not in the backbone.

    Args:
        filepath: Path of the weights file. Its string form must end in
            `.weights.h5`.

    Raises:
        ValueError: If `filepath` does not end in `.weights.h5`.
    """
    if not str(filepath).endswith(".weights.h5"):
        # The second fragment must be an f-string so the offending path
        # actually appears in the message.
        raise ValueError(
            "The filename must end in `.weights.h5`. "
            f"Received: filepath={filepath}"
        )
    # Every layer owned by the backbone is skipped, so only the remaining
    # task-level objects are read from the file.
    backbone_layer_ids = set(id(w) for w in self.backbone._flatten_layers())
    keras.saving.load_weights(
        self,
        filepath,
        objects_to_skip=backbone_layer_ids,
    )
|
285
|
+
|
286
|
+
def has_task_weights(self):
    """Return `True` if the task holds any weight not in its backbone.

    Weights are compared by object identity.
    """
    own_ids = {id(weight) for weight in self.weights}
    shared_ids = {id(weight) for weight in self.backbone.weights}
    # Anything left after removing the backbone's weights is task-only.
    return bool(own_ids - shared_ids)
|
290
|
+
|
291
|
+
def save_task_weights(self, filepath):
    """Save only the task-specific weights that are not in the backbone.

    Args:
        filepath: Destination path. Its string form must end in
            `.weights.h5`.

    Raises:
        ValueError: If `filepath` has the wrong suffix, or if the task has
            no weights outside of its backbone.
    """
    path_text = str(filepath)
    if not path_text.endswith(".weights.h5"):
        message = (
            "The filename must end in `.weights.h5`. "
            f"Received: filepath={filepath}"
        )
        raise ValueError(message)

    # Every layer owned by the backbone is excluded from the saved file.
    skipped_ids = {id(layer) for layer in self.backbone._flatten_layers()}
    if not self.has_task_weights():
        raise ValueError(
            f"Task {self} has no weights not in the `backbone`. "
            "`save_task_weights()` has nothing to save."
        )
    keras.saving.save_weights(
        self,
        filepath=filepath,
        objects_to_skip=skipped_ids,
    )
|
310
|
+
|
311
|
+
def save_to_preset(self, preset_dir):
    """Save task to a preset directory.

    Writes the task config, then the task-specific weights when the task
    has any, and finally asks the preprocessor and the backbone to save
    themselves into the same directory.

    Args:
        preset_dir: The path to the local model preset directory.

    Raises:
        ValueError: If the task has no preprocessor.
    """
    if self.preprocessor is None:
        raise ValueError(
            "Cannot save `task` to preset: `Preprocessor` is not initialized."
        )

    save_serialized_object(self, preset_dir, config_file=TASK_CONFIG_FILE)
    if self.has_task_weights():
        task_weights_path = os.path.join(preset_dir, TASK_WEIGHTS_FILE)
        self.save_task_weights(task_weights_path)

    for component in (self.preprocessor, self.backbone):
        component.save_to_preset(preset_dir)
|
328
|
+
|
329
|
+
@property
def layers(self):
    """The parent's layer list with the preprocessor left out."""
    # Dropping the preprocessor keeps it out of the model summary.
    visible = super().layers
    preprocessor = self.preprocessor
    if preprocessor and preprocessor in visible:
        visible.remove(preprocessor)
    return visible
|
336
|
+
|
337
|
+
def summary(
    self,
    line_length=None,
    positions=None,
    print_fn=None,
    **kwargs,
):
    """Print a preprocessor table (when one is set), then the model summary.

    Args:
        line_length: Total width used for the preprocessor table and
            forwarded to the parent `summary()`. Falsy values fall back
            to 108.
        positions: Forwarded unchanged to the parent `summary()`.
        print_fn: Callable that receives the rendered text; it is invoked
            as `print_fn(text, line_break=False)`. When falsy and
            interactive logging is disabled, `print_msg` is used.
        **kwargs: Forwarded unchanged to the parent `summary()`.
    """

    # Compat fixes for tf.keras: fill in `compiled` on the model and
    # `built` on the optimizer when those attributes are missing.
    if not hasattr(self, "compiled"):
        self.compiled = getattr(self.optimizer, "_is_compiled", False)
    optimizer_lacks_built = (
        self.compiled
        and self.optimizer
        and not hasattr(self.optimizer, "built")
    )
    if optimizer_lacks_built:
        self.optimizer.built = getattr(self.optimizer, "_built", False)

    # Below is copied from keras-core for now.
    # We should consider an API contract.
    if not line_length:
        line_length = 108

    if not (print_fn or keras.utils.is_interactive_logging_enabled()):
        print_fn = print_msg

    def highlight_number(x):
        if x is None:
            return f"[color(45)]{x}[/]"
        return f"[color(34)]{x}[/]"

    def highlight_symbol(x):
        return f"[color(33)]{x}[/]"

    def bold_text(x):
        return f"[bold]{x}[/]"

    if self.preprocessor:
        # With a `print_fn`, render without terminal styling and capture
        # the output; otherwise let rich print directly.
        capture_output = bool(print_fn)
        if capture_output:
            console = rich_console.Console(
                highlight=False, force_terminal=False, color_system=None
            )
            console.begin_capture()
        else:
            console = rich_console.Console(highlight=False)

        # Two equally wide columns: tokenizer identity and vocabulary size.
        half_width = int(0.5 * line_length)
        name_column = rich_table.Column(
            "Tokenizer (type)",
            justify="left",
            width=half_width,
        )
        vocab_column = rich_table.Column(
            "Vocab #",
            justify="right",
            width=half_width,
        )
        table = rich_table.Table(
            name_column, vocab_column, width=line_length, show_lines=True
        )

        tokenizer = self.preprocessor.tokenizer
        escaped_name = markup.escape(tokenizer.name)
        styled_class = highlight_symbol(
            markup.escape(tokenizer.__class__.__name__)
        )
        name_cell = f"{escaped_name} ({styled_class})"
        vocab_cell = highlight_number(f"{tokenizer.vocabulary_size():,}")
        table.add_row(name_cell, vocab_cell)

        # Print the title and the table to the console.
        preprocessor_name = markup.escape(self.preprocessor.name)
        console.print(bold_text(f'Preprocessor: "{preprocessor_name}"'))
        console.print(table)

        # Output captured summary for non-interactive logging.
        if capture_output:
            print_fn(console.end_capture(), line_break=False)

    super().summary(
        line_length=line_length,
        positions=positions,
        print_fn=print_fn,
        **kwargs,
    )
|
@@ -0,0 +1,13 @@
|
|
1
|
+
# Copyright 2024 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
@@ -0,0 +1,158 @@
|
|
1
|
+
# Copyright 2023 The KerasHub Authors
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import keras
|
15
|
+
from keras import layers
|
16
|
+
|
17
|
+
from keras_hub.src.api_export import keras_hub_export
|
18
|
+
from keras_hub.src.models.backbone import Backbone
|
19
|
+
|
20
|
+
|
21
|
+
@keras_hub_export("keras_hub.models.VGGBackbone")
class VGGBackbone(Backbone):
    """This class represents Keras Backbone of VGG model.

    This class implements a VGG backbone as described in [Very Deep
    Convolutional Networks for Large-Scale Image Recognition](
    https://arxiv.org/abs/1409.1556)(ICLR 2015).

    Args:
        stackwise_num_repeats: list of ints, number of repeated convolutional
            blocks per VGG block. For VGG16 this is [2, 2, 3, 3, 3] and for
            VGG19 this is [2, 2, 4, 4, 4]. One VGG block is built for every
            entry of this list.
        stackwise_num_filters: list of ints, filter size for convolutional
            blocks per VGG block. For both VGG16 and VGG19 this is [
            64, 128, 256, 512, 512]. Must provide one entry for each entry
            of `stackwise_num_repeats`.
        include_rescaling: bool, whether to rescale the inputs. If set to
            True, inputs will be passed through a `Rescaling(1/255.0)` layer.
        image_shape: tuple, optional shape tuple, defaults to (224, 224, 3).
        pooling: str or `None`, optional pooling mode applied to the output
            of the last convolutional block. Defaults to `"avg"`.
            - `None` means that the output of the model will be
                the 4D tensor output of the
                last convolutional block.
            - `avg` means that global average pooling
                will be applied to the output of the
                last convolutional block, and thus
                the output of the model will be a 2D tensor.
            - `max` means that global max pooling will
                be applied.

    Examples:
    ```python
    input_data = np.ones((2, 224, 224, 3), dtype="float32")

    # Pretrained VGG backbone.
    model = keras_hub.models.VGGBackbone.from_preset("vgg16")
    model(input_data)

    # Randomly initialized VGG backbone with a custom config.
    model = keras_hub.models.VGGBackbone(
        stackwise_num_repeats = [2, 2, 3, 3, 3],
        stackwise_num_filters = [64, 128, 256, 512, 512],
        image_shape = (224, 224, 3),
        include_rescaling = False,
        pooling = "avg",
    )
    model(input_data)
    ```
    """

    def __init__(
        self,
        stackwise_num_repeats,
        stackwise_num_filters,
        include_rescaling,
        image_shape=(224, 224, 3),
        pooling="avg",
        **kwargs,
    ):

        # === Functional Model ===
        img_input = keras.layers.Input(shape=image_shape)
        x = img_input

        if include_rescaling:
            x = layers.Rescaling(scale=1 / 255.0)(x)
        # Build one block per configured stack. Iterating over every entry
        # (rather than all but the last) keeps the final stack, e.g. the
        # fifth block of VGG16/VGG19.
        for stack_index, num_repeats in enumerate(stackwise_num_repeats):
            x = apply_vgg_block(
                x=x,
                num_layers=num_repeats,
                filters=stackwise_num_filters[stack_index],
                kernel_size=(3, 3),
                activation="relu",
                padding="same",
                max_pool=True,
                name=f"block{stack_index + 1}",
            )
        # Any other `pooling` value leaves the 4D feature map untouched.
        if pooling == "avg":
            x = layers.GlobalAveragePooling2D()(x)
        elif pooling == "max":
            x = layers.GlobalMaxPooling2D()(x)

        super().__init__(inputs=img_input, outputs=x, **kwargs)

        # === Config ===
        self.stackwise_num_repeats = stackwise_num_repeats
        self.stackwise_num_filters = stackwise_num_filters
        self.include_rescaling = include_rescaling
        self.image_shape = image_shape
        self.pooling = pooling

    def get_config(self):
        """Return the constructor arguments stored on this instance."""
        return {
            "stackwise_num_repeats": self.stackwise_num_repeats,
            "stackwise_num_filters": self.stackwise_num_filters,
            "include_rescaling": self.include_rescaling,
            "image_shape": self.image_shape,
            "pooling": self.pooling,
        }
|
120
|
+
|
121
|
+
|
122
|
+
def apply_vgg_block(
    x,
    num_layers,
    filters,
    kernel_size,
    activation,
    padding,
    max_pool,
    name,
):
    """Chain `num_layers` convolutions on `x`, optionally ending in a pool.

    Convolutions are named `{name}_conv1` ... `{name}_conv{num_layers}`;
    the optional pooling layer is named `{name}_pool`.

    Args:
        x: Tensor, input tensor to pass through network
        num_layers: int, number of CNN layers in the block
        filters: int, filter size of each CNN layer in block
        kernel_size: int (or) tuple, kernel size for CNN layer in block
        activation: str (or) callable, activation function for each CNN
            layer in block
        padding: value forwarded as the `padding` argument of each CNN
            layer in block
        max_pool: bool, whether to add MaxPooling2D layer at end of block
        name: str, name of the block

    Returns:
        The output of the last layer applied to `x`.
    """
    for conv_index in range(num_layers):
        conv_layer = layers.Conv2D(
            filters,
            kernel_size,
            activation=activation,
            padding=padding,
            name=f"{name}_conv{conv_index + 1}",
        )
        x = conv_layer(x)

    if not max_pool:
        return x

    # 2x2 window with a 2x2 stride.
    pool_layer = layers.MaxPooling2D((2, 2), (2, 2), name=f"{name}_pool")
    return pool_layer(x)
|