keras-hub-nightly 0.16.1.dev202410020340__py3-none-any.whl → 0.19.0.dev202501260345__py3-none-any.whl
This diff shows the changes between two package versions publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +21 -3
- keras_hub/api/models/__init__.py +71 -12
- keras_hub/api/tokenizers/__init__.py +1 -1
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/f_net_encoder.py +1 -1
- keras_hub/src/layers/modeling/masked_lm_head.py +2 -1
- keras_hub/src/layers/modeling/reversible_embedding.py +3 -16
- keras_hub/src/layers/modeling/rms_normalization.py +36 -0
- keras_hub/src/layers/modeling/rotary_embedding.py +3 -2
- keras_hub/src/layers/modeling/token_and_position_embedding.py +1 -1
- keras_hub/src/layers/modeling/transformer_decoder.py +8 -6
- keras_hub/src/layers/modeling/transformer_encoder.py +29 -7
- keras_hub/src/layers/preprocessing/audio_converter.py +3 -7
- keras_hub/src/layers/preprocessing/image_converter.py +170 -34
- keras_hub/src/metrics/bleu.py +4 -3
- keras_hub/src/models/albert/albert_presets.py +4 -12
- keras_hub/src/models/albert/albert_text_classifier.py +7 -7
- keras_hub/src/models/backbone.py +3 -14
- keras_hub/src/models/bart/bart_backbone.py +4 -4
- keras_hub/src/models/bart/bart_presets.py +3 -9
- keras_hub/src/models/bart/bart_seq_2_seq_lm.py +9 -8
- keras_hub/src/models/basnet/__init__.py +5 -0
- keras_hub/src/models/basnet/basnet.py +122 -0
- keras_hub/src/models/basnet/basnet_backbone.py +366 -0
- keras_hub/src/models/basnet/basnet_image_converter.py +8 -0
- keras_hub/src/models/basnet/basnet_preprocessor.py +14 -0
- keras_hub/src/models/basnet/basnet_presets.py +17 -0
- keras_hub/src/models/bert/bert_presets.py +14 -32
- keras_hub/src/models/bert/bert_text_classifier.py +3 -3
- keras_hub/src/models/bloom/bloom_presets.py +8 -24
- keras_hub/src/models/causal_lm.py +56 -12
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/csp_darknet/csp_darknet_backbone.py +2 -1
- keras_hub/src/models/csp_darknet/csp_darknet_image_classifier.py +0 -109
- keras_hub/src/models/deberta_v3/deberta_v3_backbone.py +1 -1
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +5 -15
- keras_hub/src/models/deberta_v3/deberta_v3_text_classifier.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_attention_encoder.py +4 -4
- keras_hub/src/models/deberta_v3/disentangled_self_attention.py +3 -2
- keras_hub/src/models/deberta_v3/relative_embedding.py +1 -1
- keras_hub/src/models/deeplab_v3/__init__.py +7 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py +200 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py +10 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py +16 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_layers.py +215 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +17 -0
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +111 -0
- keras_hub/src/models/densenet/densenet_backbone.py +6 -4
- keras_hub/src/models/densenet/densenet_image_classifier.py +1 -129
- keras_hub/src/models/densenet/densenet_image_converter.py +2 -4
- keras_hub/src/models/densenet/densenet_presets.py +9 -15
- keras_hub/src/models/distil_bert/distil_bert_masked_lm.py +1 -1
- keras_hub/src/models/distil_bert/distil_bert_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/distil_bert/distil_bert_presets.py +5 -10
- keras_hub/src/models/distil_bert/distil_bert_text_classifier.py +5 -5
- keras_hub/src/models/distil_bert/distil_bert_tokenizer.py +3 -3
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +160 -61
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +193 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +84 -41
- keras_hub/src/models/efficientnet/mbconv.py +53 -22
- keras_hub/src/models/electra/electra_backbone.py +2 -2
- keras_hub/src/models/electra/electra_presets.py +6 -18
- keras_hub/src/models/f_net/f_net_presets.py +2 -6
- keras_hub/src/models/f_net/f_net_text_classifier.py +3 -3
- keras_hub/src/models/f_net/f_net_text_classifier_preprocessor.py +3 -3
- keras_hub/src/models/falcon/falcon_backbone.py +5 -3
- keras_hub/src/models/falcon/falcon_causal_lm.py +18 -8
- keras_hub/src/models/falcon/falcon_presets.py +1 -3
- keras_hub/src/models/falcon/falcon_tokenizer.py +7 -2
- keras_hub/src/models/feature_pyramid_backbone.py +1 -1
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +496 -0
- keras_hub/src/models/flux/flux_maths.py +225 -0
- keras_hub/src/models/flux/flux_model.py +236 -0
- keras_hub/src/models/flux/flux_presets.py +3 -0
- keras_hub/src/models/flux/flux_text_to_image.py +146 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_backbone.py +35 -20
- keras_hub/src/models/gemma/gemma_causal_lm.py +2 -2
- keras_hub/src/models/gemma/gemma_decoder_block.py +3 -1
- keras_hub/src/models/gemma/gemma_presets.py +29 -63
- keras_hub/src/models/gpt2/gpt2_causal_lm.py +2 -2
- keras_hub/src/models/gpt2/gpt2_presets.py +5 -14
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_attention.py +2 -1
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm.py +3 -3
- keras_hub/src/models/gpt_neo_x/gpt_neo_x_decoder.py +2 -1
- keras_hub/src/models/image_classifier.py +147 -2
- keras_hub/src/models/image_classifier_preprocessor.py +6 -3
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_segmenter.py +0 -5
- keras_hub/src/models/image_segmenter_preprocessor.py +29 -4
- keras_hub/src/models/image_to_image.py +417 -0
- keras_hub/src/models/inpaint.py +520 -0
- keras_hub/src/models/llama/llama_backbone.py +138 -12
- keras_hub/src/models/llama/llama_causal_lm.py +3 -1
- keras_hub/src/models/llama/llama_presets.py +10 -20
- keras_hub/src/models/llama3/llama3_backbone.py +12 -11
- keras_hub/src/models/llama3/llama3_causal_lm.py +1 -1
- keras_hub/src/models/llama3/llama3_presets.py +4 -12
- keras_hub/src/models/llama3/llama3_tokenizer.py +25 -2
- keras_hub/src/models/mistral/mistral_backbone.py +16 -15
- keras_hub/src/models/mistral/mistral_causal_lm.py +6 -4
- keras_hub/src/models/mistral/mistral_presets.py +3 -9
- keras_hub/src/models/mistral/mistral_transformer_decoder.py +2 -1
- keras_hub/src/models/mit/__init__.py +6 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_backbone.py → mit/mit_backbone.py} +47 -36
- keras_hub/src/models/mit/mit_image_classifier.py +12 -0
- keras_hub/src/models/mit/mit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/mit/mit_image_converter.py +8 -0
- keras_hub/src/models/{mix_transformer/mix_transformer_layers.py → mit/mit_layers.py} +20 -13
- keras_hub/src/models/mit/mit_presets.py +139 -0
- keras_hub/src/models/mobilenet/mobilenet_backbone.py +8 -8
- keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +0 -92
- keras_hub/src/models/opt/opt_causal_lm.py +2 -2
- keras_hub/src/models/opt/opt_presets.py +4 -12
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +63 -17
- keras_hub/src/models/pali_gemma/pali_gemma_causal_lm.py +3 -1
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_image_converter.py +2 -4
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +173 -17
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +14 -26
- keras_hub/src/models/phi3/phi3_causal_lm.py +3 -1
- keras_hub/src/models/phi3/phi3_decoder.py +0 -1
- keras_hub/src/models/phi3/phi3_presets.py +2 -6
- keras_hub/src/models/phi3/phi3_rotary_embedding.py +1 -1
- keras_hub/src/models/preprocessor.py +25 -11
- keras_hub/src/models/resnet/resnet_backbone.py +3 -14
- keras_hub/src/models/resnet/resnet_image_classifier.py +0 -137
- keras_hub/src/models/resnet/resnet_image_converter.py +2 -4
- keras_hub/src/models/resnet/resnet_presets.py +127 -18
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +103 -39
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +381 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +16 -0
- keras_hub/src/models/roberta/roberta_backbone.py +2 -2
- keras_hub/src/models/roberta/roberta_presets.py +6 -8
- keras_hub/src/models/roberta/roberta_text_classifier.py +3 -3
- keras_hub/src/models/sam/__init__.py +5 -0
- keras_hub/src/models/sam/sam_backbone.py +2 -3
- keras_hub/src/models/sam/sam_image_converter.py +2 -4
- keras_hub/src/models/sam/sam_image_segmenter.py +16 -16
- keras_hub/src/models/sam/sam_image_segmenter_preprocessor.py +11 -1
- keras_hub/src/models/sam/sam_layers.py +5 -3
- keras_hub/src/models/sam/sam_presets.py +3 -9
- keras_hub/src/models/sam/sam_prompt_encoder.py +4 -2
- keras_hub/src/models/sam/sam_transformer.py +5 -4
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +167 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +184 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +136 -0
- keras_hub/src/models/seq_2_seq_lm_preprocessor.py +1 -1
- keras_hub/src/models/stable_diffusion_3/flow_match_euler_discrete_scheduler.py +8 -1
- keras_hub/src/models/stable_diffusion_3/mmdit.py +577 -190
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +189 -163
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +178 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +193 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +43 -7
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +25 -14
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +47 -19
- keras_hub/src/models/task.py +47 -39
- keras_hub/src/models/text_classifier.py +2 -2
- keras_hub/src/models/text_to_image.py +106 -41
- keras_hub/src/models/vae/__init__.py +1 -0
- keras_hub/src/models/vae/vae_backbone.py +184 -0
- keras_hub/src/models/vae/vae_layers.py +739 -0
- keras_hub/src/models/vgg/__init__.py +5 -0
- keras_hub/src/models/vgg/vgg_backbone.py +4 -24
- keras_hub/src/models/vgg/vgg_image_classifier.py +139 -33
- keras_hub/src/models/vgg/vgg_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vgg/vgg_image_converter.py +8 -0
- keras_hub/src/models/vgg/vgg_presets.py +48 -0
- keras_hub/src/models/vit/__init__.py +5 -0
- keras_hub/src/models/vit/vit_backbone.py +152 -0
- keras_hub/src/models/vit/vit_image_classifier.py +187 -0
- keras_hub/src/models/vit/vit_image_classifier_preprocessor.py +12 -0
- keras_hub/src/models/vit/vit_image_converter.py +73 -0
- keras_hub/src/models/vit/vit_layers.py +391 -0
- keras_hub/src/models/vit/vit_presets.py +126 -0
- keras_hub/src/models/vit_det/vit_det_backbone.py +6 -4
- keras_hub/src/models/vit_det/vit_layers.py +3 -3
- keras_hub/src/models/whisper/whisper_audio_converter.py +2 -4
- keras_hub/src/models/whisper/whisper_backbone.py +6 -5
- keras_hub/src/models/whisper/whisper_decoder.py +3 -5
- keras_hub/src/models/whisper/whisper_presets.py +10 -30
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm.py +1 -1
- keras_hub/src/models/xlm_roberta/xlm_roberta_masked_lm_preprocessor.py +2 -2
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +2 -6
- keras_hub/src/models/xlm_roberta/xlm_roberta_text_classifier.py +4 -4
- keras_hub/src/models/xlm_roberta/xlm_roberta_tokenizer.py +2 -1
- keras_hub/src/models/xlnet/relative_attention.py +20 -19
- keras_hub/src/models/xlnet/xlnet_backbone.py +2 -2
- keras_hub/src/models/xlnet/xlnet_content_and_query_embedding.py +3 -5
- keras_hub/src/models/xlnet/xlnet_encoder.py +7 -9
- keras_hub/src/samplers/contrastive_sampler.py +2 -3
- keras_hub/src/samplers/sampler.py +2 -1
- keras_hub/src/tests/test_case.py +41 -6
- keras_hub/src/tokenizers/byte_pair_tokenizer.py +7 -3
- keras_hub/src/tokenizers/byte_tokenizer.py +3 -10
- keras_hub/src/tokenizers/sentence_piece_tokenizer.py +2 -9
- keras_hub/src/tokenizers/sentence_piece_tokenizer_trainer.py +9 -11
- keras_hub/src/tokenizers/tokenizer.py +10 -13
- keras_hub/src/tokenizers/unicode_codepoint_tokenizer.py +9 -7
- keras_hub/src/tokenizers/word_piece_tokenizer_trainer.py +10 -3
- keras_hub/src/utils/keras_utils.py +2 -13
- keras_hub/src/utils/pipeline_model.py +3 -3
- keras_hub/src/utils/preset_utils.py +196 -144
- keras_hub/src/utils/tensor_utils.py +4 -4
- keras_hub/src/utils/timm/convert_densenet.py +6 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +447 -0
- keras_hub/src/utils/timm/convert_resnet.py +1 -1
- keras_hub/src/utils/timm/convert_vgg.py +85 -0
- keras_hub/src/utils/timm/preset_loader.py +14 -9
- keras_hub/src/utils/transformers/convert_llama3.py +21 -5
- keras_hub/src/utils/transformers/convert_vit.py +150 -0
- keras_hub/src/utils/transformers/preset_loader.py +23 -0
- keras_hub/src/utils/transformers/safetensor_utils.py +4 -3
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/METADATA +86 -68
- keras_hub_nightly-0.19.0.dev202501260345.dist-info/RECORD +423 -0
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/WHEEL +1 -1
- keras_hub/src/layers/preprocessing/resizing_image_converter.py +0 -138
- keras_hub/src/models/mix_transformer/__init__.py +0 -0
- keras_hub/src/models/mix_transformer/mix_transformer_classifier.py +0 -119
- keras_hub/src/models/stable_diffusion_3/vae_image_decoder.py +0 -320
- keras_hub_nightly-0.16.1.dev202410020340.dist-info/RECORD +0 -357
- {keras_hub_nightly-0.16.1.dev202410020340.dist-info → keras_hub_nightly-0.19.0.dev202501260345.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ keras_hub/src/utils/timm/convert_efficientnet.py
@@ -0,0 +1,447 @@
+import math
+
+import numpy as np
+
+from keras_hub.src.models.efficientnet.efficientnet_backbone import (
+    EfficientNetBackbone,
+)
+
+backbone_cls = EfficientNetBackbone
+
+
+VARIANT_MAP = {
+    "b0": {
+        "stackwise_width_coefficients": [1.0] * 7,
+        "stackwise_depth_coefficients": [1.0] * 7,
+        "stackwise_squeeze_and_excite_ratios": [0.25] * 7,
+    },
+    "b1": {
+        "stackwise_width_coefficients": [1.0] * 7,
+        "stackwise_depth_coefficients": [1.1] * 7,
+        "stackwise_squeeze_and_excite_ratios": [0.25] * 7,
+    },
+    "b2": {
+        "stackwise_width_coefficients": [1.1] * 7,
+        "stackwise_depth_coefficients": [1.2] * 7,
+        "stackwise_squeeze_and_excite_ratios": [0.25] * 7,
+    },
+    "b3": {
+        "stackwise_width_coefficients": [1.2] * 7,
+        "stackwise_depth_coefficients": [1.4] * 7,
+        "stackwise_squeeze_and_excite_ratios": [0.25] * 7,
+    },
+    "b4": {
+        "stackwise_width_coefficients": [1.4] * 7,
+        "stackwise_depth_coefficients": [1.8] * 7,
+        "stackwise_squeeze_and_excite_ratios": [0.25] * 7,
+    },
+    "b5": {
+        "stackwise_width_coefficients": [1.6] * 7,
+        "stackwise_depth_coefficients": [2.2] * 7,
+        "stackwise_squeeze_and_excite_ratios": [0.25] * 7,
+    },
+    "lite0": {
+        "stackwise_width_coefficients": [1.0] * 7,
+        "stackwise_depth_coefficients": [1.0] * 7,
+        "stackwise_squeeze_and_excite_ratios": [0] * 7,
+        "activation": "relu6",
+    },
+    "el": {
+        "stackwise_width_coefficients": [1.2] * 6,
+        "stackwise_depth_coefficients": [1.4] * 6,
+        "stackwise_kernel_sizes": [3, 3, 3, 5, 5, 5],
+        "stackwise_num_repeats": [1, 2, 4, 5, 4, 2],
+        "stackwise_input_filters": [32, 24, 32, 48, 96, 144],
+        "stackwise_output_filters": [24, 32, 48, 96, 144, 192],
+        "stackwise_expansion_ratios": [4, 8, 8, 8, 8, 8],
+        "stackwise_strides": [1, 2, 2, 2, 1, 2],
+        "stackwise_squeeze_and_excite_ratios": [0] * 6,
+        "stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3,
+        "stackwise_force_input_filters": [24, 0, 0, 0, 0, 0],
+        "stackwise_nores_option": [True] + [False] * 5,
+        "activation": "relu",
+    },
+    "em": {
+        "stackwise_width_coefficients": [1.0] * 6,
+        "stackwise_depth_coefficients": [1.1] * 6,
+        "stackwise_kernel_sizes": [3, 3, 3, 5, 5, 5],
+        "stackwise_num_repeats": [1, 2, 4, 5, 4, 2],
+        "stackwise_input_filters": [32, 24, 32, 48, 96, 144],
+        "stackwise_output_filters": [24, 32, 48, 96, 144, 192],
+        "stackwise_expansion_ratios": [4, 8, 8, 8, 8, 8],
+        "stackwise_strides": [1, 2, 2, 2, 1, 2],
+        "stackwise_squeeze_and_excite_ratios": [0] * 6,
+        "stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3,
+        "stackwise_force_input_filters": [24, 0, 0, 0, 0, 0],
+        "stackwise_nores_option": [True] + [False] * 5,
+        "activation": "relu",
+    },
+    "es": {
+        "stackwise_width_coefficients": [1.0] * 6,
+        "stackwise_depth_coefficients": [1.0] * 6,
+        "stackwise_kernel_sizes": [3, 3, 3, 5, 5, 5],
+        "stackwise_num_repeats": [1, 2, 4, 5, 4, 2],
+        "stackwise_input_filters": [32, 24, 32, 48, 96, 144],
+        "stackwise_output_filters": [24, 32, 48, 96, 144, 192],
+        "stackwise_expansion_ratios": [4, 8, 8, 8, 8, 8],
+        "stackwise_strides": [1, 2, 2, 2, 1, 2],
+        "stackwise_squeeze_and_excite_ratios": [0] * 6,
+        "stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3,
+        "stackwise_force_input_filters": [24, 0, 0, 0, 0, 0],
+        "stackwise_nores_option": [True] + [False] * 5,
+        "activation": "relu",
+    },
+    "rw_m": {
+        "stackwise_width_coefficients": [1.2] * 6,
+        "stackwise_depth_coefficients": [1.2] * 4 + [1.6] * 2,
+        "stackwise_kernel_sizes": [3, 3, 3, 3, 3, 3],
+        "stackwise_num_repeats": [2, 4, 4, 6, 9, 15],
+        "stackwise_input_filters": [24, 24, 48, 64, 128, 160],
+        "stackwise_output_filters": [24, 48, 64, 128, 160, 272],
+        "stackwise_expansion_ratios": [1, 4, 4, 4, 6, 6],
+        "stackwise_strides": [1, 2, 2, 2, 1, 2],
+        "stackwise_squeeze_and_excite_ratios": [0, 0, 0, 0.25, 0.25, 0.25],
+        "stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3,
+        "stackwise_force_input_filters": [0, 0, 0, 0, 0, 0],
+        "stackwise_nores_option": [False] * 6,
+        "activation": "silu",
+        "num_features": 1792,
+    },
+    "rw_s": {
+        "stackwise_width_coefficients": [1.0] * 6,
+        "stackwise_depth_coefficients": [1.0] * 6,
+        "stackwise_kernel_sizes": [3, 3, 3, 3, 3, 3],
+        "stackwise_num_repeats": [2, 4, 4, 6, 9, 15],
+        "stackwise_input_filters": [24, 24, 48, 64, 128, 160],
+        "stackwise_output_filters": [24, 48, 64, 128, 160, 272],
+        "stackwise_expansion_ratios": [1, 4, 4, 4, 6, 6],
+        "stackwise_strides": [1, 2, 2, 2, 1, 2],
+        "stackwise_squeeze_and_excite_ratios": [0, 0, 0, 0.25, 0.25, 0.25],
+        "stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3,
+        "stackwise_force_input_filters": [0, 0, 0, 0, 0, 0],
+        "stackwise_nores_option": [False] * 6,
+        "activation": "silu",
+        "num_features": 1792,
+    },
+    "rw_t": {
+        "stackwise_width_coefficients": [0.8] * 6,
+        "stackwise_depth_coefficients": [0.9] * 6,
+        "stackwise_kernel_sizes": [3, 3, 3, 3, 3, 3],
+        "stackwise_num_repeats": [2, 4, 4, 6, 9, 15],
+        "stackwise_input_filters": [24, 24, 48, 64, 128, 160],
+        "stackwise_output_filters": [24, 48, 64, 128, 160, 256],
+        "stackwise_expansion_ratios": [1, 4, 4, 4, 6, 6],
+        "stackwise_strides": [1, 2, 2, 2, 1, 2],
+        "stackwise_squeeze_and_excite_ratios": [0, 0, 0, 0.25, 0.25, 0.25],
+        "stackwise_block_types": ["cba"] + ["fused"] * 2 + ["unfused"] * 3,
+        "stackwise_force_input_filters": [0, 0, 0, 0, 0, 0],
+        "stackwise_nores_option": [False] * 6,
+        "activation": "silu",
+    },
+}
+
+
+def convert_backbone_config(timm_config):
+    timm_architecture = timm_config["architecture"]
+
+    base_kwargs = {
+        "stackwise_kernel_sizes": [3, 3, 5, 3, 5, 5, 3],
+        "stackwise_num_repeats": [1, 2, 2, 3, 3, 4, 1],
+        "stackwise_input_filters": [32, 16, 24, 40, 80, 112, 192],
+        "stackwise_output_filters": [16, 24, 40, 80, 112, 192, 320],
+        "stackwise_expansion_ratios": [1, 6, 6, 6, 6, 6, 6],
+        "stackwise_strides": [1, 2, 2, 2, 1, 2, 1],
+        "stackwise_block_types": ["v1"] * 7,
+        "min_depth": None,
+        "include_stem_padding": True,
+        "use_depth_divisor_as_min_depth": True,
+        "cap_round_filter_decrease": True,
+        "stem_conv_padding": "valid",
+        "batch_norm_momentum": 0.9,
+        "batch_norm_epsilon": 1e-5,
+        "dropout": 0,
+        "projection_activation": None,
+    }
+
+    variant = "_".join(timm_architecture.split("_")[1:])
+
+    if variant not in VARIANT_MAP:
+        raise ValueError(
+            f"Currently, the architecture {timm_architecture} is not supported."
+        )
+
+    base_kwargs.update(VARIANT_MAP[variant])
+
+    return base_kwargs
+
+
+def convert_weights(backbone, loader, timm_config):
+    timm_architecture = timm_config["architecture"]
+    variant = "_".join(timm_architecture.split("_")[1:])
+
+    def port_conv2d(keras_layer, hf_weight_prefix, port_bias=True):
+        loader.port_weight(
+            keras_layer.kernel,
+            hf_weight_key=f"{hf_weight_prefix}.weight",
+            hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
+        )
+
+        if port_bias:
+            loader.port_weight(
+                keras_layer.bias,
+                hf_weight_key=f"{hf_weight_prefix}.bias",
+            )
+
+    def port_depthwise_conv2d(
+        keras_layer,
+        hf_weight_prefix,
+        port_bias=True,
+        depth_multiplier=1,
+    ):
+        def convert_pt_conv2d_kernel(pt_kernel):
+            out_channels, in_channels_per_group, height, width = pt_kernel.shape
+            # PT Convs are depthwise convs if and only if
+            # `in_channels_per_group == 1`
+            assert in_channels_per_group == 1
+            pt_kernel = np.transpose(pt_kernel, (2, 3, 0, 1))
+            in_channels = out_channels // depth_multiplier
+            return np.reshape(
+                pt_kernel, (height, width, in_channels, depth_multiplier)
+            )
+
+        loader.port_weight(
+            keras_layer.kernel,
+            hf_weight_key=f"{hf_weight_prefix}.weight",
+            hook_fn=lambda x, _: convert_pt_conv2d_kernel(x),
+        )
+
+        if port_bias:
+            loader.port_weight(
+                keras_layer.bias,
+                hf_weight_key=f"{hf_weight_prefix}.bias",
+            )
+
+    def port_batch_normalization(keras_layer, hf_weight_prefix):
+        loader.port_weight(
+            keras_layer.gamma,
+            hf_weight_key=f"{hf_weight_prefix}.weight",
+        )
+        loader.port_weight(
+            keras_layer.beta,
+            hf_weight_key=f"{hf_weight_prefix}.bias",
+        )
+        loader.port_weight(
+            keras_layer.moving_mean,
+            hf_weight_key=f"{hf_weight_prefix}.running_mean",
+        )
+        loader.port_weight(
+            keras_layer.moving_variance,
+            hf_weight_key=f"{hf_weight_prefix}.running_var",
+        )
+        # do we need num batches tracked?
+
+    # Stem
+    port_conv2d(backbone.get_layer("stem_conv"), "conv_stem", port_bias=False)
+    port_batch_normalization(backbone.get_layer("stem_bn"), "bn1")
+
+    # Stages
+    num_stacks = len(backbone.stackwise_kernel_sizes)
+
+    for stack_index in range(num_stacks):
+        block_type = backbone.stackwise_block_types[stack_index]
+        expansion_ratio = backbone.stackwise_expansion_ratios[stack_index]
+        repeats = backbone.stackwise_num_repeats[stack_index]
+        stack_depth_coefficient = backbone.stackwise_depth_coefficients[
+            stack_index
+        ]
+
+        repeats = int(math.ceil(stack_depth_coefficient * repeats))
+
+        se_ratio = VARIANT_MAP[variant]["stackwise_squeeze_and_excite_ratios"][
+            stack_index
+        ]
+
+        for block_idx in range(repeats):
+            conv_pw_count = 0
+            bn_count = 1
+
+            # 97 is the start of the lowercase alphabet.
+            letter_identifier = chr(block_idx + 97)
+
+            keras_block_prefix = f"block{stack_index + 1}{letter_identifier}_"
+            hf_block_prefix = f"blocks.{stack_index}.{block_idx}."
+
+            if block_type == "v1":
+                conv_pw_name_map = ["conv_pw", "conv_pwl"]
+                # Initial Expansion Conv
+                if expansion_ratio != 1:
+                    port_conv2d(
+                        backbone.get_layer(keras_block_prefix + "expand_conv"),
+                        hf_block_prefix + conv_pw_name_map[conv_pw_count],
+                        port_bias=False,
+                    )
+                    conv_pw_count += 1
+                    port_batch_normalization(
+                        backbone.get_layer(keras_block_prefix + "expand_bn"),
+                        hf_block_prefix + f"bn{bn_count}",
+                    )
+                    bn_count += 1
+
+                # Depthwise Conv
+                port_depthwise_conv2d(
+                    backbone.get_layer(keras_block_prefix + "dwconv"),
+                    hf_block_prefix + "conv_dw",
+                    port_bias=False,
+                )
+                port_batch_normalization(
+                    backbone.get_layer(keras_block_prefix + "dwconv_bn"),
+                    hf_block_prefix + f"bn{bn_count}",
+                )
+                bn_count += 1
+
+                if 0 < se_ratio <= 1:
+                    # Squeeze and Excite
+                    port_conv2d(
+                        backbone.get_layer(keras_block_prefix + "se_reduce"),
+                        hf_block_prefix + "se.conv_reduce",
+                    )
+                    port_conv2d(
+                        backbone.get_layer(keras_block_prefix + "se_expand"),
+                        hf_block_prefix + "se.conv_expand",
+                    )
+
+                # Output/Projection
+                port_conv2d(
+                    backbone.get_layer(keras_block_prefix + "project"),
+                    hf_block_prefix + conv_pw_name_map[conv_pw_count],
+                    port_bias=False,
+                )
+                conv_pw_count += 1
+                port_batch_normalization(
+                    backbone.get_layer(keras_block_prefix + "project_bn"),
+                    hf_block_prefix + f"bn{bn_count}",
+                )
+                bn_count += 1
+            elif block_type == "fused":
+                fused_block_layer = backbone.get_layer(keras_block_prefix)
+
+                # Initial Expansion Conv
+                port_conv2d(
+                    fused_block_layer.conv1,
+                    hf_block_prefix + "conv_exp",
+                    port_bias=False,
+                )
+                conv_pw_count += 1
+                port_batch_normalization(
+                    fused_block_layer.bn1,
+                    hf_block_prefix + f"bn{bn_count}",
+                )
+                bn_count += 1
+
+                if 0 < se_ratio <= 1:
+                    # Squeeze and Excite
+                    port_conv2d(
+                        fused_block_layer.se_conv1,
+                        hf_block_prefix + "se.conv_reduce",
+                    )
+                    port_conv2d(
+                        fused_block_layer.se_conv2,
+                        hf_block_prefix + "se.conv_expand",
+                    )
+
+                # Output/Projection
+                port_conv2d(
+                    fused_block_layer.output_conv,
+                    hf_block_prefix + "conv_pwl",
+                    port_bias=False,
+                )
+                conv_pw_count += 1
+                port_batch_normalization(
+                    fused_block_layer.bn2,
+                    hf_block_prefix + f"bn{bn_count}",
+                )
+                bn_count += 1
+
+            elif block_type == "unfused":
+                unfused_block_layer = backbone.get_layer(keras_block_prefix)
+                # Initial Expansion Conv
+                if expansion_ratio != 1:
+                    port_conv2d(
+                        unfused_block_layer.conv1,
+                        hf_block_prefix + "conv_pw",
+                        port_bias=False,
+                    )
+                    conv_pw_count += 1
+                    port_batch_normalization(
+                        unfused_block_layer.bn1,
+                        hf_block_prefix + f"bn{bn_count}",
+                    )
+                    bn_count += 1
+
+                # Depthwise Conv
+                port_depthwise_conv2d(
+                    unfused_block_layer.depthwise,
+                    hf_block_prefix + "conv_dw",
+                    port_bias=False,
+                )
+                port_batch_normalization(
+                    unfused_block_layer.bn2,
+                    hf_block_prefix + f"bn{bn_count}",
+                )
+                bn_count += 1
+
+                if 0 < se_ratio <= 1:
+                    # Squeeze and Excite
+                    port_conv2d(
+                        unfused_block_layer.se_conv1,
+                        hf_block_prefix + "se.conv_reduce",
+                    )
+                    port_conv2d(
+                        unfused_block_layer.se_conv2,
+                        hf_block_prefix + "se.conv_expand",
+                    )
+
+                # Output/Projection
+                port_conv2d(
+                    unfused_block_layer.output_conv,
+                    hf_block_prefix + "conv_pwl",
+                    port_bias=False,
+                )
+                conv_pw_count += 1
+                port_batch_normalization(
+                    unfused_block_layer.bn3,
+                    hf_block_prefix + f"bn{bn_count}",
+                )
+                bn_count += 1
+            elif block_type == "cba":
+                cba_block_layer = backbone.get_layer(keras_block_prefix)
+                # Initial Expansion Conv
+                port_conv2d(
+                    cba_block_layer.conv1,
+                    hf_block_prefix + "conv",
+                    port_bias=False,
+                )
+                conv_pw_count += 1
+                port_batch_normalization(
+                    cba_block_layer.bn1,
+                    hf_block_prefix + f"bn{bn_count}",
+                )
+                bn_count += 1
+
+    # Head/Top
+    port_conv2d(backbone.get_layer("top_conv"), "conv_head", port_bias=False)
+    port_batch_normalization(backbone.get_layer("top_bn"), "bn2")
+
+
+def convert_head(task, loader, timm_config):
+    classifier_prefix = timm_config["pretrained_cfg"]["classifier"]
+    prefix = f"{classifier_prefix}."
+    loader.port_weight(
+        task.output_dense.kernel,
+        hf_weight_key=prefix + "weight",
+        hook_fn=lambda x, _: np.transpose(np.squeeze(x)),
+    )
+    loader.port_weight(
+        task.output_dense.bias,
+        hf_weight_key=prefix + "bias",
+    )
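Note on the depthwise port above: timm stores a depthwise kernel as `(out_channels, in_channels_per_group=1, height, width)`, while Keras' `DepthwiseConv2D` expects `(height, width, in_channels, depth_multiplier)`; plain convs only need the `(2, 3, 1, 0)` transpose. A minimal standalone sketch of the same layout change (shapes chosen for illustration):

```python
import numpy as np

# A depthwise kernel as PyTorch/timm stores it:
# (out_channels, in_channels_per_group=1, height, width).
pt_kernel = np.random.rand(32, 1, 3, 3)

# Same steps as `convert_pt_conv2d_kernel` above: spatial dims first,
# then split out_channels into (in_channels, depth_multiplier).
depth_multiplier = 1
hwio = np.transpose(pt_kernel, (2, 3, 0, 1))  # -> (3, 3, 32, 1)
keras_kernel = np.reshape(
    hwio, (3, 3, 32 // depth_multiplier, depth_multiplier)
)
assert keras_kernel.shape == (3, 3, 32, 1)
```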
--- keras_hub/src/utils/timm/convert_resnet.py
+++ keras_hub/src/utils/timm/convert_resnet.py
@@ -89,7 +89,7 @@ def convert_weights(backbone, loader, timm_config):
         for block_idx in range(backbone.stackwise_num_blocks[stack_index]):
             if version == "v1":
                 keras_name = f"stack{stack_index}_block{block_idx}"
-                hf_name = f"layer{stack_index+1}.{block_idx}"
+                hf_name = f"layer{stack_index + 1}.{block_idx}"
             else:
                 keras_name = f"stack{stack_index}_block{block_idx}"
                 hf_name = f"stages.{stack_index}.blocks.{block_idx}"
--- /dev/null
+++ keras_hub/src/utils/timm/convert_vgg.py
@@ -0,0 +1,85 @@
+from typing import Any
+
+import numpy as np
+
+from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
+from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier
+
+backbone_cls = VGGBackbone
+
+
+REPEATS_BY_SIZE = {
+    "vgg11": [1, 1, 2, 2, 2],
+    "vgg13": [2, 2, 2, 2, 2],
+    "vgg16": [2, 2, 3, 3, 3],
+    "vgg19": [2, 2, 4, 4, 4],
+}
+
+
+def convert_backbone_config(timm_config):
+    architecture = timm_config["architecture"]
+    stackwise_num_repeats = REPEATS_BY_SIZE[architecture]
+    return dict(
+        stackwise_num_repeats=stackwise_num_repeats,
+        stackwise_num_filters=[64, 128, 256, 512, 512],
+    )
+
+
+def convert_conv2d(
+    model,
+    loader,
+    keras_layer_name: str,
+    hf_layer_name: str,
+):
+    loader.port_weight(
+        model.get_layer(keras_layer_name).kernel,
+        hf_weight_key=f"{hf_layer_name}.weight",
+        hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
+    )
+    loader.port_weight(
+        model.get_layer(keras_layer_name).bias,
+        hf_weight_key=f"{hf_layer_name}.bias",
+    )
+
+
+def convert_weights(
+    backbone: VGGBackbone,
+    loader,
+    timm_config: dict[Any],
+):
+    architecture = timm_config["architecture"]
+    stackwise_num_repeats = REPEATS_BY_SIZE[architecture]
+
+    hf_index_to_keras_layer_name = {}
+    layer_index = 0
+    for block_index, repeats_in_block in enumerate(stackwise_num_repeats):
+        for repeat_index in range(repeats_in_block):
+            hf_index = layer_index
+            layer_index += 2  # Conv + activation layers.
+            layer_name = f"block{block_index + 1}_conv{repeat_index + 1}"
+            hf_index_to_keras_layer_name[hf_index] = layer_name
+        layer_index += 1  # Pooling layer after blocks.
+
+    for hf_index, keras_layer_name in hf_index_to_keras_layer_name.items():
+        convert_conv2d(
+            backbone, loader, keras_layer_name, f"features.{hf_index}"
+        )
+
+
+def convert_head(
+    task: VGGImageClassifier,
+    loader,
+    timm_config: dict[Any],
+):
+    convert_conv2d(task.head, loader, "fc1", "pre_logits.fc1")
+    convert_conv2d(task.head, loader, "fc2", "pre_logits.fc2")
+
+    loader.port_weight(
+        task.head.get_layer("predictions").kernel,
+        hf_weight_key="head.fc.weight",
+        hook_fn=lambda x, _: np.transpose(np.squeeze(x)),
+    )
+    loader.port_weight(
+        task.head.get_layer("predictions").bias,
+        hf_weight_key="head.fc.bias",
+    )
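The index bookkeeping in `convert_weights` above mirrors how timm/torchvision VGG checkpoints flatten convs into a single `features` sequential, where every conv is followed by an activation (two slots) and each block ends with a pooling layer (one slot). Running the same arithmetic by hand for `vgg11`:

```python
repeats = [1, 1, 2, 2, 2]  # REPEATS_BY_SIZE["vgg11"]

mapping, index = {}, 0
for block, n in enumerate(repeats):
    for conv in range(n):
        mapping[index] = f"block{block + 1}_conv{conv + 1}"
        index += 2  # conv + activation
    index += 1  # pooling layer closing the block

print(mapping)
# {0: 'block1_conv1', 3: 'block2_conv1', 6: 'block3_conv1',
#  8: 'block3_conv2', 11: 'block4_conv1', 13: 'block4_conv2',
#  16: 'block5_conv1', 18: 'block5_conv2'}
```

These are exactly the conv positions (0, 3, 6, 8, ...) of a torchvision-style `features` module for VGG-11, so each `features.{hf_index}.weight` lands on the right Keras layer.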
--- keras_hub/src/utils/timm/preset_loader.py
+++ keras_hub/src/utils/timm/preset_loader.py
@@ -4,7 +4,9 @@ from keras_hub.src.models.image_classifier import ImageClassifier
 from keras_hub.src.utils.preset_utils import PresetLoader
 from keras_hub.src.utils.preset_utils import jax_memory_cleanup
 from keras_hub.src.utils.timm import convert_densenet
+from keras_hub.src.utils.timm import convert_efficientnet
 from keras_hub.src.utils.timm import convert_resnet
+from keras_hub.src.utils.timm import convert_vgg
 from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader
 
 
@@ -14,8 +16,12 @@ class TimmPresetLoader(PresetLoader):
         architecture = self.config["architecture"]
         if "resnet" in architecture:
             self.converter = convert_resnet
-        if "densenet" in architecture:
+        elif "densenet" in architecture:
             self.converter = convert_densenet
+        elif "vgg" in architecture:
+            self.converter = convert_vgg
+        elif "efficientnet" in architecture:
+            self.converter = convert_efficientnet
         else:
             raise ValueError(
                 "KerasHub has no converter for timm models "
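With the two new converters registered in the dispatch above, timm EfficientNet and VGG checkpoints should load straight from the Hugging Face Hub. A hedged usage sketch (the preset name is an example from timm's HF namespace, not verified against this exact nightly):

```python
import keras_hub

# TimmPresetLoader reads the checkpoint's config.json, matches
# "efficientnet" in the architecture name, and dispatches config and
# weight porting to convert_efficientnet.
model = keras_hub.models.ImageClassifier.from_preset(
    "hf://timm/efficientnet_b0.ra_in1k"
)
```

The next hunk, from the same `preset_loader.py`, reworks how the image preprocessing config is translated.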
@@ -52,20 +58,19 @@
         pretrained_cfg = self.config.get("pretrained_cfg", None)
         if not pretrained_cfg or "input_size" not in pretrained_cfg:
             return None
-        # This assumes the same basic setup for all timm preprocessing,
-        # all our image conversion will be via a `ResizingImageConverter. We may
+        # This assumes the same basic setup for all timm preprocessing, We may
         # need to extend this as we cover more model types.
         input_size = pretrained_cfg["input_size"]
         mean = pretrained_cfg["mean"]
-        variance = [s**2 for s in pretrained_cfg["std"]]
+        std = pretrained_cfg["std"]
+        scale = [1.0 / 255.0 / s for s in std]
+        offset = [-m / s for m, s in zip(mean, std)]
         interpolation = pretrained_cfg["interpolation"]
         if interpolation not in ("bilinear", "nearest", "bicubic"):
             interpolation = "bilinear"  # Unsupported interpolation type.
         return cls(
-            height=input_size[1],
-            width=input_size[2],
-            scale=1 / 255.0,
-            mean=mean,
-            variance=variance,
+            image_size=input_size[1:],
+            scale=scale,
+            offset=offset,
             interpolation=interpolation,
         )
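The `scale`/`offset` lists built in this hunk fold timm's `(x / 255 - mean) / std` normalization into the single per-channel affine `x * scale + offset` that the new converter applies. A quick numeric check (ImageNet statistics used as example values):

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406])  # example: ImageNet mean
std = np.array([0.229, 0.224, 0.225])  # example: ImageNet std

scale = 1.0 / 255.0 / std  # as computed in the hunk
offset = -mean / std

x = np.random.uniform(0, 255, size=(4, 4, 3))
assert np.allclose((x / 255.0 - mean) / std, x * scale + offset)
```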
--- keras_hub/src/utils/transformers/convert_llama3.py
+++ keras_hub/src/utils/transformers/convert_llama3.py
@@ -107,10 +107,26 @@ def convert_tokenizer(cls, preset, **kwargs):
     vocab = tokenizer_config["model"]["vocab"]
     merges = tokenizer_config["model"]["merges"]
 
-    bot = tokenizer_config["added_tokens"][0]  # begin of text
-    eot = tokenizer_config["added_tokens"][1]  # end of text
-
-    vocab[bot["content"]] = bot["id"]
-    vocab[eot["content"]] = eot["id"]
+    # Load all special tokens with the exception of "reserved" ones.
+    special_tokens = set()
+    for token in tokenizer_config["added_tokens"]:
+        if not token["content"].startswith("<|reserved_special_token_"):
+            vocab[token["content"]] = token["id"]
+            special_tokens.add(token["content"])
+
+    # Load text start and stop tokens from the config.
+    # Llama3 uses the <|end_of_text|> end token for regular models
+    # but uses <|eot_id|> for instruction-tuned variants.
+    tokenizer_config2 = load_json(preset, "tokenizer_config.json")
+    bos_token = tokenizer_config2["bos_token"]
+    eos_token = tokenizer_config2["eos_token"]
+
+    kwargs.update(
+        {
+            "bos_token": bos_token,
+            "eos_token": eos_token,
+            "misc_special_tokens": special_tokens,
+        }
+    )
 
     return cls(vocabulary=vocab, merges=merges, **kwargs)
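The filtering above keeps every added token except the `<|reserved_special_token_*|>` placeholders that pad out Llama 3's vocabulary. A toy run against a hand-made `added_tokens` list shaped like a Llama 3 `tokenizer.json` (ids illustrative):

```python
added_tokens = [
    {"content": "<|begin_of_text|>", "id": 128000},
    {"content": "<|end_of_text|>", "id": 128001},
    {"content": "<|reserved_special_token_0|>", "id": 128002},
    {"content": "<|eot_id|>", "id": 128009},
]

vocab, special_tokens = {}, set()
for token in added_tokens:
    if not token["content"].startswith("<|reserved_special_token_"):
        vocab[token["content"]] = token["id"]
        special_tokens.add(token["content"])

# Reserved placeholders are dropped; real specials reach the tokenizer.
assert special_tokens == {"<|begin_of_text|>", "<|end_of_text|>", "<|eot_id|>"}
```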