fusion-bench 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (209)
  1. fusion_bench/compat/method/__init__.py +2 -0
  2. fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
  3. fusion_bench/constants/clip_vision.py +22 -0
  4. fusion_bench/dataset/clip_dataset.py +10 -2
  5. fusion_bench/dataset/fer2013.py +1 -0
  6. fusion_bench/dataset/gsm8k.py +2 -2
  7. fusion_bench/method/__init__.py +10 -0
  8. fusion_bench/method/ada_svd/clip_vision.py +4 -1
  9. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
  10. fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
  11. fusion_bench/method/gossip/__init__.py +3 -0
  12. fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
  13. fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
  14. fusion_bench/method/gossip/entropy_loss.py +25 -0
  15. fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
  16. fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
  17. fusion_bench/method/gossip/min_norm_solvers.py +227 -0
  18. fusion_bench/method/gossip/task_wise_gossip.py +265 -0
  19. fusion_bench/method/gossip/utils.py +74 -0
  20. fusion_bench/method/isotropic_merging/__init__.py +1 -1
  21. fusion_bench/method/opcm/opcm.py +16 -7
  22. fusion_bench/method/pwe_moe/module.py +1 -1
  23. fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
  24. fusion_bench/method/regmean/regmean.py +25 -17
  25. fusion_bench/method/smile_upscaling/__init__.py +1 -1
  26. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +46 -145
  27. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +229 -0
  28. fusion_bench/method/smile_upscaling/smile_upscaling.py +19 -346
  29. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
  30. fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
  31. fusion_bench/method/ties_merging/ties_merging.py +36 -31
  32. fusion_bench/method/we_moe/we_moe.py +14 -15
  33. fusion_bench/mixins/__init__.py +6 -3
  34. fusion_bench/mixins/hydra_config.py +49 -0
  35. fusion_bench/mixins/openclip_classification.py +11 -0
  36. fusion_bench/mixins/simple_profiler.py +4 -2
  37. fusion_bench/modelpool/__init__.py +3 -1
  38. fusion_bench/modelpool/base_pool.py +2 -2
  39. fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
  40. fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
  41. fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +2 -203
  42. fusion_bench/models/modeling_smile_qwen2/__init__.py +8 -0
  43. fusion_bench/models/modeling_smile_qwen2/configuration_smile_qwen2.py +21 -0
  44. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +922 -0
  45. fusion_bench/models/modeling_smile_qwen2/register.py +11 -0
  46. fusion_bench/models/open_clip/__init__.py +6 -0
  47. fusion_bench/models/open_clip/modeling.py +176 -0
  48. fusion_bench/models/open_clip/utils.py +311 -0
  49. fusion_bench/models/open_clip/variables_and_paths.py +56 -0
  50. fusion_bench/models/parameter_dict.py +54 -13
  51. fusion_bench/models/rankone_moe.py +2 -88
  52. fusion_bench/models/smile_moe/linear_from_hf_config.py +373 -0
  53. fusion_bench/models/smile_moe/{linear.py → linear_from_module.py} +103 -33
  54. fusion_bench/models/smile_moe/utils/__init__.py +24 -0
  55. fusion_bench/models/smile_moe/utils/svd_utils.py +46 -0
  56. fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
  57. fusion_bench/taskpool/__init__.py +7 -3
  58. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  59. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
  60. fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
  61. fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
  62. fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
  63. fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
  64. fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
  65. fusion_bench/taskpool/gpt2_text_classification.py +30 -1
  66. fusion_bench/taskpool/lm_eval_harness/__init__.py +3 -0
  67. fusion_bench/taskpool/lm_eval_harness/taskpool.py +87 -0
  68. fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
  69. fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
  70. fusion_bench/utils/data.py +12 -0
  71. fusion_bench/utils/devices.py +14 -0
  72. fusion_bench/utils/instantiate.py +12 -0
  73. fusion_bench/utils/misc.py +9 -2
  74. fusion_bench/utils/packages.py +14 -0
  75. fusion_bench/utils/parameters.py +1 -1
  76. fusion_bench/utils/tensorboard.py +1 -1
  77. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/METADATA +22 -2
  78. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/RECORD +209 -157
  79. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/WHEEL +1 -1
  80. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
  81. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
  82. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
  83. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
  84. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
  85. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
  86. fusion_bench_config/fabric/auto.yaml +0 -1
  87. fusion_bench_config/fabric/llama_ddp.yaml +0 -1
  88. fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
  89. fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
  90. fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
  91. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
  92. fusion_bench_config/fabric_model_fusion.yaml +0 -1
  93. fusion_bench_config/llama_full_finetune.yaml +0 -2
  94. fusion_bench_config/llama_model_fusion.yaml +0 -2
  95. fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
  96. fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
  97. fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
  98. fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
  99. fusion_bench_config/method/adamerging.yaml +2 -2
  100. fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
  101. fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
  102. fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
  103. fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
  104. fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
  105. fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
  106. fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
  107. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
  108. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
  109. fusion_bench_config/method/dare/simple_average.yaml +0 -1
  110. fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
  111. fusion_bench_config/method/dare/ties_merging.yaml +0 -2
  112. fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
  113. fusion_bench_config/method/doge_ta/doge_ta.yaml +1 -1
  114. fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
  115. fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
  116. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
  117. fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
  118. fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
  119. fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
  120. fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
  121. fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
  122. fusion_bench_config/method/linear/llama_expo.yaml +0 -3
  123. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
  124. fusion_bench_config/method/linear/weighted_average.yaml +0 -1
  125. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
  126. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
  127. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
  128. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
  129. fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
  130. fusion_bench_config/method/model_recombination.yaml +0 -1
  131. fusion_bench_config/method/opcm/opcm.yaml +0 -1
  132. fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
  133. fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
  134. fusion_bench_config/method/opcm/weight_average.yaml +0 -1
  135. fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
  136. fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
  137. fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
  138. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
  139. fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
  140. fusion_bench_config/method/slerp/slerp.yaml +0 -2
  141. fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -2
  142. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +13 -0
  143. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
  144. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  145. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  146. fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
  147. fusion_bench_config/method/task_arithmetic.yaml +1 -1
  148. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
  149. fusion_bench_config/method/ties_merging.yaml +1 -1
  150. fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
  151. fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
  152. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
  153. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
  154. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
  155. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
  156. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
  157. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
  158. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
  159. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
  160. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
  161. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
  162. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
  167. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
  168. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
  169. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
  170. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
  171. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
  172. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
  173. fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
  174. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
  175. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
  176. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +17 -0
  177. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
  178. fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
  179. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
  180. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
  181. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
  182. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
  183. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
  184. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
  185. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
  186. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
  187. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
  188. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
  189. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +0 -2
  190. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +1 -3
  191. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
  192. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
  193. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
  194. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
  195. fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
  196. fusion_bench_config/nyuv2_config.yaml +0 -2
  197. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
  198. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
  199. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
  200. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
  201. fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +12 -0
  202. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
  203. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
  204. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
  205. fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
  206. fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
  207. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/entry_points.txt +0 -0
  208. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/licenses/LICENSE +0 -0
  209. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/top_level.txt +0 -0
fusion_bench/taskpool/openclip_vision/openclip_taskpool.py
@@ -0,0 +1,196 @@
+ import itertools
+ import json
+ import logging
+ import os
+ from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
+
+ import lightning.fabric
+ import open_clip
+ import torch
+ from omegaconf import DictConfig
+ from torch.nn import functional as F
+ from torch.utils.data import DataLoader, Dataset
+ from torchmetrics import Accuracy, MeanMetric
+ from torchmetrics.classification.accuracy import MulticlassAccuracy
+ from tqdm.auto import tqdm
+
+ from fusion_bench import BaseTaskPool
+ from fusion_bench.dataset.clip_dataset import CLIPDataset
+ from fusion_bench.mixins import LightningFabricMixin
+ from fusion_bench.modelpool.openclip_vision.modelpool import load_classifier_head
+ from fusion_bench.models.open_clip import (
+     ClassificationHead,
+     ImageClassifier,
+     ImageEncoder,
+ )
+ from fusion_bench.models.open_clip.variables_and_paths import OPENCLIP_CACHEDIR
+ from fusion_bench.utils import count_parameters, instantiate
+
+ if TYPE_CHECKING:
+     from fusion_bench.modelpool import OpenCLIPVisionModelPool
+     from fusion_bench.programs import FabricModelFusionProgram
+
+ log = logging.getLogger(__name__)
+
+
+ class OpenCLIPVisionModelTaskPool(
+     BaseTaskPool,
+     LightningFabricMixin,
+ ):
+     _is_setup = False
+
+     _program: "FabricModelFusionProgram"
+
+     processor: Optional[Callable] = None
+     test_datasets: Dict[str, CLIPDataset]
+
+     def __init__(
+         self,
+         test_datasets: Union[DictConfig, Dict[str, Dataset]],
+         classification_heads: Union[DictConfig, Dict[str, ClassificationHead]],
+         dataloader_kwargs: DictConfig,
+         model_name: Optional[str] = None,
+         fast_dev_run: bool = False,
+         **kwargs,
+     ):
+         self._test_datasets = test_datasets
+         self._classifier_heads = classification_heads
+         self._dataloader_kwargs = dataloader_kwargs
+         self._model_name = model_name
+         self.fast_dev_run = fast_dev_run
+         super().__init__(**kwargs)
+
+     def setup(self):
+         # setup the processor
+         if self._program is not None and self._program.modelpool is not None:
+             modelpool: "OpenCLIPVisionModelPool" = self._program.modelpool
+             self.processor = modelpool.test_processor
+         elif self._model_name is not None:
+             _, _, self.processor = open_clip.create_model_and_transforms(
+                 self._model_name,
+                 pretrained="openai",
+                 cache_dir=OPENCLIP_CACHEDIR,
+             )
+         else:
+             raise ValueError("Modelpool or model_name is not set")
+
+         # setup the test datasets
+         self.test_datasets = {
+             name: instantiate(dataset) if isinstance(dataset, DictConfig) else dataset
+             for name, dataset in self._test_datasets.items()
+         }
+         self.test_datasets = {
+             name: CLIPDataset(dataset, self.processor)
+             for name, dataset in self.test_datasets.items()
+         }
+         self.test_dataloaders = {
+             name: self.fabric.setup_dataloaders(
+                 DataLoader(dataset, **self._dataloader_kwargs)
+             )
+             for name, dataset in self.test_datasets.items()
+         }
+
+         # setup classifier heads
+         self.classifier_heads = {
+             name: load_classifier_head(head).to(self.fabric.device)
+             for name, head in self._classifier_heads.items()
+         }
+         self._is_setup = True
+
+     @torch.no_grad()
+     def _evaluate(
+         self,
+         classifier: ImageClassifier,
+         test_loader: DataLoader,
+         num_classes: int,
+         task_name: str,
+     ):
+         accuracy: MulticlassAccuracy = Accuracy(
+             task="multiclass", num_classes=num_classes
+         )
+         classifier.eval()
+         loss_metric = MeanMetric()
+         # if fast_dev_run is set, we only evaluate on a batch of the data
+         if self.fast_dev_run:
+             log.info("Running under fast_dev_run mode, evaluating on a single batch.")
+             test_loader = itertools.islice(test_loader, 1)
+         else:
+             test_loader = test_loader
+
+         pbar = tqdm(
+             test_loader,
+             desc=f"Evaluating {task_name}",
+             leave=False,
+             dynamic_ncols=True,
+         )
+         for batch in pbar:
+             inputs, targets = batch
+             logits = classifier(inputs)
+             loss = F.cross_entropy(logits, targets)
+             loss_metric.update(loss.detach().cpu())
+             acc = accuracy(logits.detach().cpu(), targets.detach().cpu())
+             pbar.set_postfix(
+                 {
+                     "accuracy": accuracy.compute().item(),
+                     "loss": loss_metric.compute().item(),
+                 }
+             )
+
+         acc = accuracy.compute().item()
+         loss = loss_metric.compute().item()
+         results = {"accuracy": acc, "loss": loss}
+         return results
+
+     def evaluate(self, model: ImageEncoder, **kwargs):
+         if not self._is_setup:
+             self.setup()
+
+         report = {}
+         # collect basic model information
+         training_params, all_params = count_parameters(model)
+         report["model_info"] = {
+             "trainable_params": training_params,
+             "all_params": all_params,
+             "trainable_percentage": training_params / all_params,
+         }
+
+         if not lightning.fabric.is_wrapped(model):
+             model = self.fabric.setup_module(model)
+
+         pbar = tqdm(
+             self.test_dataloaders.items(),
+             desc="Evaluating tasks",
+             total=len(self.test_dataloaders),
+         )
+         for task_name, test_dataloader in pbar:
+             classifier = ImageClassifier(model, self.classifier_heads[task_name])
+             num_classes = self.classifier_heads[task_name].weight.size(0)
+             result = self._evaluate(
+                 classifier,
+                 test_dataloader,
+                 num_classes=num_classes,
+                 task_name=task_name,
+             )
+             report[task_name] = result
+
+         # calculate the average accuracy and loss
+         if "average" not in report:
+             report["average"] = {}
+         accuracies = [
+             value["accuracy"]
+             for key, value in report.items()
+             if "accuracy" in value
+         ]
+         if len(accuracies) > 0:
+             average_accuracy = sum(accuracies) / len(accuracies)
+             report["average"]["accuracy"] = average_accuracy
+         losses = [value["loss"] for key, value in report.items() if "loss" in value]
+         if len(losses) > 0:
+             average_loss = sum(losses) / len(losses)
+             report["average"]["loss"] = average_loss
+
+         log.info(f"Evaluation Result: {report}")
+         if self.fabric.is_global_zero and len(self.fabric._loggers) > 0:
+             with open(os.path.join(self.log_dir, "report.json"), "w") as fp:
+                 json.dump(report, fp)
+         return report
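For orientation, a minimal sketch of how this new `OpenCLIPVisionModelTaskPool` would be driven. The dataset target, head config, and encoder below are hypothetical placeholders, not taken from this diff; in practice these values come from the YAML files under `fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/` listed above.

```python
from omegaconf import OmegaConf

from fusion_bench.taskpool.openclip_vision.openclip_taskpool import (
    OpenCLIPVisionModelTaskPool,
)

# Hypothetical config values; real ones live in the OpenCLIPVisionModelTaskPool YAMLs.
cfg = OmegaConf.create(
    {
        # each entry should be instantiable into a torch Dataset via `instantiate`
        "test_datasets": {"dtd": {"_target_": "my_pkg.datasets.DTDTest"}},
        # each entry is resolved by `load_classifier_head` into a ClassificationHead
        "classification_heads": {"dtd": {"_target_": "my_pkg.heads.dtd_head"}},
        "dataloader_kwargs": {"batch_size": 128, "num_workers": 8},
    }
)

taskpool = OpenCLIPVisionModelTaskPool(
    test_datasets=cfg.test_datasets,
    classification_heads=cfg.classification_heads,
    dataloader_kwargs=cfg.dataloader_kwargs,
    model_name="ViT-B-32",  # used to build a processor when no modelpool is attached
    fast_dev_run=True,  # evaluate a single batch per task
)
# `encoder` would be an ImageEncoder, e.g. the output of a fusion method:
# report = taskpool.evaluate(encoder)  # per-task metrics plus an "average" entry
```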
fusion_bench/utils/data.py
@@ -9,6 +9,18 @@ from torch.utils.data import DataLoader, Dataset
  
  
  class InfiniteDataLoader:
+     """
+     A wrapper class for DataLoader to create an infinite data loader.
+     This is useful in case we are only interested in the number of steps and not the number of epochs.
+
+     This class wraps a DataLoader and provides an iterator that resets
+     when the end of the dataset is reached, creating an infinite loop.
+
+     Attributes:
+         data_loader (DataLoader): The DataLoader to wrap.
+         data_iter (iterator): An iterator over the DataLoader.
+     """
+
      def __init__(self, data_loader: DataLoader):
          self.data_loader = data_loader
          self.data_iter = iter(data_loader)
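The docstring describes a reset-on-exhaustion loop, but the `__iter__`/`__next__` methods fall outside this hunk. A behavioral sketch of what it describes, not the package's exact code:

```python
from torch.utils.data import DataLoader


class InfiniteLoaderSketch:
    """Sketch of the reset-on-exhaustion behavior the docstring above describes."""

    def __init__(self, data_loader: DataLoader):
        self.data_loader = data_loader
        self.data_iter = iter(data_loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.data_iter)
        except StopIteration:
            # End of an epoch: restart the underlying loader transparently.
            self.data_iter = iter(self.data_loader)
            return next(self.data_iter)
```

This pairs naturally with step-based training loops, e.g. `for step, batch in zip(range(max_steps), loader): ...`.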
fusion_bench/utils/devices.py
@@ -229,3 +229,17 @@ def cleanup_cuda():
      gc.collect()
      torch.cuda.empty_cache()
      torch.cuda.reset_peak_memory_stats()
+
+
+ def print_memory_usage(print_fn=print):
+     """
+     Print the current GPU memory usage.
+
+     Returns:
+         str: A string containing the allocated and cached memory in MB.
+     """
+     allocated = torch.cuda.memory_allocated() / 1024**2  # convert to MB
+     cached = torch.cuda.memory_reserved() / 1024**2  # convert to MB
+     print_str = f"Allocated Memory: {allocated:.2f} MB\nCached Memory: {cached:.2f} MB"
+     print_fn(print_str)
+     return print_str
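A small usage sketch for the new helper; it assumes CUDA is available and initialized, and routing through a logger is optional:

```python
import logging

import torch

from fusion_bench.utils.devices import print_memory_usage

log = logging.getLogger(__name__)

if torch.cuda.is_available():
    # Emit the two-line summary through the logger instead of print().
    summary = print_memory_usage(print_fn=log.info)
```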
fusion_bench/utils/instantiate.py
@@ -2,6 +2,7 @@
  # Modified from Hydra
  import copy
  import functools
+ from contextlib import contextmanager
  from enum import Enum
  from textwrap import dedent
  from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
@@ -30,6 +31,17 @@ Function to be used for printing function calls.
  CATCH_EXCEPTION = True
  
  
+ @contextmanager
+ def set_print_function_call(value: bool):
+     global PRINT_FUNCTION_CALL
+     old_value = PRINT_FUNCTION_CALL
+     PRINT_FUNCTION_CALL = value
+     try:
+         yield
+     finally:
+         PRINT_FUNCTION_CALL = old_value
+
+
  def is_instantiable(config: Union[DictConfig, Any]) -> bool:
      if OmegaConf.is_dict(config):
          return "_target_" in config
fusion_bench/utils/misc.py
@@ -1,6 +1,6 @@
- from typing import Iterable
+ from typing import Iterable, List
  
- __all__ = ["first", "has_length"]
+ __all__ = ["first", "has_length", "join_list"]
  
  
  def first(iterable: Iterable):
@@ -16,3 +16,10 @@ def has_length(dataset):
      except TypeError:
          # TypeError: len() of unsized object
          return False
+
+
+ def join_list(list_of_list: List[List]):
+     ans = []
+     for item in list_of_list:
+         ans.extend(item)
+     return ans
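`join_list` flattens one level of nesting, matching `itertools.chain.from_iterable`:

```python
from fusion_bench.utils.misc import join_list

assert join_list([[1, 2], [3], []]) == [1, 2, 3]
# equivalently: list(itertools.chain.from_iterable([[1, 2], [3], []]))
```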
fusion_bench/utils/packages.py
@@ -82,3 +82,17 @@ def import_object(abs_obj_name: str):
      module_name, obj_name = abs_obj_name.rsplit(".", 1)
      module = importlib.import_module(module_name)
      return getattr(module, obj_name)
+
+
+ def compare_versions(v1, v2):
+     """Compare two version strings.
+     Returns -1 if v1 < v2, 0 if v1 == v2, 1 if v1 > v2"""
+
+     v1 = version.parse(v1)
+     v2 = version.parse(v2)
+     if v1 < v2:
+         return -1
+     elif v1 > v2:
+         return 1
+     else:
+         return 0
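`compare_versions` is a three-way comparison; `version` here is presumably `packaging.version` (already imported by this module, given the bare `version.parse` call), so ordering follows PEP 440 rather than string order:

```python
from fusion_bench.utils.packages import compare_versions

assert compare_versions("0.2.12", "0.2.14") == -1
assert compare_versions("0.2.14", "0.2.14") == 0
assert compare_versions("0.2.10", "0.2.9") == 1  # numeric, not lexicographic
```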
fusion_bench/utils/parameters.py
@@ -252,7 +252,7 @@ def print_parameters(
  
  
  def check_parameters_all_equal(
-     list_of_param_names: List[Union[StateDictType, nn.Module, List[str]]]
+     list_of_param_names: List[Union[StateDictType, nn.Module, List[str]]],
  ) -> None:
      """
      Checks if all models have the same parameters.
fusion_bench/utils/tensorboard.py
@@ -1,5 +1,5 @@
  """
- functions deal with tensorboard logs.
+ functions deal with tensorboard logs.
  """
  
  from typing import Dict, Iterable, List
{fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: fusion_bench
- Version: 0.2.12
+ Version: 0.2.14
  Summary: A Comprehensive Benchmark of Deep Model Fusion
  Author-email: Anke Tang <tang.anke@foxmail.com>
  License: MIT License
@@ -45,6 +45,8 @@ Requires-Dist: rich
  Requires-Dist: scipy
  Requires-Dist: h5py
  Requires-Dist: pytest
+ Provides-Extra: lm-eval-harness
+ Requires-Dist: lm-eval; extra == "lm-eval-harness"
  Dynamic: license-file
  
  <div align='center'>
@@ -122,7 +124,7 @@ Merging multiple expert models offers a promising approach for performing multi-
  
  ## Installation
  
- install from PyPI:
+ Install from PyPI:
  
  ```bash
  pip install fusion-bench
@@ -137,6 +139,24 @@ cd fusion_bench
  pip install -e . # install the package in editable mode
  ```
  
+ ### Install with [Language Model Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness)
+
+ [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.10256836.svg)](https://doi.org/10.5281/zenodo.10256836)
+
+
+ ```bash
+ pip install "fusion-bench[lm-eval-harness]"
+ ```
+
+ or install from a local directory:
+
+ ```bash
+ pip install -e ".[lm-eval-harness]"
+ ```
+
+ This will install the latest version of fusion-bench and the dependencies required for the LM-Eval Harness.
+ Documentation for using LM-Eval Harness within the FusionBench framework can be found at [this online documentation](https://tanganke.github.io/fusion_bench/taskpool/lm_eval_harness) or in the [`docs/taskpool/lm_eval_harness.md`](docs/taskpool/lm_eval_harness.md) markdown file.
+
  ## Introduction to Deep Model Fusion
  
  Deep model fusion is a technique that merges, ensemble, or fuse multiple deep neural networks to obtain a unified model.
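The new extra ties into the `LMEvalHarnessTaskPool` and its config added in this release (`fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml`). A hypothetical invocation via the Hydra-style `fusion_bench` CLI; the override names below follow the config directory layout but are a guess, and the linked documentation is authoritative:

```bash
pip install "fusion-bench[lm-eval-harness]"
# Hypothetical: config-group names inferred from the fusion_bench_config layout above.
fusion_bench \
    method=linear/weighted_average \
    modelpool=CausalLMPool/qwen2_math_1.5B_and_R1 \
    taskpool=LMEvalHarnessTaskPool/lm_eval
```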