fusion-bench 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (188)
  1. fusion_bench/__init__.py +22 -2
  2. fusion_bench/_get_started/__init__.py +3 -0
  3. fusion_bench/_get_started/greeting_program.py +49 -0
  4. fusion_bench/compat/method/base_algorithm.py +14 -0
  5. fusion_bench/constants/__init__.py +6 -0
  6. fusion_bench/constants/clip_vision.py +26 -2
  7. fusion_bench/constants/paths.py +4 -0
  8. fusion_bench/constants/runtime.py +57 -0
  9. fusion_bench/dataset/clip_dataset.py +2 -1
  10. fusion_bench/dataset/gpt2_glue.py +9 -9
  11. fusion_bench/dataset/image_corruption/__init__.py +0 -0
  12. fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
  13. fusion_bench/dataset/image_dataset.py +1 -1
  14. fusion_bench/dataset/nyuv2.py +2 -2
  15. fusion_bench/method/__init__.py +24 -5
  16. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  17. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
  18. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
  19. fusion_bench/method/base_algorithm.py +195 -12
  20. fusion_bench/method/bitdelta/__init__.py +5 -0
  21. fusion_bench/method/bitdelta/bitdelta.py +156 -0
  22. fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
  23. fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
  24. fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
  25. fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
  26. fusion_bench/method/classification/clip_finetune.py +1 -1
  27. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
  28. fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
  29. fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
  30. fusion_bench/method/doge_ta/doge_ta.py +1 -1
  31. fusion_bench/method/ensemble.py +12 -12
  32. fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
  33. fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -6
  34. fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
  35. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
  36. fusion_bench/method/fw_merging/fw_hard.py +1 -1
  37. fusion_bench/method/fw_merging/fw_soft.py +1 -1
  38. fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
  39. fusion_bench/method/linear/expo.py +2 -1
  40. fusion_bench/method/linear/linear_interpolation.py +6 -4
  41. fusion_bench/method/linear/simple_average_for_llama.py +17 -13
  42. fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
  43. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
  44. fusion_bench/method/model_recombination.py +2 -5
  45. fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
  46. fusion_bench/method/moe_pruner/utils/data.py +2 -1
  47. fusion_bench/method/moe_pruner/utils/prune.py +6 -1
  48. fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
  49. fusion_bench/method/pruning/wanda_utils/data.py +1 -2
  50. fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
  51. fusion_bench/method/randes/modelsoup.py +1 -3
  52. fusion_bench/method/regmean/clip_regmean.py +2 -2
  53. fusion_bench/method/regmean/gpt2_regmean.py +3 -10
  54. fusion_bench/method/regmean/regmean.py +2 -11
  55. fusion_bench/method/regmean_plusplus/__init__.py +1 -1
  56. fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
  57. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
  58. fusion_bench/method/simple_average.py +12 -16
  59. fusion_bench/method/slerp/slerp.py +5 -2
  60. fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
  61. fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
  62. fusion_bench/method/smile_upscaling/projected_energy.py +144 -0
  63. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
  64. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +71 -51
  65. fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
  66. fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
  67. fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
  68. fusion_bench/method/ties_merging/ties_merging.py +13 -26
  69. fusion_bench/method/we_moe/__init__.py +1 -0
  70. fusion_bench/method/we_moe/clip_we_moe.py +5 -4
  71. fusion_bench/method/we_moe/entropy_loss.py +25 -0
  72. fusion_bench/method/we_moe/flan_t5_we_moe.py +331 -0
  73. fusion_bench/method/we_moe/utils.py +15 -0
  74. fusion_bench/method/we_moe/we_moe.py +6 -6
  75. fusion_bench/method/weighted_average/llama.py +4 -16
  76. fusion_bench/metrics/continual_learning/__init__.py +1 -0
  77. fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
  78. fusion_bench/metrics/nyuv2/__init__.py +2 -2
  79. fusion_bench/metrics/nyuv2/segmentation.py +1 -1
  80. fusion_bench/mixins/__init__.py +10 -2
  81. fusion_bench/mixins/clip_classification.py +15 -45
  82. fusion_bench/mixins/hydra_config.py +105 -7
  83. fusion_bench/mixins/lightning_fabric.py +2 -0
  84. fusion_bench/mixins/serialization.py +275 -48
  85. fusion_bench/modelpool/__init__.py +2 -2
  86. fusion_bench/modelpool/base_pool.py +29 -9
  87. fusion_bench/modelpool/causal_lm/causal_lm.py +41 -33
  88. fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
  89. fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
  90. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
  91. fusion_bench/models/__init__.py +7 -1
  92. fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
  93. fusion_bench/models/hf_utils.py +160 -0
  94. fusion_bench/models/linearized/linearized_model_utils.py +4 -4
  95. fusion_bench/models/linearized/vision_model.py +1 -1
  96. fusion_bench/models/model_card_templates/default.md +46 -0
  97. fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
  98. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
  99. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
  100. fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
  101. fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
  102. fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
  103. fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
  104. fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
  105. fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
  106. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +698 -0
  107. fusion_bench/models/modeling_smile_llama/register.py +8 -0
  108. fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
  109. fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
  110. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +7 -12
  111. fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
  112. fusion_bench/models/parameter_dict.py +1 -1
  113. fusion_bench/models/sparse_we_moe.py +1 -53
  114. fusion_bench/models/utils.py +26 -0
  115. fusion_bench/models/we_moe.py +1 -53
  116. fusion_bench/models/wrappers/ensemble.py +6 -4
  117. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
  118. fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
  119. fusion_bench/programs/base_program.py +81 -2
  120. fusion_bench/programs/fabric_fusion_program.py +46 -61
  121. fusion_bench/scripts/cli.py +38 -5
  122. fusion_bench/taskpool/base_pool.py +4 -3
  123. fusion_bench/taskpool/clip_vision/taskpool.py +43 -22
  124. fusion_bench/taskpool/dummy.py +1 -1
  125. fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
  126. fusion_bench/tasks/clip_classification/__init__.py +6 -4
  127. fusion_bench/utils/__init__.py +7 -1
  128. fusion_bench/utils/cache_utils.py +101 -1
  129. fusion_bench/utils/devices.py +14 -4
  130. fusion_bench/utils/fabric.py +2 -2
  131. fusion_bench/utils/instantiate_utils.py +3 -1
  132. fusion_bench/utils/lazy_imports.py +23 -0
  133. fusion_bench/utils/lazy_state_dict.py +38 -3
  134. fusion_bench/utils/modelscope.py +127 -8
  135. fusion_bench/utils/parameters.py +2 -2
  136. fusion_bench/utils/path.py +56 -0
  137. fusion_bench/utils/pylogger.py +1 -1
  138. fusion_bench/utils/rich_utils.py +3 -0
  139. fusion_bench/utils/state_dict_arithmetic.py +25 -23
  140. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/METADATA +24 -47
  141. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/RECORD +184 -145
  142. fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
  143. fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
  144. fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
  145. fusion_bench_config/_get_started/greeting_program.yaml +4 -0
  146. fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
  147. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
  148. fusion_bench_config/fabric_model_fusion.yaml +45 -17
  149. fusion_bench_config/hydra/default.yaml +6 -2
  150. fusion_bench_config/llama_full_finetune.yaml +1 -0
  151. fusion_bench_config/method/adamerging/clip.yaml +1 -1
  152. fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
  153. fusion_bench_config/method/depth_upscaling.yaml +4 -1
  154. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
  155. fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
  156. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
  157. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  158. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
  159. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +2 -1
  160. fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
  161. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
  162. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
  167. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +3 -3
  168. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
  169. fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
  170. fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
  171. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
  172. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
  173. fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
  174. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
  175. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
  176. fusion_bench_config/nyuv2_config.yaml +3 -1
  177. fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
  178. fusion_bench_config/path/default.yaml +28 -0
  179. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
  180. fusion_bench_config/method/adamerging.yaml +0 -23
  181. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
  182. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
  183. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
  184. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/WHEEL +0 -0
  185. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/entry_points.txt +0 -0
  186. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/licenses/LICENSE +0 -0
  187. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/top_level.txt +0 -0
  188. /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
fusion_bench/scripts/cli.py
@@ -12,6 +12,7 @@ import os
 import hydra
 from omegaconf import DictConfig, OmegaConf
 
+from fusion_bench.constants import PROJECT_ROOT_PATH
 from fusion_bench.programs import BaseHydraProgram
 from fusion_bench.utils import instantiate
 
@@ -20,11 +21,10 @@ log = logging.getLogger(__name__)
 
 def _get_default_config_path():
     for config_dir in ["fusion_bench_config", "config"]:
-        config_path = os.path.join(
-            importlib.import_module("fusion_bench").__path__[0], "..", config_dir
-        )
-        if os.path.exists(config_path) and os.path.isdir(config_path):
-            return os.path.abspath(config_path)
+        for config_path_root in [os.getcwd(), PROJECT_ROOT_PATH]:
+            config_path = os.path.join(config_path_root, config_dir)
+            if os.path.exists(config_path) and os.path.isdir(config_path):
+                return os.path.abspath(config_path)
     return None
 
 
@@ -34,6 +34,39 @@ def _get_default_config_path():
     version_base=None,
 )
 def main(cfg: DictConfig) -> None:
+    """
+    Main entry point for the FusionBench command-line interface.
+
+    This function serves as the primary entry point for the `fusion_bench` CLI command.
+    It is decorated with Hydra's main decorator to handle configuration management,
+    command-line argument parsing, and configuration file loading.
+
+    The function performs the following operations:
+    1. Resolves any interpolations in the configuration using OmegaConf
+    2. Instantiates the appropriate program class based on the configuration
+    3. Executes the program's run method to perform the fusion task
+
+    Args:
+        cfg (DictConfig): The Hydra configuration object containing all settings
+            for the fusion task. This includes method configuration, model pool
+            configuration, task pool configuration, and other runtime parameters.
+            The configuration is automatically loaded by Hydra from the specified
+            config files and command-line overrides.
+
+    Returns:
+        None: This function doesn't return a value but executes the fusion
+            program which may save results, log outputs, or perform other
+            side effects as configured.
+
+    Example:
+        This function is typically called automatically when running:
+        ```bash
+        fusion_bench method=... modelpool=... taskpool=...
+        ```
+
+        The Hydra decorator handles parsing these command-line arguments and
+        loading the corresponding configuration files to populate the cfg parameter.
+    """
     OmegaConf.resolve(cfg)
     program: BaseHydraProgram = instantiate(cfg)
     program.run()
fusion_bench/taskpool/base_pool.py
@@ -1,14 +1,15 @@
 from abc import abstractmethod
+from typing import Any, Dict
 
-from fusion_bench.mixins import BaseYAMLSerializableModel
+from fusion_bench.mixins import BaseYAMLSerializable
 
 
-class BaseTaskPool(BaseYAMLSerializableModel):
+class BaseTaskPool(BaseYAMLSerializable):
     _program = None
     _config_key = "taskpool"
 
     @abstractmethod
-    def evaluate(self, model, *args, **kwargs):
+    def evaluate(self, model: Any, *args: Any, **kwargs: Any) -> Dict[str, Any]:
         """
         Evaluate the model on all tasks in the task pool, and return a report.
 
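The typed `evaluate` signature above makes the contract explicit: subclasses return a per-task report dictionary. A minimal sketch of a custom task pool against this interface; the class name and metric values are hypothetical:

```python
from typing import Any, Dict

from fusion_bench.taskpool import BaseTaskPool


class MyTaskPool(BaseTaskPool):
    """Hypothetical task pool returning one report entry per task."""

    def evaluate(self, model: Any, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        # The task name and metric below are illustrative placeholders.
        return {"my_task": {"accuracy": 0.0}}
```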
fusion_bench/taskpool/clip_vision/taskpool.py
@@ -27,8 +27,9 @@ from tqdm.autonotebook import tqdm
 from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
 from transformers.models.clip.modeling_clip import CLIPVisionTransformer
 
+from fusion_bench import RuntimeConstants
 from fusion_bench.dataset import CLIPDataset
-from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.mixins import HydraConfigMixin, LightningFabricMixin
 from fusion_bench.models.hf_clip import HFCLIPClassifier
 from fusion_bench.taskpool import BaseTaskPool
 from fusion_bench.tasks.clip_classification import get_classnames_and_templates
@@ -56,6 +57,8 @@ class LayerWiseFeatureSaver:
         first_token_only: bool = True,
         max_num: Optional[int] = None,
     ):
+        if isinstance(save_path, str):
+            save_path = Path(save_path)
         self.save_path = save_path
         self.first_token_only = first_token_only
         self.max_num = max_num
@@ -84,8 +87,9 @@ class LayerWiseFeatureSaver:
 
 
 class CLIPVisionModelTaskPool(
-    BaseTaskPool,
+    HydraConfigMixin,
     LightningFabricMixin,
+    BaseTaskPool,
 ):
     """
     This class is used to define the image classification task for CLIP models.
@@ -122,14 +126,14 @@ class CLIPVisionModelTaskPool(
         self,
         test_datasets: Union[DictConfig, Dict[str, Dataset]],
         *,
-        processor: Union[DictConfig, CLIPProcessor],
-        data_processor: Union[DictConfig, CLIPProcessor],
-        clip_model: Union[DictConfig, CLIPModel],
+        processor: Union[str, DictConfig, CLIPProcessor],
+        clip_model: Union[str, DictConfig, CLIPModel],
+        data_processor: Union[DictConfig, CLIPProcessor] = None,
         dataloader_kwargs: DictConfig = None,
         layer_wise_feature_save_path: Optional[str] = None,
         layer_wise_feature_first_token_only: bool = True,
         layer_wise_feature_max_num: Optional[int] = None,
-        fast_dev_run: bool = False,
+        fast_dev_run: Optional[bool] = None,
         **kwargs,
     ):
         """
@@ -151,7 +155,10 @@ class CLIPVisionModelTaskPool(
         self.layer_wise_feature_first_token_only = layer_wise_feature_first_token_only
         self.layer_wise_feature_max_num = layer_wise_feature_max_num
 
-        self.fast_dev_run = fast_dev_run
+        if fast_dev_run is None:
+            self.fast_dev_run = RuntimeConstants().debug
+        else:
+            self.fast_dev_run = fast_dev_run
         super().__init__(**kwargs)
 
     def setup(self):
@@ -159,21 +166,35 @@
         Set up the processor, data processor, CLIP model, test datasets, and data loaders.
         """
         # setup processor and clip model
-        self.processor = (
-            instantiate(self._processor)
-            if isinstance(self._processor, DictConfig)
-            else self._processor
-        )
-        self.data_processor = (
-            instantiate(self._data_processor)
-            if isinstance(self._data_processor, DictConfig)
-            else self._data_processor
-        )
-        self.clip_model = (
-            instantiate(self._clip_model)
-            if isinstance(self._clip_model, DictConfig)
-            else self._clip_model
-        )
+        if isinstance(self._processor, str):
+            self.processor = CLIPProcessor.from_pretrained(self._processor)
+        elif (
+            isinstance(self._processor, (dict, DictConfig))
+            and "_target_" in self._processor
+        ):
+            self.processor = instantiate(self._processor)
+        else:
+            self.processor = self._processor
+
+        if self._data_processor is None:
+            self.data_processor = self.processor
+        else:
+            self.data_processor = (
+                instantiate(self._data_processor)
+                if isinstance(self._data_processor, DictConfig)
+                else self._data_processor
+            )
+
+        if isinstance(self._clip_model, str):
+            self.clip_model = CLIPModel.from_pretrained(self._clip_model)
+        elif (
+            isinstance(self._clip_model, (dict, DictConfig))
+            and "_target_" in self._clip_model
+        ):
+            self.clip_model = instantiate(self._clip_model)
+        else:
+            self.clip_model = self._clip_model
+
         self.clip_model = self.fabric.to_device(self.clip_model)
         self.clip_model.requires_grad_(False)
         self.clip_model.eval()
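With the widened `Union[str, ...]` annotations and the new `setup` branches above, the processor and CLIP model can now be given as plain Hugging Face model identifiers instead of `_target_` configs, and `data_processor` may be omitted (it falls back to `processor`). A minimal sketch, assuming the standard `openai/clip-vit-base-patch32` checkpoint; `mnist_test` and `model` are placeholder objects the caller must supply:

```python
from fusion_bench.taskpool.clip_vision.taskpool import CLIPVisionModelTaskPool

# Strings are resolved via CLIPProcessor/CLIPModel.from_pretrained in setup().
taskpool = CLIPVisionModelTaskPool(
    test_datasets={"mnist": mnist_test},  # hypothetical dataset object
    processor="openai/clip-vit-base-patch32",
    clip_model="openai/clip-vit-base-patch32",
)
report = taskpool.evaluate(model)  # `model` is the CLIPVisionModel to score
```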
fusion_bench/taskpool/dummy.py
@@ -4,13 +4,13 @@ This is the dummy task pool that is used for debugging purposes.
 
 from typing import Optional
 
+from lightning.pytorch.utilities import rank_zero_only
 from torch import nn
 
 from fusion_bench.models.separate_io import separate_save
 from fusion_bench.taskpool.base_pool import BaseTaskPool
 from fusion_bench.utils import timeit_context
 from fusion_bench.utils.parameters import count_parameters, print_parameters
-from lightning.pytorch.utilities import rank_zero_only
 
 
 def get_model_summary(model: nn.Module) -> dict:
fusion_bench/taskpool/lm_eval_harness/taskpool.py
@@ -1,6 +1,6 @@
 import logging
 import os
-from typing import List, Literal, Optional, Union, TYPE_CHECKING
+from typing import TYPE_CHECKING, List, Literal, Optional, Union
 
 import lightning.fabric
 import lm_eval
@@ -12,7 +12,6 @@ from fusion_bench import BaseTaskPool
 from fusion_bench.mixins import LightningFabricMixin
 from fusion_bench.utils.strenum import _version
 
-
 log = logging.getLogger(__name__)
 
 
fusion_bench/tasks/clip_classification/__init__.py
@@ -1,6 +1,6 @@
 import importlib
 import warnings
-from typing import Any, Callable, Dict, List
+from typing import Any, Callable, Dict, List, Tuple
 
 from datasets import load_dataset
 
@@ -79,7 +79,9 @@ class CLIPTemplateFactory:
     }
 
     @staticmethod
-    def get_classnames_and_templates(dataset_name: str):
+    def get_classnames_and_templates(
+        dataset_name: str,
+    ) -> Tuple[List[str], List[Callable]]:
         """
         Retrieves class names and templates for the specified dataset.
 
@@ -169,7 +171,7 @@ class CLIPTemplateFactory:
         CLIPTemplateFactory._dataset_mapping[dataset_name] = dataset_info
 
     @staticmethod
-    def get_available_datasets():
+    def get_available_datasets() -> List[str]:
         """
         Get a list of all available dataset names.
 
@@ -179,5 +181,5 @@
         return list(CLIPTemplateFactory._dataset_mapping.keys())
 
 
-def get_classnames_and_templates(dataset_name: str):
+def get_classnames_and_templates(dataset_name: str) -> Tuple[List[str], List[Callable]]:
     return CLIPTemplateFactory.get_classnames_and_templates(dataset_name)
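The new `Tuple[List[str], List[Callable]]` annotation documents that each template is a callable mapping a class name to a prompt string. A short usage sketch; `"mnist"` is assumed here to be one of the registered dataset names (check `get_available_datasets()` for the actual registry contents):

```python
from fusion_bench.tasks.clip_classification import get_classnames_and_templates

classnames, templates = get_classnames_and_templates("mnist")

# Build one zero-shot prompt per (template, class) pair.
prompts = [template(name) for template in templates for name in classnames]
print(len(classnames), len(templates), prompts[0])
```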
fusion_bench/utils/__init__.py
@@ -7,10 +7,16 @@ from .cache_utils import *
 from .devices import *
 from .dtype import parse_dtype
 from .fabric import seed_everything_by_time
-from .instantiate_utils import instantiate, is_instantiable
+from .instantiate_utils import (
+    instantiate,
+    is_instantiable,
+    set_print_function_call,
+    set_print_function_call_permeanent,
+)
 from .json import load_from_json, save_to_json
 from .lazy_state_dict import LazyStateDict
 from .misc import *
 from .packages import import_object
 from .parameters import *
+from .pylogger import get_rankzero_logger
 from .timer import timeit_context
fusion_bench/utils/cache_utils.py
@@ -1,15 +1,30 @@
 import logging
 import os
 import pickle
+import warnings
 from functools import wraps
 from pathlib import Path
 from typing import Any, Callable, Union
 
-__all__ = ["cache_to_disk"]
+from joblib import Memory
+
+__all__ = ["cache_to_disk", "cache_with_joblib", "set_default_cache_dir"]
 
 
 log = logging.getLogger(__name__)
 
+DEFAULT_CACHE_DIR = Path.cwd() / "outputs" / "cache"
+
+
+def set_default_cache_dir(path: str | Path):
+    global DEFAULT_CACHE_DIR
+    if path is None:
+        return
+
+    if isinstance(path, str):
+        path = Path(path)
+    DEFAULT_CACHE_DIR = path
+
 
 def cache_to_disk(file_path: Union[str, Path]) -> Callable:
     """
@@ -17,6 +32,11 @@ def cache_to_disk(file_path: Union[str, Path]) -> Callable:
     the result is loaded from the file. Otherwise, the function is executed and
     the result is saved to the file.
 
+    !!! warning "deprecated"
+        This function is deprecated. Use `cache_with_joblib` instead for better
+        caching capabilities including automatic cache invalidation, better object
+        handling, and memory efficiency.
+
     ## Example usage
 
     ```python
@@ -32,6 +52,13 @@ def cache_to_disk(file_path: Union[str, Path]) -> Callable:
     Returns:
         Callable: The decorated function.
     """
+    warnings.warn(
+        "cache_to_disk is deprecated. Use cache_with_joblib instead for better "
+        "caching capabilities including automatic cache invalidation, better object "
+        "handling, and memory efficiency.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
     if isinstance(file_path, str):
         file_path = Path(file_path)
     assert isinstance(file_path, Path)
@@ -56,3 +83,76 @@ def cache_to_disk(file_path: Union[str, Path]) -> Callable:
         return wrapper
 
     return decorator
+
+
+def cache_with_joblib(
+    cache_dir: Union[str, Path] = None,
+    verbose: int = 0,
+) -> Callable:
+    """
+    A decorator to cache the result of a function using joblib.Memory. This provides
+    more advanced caching capabilities compared to cache_to_disk, including:
+    - Automatic cache invalidation when function arguments change
+    - Better handling of numpy arrays and other complex objects
+    - Memory-efficient storage
+    - Optional verbose output for cache hits/misses
+
+    ## Example usage
+
+    ```python
+    @cache_with_joblib("./cache", verbose=1)
+    def expensive_computation(x: int, y: str) -> Any:
+        # Function implementation
+        return complex_result
+
+    # Or with default settings:
+    @cache_with_joblib()
+    def another_function(x: int) -> int:
+        return x * 2
+    ```

+    Args:
+        cache_dir (Union[str, Path]): The directory where cache files should be stored.
+            If `None`, a default directory `outputs/cache` will be used.
+        verbose (int): Verbosity level for joblib.Memory (0=silent, higher values are more verbose).
+
+    Returns:
+        Callable: A decorator function that can be applied to functions.
+    """
+    if cache_dir is None:
+        cache_dir = DEFAULT_CACHE_DIR
+
+    if isinstance(cache_dir, str):
+        cache_dir = Path(cache_dir)
+    assert isinstance(cache_dir, Path)
+
+    # Create the cache directory if it doesn't exist
+    cache_dir.mkdir(parents=True, exist_ok=True)
+
+    # Create a Memory object for this function
+    memory = Memory(location=cache_dir, verbose=verbose)
+
+    def decorator(func: Callable) -> Callable:
+        nonlocal memory
+
+        # Create the cached version of the function
+        cached_func = memory.cache(func)
+
+        @wraps(func)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            return cached_func(*args, **kwargs)
+
+        # Expose useful methods from joblib.Memory
+        if (
+            hasattr(cached_func, "clear")
+            and hasattr(cached_func, "call")
+            and hasattr(cached_func, "check_call_in_cache")
+        ):
+            wrapper.clear = cached_func.clear
+            wrapper.call = cached_func.call
+            wrapper.check_call_in_cache = cached_func.check_call_in_cache
+
+        return wrapper
+
+    return decorator
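The `clear`/`call`/`check_call_in_cache` attributes re-exported on the wrapper come from joblib's `MemorizedFunc` API, so cache management works as it would with a bare `Memory.cache`. A short sketch of what that enables; the cache path and the function body are illustrative only:

```python
from fusion_bench.utils.cache_utils import cache_with_joblib, set_default_cache_dir

# Affects decorators applied afterwards that rely on the default cache_dir.
set_default_cache_dir("./outputs/cache")

@cache_with_joblib(verbose=1)
def square(x: int) -> int:
    return x * x

square(3)                             # computed, then written to the on-disk cache
square(3)                             # served from the cache
print(square.check_call_in_cache(3))  # True once the result is cached
square.clear()                        # drop this function's cache entries
```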
fusion_bench/utils/devices.py
@@ -1,4 +1,5 @@
 import gc
+import logging
 import os
 from typing import List, Optional, Union
 
@@ -12,7 +13,7 @@ from transformers.utils import (
 )
 
 __all__ = [
-    "cuda_empty_cache",
+    "clear_cuda_cache",
     "to_device",
     "num_devices",
     "get_device",
@@ -21,10 +22,19 @@ __all__ = [
     "get_device_capabilities",
 ]
 
+log = logging.getLogger(__name__)
 
-def cuda_empty_cache():
+
+def clear_cuda_cache():
+    """
+    Clears the CUDA memory cache to free up GPU memory.
+    Works only if CUDA is available.
+    """
     gc.collect()
-    torch.cuda.empty_cache()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    else:
+        log.warning("CUDA is not available. No cache to clear.")
 
 
 def to_device(obj, device: Optional[torch.device], **kwargs):
@@ -75,7 +85,7 @@ def num_devices(devices: Union[int, List[int], str]) -> int:
     Return the number of devices.
 
     Args:
-        devices: `devices` can be a single int to specify the number of devices, or a list of device ids, e.g. [0, 1, 2, 3] or a str of device ids, e.g. "0,1,2,3" and "[0, 1, 2]".
+        devices: `devices` can be a single int to specify the number of devices, or a list of device ids, e.g. [0, 1, 2, 3], or a str of device ids, e.g. "0,1,2,3" and "[0, 1, 2]".
 
     Returns:
         The number of devices.
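The renamed helper is now safe on CPU-only machines: it always runs `gc.collect()` and only touches `torch.cuda` when a device is present. A quick sketch, assuming the `num_devices` docstring's contract holds:

```python
from fusion_bench.utils.devices import clear_cuda_cache, num_devices

clear_cuda_cache()  # without CUDA this logs a warning instead of raising

# All three forms below should describe four devices, per the docstring above.
print(num_devices(4), num_devices([0, 1, 2, 3]), num_devices("0,1,2,3"))
```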
fusion_bench/utils/fabric.py
@@ -3,9 +3,9 @@ from typing import Optional
 
 import lightning as L
 
-from fusion_bench.utils.pylogger import getRankZeroLogger
+from fusion_bench.utils.pylogger import get_rankzero_logger
 
-log = getRankZeroLogger(__name__)
+log = get_rankzero_logger(__name__)
 
 
 def seed_everything_by_time(fabric: Optional[L.Fabric] = None):
fusion_bench/utils/instantiate_utils.py
@@ -28,7 +28,7 @@ PRINT_FUNCTION_CALL_FUNC = print
 Function to be used for printing function calls.
 """
 
-CATCH_EXCEPTION = True
+CATCH_EXCEPTION = False
 
 
 @contextmanager
@@ -41,10 +41,12 @@ def set_print_function_call(value: bool):
     finally:
         PRINT_FUNCTION_CALL = old_value
 
+
 def set_print_function_call_permeanent(value: bool):
     global PRINT_FUNCTION_CALL
     PRINT_FUNCTION_CALL = value
 
+
 def is_instantiable(config: Union[DictConfig, Any]) -> bool:
     if OmegaConf.is_dict(config):
        return "_target_" in config
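`set_print_function_call` is a context manager that temporarily toggles the module-level `PRINT_FUNCTION_CALL` flag, while the `set_print_function_call_permeanent` variant (the identifier's spelling is as released) sets it for good. A sketch of the intended use, assuming `fusion_bench.utils.instantiate` consults the flag when logging resolved `_target_` calls:

```python
from omegaconf import OmegaConf

from fusion_bench.utils import instantiate, set_print_function_call

cfg = OmegaConf.create({"_target_": "collections.Counter"})

with set_print_function_call(False):
    counter = instantiate(cfg)  # instantiated without the call being printed

counter = instantiate(cfg)  # flag restored once the context exits
```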
fusion_bench/utils/lazy_imports.py
@@ -72,3 +72,26 @@ class LazyImporter(ModuleType):
 
     def __reduce__(self):
         return (self.__class__, (self._name, self.__file__, self._import_structure))
+
+
+class LazyModule(ModuleType):
+    """Module wrapper for lazy import.
+    Adapted from Optuna: https://github.com/optuna/optuna/blob/1f92d496b0c4656645384e31539e4ee74992ff55/optuna/__init__.py
+
+    This class wraps the specified module and lazily imports it when it is actually accessed.
+
+    Args:
+        name: Name of the module to apply lazy import to.
+    """
+
+    def __init__(self, name: str) -> None:
+        super().__init__(name)
+        self._name = name
+
+    def _load(self) -> ModuleType:
+        module = importlib.import_module(self._name)
+        self.__dict__.update(module.__dict__)
+        return module
+
+    def __getattr__(self, item: str) -> Any:
+        return getattr(self._load(), item)
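As in Optuna, `LazyModule` is a module object whose real import is deferred until first attribute access: `_load` copies the real module's `__dict__` into the wrapper so later lookups are direct. A minimal sketch of the pattern, using `json` as a stand-in for a heavy dependency:

```python
from fusion_bench.utils.lazy_imports import LazyModule

# No import cost is paid here; only the module name is recorded.
lazy_json = LazyModule("json")

# First attribute access calls _load(), which imports the real module and
# merges its namespace into the wrapper.
print(lazy_json.dumps({"a": 1}))
```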
fusion_bench/utils/lazy_state_dict.py
@@ -2,7 +2,7 @@ import json
 import logging
 import os
 from copy import deepcopy
-from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Dict, Iterator, List, Mapping, Optional, Tuple, Type
 
 import torch
 from accelerate import init_empty_weights
@@ -49,7 +49,7 @@ def resolve_checkpoint_path(
     )
 
 
-class LazyStateDict:
+class LazyStateDict(Mapping[str, torch.Tensor]):
     """
     Dictionary-like object that lazily loads a state dict from a checkpoint path.
     """
@@ -168,12 +168,21 @@ class LazyStateDict:
     def config(self) -> "PretrainedConfig":
         return AutoConfig.from_pretrained(self._checkpoint)
 
+    @property
+    def dtype(self) -> torch.dtype:
+        """
+        `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+        """
+        first_key = next(iter(self.keys()))
+        first_param = self[first_key]
+        return first_param.dtype
+
     def state_dict(self, keep_vars: bool = False) -> "LazyStateDict":
         """
         Args:
             keep_vars (bool): Ignored, as LazyStateDict does not support keep_vars. Just for compatibility.
         """
-        return self
+        return deepcopy(self)
 
     def _resolve_checkpoint_files(self, checkpoint: str):
         # reference: https://huggingface.co/docs/accelerate/v0.17.1/en/usage_guides/big_modeling
@@ -290,6 +299,18 @@ class LazyStateDict:
         )
         return tensor
 
+    def pop(self, key: str):
+        assert key in list(
+            self.keys()
+        ), "KeyError: Cannot pop a tensor for a key that does not exist in the LazyStateDict."
+        if self._state_dict_cache is not None and key in self._state_dict_cache:
+            if key in self._index:
+                self._index.pop(key)
+            return self._state_dict_cache.pop(key)
+        if key in self._index:
+            self._index.pop(key)
+        return None
+
     def __setitem__(self, key: str, value: torch.Tensor) -> None:
         """
         Set a tensor in the LazyStateDict. This will update the state dict cache if it is enabled.
@@ -408,3 +429,17 @@ class LazyStateDict:
             raise KeyError(f"Key {key} not found in LazyStateDict.")
         for key, value in state_dict.items():
             self[key] = value
+
+    def __getattr__(self, name: str):
+        if "meta_module" in self.__dict__:
+            meta_module = self.__dict__["meta_module"]
+            if meta_module is not None:
+                if "_parameters" in meta_module.__dict__:
+                    if name in meta_module.__dict__["_parameters"]:
+                        return self.get_parameter(name)
+                if "_modules" in meta_module.__dict__:
+                    if name in meta_module.__dict__["_modules"]:
+                        return self.get_submodule(name)
+        raise AttributeError(
+            f"'{type(self).__name__}' object has no attribute '{name}'"
+        )