keras-hub-nightly 0.21.0.dev202505050407__py3-none-any.whl → 0.21.0.dev202505070407__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. keras_hub/models/__init__.py +21 -0
  2. keras_hub/src/models/backbone.py +5 -2
  3. keras_hub/src/models/cspnet/cspnet_backbone.py +51 -26
  4. keras_hub/src/models/cspnet/cspnet_presets.py +38 -3
  5. keras_hub/src/models/mixtral/mixtral_attention.py +263 -0
  6. keras_hub/src/models/mixtral/mixtral_backbone.py +207 -0
  7. keras_hub/src/models/mixtral/mixtral_causal_lm.py +281 -0
  8. keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py +76 -0
  9. keras_hub/src/models/mixtral/mixtral_decoder.py +494 -0
  10. keras_hub/src/models/mixtral/mixtral_layer_norm.py +34 -0
  11. keras_hub/src/models/mixtral/mixtral_tokenizer.py +21 -0
  12. keras_hub/src/models/qwen/qwen_attention.py +3 -1
  13. keras_hub/src/models/qwen/qwen_presets.py +61 -0
  14. keras_hub/src/models/qwen_moe/__init__.py +0 -0
  15. keras_hub/src/models/qwen_moe/qwen_moe_attention.py +377 -0
  16. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +373 -0
  17. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm.py +350 -0
  18. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor.py +17 -0
  19. keras_hub/src/models/qwen_moe/qwen_moe_decoder.py +625 -0
  20. keras_hub/src/models/qwen_moe/qwen_moe_layernorm.py +32 -0
  21. keras_hub/src/models/qwen_moe/qwen_moe_tokenizer.py +46 -0
  22. keras_hub/src/models/retinanet/retinanet_image_converter.py +0 -13
  23. keras_hub/src/models/retinanet/retinanet_presets.py +2 -2
  24. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +0 -18
  25. keras_hub/src/models/segformer/segformer_presets.py +12 -12
  26. keras_hub/src/models/task.py +5 -2
  27. keras_hub/src/utils/keras_utils.py +11 -0
  28. keras_hub/src/utils/preset_utils.py +69 -9
  29. keras_hub/src/utils/tensor_utils.py +27 -1
  30. keras_hub/src/utils/timm/convert_cspnet.py +94 -23
  31. keras_hub/src/utils/timm/preset_loader.py +6 -6
  32. keras_hub/src/utils/transformers/convert_mixtral.py +139 -0
  33. keras_hub/src/utils/transformers/convert_qwen_moe.py +253 -0
  34. keras_hub/src/utils/transformers/preset_loader.py +6 -0
  35. keras_hub/src/version.py +1 -1
  36. keras_hub/tokenizers/__init__.py +6 -0
  37. {keras_hub_nightly-0.21.0.dev202505050407.dist-info → keras_hub_nightly-0.21.0.dev202505070407.dist-info}/METADATA +1 -1
  38. {keras_hub_nightly-0.21.0.dev202505050407.dist-info → keras_hub_nightly-0.21.0.dev202505070407.dist-info}/RECORD +40 -22
  39. {keras_hub_nightly-0.21.0.dev202505050407.dist-info → keras_hub_nightly-0.21.0.dev202505070407.dist-info}/WHEEL +0 -0
  40. {keras_hub_nightly-0.21.0.dev202505050407.dist-info → keras_hub_nightly-0.21.0.dev202505070407.dist-info}/top_level.txt +0 -0
@@ -6,16 +6,3 @@ from keras_hub.src.models.retinanet.retinanet_backbone import RetinaNetBackbone
6
6
  @keras_hub_export("keras_hub.layers.RetinaNetImageConverter")
7
7
  class RetinaNetImageConverter(ImageConverter):
8
8
  backbone_cls = RetinaNetBackbone
9
-
10
- def __init__(
11
- self,
12
- *args,
13
- **kwargs,
14
- ):
15
- # TODO: update presets and remove these old config options. They were
16
- # never needed.
17
- if "norm_mean" in kwargs:
18
- kwargs["offset"] = [-x for x in kwargs.pop("norm_mean")]
19
- if "norm_std" in kwargs:
20
- kwargs["scale"] = [1.0 / x for x in kwargs.pop("norm_std")]
21
- super().__init__(*args, **kwargs)
@@ -11,7 +11,7 @@ backbone_presets = {
11
11
  "params": 34121239,
12
12
  "path": "retinanet",
13
13
  },
14
- "kaggle_handle": "kaggle://keras/retinanet/keras/retinanet_resnet50_fpn_coco/3",
14
+ "kaggle_handle": "kaggle://keras/retinanet/keras/retinanet_resnet50_fpn_coco/4",
15
15
  },
16
16
  "retinanet_resnet50_fpn_v2_coco": {
17
17
  "metadata": {
@@ -22,6 +22,6 @@ backbone_presets = {
22
22
  "params": 31558592,
23
23
  "path": "retinanet",
24
24
  },
25
- "kaggle_handle": "kaggle://keras/retinanet/keras/retinanet_resnet50_fpn_v2_coco/2",
25
+ "kaggle_handle": "kaggle://keras/retinanet/keras/retinanet_resnet50_fpn_v2_coco/3",
26
26
  },
27
27
  }
@@ -1,5 +1,3 @@
1
- import keras
2
-
3
1
  from keras_hub.src.api_export import keras_hub_export
4
2
  from keras_hub.src.models.image_segmenter_preprocessor import (
5
3
  ImageSegmenterPreprocessor,
@@ -8,25 +6,9 @@ from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone
8
6
  from keras_hub.src.models.segformer.segformer_image_converter import (
9
7
  SegFormerImageConverter,
10
8
  )
11
- from keras_hub.src.utils.tensor_utils import preprocessing_function
12
-
13
- IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
14
- IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
15
9
 
16
10
 
17
11
  @keras_hub_export("keras_hub.models.SegFormerImageSegmenterPreprocessor")
18
12
  class SegFormerImageSegmenterPreprocessor(ImageSegmenterPreprocessor):
19
13
  backbone_cls = SegFormerBackbone
20
14
  image_converter_cls = SegFormerImageConverter
21
-
22
- @preprocessing_function
23
- def call(self, x, y=None, sample_weight=None):
24
- if self.image_converter:
25
- x = self.image_converter(x)
26
- if y is not None:
27
- y = self.image_converter(y)
28
-
29
- x = x / 255
30
- x = (x - IMAGENET_DEFAULT_MEAN) / IMAGENET_DEFAULT_STD
31
-
32
- return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
@@ -10,7 +10,7 @@ presets = {
10
10
  "params": 3719027,
11
11
  "path": "segformer_b0",
12
12
  },
13
- "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b0_ade20k_512/2",
13
+ "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b0_ade20k_512/3",
14
14
  },
15
15
  "segformer_b1_ade20k_512": {
16
16
  "metadata": {
@@ -21,7 +21,7 @@ presets = {
21
21
  "params": 13682643,
22
22
  "path": "segformer_b1",
23
23
  },
24
- "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b1_ade20k_512/2",
24
+ "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b1_ade20k_512/5",
25
25
  },
26
26
  "segformer_b2_ade20k_512": {
27
27
  "metadata": {
@@ -32,7 +32,7 @@ presets = {
32
32
  "params": 24727507,
33
33
  "path": "segformer_b2",
34
34
  },
35
- "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b2_ade20k_512/2",
35
+ "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b2_ade20k_512/3",
36
36
  },
37
37
  "segformer_b3_ade20k_512": {
38
38
  "metadata": {
@@ -43,7 +43,7 @@ presets = {
43
43
  "params": 44603347,
44
44
  "path": "segformer_b3",
45
45
  },
46
- "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b3_ade20k_512/2",
46
+ "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b3_ade20k_512/3",
47
47
  },
48
48
  "segformer_b4_ade20k_512": {
49
49
  "metadata": {
@@ -54,7 +54,7 @@ presets = {
54
54
  "params": 61373907,
55
55
  "path": "segformer_b4",
56
56
  },
57
- "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b4_ade20k_512/2",
57
+ "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b4_ade20k_512/3",
58
58
  },
59
59
  "segformer_b5_ade20k_640": {
60
60
  "metadata": {
@@ -65,7 +65,7 @@ presets = {
65
65
  "params": 81974227,
66
66
  "path": "segformer_b5",
67
67
  },
68
- "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b5_ade20k_640/2",
68
+ "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b5_ade20k_640/3",
69
69
  },
70
70
  "segformer_b0_cityscapes_1024": {
71
71
  "metadata": {
@@ -76,7 +76,7 @@ presets = {
76
76
  "params": 3719027,
77
77
  "path": "segformer_b0",
78
78
  },
79
- "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b0_cityscapes_1024/2",
79
+ "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b0_cityscapes_1024/3",
80
80
  },
81
81
  "segformer_b1_cityscapes_1024": {
82
82
  "metadata": {
@@ -87,7 +87,7 @@ presets = {
87
87
  "params": 13682643,
88
88
  "path": "segformer_b1",
89
89
  },
90
- "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b1_ade20k_512/2",
90
+ "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b1_ade20k_512/1",
91
91
  },
92
92
  "segformer_b2_cityscapes_1024": {
93
93
  "metadata": {
@@ -98,7 +98,7 @@ presets = {
98
98
  "params": 24727507,
99
99
  "path": "segformer_b2",
100
100
  },
101
- "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b2_cityscapes_1024/2",
101
+ "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b2_cityscapes_1024/3",
102
102
  },
103
103
  "segformer_b3_cityscapes_1024": {
104
104
  "metadata": {
@@ -109,7 +109,7 @@ presets = {
109
109
  "params": 44603347,
110
110
  "path": "segformer_b3",
111
111
  },
112
- "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b3_cityscapes_1024/2",
112
+ "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b3_cityscapes_1024/3",
113
113
  },
114
114
  "segformer_b4_cityscapes_1024": {
115
115
  "metadata": {
@@ -120,7 +120,7 @@ presets = {
120
120
  "params": 61373907,
121
121
  "path": "segformer_b4",
122
122
  },
123
- "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b4_cityscapes_1024/2",
123
+ "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b4_cityscapes_1024/3",
124
124
  },
125
125
  "segformer_b5_cityscapes_1024": {
126
126
  "metadata": {
@@ -131,6 +131,6 @@ presets = {
131
131
  "params": 81974227,
132
132
  "path": "segformer_b5",
133
133
  },
134
- "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b5_cityscapes_1024/2",
134
+ "kaggle_handle": "kaggle://keras/segformer/keras/segformer_b5_cityscapes_1024/3",
135
135
  },
136
136
  }
@@ -236,14 +236,17 @@ class Task(PipelineModel):
236
236
  objects_to_skip=backbone_layer_ids,
237
237
  )
238
238
 
239
- def save_to_preset(self, preset_dir):
239
+ def save_to_preset(self, preset_dir, max_shard_size=10):
240
240
  """Save task to a preset directory.
241
241
 
242
242
  Args:
243
243
  preset_dir: The path to the local model preset directory.
244
+ max_shard_size: `int` or `float`. Maximum size in GB for each
245
+ sharded file. If `None`, no sharding will be done. Defaults to
246
+ `10`.
244
247
  """
245
248
  saver = get_preset_saver(preset_dir)
246
- saver.save_task(self)
249
+ saver.save_task(self, max_shard_size=max_shard_size)
247
250
 
248
251
  @property
249
252
  def layers(self):
@@ -1,3 +1,4 @@
1
+ import inspect
1
2
  import sys
2
3
 
3
4
  import keras
@@ -147,3 +148,13 @@ def get_gpu_names():
147
148
  ]
148
149
  else:
149
150
  return [""]
151
+
152
+
153
+ def sharded_weights_available():
154
+ """Whether sharded weights serialization is available.
155
+
156
+ Returns:
157
+ `True` if sharded weights are available, `False` otherwise.
158
+ """
159
+ save_weights_signature = inspect.signature(keras.saving.save_weights)
160
+ return "max_shard_size" in save_weights_signature.parameters
@@ -10,6 +10,8 @@ from absl import logging
10
10
 
11
11
  from keras_hub.src.api_export import keras_hub_export
12
12
  from keras_hub.src.utils.keras_utils import print_msg
13
+ from keras_hub.src.utils.keras_utils import sharded_weights_available
14
+ from keras_hub.src.utils.tensor_utils import get_tensor_size_in_bits
13
15
 
14
16
  try:
15
17
  import kagglehub
@@ -48,6 +50,7 @@ METADATA_FILE = "metadata.json"
48
50
  # Weight file names.
49
51
  MODEL_WEIGHTS_FILE = "model.weights.h5"
50
52
  TASK_WEIGHTS_FILE = "task.weights.h5"
53
+ SHARDED_MODEL_WEIGHTS_CONFIG_FILE = "model.weights.json"
51
54
 
52
55
  # HuggingFace filenames.
53
56
  README_FILE = "README.md"
@@ -647,7 +650,7 @@ class KerasPresetLoader(PresetLoader):
647
650
  backbone = self._load_serialized_object(self.config, **kwargs)
648
651
  if load_weights:
649
652
  jax_memory_cleanup(backbone)
650
- backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
653
+ self._load_backbone_weights(backbone)
651
654
  return backbone
652
655
 
653
656
  def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs):
@@ -697,8 +700,7 @@ class KerasPresetLoader(PresetLoader):
697
700
  task.load_task_weights(task_weights)
698
701
  else:
699
702
  jax_memory_cleanup(task.backbone)
700
- backbone_weights = get_file(self.preset, MODEL_WEIGHTS_FILE)
701
- task.backbone.load_weights(backbone_weights)
703
+ self._load_backbone_weights(task.backbone)
702
704
  return task
703
705
 
704
706
  def load_preprocessor(
@@ -726,18 +728,64 @@ class KerasPresetLoader(PresetLoader):
726
728
  config["config"] = {**config["config"], **kwargs}
727
729
  return keras.saving.deserialize_keras_object(config)
728
730
 
731
+ def _get_sharded_filenames(self, config_path):
732
+ with open(config_path, encoding="utf-8") as config_file:
733
+ config = json.load(config_file)
734
+ weight_map = config["weight_map"]
735
+ return sorted(set(weight_map.values()))
736
+
737
+ def _load_backbone_weights(self, backbone):
738
+ # Detect if the backbone is sharded or not.
739
+ has_single_file_weights = check_file_exists(
740
+ self.preset, MODEL_WEIGHTS_FILE
741
+ )
742
+ if has_single_file_weights:
743
+ filepath = get_file(self.preset, MODEL_WEIGHTS_FILE)
744
+ else:
745
+ if not sharded_weights_available():
746
+ raise RuntimeError(
747
+ "Sharded weights loading is not supported in the current "
748
+ f"Keras version {keras.__version__}. "
749
+ "Please update to a newer version."
750
+ )
751
+ filepath = get_file(self.preset, SHARDED_MODEL_WEIGHTS_CONFIG_FILE)
752
+ sharded_filenames = self._get_sharded_filenames(filepath)
753
+ for sharded_filename in sharded_filenames:
754
+ # Download the sharded weights.
755
+ _ = get_file(self.preset, sharded_filename)
756
+ backbone.load_weights(filepath)
757
+
729
758
 
730
759
  class KerasPresetSaver:
731
760
  def __init__(self, preset_dir):
732
761
  os.makedirs(preset_dir, exist_ok=True)
733
762
  self.preset_dir = preset_dir
734
763
 
735
- def save_backbone(self, backbone):
764
+ def save_backbone(self, backbone, max_shard_size=10):
736
765
  self._save_serialized_object(backbone, config_file=CONFIG_FILE)
737
- backbone_weight_path = os.path.join(self.preset_dir, MODEL_WEIGHTS_FILE)
738
- backbone.save_weights(backbone_weight_path)
739
766
  self._save_metadata(backbone)
740
767
 
768
+ # Save the weights.
769
+ backbone_size_in_bytes = self._get_variables_size_in_bytes(
770
+ backbone.variables
771
+ )
772
+ backbone_size_in_gb = backbone_size_in_bytes / (1024**3)
773
+ # If the size of the backbone is larger than `max_shard_size`, save
774
+ # sharded weights.
775
+ if sharded_weights_available() and backbone_size_in_gb > max_shard_size:
776
+ backbone_sharded_weights_config_path = os.path.join(
777
+ self.preset_dir, SHARDED_MODEL_WEIGHTS_CONFIG_FILE
778
+ )
779
+ backbone.save_weights(
780
+ backbone_sharded_weights_config_path,
781
+ max_shard_size=max_shard_size,
782
+ )
783
+ else:
784
+ backbone_weight_path = os.path.join(
785
+ self.preset_dir, MODEL_WEIGHTS_FILE
786
+ )
787
+ backbone.save_weights(backbone_weight_path)
788
+
741
789
  def save_tokenizer(self, tokenizer):
742
790
  config_file = TOKENIZER_CONFIG_FILE
743
791
  if hasattr(tokenizer, "config_file"):
@@ -755,7 +803,7 @@ class KerasPresetSaver:
755
803
  def save_image_converter(self, converter):
756
804
  self._save_serialized_object(converter, IMAGE_CONVERTER_CONFIG_FILE)
757
805
 
758
- def save_task(self, task):
806
+ def save_task(self, task, max_shard_size=10):
759
807
  # Save task specific config and weights.
760
808
  self._save_serialized_object(task, TASK_CONFIG_FILE)
761
809
  if task.has_task_weights():
@@ -763,10 +811,12 @@ class KerasPresetSaver:
763
811
  task.save_task_weights(task_weight_path)
764
812
  # Save backbone.
765
813
  if hasattr(task.backbone, "save_to_preset"):
766
- task.backbone.save_to_preset(self.preset_dir)
814
+ task.backbone.save_to_preset(
815
+ self.preset_dir, max_shard_size=max_shard_size
816
+ )
767
817
  else:
768
818
  # Allow saving a `keras.Model` that is not a backbone subclass.
769
- self.save_backbone(task.backbone)
819
+ self.save_backbone(task.backbone, max_shard_size=max_shard_size)
770
820
  # Save preprocessor.
771
821
  if task.preprocessor and hasattr(task.preprocessor, "save_to_preset"):
772
822
  task.preprocessor.save_to_preset(self.preset_dir)
@@ -823,3 +873,13 @@ class KerasPresetSaver:
823
873
  metadata_path = os.path.join(self.preset_dir, METADATA_FILE)
824
874
  with open(metadata_path, "w") as metadata_file:
825
875
  metadata_file.write(json.dumps(metadata, indent=4))
876
+
877
+ def _get_variables_size_in_bytes(self, variables):
878
+ unique_variables = {}
879
+ for v in variables:
880
+ if id(v) not in unique_variables:
881
+ unique_variables[id(v)] = (v.shape, v.dtype)
882
+ total_memory_size = 0
883
+ for shape, dtype in unique_variables.values():
884
+ total_memory_size += get_tensor_size_in_bits(shape, dtype)
885
+ return total_memory_size / 8
@@ -1,6 +1,8 @@
1
1
  import contextlib
2
2
  import functools
3
3
  import inspect
4
+ import math
5
+ import re
4
6
  import threading
5
7
 
6
8
  import keras
@@ -305,6 +307,29 @@ def is_string_dtype(dtype):
305
307
  return "string" in keras.backend.standardize_dtype(dtype)
306
308
 
307
309
 
310
+ def get_dtype_size_in_bits(dtype):
311
+ """Get the size of a given dtype in bits."""
312
+ dtype = keras.backend.standardize_dtype(dtype)
313
+ # If dtype is bool, return 1 immediately.
314
+ if dtype == "bool":
315
+ return 1
316
+ # Else, we extract the bit size from the string.
317
+ return int(re.sub(r"bfloat|float|uint|int", "", dtype))
318
+
319
+
320
+ def get_tensor_size_in_bits(shape, dtype):
321
+ """Calculate the size given dtype and shape in bits.
322
+
323
+ Args:
324
+ dtype: The dtype of the tensor.
325
+ shape: Iterable of integers representing the shape of the tensor.
326
+
327
+ Returns:
328
+ The size of the tensor in bits.
329
+ """
330
+ return math.prod(shape) * get_dtype_size_in_bits(dtype)
331
+
332
+
308
333
  def any_equal(inputs, values, padding_mask):
309
334
  """Return a mask that is True anywhere `inputs` has a value in `values`.
310
335
 
@@ -320,7 +345,8 @@ def any_equal(inputs, values, padding_mask):
320
345
  Returns:
321
346
  A tensor with `inputs` shape where each position is True if it contains
322
347
  a value from any `values`. Padding mask will be applied before
323
- returning."""
348
+ returning.
349
+ """
324
350
  output = ops.equal(inputs, values[0])
325
351
  for value in values[1:]:
326
352
  value_equality = ops.equal(inputs, value)
@@ -17,10 +17,69 @@ def convert_backbone_config(timm_config):
17
17
  bottle_ratio = (0.5,) + (1.0,)
18
18
  block_ratio = (1.0,) + (0.5,)
19
19
  expand_ratio = (2.0,) + (1.0,)
20
+ stem_padding = "same"
21
+ stem_pooling = None
20
22
  stage_type = "csp"
23
+ groups = 1
21
24
  block_type = "dark_block"
22
25
  down_growth = True
23
- stackwise_strides = 2
26
+ stackwise_strides = [2, 2, 2, 2, 2]
27
+ avg_down = False
28
+ cross_linear = False
29
+ elif timm_architecture == "cspresnet50":
30
+ stem_filters = 64
31
+ stem_kernel_size = 7
32
+ stem_strides = 4
33
+ stackwise_depth = [3, 3, 5, 2]
34
+ stackwise_strides = [1, 2, 2, 2]
35
+ stackwise_num_filters = [128, 256, 512, 1024]
36
+ block_type = "bottleneck_block"
37
+ stage_type = "csp"
38
+ bottle_ratio = [0.5]
39
+ block_ratio = [1.0]
40
+ expand_ratio = [2.0]
41
+ stem_padding = "valid"
42
+ stem_pooling = "max"
43
+ avg_down = False
44
+ groups = 1
45
+ down_growth = False
46
+ cross_linear = True
47
+ elif timm_architecture == "cspresnext50":
48
+ stem_filters = 64
49
+ stem_kernel_size = 7
50
+ stem_strides = 4
51
+ stackwise_depth = [3, 3, 5, 2]
52
+ stackwise_num_filters = [256, 512, 1024, 2048]
53
+ bottle_ratio = [1.0]
54
+ block_ratio = [0.5]
55
+ expand_ratio = [1.0]
56
+ stage_type = "csp"
57
+ block_type = "bottleneck_block"
58
+ stem_pooling = "max"
59
+ stackwise_strides = [1, 2, 2, 2]
60
+ groups = 32
61
+ stem_padding = "valid"
62
+ avg_down = False
63
+ down_growth = False
64
+ cross_linear = True
65
+ elif timm_architecture == "darknet53":
66
+ stem_filters = 32
67
+ stem_kernel_size = 3
68
+ stem_strides = 1
69
+ stackwise_depth = [1, 2, 8, 8, 4]
70
+ stackwise_num_filters = [64, 128, 256, 512, 1024]
71
+ bottle_ratio = [0.5]
72
+ block_ratio = [1.0]
73
+ groups = 1
74
+ expand_ratio = [1.0]
75
+ stage_type = "dark"
76
+ block_type = "dark_block"
77
+ stem_pooling = None
78
+ stackwise_strides = [2, 2, 2, 2, 2]
79
+ stem_padding = "same"
80
+ avg_down = False
81
+ down_growth = False
82
+ cross_linear = False
24
83
  else:
25
84
  raise ValueError(
26
85
  f"Currently, the architecture {timm_architecture} is not supported."
@@ -38,6 +97,11 @@ def convert_backbone_config(timm_config):
38
97
  block_type=block_type,
39
98
  stackwise_strides=stackwise_strides,
40
99
  down_growth=down_growth,
100
+ stem_pooling=stem_pooling,
101
+ stem_padding=stem_padding,
102
+ avg_down=avg_down,
103
+ cross_linear=cross_linear,
104
+ groups=groups,
41
105
  )
42
106
 
43
107
 
@@ -81,21 +145,36 @@ def convert_weights(backbone, loader, timm_config):
81
145
  stackwise_depth = backbone.stackwise_depth
82
146
  stage_type = backbone.stage_type
83
147
  block_type = backbone.block_type
148
+ strides = backbone.stackwise_strides
84
149
 
85
150
  for idx, block in enumerate(stackwise_depth):
86
- port_conv2d(
87
- f"stages.{idx}.conv_down.conv",
88
- f"stage_{idx}_{stage_type}_conv_down_1",
89
- )
90
- port_batch_normalization(
91
- f"stages.{idx}.conv_down.bn", f"stage_{idx}_{stage_type}_bn_1"
92
- )
93
- port_conv2d(
94
- f"stages.{idx}.conv_exp.conv", f"stage_{idx}_{stage_type}_conv_exp"
95
- )
96
- port_batch_normalization(
97
- f"stages.{idx}.conv_exp.bn", f"stage_{idx}_{stage_type}_bn_2"
98
- )
151
+ if strides[idx] != 1 or stage_type == "dark":
152
+ if strides[idx] == 2 and backbone.avg_down:
153
+ port_conv2d(
154
+ f"stages.{idx}.conv_down.1.conv",
155
+ f"stage_{idx}_{stage_type}_conv_down_1",
156
+ )
157
+ port_batch_normalization(
158
+ f"stages.{idx}.conv_down.1.bn",
159
+ f"stage_{idx}_{stage_type}_bn_1",
160
+ )
161
+ else:
162
+ port_conv2d(
163
+ f"stages.{idx}.conv_down.conv",
164
+ f"stage_{idx}_{stage_type}_conv_down_1",
165
+ )
166
+ port_batch_normalization(
167
+ f"stages.{idx}.conv_down.bn",
168
+ f"stage_{idx}_{stage_type}_bn_1",
169
+ )
170
+ if stage_type != "dark":
171
+ port_conv2d(
172
+ f"stages.{idx}.conv_exp.conv",
173
+ f"stage_{idx}_{stage_type}_conv_exp",
174
+ )
175
+ port_batch_normalization(
176
+ f"stages.{idx}.conv_exp.bn", f"stage_{idx}_{stage_type}_bn_2"
177
+ )
99
178
 
100
179
  for i in range(block):
101
180
  port_conv2d(
@@ -133,16 +212,8 @@ def convert_weights(backbone, loader, timm_config):
133
212
  f"stages.{idx}.conv_transition_b.bn",
134
213
  f"stage_{idx}_{stage_type}_transition_b_bn",
135
214
  )
136
- port_conv2d(
137
- f"stages.{idx}.conv_transition.conv",
138
- f"stage_{idx}_{stage_type}_conv_transition",
139
- )
140
- port_batch_normalization(
141
- f"stages.{idx}.conv_transition.bn",
142
- f"stage_{idx}_{stage_type}_transition_bn",
143
- )
144
215
 
145
- else:
216
+ if stage_type != "dark":
146
217
  port_conv2d(
147
218
  f"stages.{idx}.conv_transition.conv",
148
219
  f"stage_{idx}_{stage_type}_conv_transition",
@@ -16,17 +16,17 @@ class TimmPresetLoader(PresetLoader):
16
16
  def __init__(self, preset, config):
17
17
  super().__init__(preset, config)
18
18
  architecture = self.config["architecture"]
19
- if "resnet" in architecture:
19
+ if architecture.startswith("resnet"):
20
20
  self.converter = convert_resnet
21
- elif "csp" in architecture:
21
+ elif architecture.startswith(("csp", "dark")):
22
22
  self.converter = convert_cspnet
23
- elif "densenet" in architecture:
23
+ elif architecture.startswith("densenet"):
24
24
  self.converter = convert_densenet
25
- elif "mobilenet" in architecture:
25
+ elif architecture.startswith("mobilenet"):
26
26
  self.converter = convert_mobilenet
27
- elif "vgg" in architecture:
27
+ elif architecture.startswith("vgg"):
28
28
  self.converter = convert_vgg
29
- elif "efficientnet" in architecture:
29
+ elif architecture.startswith("efficientnet"):
30
30
  self.converter = convert_efficientnet
31
31
  else:
32
32
  raise ValueError(