keras-hub 0.20.0.dev1__py3-none-any.whl → 0.21.0.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105) hide show
  1. keras_hub/__init__.py +15 -33
  2. keras_hub/layers/__init__.py +134 -0
  3. keras_hub/metrics/__init__.py +11 -0
  4. keras_hub/models/__init__.py +642 -0
  5. keras_hub/samplers/__init__.py +18 -0
  6. keras_hub/src/layers/modeling/reversible_embedding.py +25 -35
  7. keras_hub/src/layers/preprocessing/image_converter.py +1 -0
  8. keras_hub/src/layers/preprocessing/random_deletion.py +1 -1
  9. keras_hub/src/layers/preprocessing/random_swap.py +1 -1
  10. keras_hub/src/models/audio_to_text.py +66 -0
  11. keras_hub/src/models/audio_to_text_preprocessor.py +80 -0
  12. keras_hub/src/models/backbone.py +5 -2
  13. keras_hub/src/models/cspnet/cspnet_backbone.py +51 -26
  14. keras_hub/src/models/cspnet/cspnet_presets.py +38 -3
  15. keras_hub/src/models/falcon/falcon_backbone.py +1 -1
  16. keras_hub/src/models/gemma/gemma_presets.py +10 -10
  17. keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +3 -2
  18. keras_hub/src/models/gemma3/gemma3_presets.py +8 -8
  19. keras_hub/src/models/gemma3/gemma3_vision_encoder.py +1 -1
  20. keras_hub/src/models/llama/llama_attention.py +24 -6
  21. keras_hub/src/models/llama/llama_backbone.py +50 -16
  22. keras_hub/src/models/llama/llama_decoder.py +20 -3
  23. keras_hub/src/models/llama/llama_presets.py +3 -3
  24. keras_hub/src/models/llama/llama_rotary_embedding.py +180 -0
  25. keras_hub/src/models/llama3/llama3_backbone.py +10 -2
  26. keras_hub/src/models/llama3/llama3_presets.py +84 -2
  27. keras_hub/src/models/mistral/mistral_presets.py +3 -3
  28. keras_hub/src/models/mixtral/__init__.py +5 -0
  29. keras_hub/src/models/mixtral/mixtral_attention.py +252 -0
  30. keras_hub/src/models/mixtral/mixtral_backbone.py +207 -0
  31. keras_hub/src/models/mixtral/mixtral_causal_lm.py +281 -0
  32. keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py +76 -0
  33. keras_hub/src/models/mixtral/mixtral_decoder.py +494 -0
  34. keras_hub/src/models/mixtral/mixtral_layer_norm.py +34 -0
  35. keras_hub/src/models/mixtral/mixtral_presets.py +26 -0
  36. keras_hub/src/models/mixtral/mixtral_tokenizer.py +21 -0
  37. keras_hub/src/models/moonshine/__init__.py +5 -0
  38. keras_hub/src/models/moonshine/moonshine_audio_converter.py +301 -0
  39. keras_hub/src/models/moonshine/moonshine_audio_to_text.py +383 -0
  40. keras_hub/src/models/moonshine/moonshine_audio_to_text_preprocessor.py +272 -0
  41. keras_hub/src/models/moonshine/moonshine_backbone.py +478 -0
  42. keras_hub/src/models/moonshine/moonshine_decoder.py +313 -0
  43. keras_hub/src/models/moonshine/moonshine_encoder.py +212 -0
  44. keras_hub/src/models/moonshine/moonshine_layers.py +239 -0
  45. keras_hub/src/models/moonshine/moonshine_multi_head_attention.py +355 -0
  46. keras_hub/src/models/moonshine/moonshine_presets.py +25 -0
  47. keras_hub/src/models/moonshine/moonshine_tokenizer.py +62 -0
  48. keras_hub/src/models/pali_gemma/pali_gemma_presets.py +11 -11
  49. keras_hub/src/models/pali_gemma/pali_gemma_vit.py +1 -1
  50. keras_hub/src/models/qwen/__init__.py +4 -0
  51. keras_hub/src/models/qwen/qwen_attention.py +3 -1
  52. keras_hub/src/models/qwen/qwen_backbone.py +8 -1
  53. keras_hub/src/models/qwen/qwen_causal_lm.py +7 -0
  54. keras_hub/src/models/qwen/qwen_causal_lm_preprocessor.py +7 -0
  55. keras_hub/src/models/qwen/qwen_presets.py +61 -0
  56. keras_hub/src/models/qwen/qwen_tokenizer.py +9 -0
  57. keras_hub/src/models/qwen_moe/__init__.py +5 -0
  58. keras_hub/src/models/qwen_moe/qwen_moe_attention.py +375 -0
  59. keras_hub/src/models/qwen_moe/qwen_moe_backbone.py +373 -0
  60. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm.py +350 -0
  61. keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor.py +17 -0
  62. keras_hub/src/models/qwen_moe/qwen_moe_decoder.py +625 -0
  63. keras_hub/src/models/qwen_moe/qwen_moe_layernorm.py +32 -0
  64. keras_hub/src/models/qwen_moe/qwen_moe_presets.py +15 -0
  65. keras_hub/src/models/qwen_moe/qwen_moe_tokenizer.py +46 -0
  66. keras_hub/src/models/retinanet/retinanet_image_converter.py +0 -13
  67. keras_hub/src/models/retinanet/retinanet_presets.py +2 -2
  68. keras_hub/src/models/segformer/segformer_image_segmenter_preprocessor.py +0 -18
  69. keras_hub/src/models/segformer/segformer_presets.py +12 -12
  70. keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +6 -0
  71. keras_hub/src/models/task.py +5 -2
  72. keras_hub/src/models/xception/__init__.py +5 -0
  73. keras_hub/src/models/xception/xception_backbone.py +188 -0
  74. keras_hub/src/models/xception/xception_image_classifier.py +12 -0
  75. keras_hub/src/models/xception/xception_image_classifier_preprocessor.py +14 -0
  76. keras_hub/src/models/xception/xception_image_converter.py +8 -0
  77. keras_hub/src/models/xception/xception_presets.py +14 -0
  78. keras_hub/src/tests/mocks/mock_gemma3_tokenizer.py +155 -0
  79. keras_hub/src/utils/coco/__init__.py +0 -0
  80. keras_hub/src/utils/coco/coco_utils.py +133 -0
  81. keras_hub/src/utils/imagenet/imagenet_utils.py +36 -0
  82. keras_hub/src/utils/keras_utils.py +11 -0
  83. keras_hub/src/utils/preset_utils.py +70 -10
  84. keras_hub/src/utils/tensor_utils.py +27 -1
  85. keras_hub/src/utils/timm/convert_cspnet.py +94 -23
  86. keras_hub/src/utils/timm/preset_loader.py +6 -6
  87. keras_hub/src/utils/transformers/convert_llama3.py +21 -1
  88. keras_hub/src/utils/transformers/convert_mixtral.py +139 -0
  89. keras_hub/src/utils/transformers/convert_qwen.py +1 -0
  90. keras_hub/src/utils/transformers/convert_qwen_moe.py +253 -0
  91. keras_hub/src/utils/transformers/preset_loader.py +6 -0
  92. keras_hub/src/{version_utils.py → version.py} +1 -1
  93. keras_hub/tokenizers/__init__.py +117 -0
  94. keras_hub/utils/__init__.py +21 -0
  95. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/METADATA +6 -20
  96. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/RECORD +98 -55
  97. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/WHEEL +1 -1
  98. keras_hub/api/__init__.py +0 -15
  99. keras_hub/api/layers/__init__.py +0 -86
  100. keras_hub/api/metrics/__init__.py +0 -11
  101. keras_hub/api/models/__init__.py +0 -416
  102. keras_hub/api/samplers/__init__.py +0 -16
  103. keras_hub/api/tokenizers/__init__.py +0 -58
  104. keras_hub/api/utils/__init__.py +0 -9
  105. {keras_hub-0.20.0.dev1.dist-info → keras_hub-0.21.0.dev1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,133 @@
1
+ from keras_hub.src.api_export import keras_hub_export
2
+
3
+
4
+ @keras_hub_export("keras_hub.utils.coco_id_to_name")
5
+ def coco_id_to_name(id):
6
+ """Convert a single COCO class ID to a class name.
7
+
8
+ Args:
9
+ id: An integer class id from 0 to 91.
10
+
11
+ Returns:
12
+ The human readable image class name, e.g. "bicycle".
13
+
14
+ Example:
15
+ >>> keras_hub.utils.coco_id_to_name(2)
16
+ 'bicycle'
17
+ """
18
+ return COCO_NAMES[id]
19
+
20
+
21
+ @keras_hub_export("keras_hub.utils.coco_name_to_id")
22
+ def coco_name_to_id(name):
23
+ """Convert a single COCO class name to a class ID.
24
+
25
+ Args:
26
+ name: A human readable image class name, e.g. "bicycle".
27
+
28
+ Returns:
29
+ The integer class id from 0 to 91.
30
+
31
+ Example:
32
+ >>> keras_hub.utils.coco_name_to_id("bicycle")
33
+ 2
34
+ """
35
+ return COCO_IDS[name]
36
+
37
+
38
+ COCO_NAMES = {
39
+ 0: "unlabeled",
40
+ 1: "person",
41
+ 2: "bicycle",
42
+ 3: "car",
43
+ 4: "motorcycle",
44
+ 5: "airplane",
45
+ 6: "bus",
46
+ 7: "train",
47
+ 8: "truck",
48
+ 9: "boat",
49
+ 10: "traffic_light",
50
+ 11: "fire_hydrant",
51
+ 12: "street_sign",
52
+ 13: "stop_sign",
53
+ 14: "parking_meter",
54
+ 15: "bench",
55
+ 16: "bird",
56
+ 17: "cat",
57
+ 18: "dog",
58
+ 19: "horse",
59
+ 20: "sheep",
60
+ 21: "cow",
61
+ 22: "elephant",
62
+ 23: "bear",
63
+ 24: "zebra",
64
+ 25: "giraffe",
65
+ 26: "hat",
66
+ 27: "backpack",
67
+ 28: "umbrella",
68
+ 29: "shoe",
69
+ 30: "eye_glasses",
70
+ 31: "handbag",
71
+ 32: "tie",
72
+ 33: "suitcase",
73
+ 34: "frisbee",
74
+ 35: "skis",
75
+ 36: "snowboard",
76
+ 37: "sports_ball",
77
+ 38: "kite",
78
+ 39: "baseball_bat",
79
+ 40: "baseball_glove",
80
+ 41: "skateboard",
81
+ 42: "surfboard",
82
+ 43: "tennis_racket",
83
+ 44: "bottle",
84
+ 45: "plate",
85
+ 46: "wine_glass",
86
+ 47: "cup",
87
+ 48: "fork",
88
+ 49: "knife",
89
+ 50: "spoon",
90
+ 51: "bowl",
91
+ 52: "banana",
92
+ 53: "apple",
93
+ 54: "sandwich",
94
+ 55: "orange",
95
+ 56: "broccoli",
96
+ 57: "carrot",
97
+ 58: "hot_dog",
98
+ 59: "pizza",
99
+ 60: "donut",
100
+ 61: "cake",
101
+ 62: "chair",
102
+ 63: "couch",
103
+ 64: "potted_plant",
104
+ 65: "bed",
105
+ 66: "mirror",
106
+ 67: "dining_table",
107
+ 68: "window",
108
+ 69: "desk",
109
+ 70: "toilet",
110
+ 71: "door",
111
+ 72: "tv",
112
+ 73: "laptop",
113
+ 74: "mouse",
114
+ 75: "remote",
115
+ 76: "keyboard",
116
+ 77: "cell_phone",
117
+ 78: "microwave",
118
+ 79: "oven",
119
+ 80: "toaster",
120
+ 81: "sink",
121
+ 82: "refrigerator",
122
+ 83: "blender",
123
+ 84: "book",
124
+ 85: "clock",
125
+ 86: "vase",
126
+ 87: "scissors",
127
+ 88: "teddy_bear",
128
+ 89: "hair_drier",
129
+ 90: "toothbrush",
130
+ 91: "hair_brush",
131
+ }
132
+
133
+ COCO_IDS = {v: k for k, v in COCO_NAMES.items()}
@@ -3,6 +3,40 @@ from keras import ops
3
3
  from keras_hub.src.api_export import keras_hub_export
4
4
 
5
5
 
6
+ @keras_hub_export("keras_hub.utils.imagenet_id_to_name")
7
+ def imagenet_id_to_name(id):
8
+ """Convert a single ImageNet class ID to a class name.
9
+
10
+ Args:
11
+ id: An integer class id from 0 to 999.
12
+
13
+ Returns:
14
+ The human readable image class name, e.g. "goldfish".
15
+
16
+ Example:
17
+ >>> keras_hub.utils.imagenet_id_to_name(1)
18
+ 'goldfish'
19
+ """
20
+ return IMAGENET_NAMES[id][1]
21
+
22
+
23
+ @keras_hub_export("keras_hub.utils.imagenet_name_to_id")
24
+ def imagenet_name_to_id(name):
25
+ """Convert a single ImageNet class name to a class ID.
26
+
27
+ Args:
28
+ name: A human readable image class name, e.g. "goldfish".
29
+
30
+ Returns:
31
+ The integer class id from 0 to 999.
32
+
33
+ Example:
34
+ >>> keras_hub.utils.imagenet_name_to_id("goldfish")
35
+ 1
36
+ """
37
+ return IMAGENET_IDS[name]
38
+
39
+
6
40
  @keras_hub_export("keras_hub.utils.decode_imagenet_predictions")
7
41
  def decode_imagenet_predictions(preds, top=5, include_synset_ids=False):
8
42
  """Decodes the predictions for an ImageNet-1k prediction.
@@ -1052,3 +1086,5 @@ IMAGENET_NAMES = {
1052
1086
  998: ("n13133613", "ear"),
1053
1087
  999: ("n15075141", "toilet_tissue"),
1054
1088
  }
1089
+
1090
+ IMAGENET_IDS = {v[1]: k for k, v in IMAGENET_NAMES.items()}
@@ -1,3 +1,4 @@
1
+ import inspect
1
2
  import sys
2
3
 
3
4
  import keras
@@ -147,3 +148,13 @@ def get_gpu_names():
147
148
  ]
148
149
  else:
149
150
  return [""]
151
+
152
+
153
+ def sharded_weights_available():
154
+ """Whether sharded weights serialization is available.
155
+
156
+ Returns:
157
+ `True` if sharded weights are available, `False` otherwise.
158
+ """
159
+ save_weights_signature = inspect.signature(keras.saving.save_weights)
160
+ return "max_shard_size" in save_weights_signature.parameters
@@ -10,6 +10,8 @@ from absl import logging
10
10
 
11
11
  from keras_hub.src.api_export import keras_hub_export
12
12
  from keras_hub.src.utils.keras_utils import print_msg
13
+ from keras_hub.src.utils.keras_utils import sharded_weights_available
14
+ from keras_hub.src.utils.tensor_utils import get_tensor_size_in_bits
13
15
 
14
16
  try:
15
17
  import kagglehub
@@ -48,6 +50,7 @@ METADATA_FILE = "metadata.json"
48
50
  # Weight file names.
49
51
  MODEL_WEIGHTS_FILE = "model.weights.h5"
50
52
  TASK_WEIGHTS_FILE = "task.weights.h5"
53
+ SHARDED_MODEL_WEIGHTS_CONFIG_FILE = "model.weights.json"
51
54
 
52
55
  # HuggingFace filenames.
53
56
  README_FILE = "README.md"
@@ -647,7 +650,7 @@ class KerasPresetLoader(PresetLoader):
647
650
  backbone = self._load_serialized_object(self.config, **kwargs)
648
651
  if load_weights:
649
652
  jax_memory_cleanup(backbone)
650
- backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
653
+ self._load_backbone_weights(backbone)
651
654
  return backbone
652
655
 
653
656
  def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs):
@@ -697,8 +700,7 @@ class KerasPresetLoader(PresetLoader):
697
700
  task.load_task_weights(task_weights)
698
701
  else:
699
702
  jax_memory_cleanup(task.backbone)
700
- backbone_weights = get_file(self.preset, MODEL_WEIGHTS_FILE)
701
- task.backbone.load_weights(backbone_weights)
703
+ self._load_backbone_weights(task.backbone)
702
704
  return task
703
705
 
704
706
  def load_preprocessor(
@@ -726,18 +728,64 @@ class KerasPresetLoader(PresetLoader):
726
728
  config["config"] = {**config["config"], **kwargs}
727
729
  return keras.saving.deserialize_keras_object(config)
728
730
 
731
+ def _get_sharded_filenames(self, config_path):
732
+ with open(config_path, encoding="utf-8") as config_file:
733
+ config = json.load(config_file)
734
+ weight_map = config["weight_map"]
735
+ return sorted(set(weight_map.values()))
736
+
737
+ def _load_backbone_weights(self, backbone):
738
+ # Detect if the backbone is sharded or not.
739
+ has_single_file_weights = check_file_exists(
740
+ self.preset, MODEL_WEIGHTS_FILE
741
+ )
742
+ if has_single_file_weights:
743
+ filepath = get_file(self.preset, MODEL_WEIGHTS_FILE)
744
+ else:
745
+ if not sharded_weights_available():
746
+ raise RuntimeError(
747
+ "Sharded weights loading is not supported in the current "
748
+ f"Keras version {keras.__version__}. "
749
+ "Please update to a newer version."
750
+ )
751
+ filepath = get_file(self.preset, SHARDED_MODEL_WEIGHTS_CONFIG_FILE)
752
+ sharded_filenames = self._get_sharded_filenames(filepath)
753
+ for sharded_filename in sharded_filenames:
754
+ # Download the sharded weights.
755
+ _ = get_file(self.preset, sharded_filename)
756
+ backbone.load_weights(filepath)
757
+
729
758
 
730
759
  class KerasPresetSaver:
731
760
  def __init__(self, preset_dir):
732
761
  os.makedirs(preset_dir, exist_ok=True)
733
762
  self.preset_dir = preset_dir
734
763
 
735
- def save_backbone(self, backbone):
764
+ def save_backbone(self, backbone, max_shard_size=10):
736
765
  self._save_serialized_object(backbone, config_file=CONFIG_FILE)
737
- backbone_weight_path = os.path.join(self.preset_dir, MODEL_WEIGHTS_FILE)
738
- backbone.save_weights(backbone_weight_path)
739
766
  self._save_metadata(backbone)
740
767
 
768
+ # Save the weights.
769
+ backbone_size_in_bytes = self._get_variables_size_in_bytes(
770
+ backbone.variables
771
+ )
772
+ backbone_size_in_gb = backbone_size_in_bytes / (1024**3)
773
+ # If the size of the backbone is larger than `max_shard_size`, save
774
+ # sharded weights.
775
+ if sharded_weights_available() and backbone_size_in_gb > max_shard_size:
776
+ backbone_sharded_weights_config_path = os.path.join(
777
+ self.preset_dir, SHARDED_MODEL_WEIGHTS_CONFIG_FILE
778
+ )
779
+ backbone.save_weights(
780
+ backbone_sharded_weights_config_path,
781
+ max_shard_size=max_shard_size,
782
+ )
783
+ else:
784
+ backbone_weight_path = os.path.join(
785
+ self.preset_dir, MODEL_WEIGHTS_FILE
786
+ )
787
+ backbone.save_weights(backbone_weight_path)
788
+
741
789
  def save_tokenizer(self, tokenizer):
742
790
  config_file = TOKENIZER_CONFIG_FILE
743
791
  if hasattr(tokenizer, "config_file"):
@@ -755,7 +803,7 @@ class KerasPresetSaver:
755
803
  def save_image_converter(self, converter):
756
804
  self._save_serialized_object(converter, IMAGE_CONVERTER_CONFIG_FILE)
757
805
 
758
- def save_task(self, task):
806
+ def save_task(self, task, max_shard_size=10):
759
807
  # Save task specific config and weights.
760
808
  self._save_serialized_object(task, TASK_CONFIG_FILE)
761
809
  if task.has_task_weights():
@@ -763,10 +811,12 @@ class KerasPresetSaver:
763
811
  task.save_task_weights(task_weight_path)
764
812
  # Save backbone.
765
813
  if hasattr(task.backbone, "save_to_preset"):
766
- task.backbone.save_to_preset(self.preset_dir)
814
+ task.backbone.save_to_preset(
815
+ self.preset_dir, max_shard_size=max_shard_size
816
+ )
767
817
  else:
768
818
  # Allow saving a `keras.Model` that is not a backbone subclass.
769
- self.save_backbone(task.backbone)
819
+ self.save_backbone(task.backbone, max_shard_size=max_shard_size)
770
820
  # Save preprocessor.
771
821
  if task.preprocessor and hasattr(task.preprocessor, "save_to_preset"):
772
822
  task.preprocessor.save_to_preset(self.preset_dir)
@@ -801,7 +851,7 @@ class KerasPresetSaver:
801
851
 
802
852
  def _save_metadata(self, layer):
803
853
  from keras_hub.src.models.task import Task
804
- from keras_hub.src.version_utils import __version__ as keras_hub_version
854
+ from keras_hub.src.version import __version__ as keras_hub_version
805
855
 
806
856
  # Find all tasks that are compatible with the backbone.
807
857
  # E.g. for `BertBackbone` we would have `TextClassifier` and `MaskedLM`.
@@ -823,3 +873,13 @@ class KerasPresetSaver:
823
873
  metadata_path = os.path.join(self.preset_dir, METADATA_FILE)
824
874
  with open(metadata_path, "w") as metadata_file:
825
875
  metadata_file.write(json.dumps(metadata, indent=4))
876
+
877
+ def _get_variables_size_in_bytes(self, variables):
878
+ unique_variables = {}
879
+ for v in variables:
880
+ if id(v) not in unique_variables:
881
+ unique_variables[id(v)] = (v.shape, v.dtype)
882
+ total_memory_size = 0
883
+ for shape, dtype in unique_variables.values():
884
+ total_memory_size += get_tensor_size_in_bits(shape, dtype)
885
+ return total_memory_size / 8
@@ -1,6 +1,8 @@
1
1
  import contextlib
2
2
  import functools
3
3
  import inspect
4
+ import math
5
+ import re
4
6
  import threading
5
7
 
6
8
  import keras
@@ -305,6 +307,29 @@ def is_string_dtype(dtype):
305
307
  return "string" in keras.backend.standardize_dtype(dtype)
306
308
 
307
309
 
310
+ def get_dtype_size_in_bits(dtype):
311
+ """Get the size of a given dtype in bits."""
312
+ dtype = keras.backend.standardize_dtype(dtype)
313
+ # If dtype is bool, return 1 immediately.
314
+ if dtype == "bool":
315
+ return 1
316
+ # Else, we extract the bit size from the string.
317
+ return int(re.sub(r"bfloat|float|uint|int", "", dtype))
318
+
319
+
320
+ def get_tensor_size_in_bits(shape, dtype):
321
+ """Calculate the tensor size in bits given its shape and dtype.
322
+
323
+ Args:
324
+ dtype: The dtype of the tensor.
325
+ shape: An iterable of integers representing the shape of the tensor.
326
+
327
+ Returns:
328
+ The size of the tensor in bits.
329
+ """
330
+ return math.prod(shape) * get_dtype_size_in_bits(dtype)
331
+
332
+
308
333
  def any_equal(inputs, values, padding_mask):
309
334
  """Return a mask that is True anywhere `inputs` has a value in `values`.
310
335
 
@@ -320,7 +345,8 @@ def any_equal(inputs, values, padding_mask):
320
345
  Returns:
321
346
  A tensor with `inputs` shape where each position is True if it contains
322
347
  a value from any `values`. Padding mask will be applied before
323
- returning."""
348
+ returning.
349
+ """
324
350
  output = ops.equal(inputs, values[0])
325
351
  for value in values[1:]:
326
352
  value_equality = ops.equal(inputs, value)
@@ -17,10 +17,69 @@ def convert_backbone_config(timm_config):
17
17
  bottle_ratio = (0.5,) + (1.0,)
18
18
  block_ratio = (1.0,) + (0.5,)
19
19
  expand_ratio = (2.0,) + (1.0,)
20
+ stem_padding = "same"
21
+ stem_pooling = None
20
22
  stage_type = "csp"
23
+ groups = 1
21
24
  block_type = "dark_block"
22
25
  down_growth = True
23
- stackwise_strides = 2
26
+ stackwise_strides = [2, 2, 2, 2, 2]
27
+ avg_down = False
28
+ cross_linear = False
29
+ elif timm_architecture == "cspresnet50":
30
+ stem_filters = 64
31
+ stem_kernel_size = 7
32
+ stem_strides = 4
33
+ stackwise_depth = [3, 3, 5, 2]
34
+ stackwise_strides = [1, 2, 2, 2]
35
+ stackwise_num_filters = [128, 256, 512, 1024]
36
+ block_type = "bottleneck_block"
37
+ stage_type = "csp"
38
+ bottle_ratio = [0.5]
39
+ block_ratio = [1.0]
40
+ expand_ratio = [2.0]
41
+ stem_padding = "valid"
42
+ stem_pooling = "max"
43
+ avg_down = False
44
+ groups = 1
45
+ down_growth = False
46
+ cross_linear = True
47
+ elif timm_architecture == "cspresnext50":
48
+ stem_filters = 64
49
+ stem_kernel_size = 7
50
+ stem_strides = 4
51
+ stackwise_depth = [3, 3, 5, 2]
52
+ stackwise_num_filters = [256, 512, 1024, 2048]
53
+ bottle_ratio = [1.0]
54
+ block_ratio = [0.5]
55
+ expand_ratio = [1.0]
56
+ stage_type = "csp"
57
+ block_type = "bottleneck_block"
58
+ stem_pooling = "max"
59
+ stackwise_strides = [1, 2, 2, 2]
60
+ groups = 32
61
+ stem_padding = "valid"
62
+ avg_down = False
63
+ down_growth = False
64
+ cross_linear = True
65
+ elif timm_architecture == "darknet53":
66
+ stem_filters = 32
67
+ stem_kernel_size = 3
68
+ stem_strides = 1
69
+ stackwise_depth = [1, 2, 8, 8, 4]
70
+ stackwise_num_filters = [64, 128, 256, 512, 1024]
71
+ bottle_ratio = [0.5]
72
+ block_ratio = [1.0]
73
+ groups = 1
74
+ expand_ratio = [1.0]
75
+ stage_type = "dark"
76
+ block_type = "dark_block"
77
+ stem_pooling = None
78
+ stackwise_strides = [2, 2, 2, 2, 2]
79
+ stem_padding = "same"
80
+ avg_down = False
81
+ down_growth = False
82
+ cross_linear = False
24
83
  else:
25
84
  raise ValueError(
26
85
  f"Currently, the architecture {timm_architecture} is not supported."
@@ -38,6 +97,11 @@ def convert_backbone_config(timm_config):
38
97
  block_type=block_type,
39
98
  stackwise_strides=stackwise_strides,
40
99
  down_growth=down_growth,
100
+ stem_pooling=stem_pooling,
101
+ stem_padding=stem_padding,
102
+ avg_down=avg_down,
103
+ cross_linear=cross_linear,
104
+ groups=groups,
41
105
  )
42
106
 
43
107
 
@@ -81,21 +145,36 @@ def convert_weights(backbone, loader, timm_config):
81
145
  stackwise_depth = backbone.stackwise_depth
82
146
  stage_type = backbone.stage_type
83
147
  block_type = backbone.block_type
148
+ strides = backbone.stackwise_strides
84
149
 
85
150
  for idx, block in enumerate(stackwise_depth):
86
- port_conv2d(
87
- f"stages.{idx}.conv_down.conv",
88
- f"stage_{idx}_{stage_type}_conv_down_1",
89
- )
90
- port_batch_normalization(
91
- f"stages.{idx}.conv_down.bn", f"stage_{idx}_{stage_type}_bn_1"
92
- )
93
- port_conv2d(
94
- f"stages.{idx}.conv_exp.conv", f"stage_{idx}_{stage_type}_conv_exp"
95
- )
96
- port_batch_normalization(
97
- f"stages.{idx}.conv_exp.bn", f"stage_{idx}_{stage_type}_bn_2"
98
- )
151
+ if strides[idx] != 1 or stage_type == "dark":
152
+ if strides[idx] == 2 and backbone.avg_down:
153
+ port_conv2d(
154
+ f"stages.{idx}.conv_down.1.conv",
155
+ f"stage_{idx}_{stage_type}_conv_down_1",
156
+ )
157
+ port_batch_normalization(
158
+ f"stages.{idx}.conv_down.1.bn",
159
+ f"stage_{idx}_{stage_type}_bn_1",
160
+ )
161
+ else:
162
+ port_conv2d(
163
+ f"stages.{idx}.conv_down.conv",
164
+ f"stage_{idx}_{stage_type}_conv_down_1",
165
+ )
166
+ port_batch_normalization(
167
+ f"stages.{idx}.conv_down.bn",
168
+ f"stage_{idx}_{stage_type}_bn_1",
169
+ )
170
+ if stage_type != "dark":
171
+ port_conv2d(
172
+ f"stages.{idx}.conv_exp.conv",
173
+ f"stage_{idx}_{stage_type}_conv_exp",
174
+ )
175
+ port_batch_normalization(
176
+ f"stages.{idx}.conv_exp.bn", f"stage_{idx}_{stage_type}_bn_2"
177
+ )
99
178
 
100
179
  for i in range(block):
101
180
  port_conv2d(
@@ -133,16 +212,8 @@ def convert_weights(backbone, loader, timm_config):
133
212
  f"stages.{idx}.conv_transition_b.bn",
134
213
  f"stage_{idx}_{stage_type}_transition_b_bn",
135
214
  )
136
- port_conv2d(
137
- f"stages.{idx}.conv_transition.conv",
138
- f"stage_{idx}_{stage_type}_conv_transition",
139
- )
140
- port_batch_normalization(
141
- f"stages.{idx}.conv_transition.bn",
142
- f"stage_{idx}_{stage_type}_transition_bn",
143
- )
144
215
 
145
- else:
216
+ if stage_type != "dark":
146
217
  port_conv2d(
147
218
  f"stages.{idx}.conv_transition.conv",
148
219
  f"stage_{idx}_{stage_type}_conv_transition",
@@ -16,17 +16,17 @@ class TimmPresetLoader(PresetLoader):
16
16
  def __init__(self, preset, config):
17
17
  super().__init__(preset, config)
18
18
  architecture = self.config["architecture"]
19
- if "resnet" in architecture:
19
+ if architecture.startswith("resnet"):
20
20
  self.converter = convert_resnet
21
- elif "csp" in architecture:
21
+ elif architecture.startswith(("csp", "dark")):
22
22
  self.converter = convert_cspnet
23
- elif "densenet" in architecture:
23
+ elif architecture.startswith("densenet"):
24
24
  self.converter = convert_densenet
25
- elif "mobilenet" in architecture:
25
+ elif architecture.startswith("mobilenet"):
26
26
  self.converter = convert_mobilenet
27
- elif "vgg" in architecture:
27
+ elif architecture.startswith("vgg"):
28
28
  self.converter = convert_vgg
29
- elif "efficientnet" in architecture:
29
+ elif architecture.startswith("efficientnet"):
30
30
  self.converter = convert_efficientnet
31
31
  else:
32
32
  raise ValueError(
@@ -7,7 +7,7 @@ backbone_cls = Llama3Backbone
7
7
 
8
8
 
9
9
  def convert_backbone_config(transformers_config):
10
- return {
10
+ backbone_config = {
11
11
  "vocabulary_size": transformers_config["vocab_size"],
12
12
  "num_layers": transformers_config["num_hidden_layers"],
13
13
  "num_query_heads": transformers_config["num_attention_heads"],
@@ -15,8 +15,28 @@ def convert_backbone_config(transformers_config):
15
15
  "intermediate_dim": transformers_config["intermediate_size"],
16
16
  "num_key_value_heads": transformers_config["num_key_value_heads"],
17
17
  "tie_word_embeddings": transformers_config["tie_word_embeddings"],
18
+ "rope_max_wavelength": transformers_config["rope_theta"],
18
19
  }
19
20
 
21
+ if transformers_config.get("rope_scaling", None) is not None:
22
+ if transformers_config["rope_scaling"]["rope_type"] != "llama3":
23
+ raise ValueError("The config should be a valid llama3 config.")
24
+ backbone_config["rope_frequency_adjustment_factor"] = (
25
+ transformers_config["rope_scaling"]["factor"]
26
+ )
27
+ backbone_config["rope_low_freq_factor"] = transformers_config[
28
+ "rope_scaling"
29
+ ]["low_freq_factor"]
30
+ backbone_config["rope_high_freq_factor"] = transformers_config[
31
+ "rope_scaling"
32
+ ]["high_freq_factor"]
33
+ backbone_config["rope_pretraining_sequence_length"] = (
34
+ transformers_config["rope_scaling"][
35
+ "original_max_position_embeddings"
36
+ ]
37
+ )
38
+ return backbone_config
39
+
20
40
 
21
41
  def convert_weights(backbone, loader, transformers_config):
22
42
  loader.port_weight(