keras-hub-nightly 0.16.1.dev202410200345__py3-none-any.whl → 0.19.0.dev202412070351__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_hub/api/layers/__init__.py +12 -0
- keras_hub/api/models/__init__.py +32 -0
- keras_hub/src/bounding_box/__init__.py +2 -0
- keras_hub/src/bounding_box/converters.py +102 -12
- keras_hub/src/layers/modeling/rms_normalization.py +34 -0
- keras_hub/src/layers/modeling/transformer_encoder.py +27 -7
- keras_hub/src/layers/preprocessing/image_converter.py +5 -0
- keras_hub/src/models/albert/albert_presets.py +0 -8
- keras_hub/src/models/bart/bart_presets.py +0 -6
- keras_hub/src/models/bert/bert_presets.py +0 -20
- keras_hub/src/models/bloom/bloom_presets.py +0 -16
- keras_hub/src/models/clip/__init__.py +5 -0
- keras_hub/src/models/clip/clip_backbone.py +286 -0
- keras_hub/src/models/clip/clip_encoder_block.py +19 -4
- keras_hub/src/models/clip/clip_image_converter.py +8 -0
- keras_hub/src/models/clip/clip_presets.py +93 -0
- keras_hub/src/models/clip/clip_text_encoder.py +4 -1
- keras_hub/src/models/clip/clip_tokenizer.py +18 -3
- keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
- keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
- keras_hub/src/models/deberta_v3/deberta_v3_presets.py +0 -10
- keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +0 -2
- keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +5 -3
- keras_hub/src/models/densenet/densenet_backbone.py +1 -1
- keras_hub/src/models/densenet/densenet_presets.py +0 -6
- keras_hub/src/models/distil_bert/distil_bert_presets.py +0 -6
- keras_hub/src/models/efficientnet/__init__.py +9 -0
- keras_hub/src/models/efficientnet/cba.py +141 -0
- keras_hub/src/models/efficientnet/efficientnet_backbone.py +139 -56
- keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
- keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
- keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
- keras_hub/src/models/efficientnet/efficientnet_presets.py +192 -0
- keras_hub/src/models/efficientnet/fusedmbconv.py +81 -36
- keras_hub/src/models/efficientnet/mbconv.py +52 -21
- keras_hub/src/models/electra/electra_presets.py +0 -12
- keras_hub/src/models/f_net/f_net_presets.py +0 -4
- keras_hub/src/models/falcon/falcon_presets.py +0 -2
- keras_hub/src/models/flux/__init__.py +5 -0
- keras_hub/src/models/flux/flux_layers.py +494 -0
- keras_hub/src/models/flux/flux_maths.py +218 -0
- keras_hub/src/models/flux/flux_model.py +231 -0
- keras_hub/src/models/flux/flux_presets.py +14 -0
- keras_hub/src/models/flux/flux_text_to_image.py +142 -0
- keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
- keras_hub/src/models/gemma/gemma_presets.py +0 -40
- keras_hub/src/models/gpt2/gpt2_presets.py +0 -9
- keras_hub/src/models/image_object_detector.py +87 -0
- keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
- keras_hub/src/models/image_to_image.py +16 -10
- keras_hub/src/models/inpaint.py +20 -13
- keras_hub/src/models/llama/llama_backbone.py +1 -1
- keras_hub/src/models/llama/llama_presets.py +5 -15
- keras_hub/src/models/llama3/llama3_presets.py +0 -8
- keras_hub/src/models/mistral/mistral_presets.py +0 -6
- keras_hub/src/models/mit/mit_backbone.py +41 -27
- keras_hub/src/models/mit/mit_layers.py +9 -7
- keras_hub/src/models/mit/mit_presets.py +12 -24
- keras_hub/src/models/opt/opt_presets.py +0 -8
- keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +61 -11
- keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
- keras_hub/src/models/pali_gemma/pali_gemma_presets.py +166 -10
- keras_hub/src/models/pali_gemma/pali_gemma_vit.py +12 -11
- keras_hub/src/models/phi3/phi3_presets.py +0 -4
- keras_hub/src/models/resnet/resnet_presets.py +10 -42
- keras_hub/src/models/retinanet/__init__.py +5 -0
- keras_hub/src/models/retinanet/anchor_generator.py +52 -53
- keras_hub/src/models/retinanet/feature_pyramid.py +99 -36
- keras_hub/src/models/retinanet/non_max_supression.py +1 -0
- keras_hub/src/models/retinanet/prediction_head.py +192 -0
- keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
- keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
- keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
- keras_hub/src/models/retinanet/retinanet_object_detector.py +382 -0
- keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
- keras_hub/src/models/retinanet/retinanet_presets.py +15 -0
- keras_hub/src/models/roberta/roberta_presets.py +0 -4
- keras_hub/src/models/sam/sam_backbone.py +0 -1
- keras_hub/src/models/sam/sam_image_segmenter.py +9 -10
- keras_hub/src/models/sam/sam_presets.py +0 -6
- keras_hub/src/models/segformer/__init__.py +8 -0
- keras_hub/src/models/segformer/segformer_backbone.py +163 -0
- keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
- keras_hub/src/models/segformer/segformer_image_segmenter.py +171 -0
- keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
- keras_hub/src/models/segformer/segformer_presets.py +124 -0
- keras_hub/src/models/stable_diffusion_3/mmdit.py +41 -0
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +38 -21
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +3 -3
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +3 -3
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +28 -4
- keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -1
- keras_hub/src/models/t5/t5_backbone.py +5 -4
- keras_hub/src/models/t5/t5_presets.py +41 -13
- keras_hub/src/models/text_to_image.py +13 -5
- keras_hub/src/models/vgg/vgg_backbone.py +1 -1
- keras_hub/src/models/vgg/vgg_presets.py +0 -8
- keras_hub/src/models/whisper/whisper_audio_converter.py +1 -1
- keras_hub/src/models/whisper/whisper_presets.py +0 -20
- keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +0 -4
- keras_hub/src/tests/test_case.py +25 -0
- keras_hub/src/utils/preset_utils.py +17 -4
- keras_hub/src/utils/timm/convert_efficientnet.py +449 -0
- keras_hub/src/utils/timm/preset_loader.py +3 -0
- keras_hub/src/version_utils.py +1 -1
- {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/METADATA +15 -26
- {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/RECORD +109 -76
- {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/WHEEL +1 -1
- {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/top_level.txt +0 -0
@@ -7,9 +7,7 @@ backbone_presets = {
|
|
7
7
|
"English speech data."
|
8
8
|
),
|
9
9
|
"params": 37184256,
|
10
|
-
"official_name": "Whisper",
|
11
10
|
"path": "whisper",
|
12
|
-
"model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
|
13
11
|
},
|
14
12
|
"kaggle_handle": "kaggle://keras/whisper/keras/whisper_tiny_en/3",
|
15
13
|
},
|
@@ -20,9 +18,7 @@ backbone_presets = {
|
|
20
18
|
"English speech data."
|
21
19
|
),
|
22
20
|
"params": 124439808,
|
23
|
-
"official_name": "Whisper",
|
24
21
|
"path": "whisper",
|
25
|
-
"model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
|
26
22
|
},
|
27
23
|
"kaggle_handle": "kaggle://keras/whisper/keras/whisper_base_en/3",
|
28
24
|
},
|
@@ -33,9 +29,7 @@ backbone_presets = {
|
|
33
29
|
"English speech data."
|
34
30
|
),
|
35
31
|
"params": 241734144,
|
36
|
-
"official_name": "Whisper",
|
37
32
|
"path": "whisper",
|
38
|
-
"model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
|
39
33
|
},
|
40
34
|
"kaggle_handle": "kaggle://keras/whisper/keras/whisper_small_en/3",
|
41
35
|
},
|
@@ -46,9 +40,7 @@ backbone_presets = {
|
|
46
40
|
"English speech data."
|
47
41
|
),
|
48
42
|
"params": 763856896,
|
49
|
-
"official_name": "Whisper",
|
50
43
|
"path": "whisper",
|
51
|
-
"model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
|
52
44
|
},
|
53
45
|
"kaggle_handle": "kaggle://keras/whisper/keras/whisper_medium_en/3",
|
54
46
|
},
|
@@ -59,9 +51,7 @@ backbone_presets = {
|
|
59
51
|
"multilingual speech data."
|
60
52
|
),
|
61
53
|
"params": 37760640,
|
62
|
-
"official_name": "Whisper",
|
63
54
|
"path": "whisper",
|
64
|
-
"model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
|
65
55
|
},
|
66
56
|
"kaggle_handle": "kaggle://keras/whisper/keras/whisper_tiny_multi/3",
|
67
57
|
},
|
@@ -72,9 +62,7 @@ backbone_presets = {
|
|
72
62
|
"multilingual speech data."
|
73
63
|
),
|
74
64
|
"params": 72593920,
|
75
|
-
"official_name": "Whisper",
|
76
65
|
"path": "whisper",
|
77
|
-
"model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
|
78
66
|
},
|
79
67
|
"kaggle_handle": "kaggle://keras/whisper/keras/whisper_base_multi/3",
|
80
68
|
},
|
@@ -85,9 +73,7 @@ backbone_presets = {
|
|
85
73
|
"multilingual speech data."
|
86
74
|
),
|
87
75
|
"params": 241734912,
|
88
|
-
"official_name": "Whisper",
|
89
76
|
"path": "whisper",
|
90
|
-
"model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
|
91
77
|
},
|
92
78
|
"kaggle_handle": "kaggle://keras/whisper/keras/whisper_small_multi/3",
|
93
79
|
},
|
@@ -98,9 +84,7 @@ backbone_presets = {
|
|
98
84
|
"multilingual speech data."
|
99
85
|
),
|
100
86
|
"params": 763857920,
|
101
|
-
"official_name": "Whisper",
|
102
87
|
"path": "whisper",
|
103
|
-
"model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
|
104
88
|
},
|
105
89
|
"kaggle_handle": "kaggle://keras/whisper/keras/whisper_medium_multi/3",
|
106
90
|
},
|
@@ -111,9 +95,7 @@ backbone_presets = {
|
|
111
95
|
"multilingual speech data."
|
112
96
|
),
|
113
97
|
"params": 1543304960,
|
114
|
-
"official_name": "Whisper",
|
115
98
|
"path": "whisper",
|
116
|
-
"model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
|
117
99
|
},
|
118
100
|
"kaggle_handle": "kaggle://keras/whisper/keras/whisper_large_multi/3",
|
119
101
|
},
|
@@ -125,9 +107,7 @@ backbone_presets = {
|
|
125
107
|
"of `whisper_large_multi`."
|
126
108
|
),
|
127
109
|
"params": 1543304960,
|
128
|
-
"official_name": "Whisper",
|
129
110
|
"path": "whisper",
|
130
|
-
"model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
|
131
111
|
},
|
132
112
|
"kaggle_handle": "kaggle://keras/whisper/keras/whisper_large_multi_v2/3",
|
133
113
|
},
|
@@ -8,9 +8,7 @@ backbone_presets = {
|
|
8
8
|
"Trained on CommonCrawl in 100 languages."
|
9
9
|
),
|
10
10
|
"params": 277450752,
|
11
|
-
"official_name": "XLM-RoBERTa",
|
12
11
|
"path": "xlm_roberta",
|
13
|
-
"model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/xlmr/README.md",
|
14
12
|
},
|
15
13
|
"kaggle_handle": "kaggle://keras/xlm_roberta/keras/xlm_roberta_base_multi/2",
|
16
14
|
},
|
@@ -21,9 +19,7 @@ backbone_presets = {
|
|
21
19
|
"Trained on CommonCrawl in 100 languages."
|
22
20
|
),
|
23
21
|
"params": 558837760,
|
24
|
-
"official_name": "XLM-RoBERTa",
|
25
22
|
"path": "xlm_roberta",
|
26
|
-
"model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/xlmr/README.md",
|
27
23
|
},
|
28
24
|
"kaggle_handle": "kaggle://keras/xlm_roberta/keras/xlm_roberta_large_multi/2",
|
29
25
|
},
|
keras_hub/src/tests/test_case.py
CHANGED
@@ -313,6 +313,14 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
|
|
313
313
|
|
314
314
|
for policy in ["mixed_float16", "mixed_bfloat16", "bfloat16"]:
|
315
315
|
policy = keras.mixed_precision.Policy(policy)
|
316
|
+
# Ensure the correct `dtype` is set for sublayers or submodels in
|
317
|
+
# `init_kwargs`.
|
318
|
+
original_init_kwargs = init_kwargs.copy()
|
319
|
+
for k, v in init_kwargs.items():
|
320
|
+
if isinstance(v, keras.Layer):
|
321
|
+
config = v.get_config()
|
322
|
+
config["dtype"] = policy
|
323
|
+
init_kwargs[k] = v.__class__.from_config(config)
|
316
324
|
layer = cls(**{**init_kwargs, "dtype": policy})
|
317
325
|
if isinstance(layer, keras.Model):
|
318
326
|
output_data = layer(input_data)
|
@@ -343,8 +351,15 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
|
|
343
351
|
continue
|
344
352
|
self.assertEqual(policy.compute_dtype, sublayer.compute_dtype)
|
345
353
|
self.assertEqual(policy.variable_dtype, sublayer.variable_dtype)
|
354
|
+
# Restore `init_kwargs`.
|
355
|
+
init_kwargs = original_init_kwargs
|
346
356
|
|
347
357
|
def run_quantization_test(self, instance, cls, init_kwargs, input_data):
|
358
|
+
# TODO: revert the following if. This works around a torch
|
359
|
+
# quantization failure in `MultiHeadAttention` with Keras 3.7.
|
360
|
+
if keras.config.backend() == "torch":
|
361
|
+
return
|
362
|
+
|
348
363
|
def _get_supported_layers(mode):
|
349
364
|
supported_layers = [keras.layers.Dense, keras.layers.EinsumDense]
|
350
365
|
if mode == "int8":
|
@@ -361,6 +376,14 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
|
|
361
376
|
policy_map[layer.path] = keras.dtype_policies.get(
|
362
377
|
f"{mode}_from_float32"
|
363
378
|
)
|
379
|
+
# Ensure the correct `dtype` is set for sublayers or submodels in
|
380
|
+
# `init_kwargs`.
|
381
|
+
original_init_kwargs = init_kwargs.copy()
|
382
|
+
for k, v in init_kwargs.items():
|
383
|
+
if isinstance(v, keras.Layer):
|
384
|
+
config = v.get_config()
|
385
|
+
config["dtype"] = policy_map
|
386
|
+
init_kwargs[k] = v.__class__.from_config(config)
|
364
387
|
# Instantiate the layer.
|
365
388
|
model = cls(**{**init_kwargs, "dtype": policy_map})
|
366
389
|
# Call layer eagerly.
|
@@ -382,6 +405,8 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
|
|
382
405
|
# Check weights loading.
|
383
406
|
weights = model.get_weights()
|
384
407
|
revived_model.set_weights(weights)
|
408
|
+
# Restore `init_kwargs`.
|
409
|
+
init_kwargs = original_init_kwargs
|
385
410
|
|
386
411
|
def run_model_saving_test(
|
387
412
|
self,
|
@@ -563,10 +563,8 @@ class PresetLoader:
|
|
563
563
|
backbone_kwargs["dtype"] = kwargs.pop("dtype", None)
|
564
564
|
|
565
565
|
# Forward `height` and `width` to backbone when using `TextToImage`.
|
566
|
-
if "height" in kwargs:
|
567
|
-
backbone_kwargs["height"] = kwargs.pop("height", None)
|
568
|
-
if "width" in kwargs:
|
569
|
-
backbone_kwargs["width"] = kwargs.pop("width", None)
|
566
|
+
if "image_shape" in kwargs:
|
567
|
+
backbone_kwargs["image_shape"] = kwargs.pop("image_shape", None)
|
570
568
|
|
571
569
|
return backbone_kwargs, kwargs
|
572
570
|
|
@@ -660,6 +658,12 @@ class KerasPresetLoader(PresetLoader):
|
|
660
658
|
cls, load_weights, load_task_weights, **kwargs
|
661
659
|
)
|
662
660
|
# We found a `task.json` with a complete config for our class.
|
661
|
+
# Forward backbone args.
|
662
|
+
backbone_kwargs, kwargs = self.get_backbone_kwargs(**kwargs)
|
663
|
+
if "backbone" in task_config["config"]:
|
664
|
+
backbone_config = task_config["config"]["backbone"]["config"]
|
665
|
+
backbone_config = {**backbone_config, **backbone_kwargs}
|
666
|
+
task_config["config"]["backbone"]["config"] = backbone_config
|
663
667
|
task = load_serialized_object(task_config, **kwargs)
|
664
668
|
if task.preprocessor and hasattr(
|
665
669
|
task.preprocessor, "load_preset_assets"
|
@@ -767,14 +771,23 @@ class KerasPresetSaver:
|
|
767
771
|
config_file.write(json.dumps(config, indent=4))
|
768
772
|
|
769
773
|
def _save_metadata(self, layer):
|
774
|
+
from keras_hub.src.models.task import Task
|
770
775
|
from keras_hub.src.version_utils import __version__ as keras_hub_version
|
771
776
|
|
777
|
+
# Find all tasks that are compatible with the backbone.
|
778
|
+
# E.g. for `BertBackbone` we would have `TextClassifier` and `MaskedLM`.
|
779
|
+
# For `ResNetBackbone` we would have `ImageClassifier`.
|
780
|
+
tasks = list_subclasses(Task)
|
781
|
+
tasks = filter(lambda x: x.backbone_cls == type(layer), tasks)
|
782
|
+
tasks = [task.__base__.__name__ for task in tasks]
|
783
|
+
|
772
784
|
keras_version = keras.version() if hasattr(keras, "version") else None
|
773
785
|
metadata = {
|
774
786
|
"keras_version": keras_version,
|
775
787
|
"keras_hub_version": keras_hub_version,
|
776
788
|
"parameter_count": layer.count_params(),
|
777
789
|
"date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"),
|
790
|
+
"tasks": tasks,
|
778
791
|
}
|
779
792
|
metadata_path = os.path.join(self.preset_dir, METADATA_FILE)
|
780
793
|
with open(metadata_path, "w") as metadata_file:
|
@@ -0,0 +1,449 @@
|
|
1
|
+
import math
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
|
5
|
+
from keras_hub.src.models.efficientnet.efficientnet_backbone import (
|
6
|
+
EfficientNetBackbone,
|
7
|
+
)
|
8
|
+
|
9
|
+
backbone_cls = EfficientNetBackbone
|
10
|
+
|
11
|
+
|
12
|
+
VARIANT_MAP = {
|
13
|
+
"b0": {
|
14
|
+
"stackwise_width_coefficients": [1.0] * 7,
|
15
|
+
"stackwise_depth_coefficients": [1.0] * 7,
|
16
|
+
"stackwise_squeeze_and_excite_ratios": [0.25] * 7,
|
17
|
+
},
|
18
|
+
"b1": {
|
19
|
+
"stackwise_width_coefficients": [1.0] * 7,
|
20
|
+
"stackwise_depth_coefficients": [1.1] * 7,
|
21
|
+
"stackwise_squeeze_and_excite_ratios": [0.25] * 7,
|
22
|
+
},
|
23
|
+
"b2": {
|
24
|
+
"stackwise_width_coefficients": [1.1] * 7,
|
25
|
+
"stackwise_depth_coefficients": [1.2] * 7,
|
26
|
+
"stackwise_squeeze_and_excite_ratios": [0.25] * 7,
|
27
|
+
},
|
28
|
+
"b3": {
|
29
|
+
"stackwise_width_coefficients": [1.2] * 7,
|
30
|
+
"stackwise_depth_coefficients": [1.4] * 7,
|
31
|
+
"stackwise_squeeze_and_excite_ratios": [0.25] * 7,
|
32
|
+
},
|
33
|
+
"b4": {
|
34
|
+
"stackwise_width_coefficients": [1.4] * 7,
|
35
|
+
"stackwise_depth_coefficients": [1.8] * 7,
|
36
|
+
"stackwise_squeeze_and_excite_ratios": [0.25] * 7,
|
37
|
+
},
|
38
|
+
"b5": {
|
39
|
+
"stackwise_width_coefficients": [1.6] * 7,
|
40
|
+
"stackwise_depth_coefficients": [2.2] * 7,
|
41
|
+
"stackwise_squeeze_and_excite_ratios": [0.25] * 7,
|
42
|
+
},
|
43
|
+
"lite0": {
|
44
|
+
"stackwise_width_coefficients": [1.0] * 7,
|
45
|
+
"stackwise_depth_coefficients": [1.0] * 7,
|
46
|
+
"stackwise_squeeze_and_excite_ratios": [0] * 7,
|
47
|
+
"activation": "relu6",
|
48
|
+
},
|
49
|
+
"el": {
|
50
|
+
"stackwise_width_coefficients": [1.2] * 6,
|
51
|
+
"stackwise_depth_coefficients": [1.4] * 6,
|
52
|
+
"stackwise_kernel_sizes": [3, 3, 3, 5, 5, 5],
|
53
|
+
"stackwise_num_repeats": [1, 2, 4, 5, 4, 2],
|
54
|
+
"stackwise_input_filters": [32, 24, 32, 48, 96, 144],
|
55
|
+
"stackwise_output_filters": [24, 32, 48, 96, 144, 192],
|
56
|
+
"stackwise_expansion_ratios": [4, 8, 8, 8, 8, 8],
|
57
|
+
"stackwise_strides": [1, 2, 2, 2, 1, 2],
|
58
|
+
"stackwise_squeeze_and_excite_ratios": [0] * 6,
|
59
|
+
"stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3,
|
60
|
+
"stackwise_force_input_filters": [24, 0, 0, 0, 0, 0],
|
61
|
+
"stackwise_nores_option": [True] + [False] * 5,
|
62
|
+
"activation": "relu",
|
63
|
+
},
|
64
|
+
"em": {
|
65
|
+
"stackwise_width_coefficients": [1.0] * 6,
|
66
|
+
"stackwise_depth_coefficients": [1.1] * 6,
|
67
|
+
"stackwise_kernel_sizes": [3, 3, 3, 5, 5, 5],
|
68
|
+
"stackwise_num_repeats": [1, 2, 4, 5, 4, 2],
|
69
|
+
"stackwise_input_filters": [32, 24, 32, 48, 96, 144],
|
70
|
+
"stackwise_output_filters": [24, 32, 48, 96, 144, 192],
|
71
|
+
"stackwise_expansion_ratios": [4, 8, 8, 8, 8, 8],
|
72
|
+
"stackwise_strides": [1, 2, 2, 2, 1, 2],
|
73
|
+
"stackwise_squeeze_and_excite_ratios": [0] * 6,
|
74
|
+
"stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3,
|
75
|
+
"stackwise_force_input_filters": [24, 0, 0, 0, 0, 0],
|
76
|
+
"stackwise_nores_option": [True] + [False] * 5,
|
77
|
+
"activation": "relu",
|
78
|
+
},
|
79
|
+
"es": {
|
80
|
+
"stackwise_width_coefficients": [1.0] * 6,
|
81
|
+
"stackwise_depth_coefficients": [1.0] * 6,
|
82
|
+
"stackwise_kernel_sizes": [3, 3, 3, 5, 5, 5],
|
83
|
+
"stackwise_num_repeats": [1, 2, 4, 5, 4, 2],
|
84
|
+
"stackwise_input_filters": [32, 24, 32, 48, 96, 144],
|
85
|
+
"stackwise_output_filters": [24, 32, 48, 96, 144, 192],
|
86
|
+
"stackwise_expansion_ratios": [4, 8, 8, 8, 8, 8],
|
87
|
+
"stackwise_strides": [1, 2, 2, 2, 1, 2],
|
88
|
+
"stackwise_squeeze_and_excite_ratios": [0] * 6,
|
89
|
+
"stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3,
|
90
|
+
"stackwise_force_input_filters": [24, 0, 0, 0, 0, 0],
|
91
|
+
"stackwise_nores_option": [True] + [False] * 5,
|
92
|
+
"activation": "relu",
|
93
|
+
},
|
94
|
+
"rw_m": {
|
95
|
+
"stackwise_width_coefficients": [1.2] * 6,
|
96
|
+
"stackwise_depth_coefficients": [1.2] * 4 + [1.6] * 2,
|
97
|
+
"stackwise_kernel_sizes": [3, 3, 3, 3, 3, 3],
|
98
|
+
"stackwise_num_repeats": [2, 4, 4, 6, 9, 15],
|
99
|
+
"stackwise_input_filters": [24, 24, 48, 64, 128, 160],
|
100
|
+
"stackwise_output_filters": [24, 48, 64, 128, 160, 272],
|
101
|
+
"stackwise_expansion_ratios": [1, 4, 4, 4, 6, 6],
|
102
|
+
"stackwise_strides": [1, 2, 2, 2, 1, 2],
|
103
|
+
"stackwise_squeeze_and_excite_ratios": [0, 0, 0, 0.25, 0.25, 0.25],
|
104
|
+
"stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3,
|
105
|
+
"stackwise_force_input_filters": [0, 0, 0, 0, 0, 0],
|
106
|
+
"stackwise_nores_option": [False] * 6,
|
107
|
+
"activation": "silu",
|
108
|
+
"num_features": 1792,
|
109
|
+
},
|
110
|
+
"rw_s": {
|
111
|
+
"stackwise_width_coefficients": [1.0] * 6,
|
112
|
+
"stackwise_depth_coefficients": [1.0] * 6,
|
113
|
+
"stackwise_kernel_sizes": [3, 3, 3, 3, 3, 3],
|
114
|
+
"stackwise_num_repeats": [2, 4, 4, 6, 9, 15],
|
115
|
+
"stackwise_input_filters": [24, 24, 48, 64, 128, 160],
|
116
|
+
"stackwise_output_filters": [24, 48, 64, 128, 160, 272],
|
117
|
+
"stackwise_expansion_ratios": [1, 4, 4, 4, 6, 6],
|
118
|
+
"stackwise_strides": [1, 2, 2, 2, 1, 2],
|
119
|
+
"stackwise_squeeze_and_excite_ratios": [0, 0, 0, 0.25, 0.25, 0.25],
|
120
|
+
"stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3,
|
121
|
+
"stackwise_force_input_filters": [0, 0, 0, 0, 0, 0],
|
122
|
+
"stackwise_nores_option": [False] * 6,
|
123
|
+
"activation": "silu",
|
124
|
+
"num_features": 1792,
|
125
|
+
},
|
126
|
+
"rw_t": {
|
127
|
+
"stackwise_width_coefficients": [0.8] * 6,
|
128
|
+
"stackwise_depth_coefficients": [0.9] * 6,
|
129
|
+
"stackwise_kernel_sizes": [3, 3, 3, 3, 3, 3],
|
130
|
+
"stackwise_num_repeats": [2, 4, 4, 6, 9, 15],
|
131
|
+
"stackwise_input_filters": [24, 24, 48, 64, 128, 160],
|
132
|
+
"stackwise_output_filters": [24, 48, 64, 128, 160, 256],
|
133
|
+
"stackwise_expansion_ratios": [1, 4, 4, 4, 6, 6],
|
134
|
+
"stackwise_strides": [1, 2, 2, 2, 1, 2],
|
135
|
+
"stackwise_squeeze_and_excite_ratios": [0, 0, 0, 0.25, 0.25, 0.25],
|
136
|
+
"stackwise_block_types": ["cba"] + ["fused"] * 2 + ["unfused"] * 3,
|
137
|
+
"stackwise_force_input_filters": [0, 0, 0, 0, 0, 0],
|
138
|
+
"stackwise_nores_option": [False] * 6,
|
139
|
+
"activation": "silu",
|
140
|
+
},
|
141
|
+
}
|
142
|
+
|
143
|
+
|
144
|
+
def convert_backbone_config(timm_config):
|
145
|
+
timm_architecture = timm_config["architecture"]
|
146
|
+
|
147
|
+
base_kwargs = {
|
148
|
+
"stackwise_kernel_sizes": [3, 3, 5, 3, 5, 5, 3],
|
149
|
+
"stackwise_num_repeats": [1, 2, 2, 3, 3, 4, 1],
|
150
|
+
"stackwise_input_filters": [32, 16, 24, 40, 80, 112, 192],
|
151
|
+
"stackwise_output_filters": [16, 24, 40, 80, 112, 192, 320],
|
152
|
+
"stackwise_expansion_ratios": [1, 6, 6, 6, 6, 6, 6],
|
153
|
+
"stackwise_strides": [1, 2, 2, 2, 1, 2, 1],
|
154
|
+
"stackwise_block_types": ["v1"] * 7,
|
155
|
+
"min_depth": None,
|
156
|
+
"include_stem_padding": True,
|
157
|
+
"use_depth_divisor_as_min_depth": True,
|
158
|
+
"cap_round_filter_decrease": True,
|
159
|
+
"stem_conv_padding": "valid",
|
160
|
+
"batch_norm_momentum": 0.9,
|
161
|
+
"batch_norm_epsilon": 1e-5,
|
162
|
+
"dropout": 0,
|
163
|
+
"projection_activation": None,
|
164
|
+
}
|
165
|
+
|
166
|
+
variant = "_".join(timm_architecture.split("_")[1:])
|
167
|
+
|
168
|
+
if variant not in VARIANT_MAP:
|
169
|
+
raise ValueError(
|
170
|
+
f"Currently, the architecture {timm_architecture} is not supported."
|
171
|
+
)
|
172
|
+
|
173
|
+
base_kwargs.update(VARIANT_MAP[variant])
|
174
|
+
|
175
|
+
return base_kwargs
|
176
|
+
|
177
|
+
|
178
|
+
def convert_weights(backbone, loader, timm_config):
|
179
|
+
timm_architecture = timm_config["architecture"]
|
180
|
+
variant = "_".join(timm_architecture.split("_")[1:])
|
181
|
+
|
182
|
+
def port_conv2d(keras_layer, hf_weight_prefix, port_bias=True):
|
183
|
+
loader.port_weight(
|
184
|
+
keras_layer.kernel,
|
185
|
+
hf_weight_key=f"{hf_weight_prefix}.weight",
|
186
|
+
hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
|
187
|
+
)
|
188
|
+
|
189
|
+
if port_bias:
|
190
|
+
loader.port_weight(
|
191
|
+
keras_layer.bias,
|
192
|
+
hf_weight_key=f"{hf_weight_prefix}.bias",
|
193
|
+
)
|
194
|
+
|
195
|
+
def port_depthwise_conv2d(
|
196
|
+
keras_layer,
|
197
|
+
hf_weight_prefix,
|
198
|
+
port_bias=True,
|
199
|
+
depth_multiplier=1,
|
200
|
+
):
|
201
|
+
|
202
|
+
def convert_pt_conv2d_kernel(pt_kernel):
|
203
|
+
out_channels, in_channels_per_group, height, width = pt_kernel.shape
|
204
|
+
# PT Convs are depthwise convs if and only if in_channels_per_group == 1
|
205
|
+
assert in_channels_per_group == 1
|
206
|
+
pt_kernel = np.transpose(pt_kernel, (2, 3, 0, 1))
|
207
|
+
in_channels = out_channels // depth_multiplier
|
208
|
+
return np.reshape(
|
209
|
+
pt_kernel, (height, width, in_channels, depth_multiplier)
|
210
|
+
)
|
211
|
+
|
212
|
+
loader.port_weight(
|
213
|
+
keras_layer.kernel,
|
214
|
+
hf_weight_key=f"{hf_weight_prefix}.weight",
|
215
|
+
hook_fn=lambda x, _: convert_pt_conv2d_kernel(x),
|
216
|
+
)
|
217
|
+
|
218
|
+
if port_bias:
|
219
|
+
loader.port_weight(
|
220
|
+
keras_layer.bias,
|
221
|
+
hf_weight_key=f"{hf_weight_prefix}.bias",
|
222
|
+
)
|
223
|
+
|
224
|
+
def port_batch_normalization(keras_layer, hf_weight_prefix):
|
225
|
+
loader.port_weight(
|
226
|
+
keras_layer.gamma,
|
227
|
+
hf_weight_key=f"{hf_weight_prefix}.weight",
|
228
|
+
)
|
229
|
+
loader.port_weight(
|
230
|
+
keras_layer.beta,
|
231
|
+
hf_weight_key=f"{hf_weight_prefix}.bias",
|
232
|
+
)
|
233
|
+
loader.port_weight(
|
234
|
+
keras_layer.moving_mean,
|
235
|
+
hf_weight_key=f"{hf_weight_prefix}.running_mean",
|
236
|
+
)
|
237
|
+
loader.port_weight(
|
238
|
+
keras_layer.moving_variance,
|
239
|
+
hf_weight_key=f"{hf_weight_prefix}.running_var",
|
240
|
+
)
|
241
|
+
# do we need num batches tracked?
|
242
|
+
|
243
|
+
# Stem
|
244
|
+
port_conv2d(backbone.get_layer("stem_conv"), "conv_stem", port_bias=False)
|
245
|
+
port_batch_normalization(backbone.get_layer("stem_bn"), "bn1")
|
246
|
+
|
247
|
+
# Stages
|
248
|
+
num_stacks = len(backbone.stackwise_kernel_sizes)
|
249
|
+
|
250
|
+
for stack_index in range(num_stacks):
|
251
|
+
|
252
|
+
block_type = backbone.stackwise_block_types[stack_index]
|
253
|
+
expansion_ratio = backbone.stackwise_expansion_ratios[stack_index]
|
254
|
+
repeats = backbone.stackwise_num_repeats[stack_index]
|
255
|
+
stack_depth_coefficient = backbone.stackwise_depth_coefficients[
|
256
|
+
stack_index
|
257
|
+
]
|
258
|
+
|
259
|
+
repeats = int(math.ceil(stack_depth_coefficient * repeats))
|
260
|
+
|
261
|
+
se_ratio = VARIANT_MAP[variant]["stackwise_squeeze_and_excite_ratios"][
|
262
|
+
stack_index
|
263
|
+
]
|
264
|
+
|
265
|
+
for block_idx in range(repeats):
|
266
|
+
|
267
|
+
conv_pw_count = 0
|
268
|
+
bn_count = 1
|
269
|
+
|
270
|
+
# 97 is the start of the lowercase alphabet.
|
271
|
+
letter_identifier = chr(block_idx + 97)
|
272
|
+
|
273
|
+
keras_block_prefix = f"block{stack_index+1}{letter_identifier}_"
|
274
|
+
hf_block_prefix = f"blocks.{stack_index}.{block_idx}."
|
275
|
+
|
276
|
+
if block_type == "v1":
|
277
|
+
conv_pw_name_map = ["conv_pw", "conv_pwl"]
|
278
|
+
# Initial Expansion Conv
|
279
|
+
if expansion_ratio != 1:
|
280
|
+
port_conv2d(
|
281
|
+
backbone.get_layer(keras_block_prefix + "expand_conv"),
|
282
|
+
hf_block_prefix + conv_pw_name_map[conv_pw_count],
|
283
|
+
port_bias=False,
|
284
|
+
)
|
285
|
+
conv_pw_count += 1
|
286
|
+
port_batch_normalization(
|
287
|
+
backbone.get_layer(keras_block_prefix + "expand_bn"),
|
288
|
+
hf_block_prefix + f"bn{bn_count}",
|
289
|
+
)
|
290
|
+
bn_count += 1
|
291
|
+
|
292
|
+
# Depthwise Conv
|
293
|
+
port_depthwise_conv2d(
|
294
|
+
backbone.get_layer(keras_block_prefix + "dwconv"),
|
295
|
+
hf_block_prefix + "conv_dw",
|
296
|
+
port_bias=False,
|
297
|
+
)
|
298
|
+
port_batch_normalization(
|
299
|
+
backbone.get_layer(keras_block_prefix + "dwconv_bn"),
|
300
|
+
hf_block_prefix + f"bn{bn_count}",
|
301
|
+
)
|
302
|
+
bn_count += 1
|
303
|
+
|
304
|
+
if 0 < se_ratio <= 1:
|
305
|
+
# Squeeze and Excite
|
306
|
+
port_conv2d(
|
307
|
+
backbone.get_layer(keras_block_prefix + "se_reduce"),
|
308
|
+
hf_block_prefix + "se.conv_reduce",
|
309
|
+
)
|
310
|
+
port_conv2d(
|
311
|
+
backbone.get_layer(keras_block_prefix + "se_expand"),
|
312
|
+
hf_block_prefix + "se.conv_expand",
|
313
|
+
)
|
314
|
+
|
315
|
+
# Output/Projection
|
316
|
+
port_conv2d(
|
317
|
+
backbone.get_layer(keras_block_prefix + "project"),
|
318
|
+
hf_block_prefix + conv_pw_name_map[conv_pw_count],
|
319
|
+
port_bias=False,
|
320
|
+
)
|
321
|
+
conv_pw_count += 1
|
322
|
+
port_batch_normalization(
|
323
|
+
backbone.get_layer(keras_block_prefix + "project_bn"),
|
324
|
+
hf_block_prefix + f"bn{bn_count}",
|
325
|
+
)
|
326
|
+
bn_count += 1
|
327
|
+
elif block_type == "fused":
|
328
|
+
fused_block_layer = backbone.get_layer(keras_block_prefix)
|
329
|
+
|
330
|
+
# Initial Expansion Conv
|
331
|
+
port_conv2d(
|
332
|
+
fused_block_layer.conv1,
|
333
|
+
hf_block_prefix + "conv_exp",
|
334
|
+
port_bias=False,
|
335
|
+
)
|
336
|
+
conv_pw_count += 1
|
337
|
+
port_batch_normalization(
|
338
|
+
fused_block_layer.bn1,
|
339
|
+
hf_block_prefix + f"bn{bn_count}",
|
340
|
+
)
|
341
|
+
bn_count += 1
|
342
|
+
|
343
|
+
if 0 < se_ratio <= 1:
|
344
|
+
# Squeeze and Excite
|
345
|
+
port_conv2d(
|
346
|
+
fused_block_layer.se_conv1,
|
347
|
+
hf_block_prefix + "se.conv_reduce",
|
348
|
+
)
|
349
|
+
port_conv2d(
|
350
|
+
fused_block_layer.se_conv2,
|
351
|
+
hf_block_prefix + "se.conv_expand",
|
352
|
+
)
|
353
|
+
|
354
|
+
# Output/Projection
|
355
|
+
port_conv2d(
|
356
|
+
fused_block_layer.output_conv,
|
357
|
+
hf_block_prefix + "conv_pwl",
|
358
|
+
port_bias=False,
|
359
|
+
)
|
360
|
+
conv_pw_count += 1
|
361
|
+
port_batch_normalization(
|
362
|
+
fused_block_layer.bn2,
|
363
|
+
hf_block_prefix + f"bn{bn_count}",
|
364
|
+
)
|
365
|
+
bn_count += 1
|
366
|
+
|
367
|
+
elif block_type == "unfused":
|
368
|
+
unfused_block_layer = backbone.get_layer(keras_block_prefix)
|
369
|
+
# Initial Expansion Conv
|
370
|
+
if expansion_ratio != 1:
|
371
|
+
port_conv2d(
|
372
|
+
unfused_block_layer.conv1,
|
373
|
+
hf_block_prefix + "conv_pw",
|
374
|
+
port_bias=False,
|
375
|
+
)
|
376
|
+
conv_pw_count += 1
|
377
|
+
port_batch_normalization(
|
378
|
+
unfused_block_layer.bn1,
|
379
|
+
hf_block_prefix + f"bn{bn_count}",
|
380
|
+
)
|
381
|
+
bn_count += 1
|
382
|
+
|
383
|
+
# Depthwise Conv
|
384
|
+
port_depthwise_conv2d(
|
385
|
+
unfused_block_layer.depthwise,
|
386
|
+
hf_block_prefix + "conv_dw",
|
387
|
+
port_bias=False,
|
388
|
+
)
|
389
|
+
port_batch_normalization(
|
390
|
+
unfused_block_layer.bn2,
|
391
|
+
hf_block_prefix + f"bn{bn_count}",
|
392
|
+
)
|
393
|
+
bn_count += 1
|
394
|
+
|
395
|
+
if 0 < se_ratio <= 1:
|
396
|
+
# Squeeze and Excite
|
397
|
+
port_conv2d(
|
398
|
+
unfused_block_layer.se_conv1,
|
399
|
+
hf_block_prefix + "se.conv_reduce",
|
400
|
+
)
|
401
|
+
port_conv2d(
|
402
|
+
unfused_block_layer.se_conv2,
|
403
|
+
hf_block_prefix + "se.conv_expand",
|
404
|
+
)
|
405
|
+
|
406
|
+
# Output/Projection
|
407
|
+
port_conv2d(
|
408
|
+
unfused_block_layer.output_conv,
|
409
|
+
hf_block_prefix + "conv_pwl",
|
410
|
+
port_bias=False,
|
411
|
+
)
|
412
|
+
conv_pw_count += 1
|
413
|
+
port_batch_normalization(
|
414
|
+
unfused_block_layer.bn3,
|
415
|
+
hf_block_prefix + f"bn{bn_count}",
|
416
|
+
)
|
417
|
+
bn_count += 1
|
418
|
+
elif block_type == "cba":
|
419
|
+
cba_block_layer = backbone.get_layer(keras_block_prefix)
|
420
|
+
# Initial Expansion Conv
|
421
|
+
port_conv2d(
|
422
|
+
cba_block_layer.conv1,
|
423
|
+
hf_block_prefix + "conv",
|
424
|
+
port_bias=False,
|
425
|
+
)
|
426
|
+
conv_pw_count += 1
|
427
|
+
port_batch_normalization(
|
428
|
+
cba_block_layer.bn1,
|
429
|
+
hf_block_prefix + f"bn{bn_count}",
|
430
|
+
)
|
431
|
+
bn_count += 1
|
432
|
+
|
433
|
+
# Head/Top
|
434
|
+
port_conv2d(backbone.get_layer("top_conv"), "conv_head", port_bias=False)
|
435
|
+
port_batch_normalization(backbone.get_layer("top_bn"), "bn2")
|
436
|
+
|
437
|
+
|
438
|
+
def convert_head(task, loader, timm_config):
|
439
|
+
classifier_prefix = timm_config["pretrained_cfg"]["classifier"]
|
440
|
+
prefix = f"{classifier_prefix}."
|
441
|
+
loader.port_weight(
|
442
|
+
task.output_dense.kernel,
|
443
|
+
hf_weight_key=prefix + "weight",
|
444
|
+
hook_fn=lambda x, _: np.transpose(np.squeeze(x)),
|
445
|
+
)
|
446
|
+
loader.port_weight(
|
447
|
+
task.output_dense.bias,
|
448
|
+
hf_weight_key=prefix + "bias",
|
449
|
+
)
|