fusion-bench 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (193)
  1. fusion_bench/__init__.py +1 -0
  2. fusion_bench/_get_started/__init__.py +3 -0
  3. fusion_bench/_get_started/greeting_program.py +49 -0
  4. fusion_bench/compat/method/base_algorithm.py +14 -0
  5. fusion_bench/constants/__init__.py +5 -0
  6. fusion_bench/constants/clip_vision.py +26 -2
  7. fusion_bench/constants/paths.py +4 -0
  8. fusion_bench/dataset/clip_dataset.py +2 -1
  9. fusion_bench/dataset/gpt2_glue.py +9 -9
  10. fusion_bench/dataset/image_corruption/__init__.py +0 -0
  11. fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
  12. fusion_bench/dataset/image_dataset.py +1 -1
  13. fusion_bench/dataset/nyuv2.py +2 -2
  14. fusion_bench/method/__init__.py +16 -1
  15. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  16. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
  17. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
  18. fusion_bench/method/base_algorithm.py +195 -12
  19. fusion_bench/method/bitdelta/__init__.py +4 -0
  20. fusion_bench/method/bitdelta/bitdelta.py +156 -0
  21. fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
  22. fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
  23. fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
  24. fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
  25. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
  26. fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
  27. fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
  28. fusion_bench/method/doge_ta/doge_ta.py +1 -1
  29. fusion_bench/method/ensemble.py +12 -12
  30. fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
  31. fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
  32. fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
  33. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
  34. fusion_bench/method/fw_merging/fw_hard.py +1 -1
  35. fusion_bench/method/fw_merging/fw_soft.py +1 -1
  36. fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
  37. fusion_bench/method/linear/expo.py +2 -1
  38. fusion_bench/method/linear/linear_interpolation.py +6 -4
  39. fusion_bench/method/linear/simple_average_for_llama.py +16 -6
  40. fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
  41. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
  42. fusion_bench/method/model_recombination.py +2 -5
  43. fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
  44. fusion_bench/method/moe_pruner/utils/data.py +2 -1
  45. fusion_bench/method/moe_pruner/utils/prune.py +6 -1
  46. fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
  47. fusion_bench/method/pruning/wanda_utils/data.py +1 -2
  48. fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
  49. fusion_bench/method/randes/modelsoup.py +1 -3
  50. fusion_bench/method/regmean/clip_regmean.py +2 -2
  51. fusion_bench/method/regmean/gpt2_regmean.py +3 -10
  52. fusion_bench/method/regmean/regmean.py +2 -11
  53. fusion_bench/method/regmean_plusplus/__init__.py +3 -0
  54. fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +199 -0
  55. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +383 -0
  56. fusion_bench/method/simple_average.py +16 -4
  57. fusion_bench/method/slerp/slerp.py +5 -2
  58. fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
  59. fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
  60. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
  61. fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
  62. fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
  63. fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
  64. fusion_bench/method/ties_merging/ties_merging.py +13 -26
  65. fusion_bench/method/we_moe/clip_we_moe.py +5 -4
  66. fusion_bench/method/we_moe/we_moe.py +6 -6
  67. fusion_bench/method/weighted_average/llama.py +4 -16
  68. fusion_bench/metrics/continual_learning/__init__.py +1 -0
  69. fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
  70. fusion_bench/metrics/nyuv2/__init__.py +2 -2
  71. fusion_bench/metrics/nyuv2/segmentation.py +1 -1
  72. fusion_bench/mixins/__init__.py +10 -2
  73. fusion_bench/mixins/clip_classification.py +4 -3
  74. fusion_bench/mixins/hydra_config.py +105 -7
  75. fusion_bench/mixins/lightning_fabric.py +2 -0
  76. fusion_bench/mixins/serialization.py +265 -48
  77. fusion_bench/modelpool/__init__.py +2 -2
  78. fusion_bench/modelpool/base_pool.py +29 -9
  79. fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
  80. fusion_bench/modelpool/clip_vision/modelpool.py +43 -12
  81. fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
  82. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
  83. fusion_bench/models/__init__.py +2 -1
  84. fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
  85. fusion_bench/models/hf_utils.py +182 -0
  86. fusion_bench/models/linearized/linearized_model_utils.py +4 -4
  87. fusion_bench/models/linearized/vision_model.py +1 -1
  88. fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
  89. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
  90. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
  91. fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
  92. fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
  93. fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
  94. fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
  95. fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
  96. fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
  97. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
  98. fusion_bench/models/modeling_smile_llama/register.py +8 -0
  99. fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
  100. fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
  101. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
  102. fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
  103. fusion_bench/models/parameter_dict.py +1 -1
  104. fusion_bench/models/sparse_we_moe.py +1 -53
  105. fusion_bench/models/utils.py +26 -0
  106. fusion_bench/models/we_moe.py +1 -53
  107. fusion_bench/models/wrappers/ensemble.py +6 -4
  108. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
  109. fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
  110. fusion_bench/programs/base_program.py +81 -2
  111. fusion_bench/programs/fabric_fusion_program.py +24 -8
  112. fusion_bench/scripts/cli.py +6 -6
  113. fusion_bench/taskpool/base_pool.py +4 -3
  114. fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
  115. fusion_bench/taskpool/dummy.py +1 -1
  116. fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
  117. fusion_bench/tasks/clip_classification/__init__.py +6 -4
  118. fusion_bench/utils/__init__.py +6 -1
  119. fusion_bench/utils/devices.py +14 -4
  120. fusion_bench/utils/instantiate_utils.py +3 -1
  121. fusion_bench/utils/misc.py +48 -2
  122. fusion_bench/utils/modelscope.py +265 -0
  123. fusion_bench/utils/parameters.py +2 -2
  124. fusion_bench/utils/rich_utils.py +3 -0
  125. fusion_bench/utils/state_dict_arithmetic.py +34 -27
  126. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +31 -24
  127. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +189 -153
  128. fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
  129. fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
  130. fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
  131. fusion_bench_config/_get_started/greeting_program.yaml +4 -0
  132. fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
  133. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
  134. fusion_bench_config/fabric_model_fusion.yaml +45 -17
  135. fusion_bench_config/hydra/default.yaml +6 -2
  136. fusion_bench_config/llama_full_finetune.yaml +1 -0
  137. fusion_bench_config/method/adamerging/clip.yaml +1 -1
  138. fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
  139. fusion_bench_config/method/depth_upscaling.yaml +4 -1
  140. fusion_bench_config/method/regmean/clip_regmean.yaml +1 -1
  141. fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +11 -0
  142. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  143. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
  144. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
  145. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
  146. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +73 -8
  147. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +27 -7
  148. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +34 -4
  149. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +14 -17
  150. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +14 -3
  151. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +39 -5
  152. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +49 -5
  153. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +55 -5
  154. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +21 -4
  155. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +61 -5
  156. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +67 -5
  157. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +73 -5
  158. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +26 -3
  159. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
  160. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +7 -5
  161. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +6 -10
  162. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +6 -7
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +6 -7
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +7 -8
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +8 -6
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +4 -6
  167. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +32 -7
  168. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +14 -6
  169. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +73 -8
  170. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +27 -7
  171. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +6 -10
  172. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
  173. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
  174. fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
  175. fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
  176. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
  177. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
  178. fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
  179. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
  180. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
  181. fusion_bench_config/nyuv2_config.yaml +3 -1
  182. fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
  183. fusion_bench_config/path/default.yaml +28 -0
  184. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
  185. fusion_bench_config/method/adamerging.yaml +0 -23
  186. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
  187. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
  188. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
  189. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
  190. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
  191. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
  192. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
  193. /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
@@ -56,6 +56,8 @@ class LayerWiseFeatureSaver:
56
56
  first_token_only: bool = True,
57
57
  max_num: Optional[int] = None,
58
58
  ):
59
+ if isinstance(save_path, str):
60
+ save_path = Path(save_path)
59
61
  self.save_path = save_path
60
62
  self.first_token_only = first_token_only
61
63
  self.max_num = max_num
@@ -122,9 +124,9 @@ class CLIPVisionModelTaskPool(
122
124
  self,
123
125
  test_datasets: Union[DictConfig, Dict[str, Dataset]],
124
126
  *,
125
- processor: Union[DictConfig, CLIPProcessor],
126
- data_processor: Union[DictConfig, CLIPProcessor],
127
- clip_model: Union[DictConfig, CLIPModel],
127
+ processor: Union[str, DictConfig, CLIPProcessor],
128
+ clip_model: Union[str, DictConfig, CLIPModel],
129
+ data_processor: Union[DictConfig, CLIPProcessor] = None,
128
130
  dataloader_kwargs: DictConfig = None,
129
131
  layer_wise_feature_save_path: Optional[str] = None,
130
132
  layer_wise_feature_first_token_only: bool = True,
@@ -159,21 +161,35 @@ class CLIPVisionModelTaskPool(
159
161
  Set up the processor, data processor, CLIP model, test datasets, and data loaders.
160
162
  """
161
163
  # setup processor and clip model
162
- self.processor = (
163
- instantiate(self._processor)
164
- if isinstance(self._processor, DictConfig)
165
- else self._processor
166
- )
167
- self.data_processor = (
168
- instantiate(self._data_processor)
169
- if isinstance(self._data_processor, DictConfig)
170
- else self._data_processor
171
- )
172
- self.clip_model = (
173
- instantiate(self._clip_model)
174
- if isinstance(self._clip_model, DictConfig)
175
- else self._clip_model
176
- )
164
+ if isinstance(self._processor, str):
165
+ self.processor = CLIPProcessor.from_pretrained(self._processor)
166
+ elif (
167
+ isinstance(self._processor, (dict, DictConfig))
168
+ and "_target_" in self._processor
169
+ ):
170
+ self.processor = instantiate(self._processor)
171
+ else:
172
+ self.processor = self._processor
173
+
174
+ if self._data_processor is None:
175
+ self.data_processor = self.processor
176
+ else:
177
+ self.data_processor = (
178
+ instantiate(self._data_processor)
179
+ if isinstance(self._data_processor, DictConfig)
180
+ else self._data_processor
181
+ )
182
+
183
+ if isinstance(self._clip_model, str):
184
+ self.clip_model = CLIPModel.from_pretrained(self._clip_model)
185
+ elif (
186
+ isinstance(self._clip_model, (dict, DictConfig))
187
+ and "_target_" in self._clip_model
188
+ ):
189
+ self.clip_model = instantiate(self._clip_model)
190
+ else:
191
+ self.clip_model = self._clip_model
192
+
177
193
  self.clip_model = self.fabric.to_device(self.clip_model)
178
194
  self.clip_model.requires_grad_(False)
179
195
  self.clip_model.eval()
@@ -4,13 +4,13 @@ This is the dummy task pool that is used for debugging purposes.
4
4
 
5
5
  from typing import Optional
6
6
 
7
+ from lightning.pytorch.utilities import rank_zero_only
7
8
  from torch import nn
8
9
 
9
10
  from fusion_bench.models.separate_io import separate_save
10
11
  from fusion_bench.taskpool.base_pool import BaseTaskPool
11
12
  from fusion_bench.utils import timeit_context
12
13
  from fusion_bench.utils.parameters import count_parameters, print_parameters
13
- from lightning.pytorch.utilities import rank_zero_only
14
14
 
15
15
 
16
16
  def get_model_summary(model: nn.Module) -> dict:
@@ -1,6 +1,6 @@
1
1
  import logging
2
2
  import os
3
- from typing import List, Literal, Optional, Union, TYPE_CHECKING
3
+ from typing import TYPE_CHECKING, List, Literal, Optional, Union
4
4
 
5
5
  import lightning.fabric
6
6
  import lm_eval
@@ -12,7 +12,6 @@ from fusion_bench import BaseTaskPool
12
12
  from fusion_bench.mixins import LightningFabricMixin
13
13
  from fusion_bench.utils.strenum import _version
14
14
 
15
-
16
15
  log = logging.getLogger(__name__)
17
16
 
18
17
 
@@ -1,6 +1,6 @@
1
1
  import importlib
2
2
  import warnings
3
- from typing import Any, Callable, Dict, List
3
+ from typing import Any, Callable, Dict, List, Tuple
4
4
 
5
5
  from datasets import load_dataset
6
6
 
@@ -79,7 +79,9 @@ class CLIPTemplateFactory:
79
79
  }
80
80
 
81
81
  @staticmethod
82
- def get_classnames_and_templates(dataset_name: str):
82
+ def get_classnames_and_templates(
83
+ dataset_name: str,
84
+ ) -> Tuple[List[str], List[Callable]]:
83
85
  """
84
86
  Retrieves class names and templates for the specified dataset.
85
87
 
@@ -169,7 +171,7 @@ class CLIPTemplateFactory:
169
171
  CLIPTemplateFactory._dataset_mapping[dataset_name] = dataset_info
170
172
 
171
173
  @staticmethod
172
- def get_available_datasets():
174
+ def get_available_datasets() -> List[str]:
173
175
  """
174
176
  Get a list of all available dataset names.
175
177
 
@@ -179,5 +181,5 @@ class CLIPTemplateFactory:
179
181
  return list(CLIPTemplateFactory._dataset_mapping.keys())
180
182
 
181
183
 
182
- def get_classnames_and_templates(dataset_name: str):
184
+ def get_classnames_and_templates(dataset_name: str) -> Tuple[List[str], List[Callable]]:
183
185
  return CLIPTemplateFactory.get_classnames_and_templates(dataset_name)
@@ -7,7 +7,12 @@ from .cache_utils import *
7
7
  from .devices import *
8
8
  from .dtype import parse_dtype
9
9
  from .fabric import seed_everything_by_time
10
- from .instantiate_utils import instantiate, is_instantiable
10
+ from .instantiate_utils import (
11
+ instantiate,
12
+ is_instantiable,
13
+ set_print_function_call,
14
+ set_print_function_call_permeanent,
15
+ )
11
16
  from .json import load_from_json, save_to_json
12
17
  from .lazy_state_dict import LazyStateDict
13
18
  from .misc import *
@@ -1,4 +1,5 @@
1
1
  import gc
2
+ import logging
2
3
  import os
3
4
  from typing import List, Optional, Union
4
5
 
@@ -12,7 +13,7 @@ from transformers.utils import (
12
13
  )
13
14
 
14
15
  __all__ = [
15
- "cuda_empty_cache",
16
+ "clear_cuda_cache",
16
17
  "to_device",
17
18
  "num_devices",
18
19
  "get_device",
@@ -21,10 +22,19 @@ __all__ = [
21
22
  "get_device_capabilities",
22
23
  ]
23
24
 
25
+ log = logging.getLogger(__name__)
24
26
 
25
- def cuda_empty_cache():
27
+
28
def clear_cuda_cache():
    """
    Clears the CUDA memory cache to free up GPU memory.
    Works only if CUDA is available.
    """
    # Collect unreachable Python objects first so their tensors become
    # eligible for release before the allocator cache is emptied.
    gc.collect()
    if not torch.cuda.is_available():
        log.warning("CUDA is not available. No cache to clear.")
        return
    torch.cuda.empty_cache()
28
38
 
29
39
 
30
40
  def to_device(obj, device: Optional[torch.device], **kwargs):
@@ -75,7 +85,7 @@ def num_devices(devices: Union[int, List[int], str]) -> int:
75
85
  Return the number of devices.
76
86
 
77
87
  Args:
78
- devices: `devices` can be a single int to specify the number of devices, or a list of device ids, e.g. [0, 1, 2, 3] or a str of device ids, e.g. "0,1,2,3" and "[0, 1, 2]".
88
+ devices: `devices` can be a single int to specify the number of devices, or a list of device ids, e.g. [0, 1, 2, 3], or a str of device ids, e.g. "0,1,2,3" and "[0, 1, 2]".
79
89
 
80
90
  Returns:
81
91
  The number of devices.
@@ -28,7 +28,7 @@ PRINT_FUNCTION_CALL_FUNC = print
28
28
  Function to be used for printing function calls.
29
29
  """
30
30
 
31
- CATCH_EXCEPTION = True
31
+ CATCH_EXCEPTION = False
32
32
 
33
33
 
34
34
  @contextmanager
@@ -41,10 +41,12 @@ def set_print_function_call(value: bool):
41
41
  finally:
42
42
  PRINT_FUNCTION_CALL = old_value
43
43
 
44
+
44
45
  def set_print_function_call_permeanent(value: bool):
45
46
  global PRINT_FUNCTION_CALL
46
47
  PRINT_FUNCTION_CALL = value
47
48
 
49
+
48
50
  def is_instantiable(config: Union[DictConfig, Any]) -> bool:
49
51
  if OmegaConf.is_dict(config):
50
52
  return "_target_" in config
@@ -1,6 +1,13 @@
1
- from typing import Iterable, List
1
+ from difflib import get_close_matches
2
+ from typing import Any, Iterable, List, Optional
2
3
 
3
- __all__ = ["first", "has_length", "join_list", "attr_equal"]
4
+ __all__ = [
5
+ "first",
6
+ "has_length",
7
+ "join_list",
8
+ "attr_equal",
9
+ "validate_and_suggest_corrections",
10
+ ]
4
11
 
5
12
 
6
13
  def first(iterable: Iterable):
@@ -41,3 +48,42 @@ def attr_equal(obj, attr: str, value):
41
48
  if not hasattr(obj, attr):
42
49
  return False
43
50
  return getattr(obj, attr) == value
51
+
52
+
53
def validate_and_suggest_corrections(
    obj: Any, values: Iterable[Any], *, max_suggestions: int = 3, cutoff: float = 0.6
) -> Any:
    """
    Return *obj* if it is contained in *values*.
    Otherwise raise a helpful ``ValueError`` that lists the closest matches.

    Args:
        obj: The value to validate.
        values: The set of allowed values.
        max_suggestions: How many typo-hints to include at most (default 3).
        cutoff: Similarity threshold for suggestions (0.0-1.0, default 0.6).

    Returns:
        The original *obj* if it is valid.

    Raises:
        ValueError: With a friendly message that points out possible typos.
    """
    # Materialise once: *values* may be a one-shot iterable and we need it
    # twice (membership test + error message).
    allowed = list(values)

    if obj in allowed:
        return obj

    # Compare string forms so non-string values still get fuzzy matching.
    close = get_close_matches(
        str(obj), [str(v) for v in allowed], n=max_suggestions, cutoff=cutoff
    )

    message = f"Invalid value {obj!r}. Allowed values: {allowed}"
    if close:
        message += f". Did you mean {', '.join(repr(m) for m in close)}?"
    raise ValueError(message)
@@ -0,0 +1,265 @@
1
+ import os
2
+ from typing import Literal, Optional
3
+
4
+ from datasets import load_dataset as datasets_load_dataset
5
+
6
+ from fusion_bench.utils import validate_and_suggest_corrections
7
+
8
+ try:
9
+ from modelscope import dataset_file_download as modelscope_dataset_file_download
10
+ from modelscope import model_file_download as modelscope_model_file_download
11
+ from modelscope import snapshot_download as modelscope_snapshot_download
12
+
13
+ except ImportError:
14
+
15
+ def _raise_modelscope_not_installed_error(*args, **kwargs):
16
+ raise ImportError(
17
+ "ModelScope is not installed. Please install it using `pip install modelscope` to use ModelScope models."
18
+ )
19
+
20
+ modelscope_snapshot_download = _raise_modelscope_not_installed_error
21
+ modelscope_model_file_download = _raise_modelscope_not_installed_error
22
+ modelscope_dataset_file_download = _raise_modelscope_not_installed_error
23
+
24
+ try:
25
+ from huggingface_hub import hf_hub_download
26
+ from huggingface_hub import snapshot_download as huggingface_snapshot_download
27
+ except ImportError:
28
+
29
+ def _raise_hugggingface_not_installed_error(*args, **kwargs):
30
+ raise ImportError(
31
+ "Hugging Face Hub is not installed. Please install it using `pip install huggingface_hub` to use Hugging Face models."
32
+ )
33
+
34
+ huggingface_snapshot_download = _raise_hugggingface_not_installed_error
35
+ hf_hub_download = _raise_hugggingface_not_installed_error
36
+
37
+ __all__ = [
38
+ "load_dataset",
39
+ "resolve_repo_path",
40
+ ]
41
+
42
+ AVAILABLE_PLATFORMS = ["hf", "huggingface", "modelscope"]
43
+
44
+
45
+ def _raise_unknown_platform_error(platform: str):
46
+ raise ValueError(
47
+ f"Unsupported platform: {platform}. Supported platforms are 'hf', 'huggingface', and 'modelscope'."
48
+ )
49
+
50
+
51
def load_dataset(
    name: str,
    split: str = "train",
    platform: Literal["hf", "huggingface", "modelscope"] = "hf",
):
    """
    Load a dataset from Hugging Face or ModelScope.

    Args:
        name (str): The name of the dataset.
        split (str): The split of the dataset to load (default is "train").
        platform (Literal["hf", "huggingface", "modelscope"]): The platform to
            load the dataset from (default is "hf").

    Returns:
        Dataset: The loaded dataset.

    Raises:
        ValueError: If *platform* is not one of the supported platforms.
    """
    validate_and_suggest_corrections(platform, AVAILABLE_PLATFORMS)
    if platform in ("hf", "huggingface"):
        return datasets_load_dataset(name, split=split)
    elif platform == "modelscope":
        # Download the dataset snapshot into the local cache first, then load
        # it with the regular `datasets` loader.
        dataset_dir = modelscope_snapshot_download(name, repo_type="dataset")
        return datasets_load_dataset(dataset_dir, split=split)
    else:
        # Defensive: validate_and_suggest_corrections already raised for any
        # unknown platform, so this branch is normally unreachable.
        _raise_unknown_platform_error(platform)
75
+
76
+
77
def resolve_repo_path(
    repo_id: str,
    repo_type: Optional[str] = "model",
    platform: Literal["hf", "huggingface", "modelscope"] = "hf",
    **kwargs,
):
    """
    Resolve and download a repository from various platforms to a local path.

    This function handles multiple repository sources including local paths,
    Hugging Face, and ModelScope. It automatically downloads remote
    repositories to local cache and returns the local path for further use.

    Args:
        repo_id (str): Repository identifier. Can be:
            - Local file/directory path (returned as-is if exists)
            - Hugging Face model/dataset ID (e.g., "bert-base-uncased")
            - ModelScope model/dataset ID
            - URL-prefixed ID (e.g., "hf://model-name", "modelscope://model-name").
              The prefix will override the platform argument.
        repo_type (str, optional): Type of repository to download. Defaults to
            "model". Common values include "model" and "dataset".
        platform (Literal["hf", "huggingface", "modelscope"], optional):
            Platform to download from. Defaults to "hf".
        **kwargs: Additional arguments passed to the underlying download functions.

    Returns:
        str: Local path to the repository (either existing local path or
        downloaded cache path).

    Raises:
        ValueError: If an unsupported platform is specified.
        FileNotFoundError: If the repository cannot be found or downloaded.
        ImportError: If required dependencies for the platform are not installed.

    Examples:
        >>> resolve_repo_path("/path/to/local/model")
        "/path/to/local/model"
        >>> resolve_repo_path("hf://microsoft/DialoGPT-medium")  # doctest: +SKIP
        "/home/user/.cache/huggingface/hub/models--microsoft--DialoGPT-medium/..."
    """
    # A URL-style prefix on the repo id overrides the `platform` argument.
    if repo_id.startswith(("hf://", "huggingface://")):
        repo_id = repo_id.split("://", 1)[1]
        platform = "hf"
    elif repo_id.startswith("modelscope://"):
        repo_id = repo_id.split("://", 1)[1]
        platform = "modelscope"

    # If it's a local file or directory, return it as-is.
    if os.path.exists(repo_id):
        return repo_id

    # Validate *outside* the download try-block so an unsupported platform
    # raises ValueError (as documented) instead of being converted into a
    # FileNotFoundError by the except clause below.
    validate_and_suggest_corrections(platform, AVAILABLE_PLATFORMS)

    try:
        # This will download the repository to the cache and return the local path.
        if platform in ("hf", "huggingface"):
            local_path = huggingface_snapshot_download(
                repo_id=repo_id, repo_type=repo_type, **kwargs
            )
        else:  # platform == "modelscope" (validated above)
            local_path = modelscope_snapshot_download(
                repo_id=repo_id, repo_type=repo_type, **kwargs
            )
        return local_path
    except Exception as e:
        # Chain the original error so the root cause stays in the traceback.
        raise FileNotFoundError(
            f"Could not resolve checkpoint: {repo_id}. Error: {e}"
        ) from e
159
+
160
+
161
def resolve_file_path(
    repo_id: str,
    filename: str,
    repo_type: Literal["model", "dataset"] = "model",
    platform: Literal["hf", "huggingface", "modelscope"] = "hf",
    **kwargs,
) -> str:
    """
    Resolve and download a specific file from a repository across multiple platforms.

    This function downloads a specific file from repositories hosted on various
    platforms including local paths, Hugging Face Hub, and ModelScope. It
    handles platform-specific URL prefixes and automatically determines the
    appropriate download method.

    Args:
        repo_id (str): Repository identifier. Can be:
            - Local directory path (the file path is joined with it if it exists)
            - Hugging Face model/dataset ID (e.g., "bert-base-uncased")
            - ModelScope model/dataset ID
            - URL-prefixed ID (e.g., "hf://model-name", "modelscope://model-name").
              The prefix will override the platform argument.
        filename (str): The specific file to download from the repository.
        repo_type (Literal["model", "dataset"], optional): Type of repository.
            Defaults to "model". Used for ModelScope to select the correct
            download function.
        platform (Literal["hf", "huggingface", "modelscope"], optional):
            Platform to download from. Defaults to "hf".
        **kwargs: Additional arguments passed to the underlying download
            functions (e.g., cache_dir, force_download, use_auth_token).

    Returns:
        str: Local path to the downloaded file.

    Raises:
        ValueError: If an unsupported platform or repo_type is specified.
        ImportError: If required dependencies for the platform are not installed.
        FileNotFoundError: If the file cannot be found or downloaded.

    Examples:
        >>> resolve_file_path("/path/to/local/model", "config.json")
        "/path/to/local/model/config.json"
        >>> resolve_file_path("bert-base-uncased", "config.json")  # doctest: +SKIP
        "/home/user/.cache/huggingface/hub/models--bert-base-uncased/.../config.json"
    """
    # A URL-style prefix on the repo id overrides the `platform` argument.
    if repo_id.startswith(("hf://", "huggingface://")):
        repo_id = repo_id.split("://", 1)[1]
        platform = "hf"
    elif repo_id.startswith("modelscope://"):
        repo_id = repo_id.split("://", 1)[1]
        platform = "modelscope"

    # If the repo id is an existing local directory, resolve the file inside it
    # (no download involved).
    if os.path.exists(repo_id):
        return os.path.join(repo_id, filename)

    # Consistent with resolve_repo_path: reject typo'd platforms early with a
    # helpful suggestion instead of an opaque failure.
    validate_and_suggest_corrections(platform, AVAILABLE_PLATFORMS)

    if platform in ("hf", "huggingface"):
        return hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            repo_type=repo_type,
            **kwargs,
        )
    elif platform == "modelscope":
        # ModelScope uses distinct download entry points per repository type.
        if repo_type == "model":
            return modelscope_model_file_download(
                model_id=repo_id, file_path=filename, **kwargs
            )
        elif repo_type == "dataset":
            return modelscope_dataset_file_download(
                dataset_id=repo_id, file_path=filename, **kwargs
            )
        else:
            raise ValueError(
                f"Unsupported repo_type: {repo_type}. Supported types are 'model' and 'dataset'."
            )
    else:
        # Defensive: validation above already raised for unknown platforms.
        _raise_unknown_platform_error(platform)
@@ -1,6 +1,6 @@
1
1
  import copy
2
2
  from collections import OrderedDict
3
- from typing import List, Mapping, Optional, Union
3
+ from typing import Dict, List, Mapping, Optional, Union
4
4
 
5
5
  import torch
6
6
  from torch import nn
@@ -83,7 +83,7 @@ def vector_to_state_dict(
83
83
  vector: torch.Tensor,
84
84
  state_dict: Union[StateDictType, nn.Module],
85
85
  remove_keys: Optional[List[str]] = None,
86
- ):
86
+ ) -> Dict[str, torch.Tensor]:
87
87
  """
88
88
  Convert a vector to a state dictionary.
89
89
 
@@ -189,6 +189,9 @@ if __name__ == "__main__":
189
189
 
190
190
 
191
191
  def setup_colorlogging(force=False, **config_kwargs):
192
+ """
193
+ Sets up color logging for the application.
194
+ """
192
195
  FORMAT = "%(message)s"
193
196
 
194
197
  logging.basicConfig(