keras-hub-nightly 0.16.1.dev202410200345__py3-none-any.whl → 0.19.0.dev202412070351__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. keras_hub/api/layers/__init__.py +12 -0
  2. keras_hub/api/models/__init__.py +32 -0
  3. keras_hub/src/bounding_box/__init__.py +2 -0
  4. keras_hub/src/bounding_box/converters.py +102 -12
  5. keras_hub/src/layers/modeling/rms_normalization.py +34 -0
  6. keras_hub/src/layers/modeling/transformer_encoder.py +27 -7
  7. keras_hub/src/layers/preprocessing/image_converter.py +5 -0
  8. keras_hub/src/models/albert/albert_presets.py +0 -8
  9. keras_hub/src/models/bart/bart_presets.py +0 -6
  10. keras_hub/src/models/bert/bert_presets.py +0 -20
  11. keras_hub/src/models/bloom/bloom_presets.py +0 -16
  12. keras_hub/src/models/clip/__init__.py +5 -0
  13. keras_hub/src/models/clip/clip_backbone.py +286 -0
  14. keras_hub/src/models/clip/clip_encoder_block.py +19 -4
  15. keras_hub/src/models/clip/clip_image_converter.py +8 -0
  16. keras_hub/src/models/clip/clip_presets.py +93 -0
  17. keras_hub/src/models/clip/clip_text_encoder.py +4 -1
  18. keras_hub/src/models/clip/clip_tokenizer.py +18 -3
  19. keras_hub/src/models/clip/clip_vision_embedding.py +101 -0
  20. keras_hub/src/models/clip/clip_vision_encoder.py +159 -0
  21. keras_hub/src/models/deberta_v3/deberta_v3_presets.py +0 -10
  22. keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py +0 -2
  23. keras_hub/src/models/deeplab_v3/deeplab_v3_segmenter.py +5 -3
  24. keras_hub/src/models/densenet/densenet_backbone.py +1 -1
  25. keras_hub/src/models/densenet/densenet_presets.py +0 -6
  26. keras_hub/src/models/distil_bert/distil_bert_presets.py +0 -6
  27. keras_hub/src/models/efficientnet/__init__.py +9 -0
  28. keras_hub/src/models/efficientnet/cba.py +141 -0
  29. keras_hub/src/models/efficientnet/efficientnet_backbone.py +139 -56
  30. keras_hub/src/models/efficientnet/efficientnet_image_classifier.py +14 -0
  31. keras_hub/src/models/efficientnet/efficientnet_image_classifier_preprocessor.py +16 -0
  32. keras_hub/src/models/efficientnet/efficientnet_image_converter.py +10 -0
  33. keras_hub/src/models/efficientnet/efficientnet_presets.py +192 -0
  34. keras_hub/src/models/efficientnet/fusedmbconv.py +81 -36
  35. keras_hub/src/models/efficientnet/mbconv.py +52 -21
  36. keras_hub/src/models/electra/electra_presets.py +0 -12
  37. keras_hub/src/models/f_net/f_net_presets.py +0 -4
  38. keras_hub/src/models/falcon/falcon_presets.py +0 -2
  39. keras_hub/src/models/flux/__init__.py +5 -0
  40. keras_hub/src/models/flux/flux_layers.py +494 -0
  41. keras_hub/src/models/flux/flux_maths.py +218 -0
  42. keras_hub/src/models/flux/flux_model.py +231 -0
  43. keras_hub/src/models/flux/flux_presets.py +14 -0
  44. keras_hub/src/models/flux/flux_text_to_image.py +142 -0
  45. keras_hub/src/models/flux/flux_text_to_image_preprocessor.py +73 -0
  46. keras_hub/src/models/gemma/gemma_presets.py +0 -40
  47. keras_hub/src/models/gpt2/gpt2_presets.py +0 -9
  48. keras_hub/src/models/image_object_detector.py +87 -0
  49. keras_hub/src/models/image_object_detector_preprocessor.py +57 -0
  50. keras_hub/src/models/image_to_image.py +16 -10
  51. keras_hub/src/models/inpaint.py +20 -13
  52. keras_hub/src/models/llama/llama_backbone.py +1 -1
  53. keras_hub/src/models/llama/llama_presets.py +5 -15
  54. keras_hub/src/models/llama3/llama3_presets.py +0 -8
  55. keras_hub/src/models/mistral/mistral_presets.py +0 -6
  56. keras_hub/src/models/mit/mit_backbone.py +41 -27
  57. keras_hub/src/models/mit/mit_layers.py +9 -7
  58. keras_hub/src/models/mit/mit_presets.py +12 -24
  59. keras_hub/src/models/opt/opt_presets.py +0 -8
  60. keras_hub/src/models/pali_gemma/pali_gemma_backbone.py +61 -11
  61. keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py +21 -23
  62. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +166 -10
  63. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +12 -11
  64. keras_hub/src/models/phi3/phi3_presets.py +0 -4
  65. keras_hub/src/models/resnet/resnet_presets.py +10 -42
  66. keras_hub/src/models/retinanet/__init__.py +5 -0
  67. keras_hub/src/models/retinanet/anchor_generator.py +52 -53
  68. keras_hub/src/models/retinanet/feature_pyramid.py +99 -36
  69. keras_hub/src/models/retinanet/non_max_supression.py +1 -0
  70. keras_hub/src/models/retinanet/prediction_head.py +192 -0
  71. keras_hub/src/models/retinanet/retinanet_backbone.py +146 -0
  72. keras_hub/src/models/retinanet/retinanet_image_converter.py +53 -0
  73. keras_hub/src/models/retinanet/retinanet_label_encoder.py +49 -51
  74. keras_hub/src/models/retinanet/retinanet_object_detector.py +382 -0
  75. keras_hub/src/models/retinanet/retinanet_object_detector_preprocessor.py +14 -0
  76. keras_hub/src/models/retinanet/retinanet_presets.py +15 -0
  77. keras_hub/src/models/roberta/roberta_presets.py +0 -4
  78. keras_hub/src/models/sam/sam_backbone.py +0 -1
  79. keras_hub/src/models/sam/sam_image_segmenter.py +9 -10
  80. keras_hub/src/models/sam/sam_presets.py +0 -6
  81. keras_hub/src/models/segformer/__init__.py +8 -0
  82. keras_hub/src/models/segformer/segformer_backbone.py +163 -0
  83. keras_hub/src/models/segformer/segformer_image_converter.py +8 -0
  84. keras_hub/src/models/segformer/segformer_image_segmenter.py +171 -0
  85. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +31 -0
  86. keras_hub/src/models/segformer/segformer_presets.py +124 -0
  87. keras_hub/src/models/stable_diffusion_3/mmdit.py +41 -0
  88. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +38 -21
  89. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +3 -3
  90. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +3 -3
  91. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +28 -4
  92. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py +1 -1
  93. keras_hub/src/models/t5/t5_backbone.py +5 -4
  94. keras_hub/src/models/t5/t5_presets.py +41 -13
  95. keras_hub/src/models/text_to_image.py +13 -5
  96. keras_hub/src/models/vgg/vgg_backbone.py +1 -1
  97. keras_hub/src/models/vgg/vgg_presets.py +0 -8
  98. keras_hub/src/models/whisper/whisper_audio_converter.py +1 -1
  99. keras_hub/src/models/whisper/whisper_presets.py +0 -20
  100. keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py +0 -4
  101. keras_hub/src/tests/test_case.py +25 -0
  102. keras_hub/src/utils/preset_utils.py +17 -4
  103. keras_hub/src/utils/timm/convert_efficientnet.py +449 -0
  104. keras_hub/src/utils/timm/preset_loader.py +3 -0
  105. keras_hub/src/version_utils.py +1 -1
  106. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/METADATA +15 -26
  107. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/RECORD +109 -76
  108. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/WHEEL +1 -1
  109. {keras_hub_nightly-0.16.1.dev202410200345.dist-info → keras_hub_nightly-0.19.0.dev202412070351.dist-info}/top_level.txt +0 -0
keras_hub/src/models/whisper/whisper_presets.py
@@ -7,9 +7,7 @@ backbone_presets = {
                 "English speech data."
             ),
             "params": 37184256,
-            "official_name": "Whisper",
             "path": "whisper",
-            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
         },
         "kaggle_handle": "kaggle://keras/whisper/keras/whisper_tiny_en/3",
     },
@@ -20,9 +18,7 @@ backbone_presets = {
                 "English speech data."
             ),
             "params": 124439808,
-            "official_name": "Whisper",
             "path": "whisper",
-            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
         },
         "kaggle_handle": "kaggle://keras/whisper/keras/whisper_base_en/3",
     },
@@ -33,9 +29,7 @@ backbone_presets = {
                 "English speech data."
             ),
             "params": 241734144,
-            "official_name": "Whisper",
             "path": "whisper",
-            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
         },
         "kaggle_handle": "kaggle://keras/whisper/keras/whisper_small_en/3",
     },
@@ -46,9 +40,7 @@ backbone_presets = {
                 "English speech data."
             ),
             "params": 763856896,
-            "official_name": "Whisper",
             "path": "whisper",
-            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
         },
         "kaggle_handle": "kaggle://keras/whisper/keras/whisper_medium_en/3",
     },
@@ -59,9 +51,7 @@ backbone_presets = {
                 "multilingual speech data."
             ),
             "params": 37760640,
-            "official_name": "Whisper",
             "path": "whisper",
-            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
         },
         "kaggle_handle": "kaggle://keras/whisper/keras/whisper_tiny_multi/3",
     },
@@ -72,9 +62,7 @@ backbone_presets = {
                 "multilingual speech data."
             ),
             "params": 72593920,
-            "official_name": "Whisper",
             "path": "whisper",
-            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
         },
         "kaggle_handle": "kaggle://keras/whisper/keras/whisper_base_multi/3",
     },
@@ -85,9 +73,7 @@ backbone_presets = {
                 "multilingual speech data."
             ),
             "params": 241734912,
-            "official_name": "Whisper",
             "path": "whisper",
-            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
         },
         "kaggle_handle": "kaggle://keras/whisper/keras/whisper_small_multi/3",
     },
@@ -98,9 +84,7 @@ backbone_presets = {
                 "multilingual speech data."
             ),
             "params": 763857920,
-            "official_name": "Whisper",
             "path": "whisper",
-            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
         },
         "kaggle_handle": "kaggle://keras/whisper/keras/whisper_medium_multi/3",
     },
@@ -111,9 +95,7 @@ backbone_presets = {
                 "multilingual speech data."
             ),
             "params": 1543304960,
-            "official_name": "Whisper",
             "path": "whisper",
-            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
         },
         "kaggle_handle": "kaggle://keras/whisper/keras/whisper_large_multi/3",
     },
@@ -125,9 +107,7 @@ backbone_presets = {
                 "of `whisper_large_multi`."
             ),
             "params": 1543304960,
-            "official_name": "Whisper",
             "path": "whisper",
-            "model_card": "https://github.com/openai/whisper/blob/main/model-card.md",
         },
         "kaggle_handle": "kaggle://keras/whisper/keras/whisper_large_multi_v2/3",
     },
keras_hub/src/models/xlm_roberta/xlm_roberta_presets.py
@@ -8,9 +8,7 @@ backbone_presets = {
                 "Trained on CommonCrawl in 100 languages."
             ),
             "params": 277450752,
-            "official_name": "XLM-RoBERTa",
             "path": "xlm_roberta",
-            "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/xlmr/README.md",
         },
         "kaggle_handle": "kaggle://keras/xlm_roberta/keras/xlm_roberta_base_multi/2",
     },
@@ -21,9 +19,7 @@ backbone_presets = {
                 "Trained on CommonCrawl in 100 languages."
             ),
             "params": 558837760,
-            "official_name": "XLM-RoBERTa",
             "path": "xlm_roberta",
-            "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/xlmr/README.md",
         },
         "kaggle_handle": "kaggle://keras/xlm_roberta/keras/xlm_roberta_large_multi/2",
     },
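
Across the preset files in this release, the change is the same: the `official_name` and `model_card` keys are dropped from each preset's metadata. A preset entry after the cleanup keeps only the fields below (a sketch assembled from the hunks above; the description string is unchanged and elided in the diff context):

"whisper_tiny_en": {
    "metadata": {
        "description": "...",  # unchanged; elided in the diff context
        "params": 37184256,
        "path": "whisper",
    },
    "kaggle_handle": "kaggle://keras/whisper/keras/whisper_tiny_en/3",
},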
keras_hub/src/tests/test_case.py
@@ -313,6 +313,14 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):

         for policy in ["mixed_float16", "mixed_bfloat16", "bfloat16"]:
             policy = keras.mixed_precision.Policy(policy)
+            # Ensure the correct `dtype` is set for sublayers or submodels in
+            # `init_kwargs`.
+            original_init_kwargs = init_kwargs.copy()
+            for k, v in init_kwargs.items():
+                if isinstance(v, keras.Layer):
+                    config = v.get_config()
+                    config["dtype"] = policy
+                    init_kwargs[k] = v.__class__.from_config(config)
             layer = cls(**{**init_kwargs, "dtype": policy})
             if isinstance(layer, keras.Model):
                 output_data = layer(input_data)
@@ -343,8 +351,15 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
                     continue
                 self.assertEqual(policy.compute_dtype, sublayer.compute_dtype)
                 self.assertEqual(policy.variable_dtype, sublayer.variable_dtype)
+            # Restore `init_kwargs`.
+            init_kwargs = original_init_kwargs

     def run_quantization_test(self, instance, cls, init_kwargs, input_data):
+        # TODO: revert the following if. This works around a torch
+        # quantization failure in `MultiHeadAttention` with Keras 3.7.
+        if keras.config.backend() == "torch":
+            return
+
         def _get_supported_layers(mode):
             supported_layers = [keras.layers.Dense, keras.layers.EinsumDense]
             if mode == "int8":
@@ -361,6 +376,14 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
                     policy_map[layer.path] = keras.dtype_policies.get(
                         f"{mode}_from_float32"
                     )
+            # Ensure the correct `dtype` is set for sublayers or submodels in
+            # `init_kwargs`.
+            original_init_kwargs = init_kwargs.copy()
+            for k, v in init_kwargs.items():
+                if isinstance(v, keras.Layer):
+                    config = v.get_config()
+                    config["dtype"] = policy_map
+                    init_kwargs[k] = v.__class__.from_config(config)
             # Instantiate the layer.
             model = cls(**{**init_kwargs, "dtype": policy_map})
             # Call layer eagerly.
@@ -382,6 +405,8 @@ class TestCase(tf.test.TestCase, parameterized.TestCase):
             # Check weights loading.
             weights = model.get_weights()
             revived_model.set_weights(weights)
+            # Restore `init_kwargs`.
+            init_kwargs = original_init_kwargs

     def run_model_saving_test(
         self,
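
Both additions rely on the same round trip: serialize a sublayer to its config, inject the target dtype (a policy in the mixed-precision test, a `DTypePolicyMap` in the quantization test), and rebuild the layer from that config. A minimal standalone sketch of the pattern (the `Dense` layer and policy here are illustrative, not taken from keras_hub):

import keras

# Rebuild a sublayer from its config with the target dtype policy
# injected, mirroring the loops added above.
sublayer = keras.layers.Dense(8)
policy = keras.mixed_precision.Policy("mixed_float16")
config = sublayer.get_config()
config["dtype"] = policy
rebuilt = sublayer.__class__.from_config(config)

# Under "mixed_float16", compute runs in float16 while variables stay float32.
assert rebuilt.compute_dtype == "float16"
assert rebuilt.variable_dtype == "float32"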
keras_hub/src/utils/preset_utils.py
@@ -563,10 +563,8 @@ class PresetLoader:
         backbone_kwargs["dtype"] = kwargs.pop("dtype", None)

         # Forward `height` and `width` to backbone when using `TextToImage`.
-        if "height" in kwargs:
-            backbone_kwargs["height"] = kwargs.pop("height", None)
-        if "width" in kwargs:
-            backbone_kwargs["width"] = kwargs.pop("width", None)
+        if "image_shape" in kwargs:
+            backbone_kwargs["image_shape"] = kwargs.pop("image_shape", None)

         return backbone_kwargs, kwargs
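
The net effect for callers: `from_preset` now forwards a single `image_shape` tuple to the backbone instead of separate `height` and `width` arguments. A hypothetical call after this change (preset name illustrative):

import keras_hub

# `image_shape` is popped from kwargs and routed to the backbone config.
text_to_image = keras_hub.models.TextToImage.from_preset(
    "stable_diffusion_3_medium",
    image_shape=(512, 512, 3),
)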
@@ -660,6 +658,12 @@ class KerasPresetLoader(PresetLoader):
                 cls, load_weights, load_task_weights, **kwargs
             )
         # We found a `task.json` with a complete config for our class.
+        # Forward backbone args.
+        backbone_kwargs, kwargs = self.get_backbone_kwargs(**kwargs)
+        if "backbone" in task_config["config"]:
+            backbone_config = task_config["config"]["backbone"]["config"]
+            backbone_config = {**backbone_config, **backbone_kwargs}
+            task_config["config"]["backbone"]["config"] = backbone_config
         task = load_serialized_object(task_config, **kwargs)
         if task.preprocessor and hasattr(
             task.preprocessor, "load_preset_assets"
@@ -767,14 +771,23 @@ class KerasPresetSaver:
             config_file.write(json.dumps(config, indent=4))

     def _save_metadata(self, layer):
+        from keras_hub.src.models.task import Task
         from keras_hub.src.version_utils import __version__ as keras_hub_version

+        # Find all tasks that are compatible with the backbone.
+        # E.g. for `BertBackbone` we would have `TextClassifier` and `MaskedLM`.
+        # For `ResNetBackbone` we would have `ImageClassifier`.
+        tasks = list_subclasses(Task)
+        tasks = filter(lambda x: x.backbone_cls == type(layer), tasks)
+        tasks = [task.__base__.__name__ for task in tasks]
+
         keras_version = keras.version() if hasattr(keras, "version") else None
         metadata = {
             "keras_version": keras_version,
             "keras_hub_version": keras_hub_version,
             "parameter_count": layer.count_params(),
             "date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"),
+            "tasks": tasks,
         }
         metadata_path = os.path.join(self.preset_dir, METADATA_FILE)
         with open(metadata_path, "w") as metadata_file:
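
With the new `tasks` field, `metadata.json` records which task classes can be built on top of the saved backbone. Roughly what a saved file now contains (all values illustrative):

{
    "keras_version": "3.7.0",
    "keras_hub_version": "0.19.0",
    "parameter_count": 109482240,
    "date_saved": "2024-12-07@03:51:00",
    "tasks": ["TextClassifier", "MaskedLM"]
}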
keras_hub/src/utils/timm/convert_efficientnet.py (new file)
@@ -0,0 +1,449 @@
+import math
+
+import numpy as np
+
+from keras_hub.src.models.efficientnet.efficientnet_backbone import (
+    EfficientNetBackbone,
+)
+
+backbone_cls = EfficientNetBackbone
+
+
+VARIANT_MAP = {
+    "b0": {
+        "stackwise_width_coefficients": [1.0] * 7,
+        "stackwise_depth_coefficients": [1.0] * 7,
+        "stackwise_squeeze_and_excite_ratios": [0.25] * 7,
+    },
+    "b1": {
+        "stackwise_width_coefficients": [1.0] * 7,
+        "stackwise_depth_coefficients": [1.1] * 7,
+        "stackwise_squeeze_and_excite_ratios": [0.25] * 7,
+    },
+    "b2": {
+        "stackwise_width_coefficients": [1.1] * 7,
+        "stackwise_depth_coefficients": [1.2] * 7,
+        "stackwise_squeeze_and_excite_ratios": [0.25] * 7,
+    },
+    "b3": {
+        "stackwise_width_coefficients": [1.2] * 7,
+        "stackwise_depth_coefficients": [1.4] * 7,
+        "stackwise_squeeze_and_excite_ratios": [0.25] * 7,
+    },
+    "b4": {
+        "stackwise_width_coefficients": [1.4] * 7,
+        "stackwise_depth_coefficients": [1.8] * 7,
+        "stackwise_squeeze_and_excite_ratios": [0.25] * 7,
+    },
+    "b5": {
+        "stackwise_width_coefficients": [1.6] * 7,
+        "stackwise_depth_coefficients": [2.2] * 7,
+        "stackwise_squeeze_and_excite_ratios": [0.25] * 7,
+    },
+    "lite0": {
+        "stackwise_width_coefficients": [1.0] * 7,
+        "stackwise_depth_coefficients": [1.0] * 7,
+        "stackwise_squeeze_and_excite_ratios": [0] * 7,
+        "activation": "relu6",
+    },
+    "el": {
+        "stackwise_width_coefficients": [1.2] * 6,
+        "stackwise_depth_coefficients": [1.4] * 6,
+        "stackwise_kernel_sizes": [3, 3, 3, 5, 5, 5],
+        "stackwise_num_repeats": [1, 2, 4, 5, 4, 2],
+        "stackwise_input_filters": [32, 24, 32, 48, 96, 144],
+        "stackwise_output_filters": [24, 32, 48, 96, 144, 192],
+        "stackwise_expansion_ratios": [4, 8, 8, 8, 8, 8],
+        "stackwise_strides": [1, 2, 2, 2, 1, 2],
+        "stackwise_squeeze_and_excite_ratios": [0] * 6,
+        "stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3,
+        "stackwise_force_input_filters": [24, 0, 0, 0, 0, 0],
+        "stackwise_nores_option": [True] + [False] * 5,
+        "activation": "relu",
+    },
+    "em": {
+        "stackwise_width_coefficients": [1.0] * 6,
+        "stackwise_depth_coefficients": [1.1] * 6,
+        "stackwise_kernel_sizes": [3, 3, 3, 5, 5, 5],
+        "stackwise_num_repeats": [1, 2, 4, 5, 4, 2],
+        "stackwise_input_filters": [32, 24, 32, 48, 96, 144],
+        "stackwise_output_filters": [24, 32, 48, 96, 144, 192],
+        "stackwise_expansion_ratios": [4, 8, 8, 8, 8, 8],
+        "stackwise_strides": [1, 2, 2, 2, 1, 2],
+        "stackwise_squeeze_and_excite_ratios": [0] * 6,
+        "stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3,
+        "stackwise_force_input_filters": [24, 0, 0, 0, 0, 0],
+        "stackwise_nores_option": [True] + [False] * 5,
+        "activation": "relu",
+    },
+    "es": {
+        "stackwise_width_coefficients": [1.0] * 6,
+        "stackwise_depth_coefficients": [1.0] * 6,
+        "stackwise_kernel_sizes": [3, 3, 3, 5, 5, 5],
+        "stackwise_num_repeats": [1, 2, 4, 5, 4, 2],
+        "stackwise_input_filters": [32, 24, 32, 48, 96, 144],
+        "stackwise_output_filters": [24, 32, 48, 96, 144, 192],
+        "stackwise_expansion_ratios": [4, 8, 8, 8, 8, 8],
+        "stackwise_strides": [1, 2, 2, 2, 1, 2],
+        "stackwise_squeeze_and_excite_ratios": [0] * 6,
+        "stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3,
+        "stackwise_force_input_filters": [24, 0, 0, 0, 0, 0],
+        "stackwise_nores_option": [True] + [False] * 5,
+        "activation": "relu",
+    },
+    "rw_m": {
+        "stackwise_width_coefficients": [1.2] * 6,
+        "stackwise_depth_coefficients": [1.2] * 4 + [1.6] * 2,
+        "stackwise_kernel_sizes": [3, 3, 3, 3, 3, 3],
+        "stackwise_num_repeats": [2, 4, 4, 6, 9, 15],
+        "stackwise_input_filters": [24, 24, 48, 64, 128, 160],
+        "stackwise_output_filters": [24, 48, 64, 128, 160, 272],
+        "stackwise_expansion_ratios": [1, 4, 4, 4, 6, 6],
+        "stackwise_strides": [1, 2, 2, 2, 1, 2],
+        "stackwise_squeeze_and_excite_ratios": [0, 0, 0, 0.25, 0.25, 0.25],
+        "stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3,
+        "stackwise_force_input_filters": [0, 0, 0, 0, 0, 0],
+        "stackwise_nores_option": [False] * 6,
+        "activation": "silu",
+        "num_features": 1792,
+    },
+    "rw_s": {
+        "stackwise_width_coefficients": [1.0] * 6,
+        "stackwise_depth_coefficients": [1.0] * 6,
+        "stackwise_kernel_sizes": [3, 3, 3, 3, 3, 3],
+        "stackwise_num_repeats": [2, 4, 4, 6, 9, 15],
+        "stackwise_input_filters": [24, 24, 48, 64, 128, 160],
+        "stackwise_output_filters": [24, 48, 64, 128, 160, 272],
+        "stackwise_expansion_ratios": [1, 4, 4, 4, 6, 6],
+        "stackwise_strides": [1, 2, 2, 2, 1, 2],
+        "stackwise_squeeze_and_excite_ratios": [0, 0, 0, 0.25, 0.25, 0.25],
+        "stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3,
+        "stackwise_force_input_filters": [0, 0, 0, 0, 0, 0],
+        "stackwise_nores_option": [False] * 6,
+        "activation": "silu",
+        "num_features": 1792,
+    },
+    "rw_t": {
+        "stackwise_width_coefficients": [0.8] * 6,
+        "stackwise_depth_coefficients": [0.9] * 6,
+        "stackwise_kernel_sizes": [3, 3, 3, 3, 3, 3],
+        "stackwise_num_repeats": [2, 4, 4, 6, 9, 15],
+        "stackwise_input_filters": [24, 24, 48, 64, 128, 160],
+        "stackwise_output_filters": [24, 48, 64, 128, 160, 256],
+        "stackwise_expansion_ratios": [1, 4, 4, 4, 6, 6],
+        "stackwise_strides": [1, 2, 2, 2, 1, 2],
+        "stackwise_squeeze_and_excite_ratios": [0, 0, 0, 0.25, 0.25, 0.25],
+        "stackwise_block_types": ["cba"] + ["fused"] * 2 + ["unfused"] * 3,
+        "stackwise_force_input_filters": [0, 0, 0, 0, 0, 0],
+        "stackwise_nores_option": [False] * 6,
+        "activation": "silu",
+    },
+}
+
+
+def convert_backbone_config(timm_config):
+    timm_architecture = timm_config["architecture"]
+
+    base_kwargs = {
+        "stackwise_kernel_sizes": [3, 3, 5, 3, 5, 5, 3],
+        "stackwise_num_repeats": [1, 2, 2, 3, 3, 4, 1],
+        "stackwise_input_filters": [32, 16, 24, 40, 80, 112, 192],
+        "stackwise_output_filters": [16, 24, 40, 80, 112, 192, 320],
+        "stackwise_expansion_ratios": [1, 6, 6, 6, 6, 6, 6],
+        "stackwise_strides": [1, 2, 2, 2, 1, 2, 1],
+        "stackwise_block_types": ["v1"] * 7,
+        "min_depth": None,
+        "include_stem_padding": True,
+        "use_depth_divisor_as_min_depth": True,
+        "cap_round_filter_decrease": True,
+        "stem_conv_padding": "valid",
+        "batch_norm_momentum": 0.9,
+        "batch_norm_epsilon": 1e-5,
+        "dropout": 0,
+        "projection_activation": None,
+    }
+
+    variant = "_".join(timm_architecture.split("_")[1:])
+
+    if variant not in VARIANT_MAP:
+        raise ValueError(
+            f"Currently, the architecture {timm_architecture} is not supported."
+        )
+
+    base_kwargs.update(VARIANT_MAP[variant])
+
+    return base_kwargs
+
+
+def convert_weights(backbone, loader, timm_config):
+    timm_architecture = timm_config["architecture"]
+    variant = "_".join(timm_architecture.split("_")[1:])
+
+    def port_conv2d(keras_layer, hf_weight_prefix, port_bias=True):
+        loader.port_weight(
+            keras_layer.kernel,
+            hf_weight_key=f"{hf_weight_prefix}.weight",
+            hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
+        )
+
+        if port_bias:
+            loader.port_weight(
+                keras_layer.bias,
+                hf_weight_key=f"{hf_weight_prefix}.bias",
+            )
+
+    def port_depthwise_conv2d(
+        keras_layer,
+        hf_weight_prefix,
+        port_bias=True,
+        depth_multiplier=1,
+    ):
+        def convert_pt_conv2d_kernel(pt_kernel):
+            out_channels, in_channels_per_group, height, width = pt_kernel.shape
+            # PT convs are depthwise convs if and only if
+            # `in_channels_per_group == 1`.
+            assert in_channels_per_group == 1
+            pt_kernel = np.transpose(pt_kernel, (2, 3, 0, 1))
+            in_channels = out_channels // depth_multiplier
+            return np.reshape(
+                pt_kernel, (height, width, in_channels, depth_multiplier)
+            )
+
+        loader.port_weight(
+            keras_layer.kernel,
+            hf_weight_key=f"{hf_weight_prefix}.weight",
+            hook_fn=lambda x, _: convert_pt_conv2d_kernel(x),
+        )
+
+        if port_bias:
+            loader.port_weight(
+                keras_layer.bias,
+                hf_weight_key=f"{hf_weight_prefix}.bias",
+            )
+
+    def port_batch_normalization(keras_layer, hf_weight_prefix):
+        loader.port_weight(
+            keras_layer.gamma,
+            hf_weight_key=f"{hf_weight_prefix}.weight",
+        )
+        loader.port_weight(
+            keras_layer.beta,
+            hf_weight_key=f"{hf_weight_prefix}.bias",
+        )
+        loader.port_weight(
+            keras_layer.moving_mean,
+            hf_weight_key=f"{hf_weight_prefix}.running_mean",
+        )
+        loader.port_weight(
+            keras_layer.moving_variance,
+            hf_weight_key=f"{hf_weight_prefix}.running_var",
+        )
+        # Do we need `num_batches_tracked`?
+
+    # Stem
+    port_conv2d(backbone.get_layer("stem_conv"), "conv_stem", port_bias=False)
+    port_batch_normalization(backbone.get_layer("stem_bn"), "bn1")
+
+    # Stages
+    num_stacks = len(backbone.stackwise_kernel_sizes)
+
+    for stack_index in range(num_stacks):
+        block_type = backbone.stackwise_block_types[stack_index]
+        expansion_ratio = backbone.stackwise_expansion_ratios[stack_index]
+        repeats = backbone.stackwise_num_repeats[stack_index]
+        stack_depth_coefficient = backbone.stackwise_depth_coefficients[
+            stack_index
+        ]
+
+        repeats = int(math.ceil(stack_depth_coefficient * repeats))
+
+        se_ratio = VARIANT_MAP[variant]["stackwise_squeeze_and_excite_ratios"][
+            stack_index
+        ]
+
+        for block_idx in range(repeats):
+            conv_pw_count = 0
+            bn_count = 1
+
+            # 97 is the start of the lowercase alphabet.
+            letter_identifier = chr(block_idx + 97)
+
+            keras_block_prefix = f"block{stack_index+1}{letter_identifier}_"
+            hf_block_prefix = f"blocks.{stack_index}.{block_idx}."
+
+            if block_type == "v1":
+                conv_pw_name_map = ["conv_pw", "conv_pwl"]
+                # Initial Expansion Conv
+                if expansion_ratio != 1:
+                    port_conv2d(
+                        backbone.get_layer(keras_block_prefix + "expand_conv"),
+                        hf_block_prefix + conv_pw_name_map[conv_pw_count],
+                        port_bias=False,
+                    )
+                    conv_pw_count += 1
+                    port_batch_normalization(
+                        backbone.get_layer(keras_block_prefix + "expand_bn"),
+                        hf_block_prefix + f"bn{bn_count}",
+                    )
+                    bn_count += 1
+
+                # Depthwise Conv
+                port_depthwise_conv2d(
+                    backbone.get_layer(keras_block_prefix + "dwconv"),
+                    hf_block_prefix + "conv_dw",
+                    port_bias=False,
+                )
+                port_batch_normalization(
+                    backbone.get_layer(keras_block_prefix + "dwconv_bn"),
+                    hf_block_prefix + f"bn{bn_count}",
+                )
+                bn_count += 1
+
+                if 0 < se_ratio <= 1:
+                    # Squeeze and Excite
+                    port_conv2d(
+                        backbone.get_layer(keras_block_prefix + "se_reduce"),
+                        hf_block_prefix + "se.conv_reduce",
+                    )
+                    port_conv2d(
+                        backbone.get_layer(keras_block_prefix + "se_expand"),
+                        hf_block_prefix + "se.conv_expand",
+                    )
+
+                # Output/Projection
+                port_conv2d(
+                    backbone.get_layer(keras_block_prefix + "project"),
+                    hf_block_prefix + conv_pw_name_map[conv_pw_count],
+                    port_bias=False,
+                )
+                conv_pw_count += 1
+                port_batch_normalization(
+                    backbone.get_layer(keras_block_prefix + "project_bn"),
+                    hf_block_prefix + f"bn{bn_count}",
+                )
+                bn_count += 1
+            elif block_type == "fused":
+                fused_block_layer = backbone.get_layer(keras_block_prefix)
+
+                # Initial Expansion Conv
+                port_conv2d(
+                    fused_block_layer.conv1,
+                    hf_block_prefix + "conv_exp",
+                    port_bias=False,
+                )
+                conv_pw_count += 1
+                port_batch_normalization(
+                    fused_block_layer.bn1,
+                    hf_block_prefix + f"bn{bn_count}",
+                )
+                bn_count += 1
+
+                if 0 < se_ratio <= 1:
+                    # Squeeze and Excite
+                    port_conv2d(
+                        fused_block_layer.se_conv1,
+                        hf_block_prefix + "se.conv_reduce",
+                    )
+                    port_conv2d(
+                        fused_block_layer.se_conv2,
+                        hf_block_prefix + "se.conv_expand",
+                    )
+
+                # Output/Projection
+                port_conv2d(
+                    fused_block_layer.output_conv,
+                    hf_block_prefix + "conv_pwl",
+                    port_bias=False,
+                )
+                conv_pw_count += 1
+                port_batch_normalization(
+                    fused_block_layer.bn2,
+                    hf_block_prefix + f"bn{bn_count}",
+                )
+                bn_count += 1
+
+            elif block_type == "unfused":
+                unfused_block_layer = backbone.get_layer(keras_block_prefix)
+                # Initial Expansion Conv
+                if expansion_ratio != 1:
+                    port_conv2d(
+                        unfused_block_layer.conv1,
+                        hf_block_prefix + "conv_pw",
+                        port_bias=False,
+                    )
+                    conv_pw_count += 1
+                    port_batch_normalization(
+                        unfused_block_layer.bn1,
+                        hf_block_prefix + f"bn{bn_count}",
+                    )
+                    bn_count += 1
+
+                # Depthwise Conv
+                port_depthwise_conv2d(
+                    unfused_block_layer.depthwise,
+                    hf_block_prefix + "conv_dw",
+                    port_bias=False,
+                )
+                port_batch_normalization(
+                    unfused_block_layer.bn2,
+                    hf_block_prefix + f"bn{bn_count}",
+                )
+                bn_count += 1
+
+                if 0 < se_ratio <= 1:
+                    # Squeeze and Excite
+                    port_conv2d(
+                        unfused_block_layer.se_conv1,
+                        hf_block_prefix + "se.conv_reduce",
+                    )
+                    port_conv2d(
+                        unfused_block_layer.se_conv2,
+                        hf_block_prefix + "se.conv_expand",
+                    )
+
+                # Output/Projection
+                port_conv2d(
+                    unfused_block_layer.output_conv,
+                    hf_block_prefix + "conv_pwl",
+                    port_bias=False,
+                )
+                conv_pw_count += 1
+                port_batch_normalization(
+                    unfused_block_layer.bn3,
+                    hf_block_prefix + f"bn{bn_count}",
+                )
+                bn_count += 1
+            elif block_type == "cba":
+                cba_block_layer = backbone.get_layer(keras_block_prefix)
+                # Initial Expansion Conv
+                port_conv2d(
+                    cba_block_layer.conv1,
+                    hf_block_prefix + "conv",
+                    port_bias=False,
+                )
+                conv_pw_count += 1
+                port_batch_normalization(
+                    cba_block_layer.bn1,
+                    hf_block_prefix + f"bn{bn_count}",
+                )
+                bn_count += 1
+
+    # Head/Top
+    port_conv2d(backbone.get_layer("top_conv"), "conv_head", port_bias=False)
+    port_batch_normalization(backbone.get_layer("top_bn"), "bn2")
+
+
+def convert_head(task, loader, timm_config):
+    classifier_prefix = timm_config["pretrained_cfg"]["classifier"]
+    prefix = f"{classifier_prefix}."
+    loader.port_weight(
+        task.output_dense.kernel,
+        hf_weight_key=prefix + "weight",
+        hook_fn=lambda x, _: np.transpose(np.squeeze(x)),
+    )
+    loader.port_weight(
+        task.output_dense.bias,
+        hf_weight_key=prefix + "bias",
+    )
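
Two kernel layouts matter throughout `convert_weights`: timm (PyTorch) stores conv kernels as `(out_channels, in_channels, height, width)`, while Keras expects `(height, width, in_channels, out_channels)` for regular convs and `(height, width, in_channels, depth_multiplier)` for depthwise convs. A small standalone check of the two hook functions' math (shapes illustrative):

import numpy as np

# Regular conv, as in `port_conv2d`: OIHW -> HWIO.
pt_kernel = np.zeros((320, 112, 3, 3))  # (out, in, h, w)
keras_kernel = np.transpose(pt_kernel, (2, 3, 1, 0))
assert keras_kernel.shape == (3, 3, 112, 320)  # (h, w, in, out)

# Depthwise conv, as in `port_depthwise_conv2d` with depth_multiplier=1:
# (out, 1, h, w) -> (h, w, in, depth_multiplier).
pt_dw = np.zeros((112, 1, 5, 5))
keras_dw = np.reshape(np.transpose(pt_dw, (2, 3, 0, 1)), (5, 5, 112, 1))
assert keras_dw.shape == (5, 5, 112, 1)

Note also that `convert_backbone_config` picks the `VARIANT_MAP` entry by dropping the first underscore-separated token of the timm architecture string, so an architecture named `efficientnet_b0` resolves to the `b0` entry.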