fusion-bench 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (199)
  1. fusion_bench/compat/method/__init__.py +3 -1
  2. fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
  3. fusion_bench/constants/clip_vision.py +22 -0
  4. fusion_bench/dataset/clip_dataset.py +10 -2
  5. fusion_bench/dataset/gsm8k.py +2 -2
  6. fusion_bench/method/__init__.py +12 -2
  7. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  8. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
  9. fusion_bench/method/doge_ta/__init__.py +2 -0
  10. fusion_bench/method/{DOGE_TA → doge_ta}/clip_layer_wise_adamerging.py +1 -1
  11. fusion_bench/method/{DOGE_TA/DOGE_TA.py → doge_ta/doge_ta.py} +1 -1
  12. fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
  13. fusion_bench/method/gossip/__init__.py +3 -0
  14. fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
  15. fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
  16. fusion_bench/method/gossip/entropy_loss.py +25 -0
  17. fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
  18. fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
  19. fusion_bench/method/gossip/min_norm_solvers.py +227 -0
  20. fusion_bench/method/gossip/task_wise_gossip.py +265 -0
  21. fusion_bench/method/gossip/utils.py +74 -0
  22. fusion_bench/method/isotropic_merging/__init__.py +1 -1
  23. fusion_bench/method/opcm/opcm.py +102 -84
  24. fusion_bench/method/opcm/task_arithmetic.py +35 -21
  25. fusion_bench/method/opcm/ties_merging.py +71 -52
  26. fusion_bench/method/pwe_moe/module.py +1 -1
  27. fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
  28. fusion_bench/method/regmean/regmean.py +25 -17
  29. fusion_bench/method/smile_upscaling/__init__.py +1 -1
  30. fusion_bench/method/smile_upscaling/smile_upscaling.py +13 -10
  31. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
  32. fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
  33. fusion_bench/method/ties_merging/ties_merging.py +36 -31
  34. fusion_bench/method/we_moe/we_moe.py +14 -15
  35. fusion_bench/mixins/__init__.py +6 -3
  36. fusion_bench/mixins/hydra_config.py +49 -0
  37. fusion_bench/mixins/openclip_classification.py +11 -0
  38. fusion_bench/mixins/simple_profiler.py +4 -2
  39. fusion_bench/modelpool/__init__.py +3 -1
  40. fusion_bench/modelpool/base_pool.py +2 -2
  41. fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
  42. fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
  43. fusion_bench/models/open_clip/__init__.py +6 -0
  44. fusion_bench/models/open_clip/modeling.py +176 -0
  45. fusion_bench/models/open_clip/utils.py +311 -0
  46. fusion_bench/models/open_clip/variables_and_paths.py +56 -0
  47. fusion_bench/models/parameter_dict.py +54 -13
  48. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -46
  49. fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +4 -119
  50. fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
  51. fusion_bench/taskpool/__init__.py +5 -3
  52. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  53. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
  54. fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
  55. fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
  56. fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
  57. fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
  58. fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
  59. fusion_bench/taskpool/gpt2_text_classification.py +30 -1
  60. fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
  61. fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
  62. fusion_bench/utils/data.py +12 -0
  63. fusion_bench/utils/devices.py +14 -0
  64. fusion_bench/utils/instantiate.py +12 -0
  65. fusion_bench/utils/misc.py +9 -2
  66. fusion_bench/utils/packages.py +14 -0
  67. fusion_bench/utils/parameters.py +1 -1
  68. fusion_bench/utils/tensorboard.py +1 -1
  69. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/METADATA +15 -2
  70. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/RECORD +198 -158
  71. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/WHEEL +1 -1
  72. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
  73. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
  74. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
  75. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
  76. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
  77. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
  78. fusion_bench_config/fabric/auto.yaml +0 -1
  79. fusion_bench_config/fabric/llama_ddp.yaml +0 -1
  80. fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
  81. fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
  82. fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
  83. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
  84. fusion_bench_config/fabric_model_fusion.yaml +0 -1
  85. fusion_bench_config/llama_full_finetune.yaml +0 -2
  86. fusion_bench_config/llama_model_fusion.yaml +0 -2
  87. fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
  88. fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
  89. fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
  90. fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
  91. fusion_bench_config/method/adamerging.yaml +2 -2
  92. fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
  93. fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
  94. fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
  95. fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
  96. fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
  97. fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
  98. fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
  99. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
  100. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
  101. fusion_bench_config/method/dare/simple_average.yaml +0 -1
  102. fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
  103. fusion_bench_config/method/dare/ties_merging.yaml +0 -2
  104. fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
  105. fusion_bench_config/method/{DOGE_TA/DOGE_TA.yaml → doge_ta/doge_ta.yaml} +1 -1
  106. fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
  107. fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
  108. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
  109. fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
  110. fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
  111. fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
  112. fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
  113. fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
  114. fusion_bench_config/method/linear/llama_expo.yaml +0 -3
  115. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
  116. fusion_bench_config/method/linear/weighted_average.yaml +0 -1
  117. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
  118. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
  119. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
  120. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
  121. fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
  122. fusion_bench_config/method/model_recombination.yaml +0 -1
  123. fusion_bench_config/method/opcm/opcm.yaml +0 -1
  124. fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
  125. fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
  126. fusion_bench_config/method/opcm/weight_average.yaml +0 -1
  127. fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
  128. fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
  129. fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
  130. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
  131. fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
  132. fusion_bench_config/method/slerp/slerp.yaml +0 -2
  133. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
  134. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  135. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  136. fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
  137. fusion_bench_config/method/task_arithmetic.yaml +1 -1
  138. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
  139. fusion_bench_config/method/ties_merging.yaml +1 -1
  140. fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
  141. fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
  142. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
  143. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
  144. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
  145. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
  146. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
  147. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
  148. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
  149. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
  150. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
  151. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
  152. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
  153. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
  154. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
  155. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
  156. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
  157. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
  158. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
  159. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
  160. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
  161. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
  162. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
  163. fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
  164. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
  165. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
  166. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
  167. fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
  168. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
  169. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
  170. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
  171. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
  172. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
  173. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
  174. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
  175. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
  176. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
  177. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
  178. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +8 -10
  179. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +66 -0
  180. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
  181. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
  182. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
  183. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
  184. fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
  185. fusion_bench_config/nyuv2_config.yaml +0 -2
  186. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
  187. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
  188. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
  189. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
  190. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
  191. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
  192. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
  193. fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
  194. fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
  195. fusion_bench/method/DOGE_TA/__init__.py +0 -2
  196. fusion_bench/method/{DOGE_TA → doge_ta}/layer_wise_adamerging.py +0 -0
  197. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/entry_points.txt +0 -0
  198. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info/licenses}/LICENSE +0 -0
  199. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/top_level.txt +0 -0
fusion_bench/models/open_clip/utils.py
@@ -0,0 +1,311 @@
+ import copy
+ import os
+ import pickle
+ from collections import OrderedDict
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+
+ def compute_l1_norm(
+     model1: nn.Module, model2: nn.Module
+ ) -> Tuple[torch.Tensor, Dict[str, float]]:
+     """
+     Computes the L1 norm between the parameters of two models.
+
+     Args:
+         model1 (nn.Module): The first model.
+         model2 (nn.Module): The second model.
+
+     Returns:
+         Tuple[torch.Tensor, Dict[str, float]]: A tuple containing the total L1 norm and a dictionary
+             with the L1 norm for each layer.
+
+     """
+     norms = dict()
+     l1_norm = 0.0
+     for (n, p1), p2 in zip(model1.named_parameters(), model2.parameters()):
+         layer_l1_norm = torch.norm(p1 - p2, 1)
+         l1_norm += layer_l1_norm
+         norms[n] = layer_l1_norm.item()
+
+     return l1_norm, norms
+
+
+ def assign_learning_rate(param_group, new_lr):
+     param_group["lr"] = new_lr
+
+
+ def _warmup_lr(base_lr, warmup_length, step):
+     return base_lr * (step + 1) / warmup_length
+
+
+ def cosine_lr(optimizer, base_lrs, warmup_length, steps):
+     if not isinstance(base_lrs, list):
+         base_lrs = [base_lrs for _ in optimizer.param_groups]
+     assert len(base_lrs) == len(optimizer.param_groups)
+
+     def _lr_adjuster(step):
+         for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
+             if step < warmup_length:
+                 lr = _warmup_lr(base_lr, warmup_length, step)
+             else:
+                 e = step - warmup_length
+                 es = steps - warmup_length
+                 lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
+             assign_learning_rate(param_group, lr)
+
+     return _lr_adjuster
+
+
+ def accuracy(output: torch.Tensor, target: torch.Tensor, topk: List[int] = (1,)):
+     pred = output.topk(max(topk), 1, True, True)[1].t()
+     correct = pred.eq(target.view(1, -1).expand_as(pred))
+     return [
+         float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
+         for k in topk
+     ]
+
+
+ def torch_load_old(save_path: str, device=None):
+     with open(save_path, "rb") as f:
+         classifier = pickle.load(f)
+     if device is not None:
+         classifier = classifier.to(device)
+     return classifier
+
+
+ def torch_save(model, save_path, save_state_dict=True):
+     # TODO: hacky way to save state dict
+     if save_state_dict and isinstance(model, torch.nn.Module):
+         model = model.state_dict()
+     if os.path.dirname(save_path) != "":
+         os.makedirs(os.path.dirname(save_path), exist_ok=True)
+     torch.save(model, save_path)
+
+
+ def torch_load(save_path, device=None):
+     model = torch.load(save_path, map_location="cpu")
+     if device is not None:
+         model = model.to(device)
+     return model
+
+
+ def get_logits(inputs, classifier):
+     assert callable(classifier)
+     if hasattr(classifier, "to"):
+         classifier = classifier.to(inputs.device)
+     return classifier(inputs)
+
+
+ def get_probs(inputs, classifier):
+     if hasattr(classifier, "predict_proba"):
+         probs = classifier.predict_proba(inputs.detach().cpu().numpy())
+         return torch.from_numpy(probs)
+     logits = get_logits(inputs, classifier)
+     return logits.softmax(dim=1)
+
+
+ class LabelSmoothing(torch.nn.Module):
+     def __init__(self, smoothing=0.0):
+         super(LabelSmoothing, self).__init__()
+         self.confidence = 1.0 - smoothing
+         self.smoothing = smoothing
+
+     def forward(self, x, target):
+         logprobs = torch.nn.functional.log_softmax(x, dim=-1)
+
+         nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
+         nll_loss = nll_loss.squeeze(1)
+         smooth_loss = -logprobs.mean(dim=-1)
+         loss = self.confidence * nll_loss + self.smoothing * smooth_loss
+         return loss.mean()
+
+
+ class DotDict(dict):
+     """dot.notation access to dictionary attributes"""
+
+     __getattr__ = dict.get
+     __setattr__ = dict.__setitem__
+     __delattr__ = dict.__delitem__
+
+
+ def find_optimal_coef(
+     results: Dict[str, Any],
+     metric: str = "avg_normalized_top1",
+     minimize: bool = False,
+     control_metric: Optional[str] = None,
+     control_metric_threshold: float = 0.0,
+ ) -> float:
+     """
+     Finds the optimal coefficient based on the given results and metric.
+
+     Args:
+         results (Dict[str, Any]): A dictionary containing the results for different scaling coefficients.
+         metric (str, optional): The metric to optimize. Defaults to "avg_normalized_top1".
+         minimize (bool, optional): Whether to minimize the metric. Defaults to False.
+         control_metric (str, optional): The control metric to check against. Defaults to None.
+         control_metric_threshold (float, optional): The threshold value for the control metric. Defaults to 0.0.
+
+     Returns:
+         The optimal coefficient based on the given results and metric.
+     """
+     best_coef = None
+     if minimize:
+         best_metric = 1
+     else:
+         best_metric = 0
+     for scaling_coef in results.keys():
+         if control_metric is not None:
+             if results[scaling_coef][control_metric] < control_metric_threshold:
+                 print(f"Control metric fell below {control_metric_threshold} threshold")
+                 continue
+         if minimize:
+             if results[scaling_coef][metric] < best_metric:
+                 best_metric = results[scaling_coef][metric]
+                 best_coef = scaling_coef
+         else:
+             if results[scaling_coef][metric] > best_metric:
+                 best_metric = results[scaling_coef][metric]
+                 best_coef = scaling_coef
+     return best_coef
+
+
+ def nonlinear_advantage(nonlinear_acc, linear_acc, num_classes):
+     """Computes the normalized non-linear advantage of a finetuned model.
+
+     The nonlinear_advantage is defined as:
+         error_rate(linear_model) - error_rate(nonlinear_model) / (1 - 1 / num_classes)
+     and takes values between [-1, 1]. A value of 0 indicates that the nonlinear
+     model is no better than the linear one. Meanwhile, a value of 1 indicates
+     that the nonlinear model is perfect and the linear trivial, and a value of
+     -1 indicates the opposite.
+     """
+     return (nonlinear_acc - linear_acc) / (1.0 - 1.0 / num_classes)
+
+
+ def to_cuda(input_dict):
+     cuda_dict = {}
+     for key, value in input_dict.items():
+         cuda_dict[key] = value.to("cuda")
+     return cuda_dict
+
+
+ def state_dict_to_vector(state_dict, remove_keys=[]):
+     shared_state_dict = copy.deepcopy(state_dict)
+     for key in remove_keys:
+         if key in shared_state_dict:
+             del shared_state_dict[key]
+     sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items()))
+     return torch.nn.utils.parameters_to_vector(
+         [value.reshape(-1) for key, value in sorted_shared_state_dict.items()]
+     )
+
+
+ def vector_to_state_dict(vector, state_dict, remove_keys=[]):
+     # create a reference dict to define the order of the vector
+     reference_dict = copy.deepcopy(state_dict)
+     for key in remove_keys:
+         if key in reference_dict:
+             del reference_dict[key]
+     sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))
+
+     # create a shared state dict using the reference dict
+     torch.nn.utils.vector_to_parameters(vector, sorted_reference_dict.values())
+
+     # add back the encoder and decoder embedding weights.
+     if "transformer.shared.weight" in sorted_reference_dict:
+         for key in remove_keys:
+             sorted_reference_dict[key] = sorted_reference_dict[
+                 "transformer.shared.weight"
+             ]
+     return sorted_reference_dict
+
+
+ def add_ptm_to_tv(tv_dict, ptm_dict):
+     assert set(tv_dict.keys()) == set(
+         ptm_dict.keys()
+     ), "Differing parameter names in models."
+     final_dict = copy.deepcopy(tv_dict)
+     for k, v in ptm_dict.items():
+         final_dict[k] = tv_dict[k] + v
+     return final_dict
+
+
+ def check_parameterNamesMatch(checkpoints):
+     parameter_names = set(checkpoints[0].keys())
+
+     if len(checkpoints) >= 2:
+         # raise ValueError("Number of models is less than 2.")
+         for checkpoint in checkpoints[1:]:
+             current_parameterNames = set(checkpoint.keys())
+             if current_parameterNames != parameter_names:
+                 raise ValueError(
+                     "Differing parameter names in models. "
+                     f"The different parameters are {parameter_names.symmetric_difference(current_parameterNames)}"
+                 )
+
+
+ def check_state_dicts_equal(state_dict1, state_dict2):
+     if set(state_dict1.keys()) != set(state_dict2.keys()):
+         return False
+
+     for key in state_dict1.keys():
+         if not torch.equal(state_dict1[key], state_dict2[key]):
+             return False
+
+     return True
+
+
+ def topk_values_mask(M, K=0.7, return_mask=False, reshape_mask=False):
+     if K == 100:
+         # print("Not applying mask")
+         if return_mask:
+             return M, torch.ones_like(M), None
+         else:
+             return M, torch.ones_like(M)
+
+     if K >= 1:
+         K /= 100
+
+     original_shape = M.shape
+     if M.dim() == 1:
+         M = M.unsqueeze(0)
+
+     n, d = M.shape
+     k = int(d * K)
+     k = d - k  # Keep top k elements instead of bottom k elements
+
+     # Find the k-th smallest element by magnitude for each row
+     kth_values, _ = M.abs().kthvalue(k, dim=1, keepdim=True)
+     # Create a mask tensor with True for the top k elements in each row
+     mask = M.abs() >= kth_values
+     final_mask = mask.squeeze() if original_shape == M.squeeze().shape else mask
+
+     if reshape_mask:
+         final_mask = final_mask.reshape(M.shape)
+
+     if return_mask:
+         return M * final_mask, final_mask.float().mean(dim=1), final_mask
+     else:
+         return M * final_mask, final_mask.float().mean(dim=1)
+
+
+ def cleanup_linear(state_dict):
+     # The linear model also has keys for the reference point $\theta_0$ in the state dict with the prefix `params0`.
+     state_dict = {k: v for k, v in state_dict.items() if "params." in k}
+     return state_dict
+
+
+ def get_ptm_linear(state_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+     # rename keys so that they match afterwards
+     state_dict_new = {
+         k.replace("params0", "params"): v
+         for k, v in state_dict.items()
+         if "params0." in k
+     }
+     state_dict_remaining = {k: v for k, v in state_dict.items() if "params." not in k}
+
+     return state_dict_new, state_dict_remaining
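The task-vector round trip is the core of what this new module enables: flatten a state dict to a 1-D vector, operate on it, and map it back onto named parameters. A minimal sketch, assuming the module is importable as `fusion_bench.models.open_clip.utils`; the toy `nn.Linear` models and the `K=0.7` density are illustrative values, not anything prescribed by FusionBench:

    import torch
    import torch.nn as nn

    from fusion_bench.models.open_clip.utils import (
        state_dict_to_vector,
        topk_values_mask,
        vector_to_state_dict,
    )

    finetuned, pretrained = nn.Linear(4, 2), nn.Linear(4, 2)

    # Task vector: flatten both state dicts (sorted by key) and subtract.
    tv = state_dict_to_vector(finetuned.state_dict()) - state_dict_to_vector(
        pretrained.state_dict()
    )

    # TIES-style trimming: keep only the top 70% of entries by magnitude.
    # Pass a 2-D (num_models, num_params) tensor; the 1-D path errors on the
    # final mean(dim=1).
    trimmed, density = topk_values_mask(tv.unsqueeze(0), K=0.7)

    # Write the trimmed vector back onto the original parameter names.
    merged_sd = vector_to_state_dict(trimmed.squeeze(0), pretrained.state_dict())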
fusion_bench/models/open_clip/variables_and_paths.py
@@ -0,0 +1,56 @@
+ from pathlib import Path
+ from typing import Literal
+
+ TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}{bar:-10b}"
+ MODELS = ["ViT-B-32", "ViT-B-16", "ViT-L-14"]
+ OPENCLIP_CACHEDIR = Path(Path.home(), "openclip-cachedir", "open_clip").as_posix()
+ CACHEDIR = None
+
+ ALL_DATASETS = [
+     "Cars",
+     "DTD",
+     "EuroSAT",
+     "GTSRB",
+     "MNIST",
+     "RESISC45",
+     "SVHN",
+     "SUN397",
+     "STL10",
+     "OxfordIIITPet",
+     "Flowers102",
+     "CIFAR100",
+     "PCAM",
+     "FER2013",
+     "CIFAR10",
+     "Food101",
+     "FashionMNIST",
+     "RenderedSST2",
+     "EMNIST",
+     "KMNIST",
+ ]
+
+ DATASETS_8 = ALL_DATASETS[:8]
+ DATASETS_14 = ALL_DATASETS[:14]
+ DATASETS_20 = ALL_DATASETS[:20]
+
+
+ def cleanup_dataset_name(dataset_name: str):
+     return dataset_name.replace("Val", "") + "Val"
+
+
+ def get_zeroshot_path(root, dataset, model):
+     return Path(
+         root, model, cleanup_dataset_name(dataset), f"nonlinear_zeroshot.pt"
+     ).as_posix()
+
+
+ def get_finetuned_path(root, dataset, model):
+     return Path(
+         root, model, cleanup_dataset_name(dataset), f"nonlinear_finetuned.pt"
+     ).as_posix()
+
+
+ def get_single_task_accuracies_path(model):
+     return Path(
+         "results/single_task", model, f"nonlinear_ft_accuracies.json"
+     ).as_posix()
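For orientation, a sketch of what these path helpers resolve to, assuming the names above are importable from `fusion_bench.models.open_clip.variables_and_paths`; the `"checkpoints"` root is a hypothetical example value:

    from fusion_bench.models.open_clip.variables_and_paths import (
        cleanup_dataset_name,
        get_finetuned_path,
        get_zeroshot_path,
    )

    # cleanup_dataset_name is idempotent: it always yields the "<name>Val" form.
    cleanup_dataset_name("Cars")     # -> "CarsVal"
    cleanup_dataset_name("CarsVal")  # -> "CarsVal"

    get_zeroshot_path("checkpoints", "Cars", "ViT-B-32")
    # -> "checkpoints/ViT-B-32/CarsVal/nonlinear_zeroshot.pt"
    get_finetuned_path("checkpoints", "Cars", "ViT-B-32")
    # -> "checkpoints/ViT-B-32/CarsVal/nonlinear_finetuned.pt"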
fusion_bench/models/parameter_dict.py
@@ -1,4 +1,4 @@
- from typing import List, Mapping
+ from typing import List, Mapping, Optional, Tuple

  import torch
  from torch import nn
@@ -6,7 +6,13 @@ from torch import nn
  __all__ = "ParamterDictModel"


- def set_attr(obj, names: List[str], val, check_parent: bool = False):
+ def _set_attr(
+     obj,
+     names: List[str],
+     val,
+     check_parent: bool = False,
+     parent_builder=nn.Module,
+ ):
      """
      Sets an attribute of an object recursively.

@@ -20,8 +26,14 @@ def set_attr(obj, names: List[str], val, check_parent: bool = False):
          setattr(obj, names[0], val)
      else:
          if check_parent and not hasattr(obj, names[0]):
-             setattr(obj, names[0], nn.Module())
-         set_attr(getattr(obj, names[0]), names[1:], val, check_parent=check_parent)
+             setattr(obj, names[0], parent_builder())
+         _set_attr(
+             getattr(obj, names[0]),
+             names[1:],
+             val,
+             check_parent=check_parent,
+             parent_builder=parent_builder,
+         )


  def has_attr(obj, names: List[str]):
@@ -49,17 +61,19 @@ class ParameterDictModel(nn.Module):

      def __init__(
          self,
-         parameters: Mapping[str, nn.Parameter],
+         parameters: Optional[Mapping[str, nn.Parameter]] = None,
      ):
          super().__init__()
-         for name, param in parameters.items():
-             assert isinstance(param, nn.Parameter), f"{name} is not a nn.Parameter"
-             set_attr(
-                 self,
-                 name.split("."),
-                 param,
-                 check_parent=True,
-             )
+         if parameters is not None:
+             for name, param in parameters.items():
+                 assert isinstance(param, nn.Parameter), f"{name} is not a nn.Parameter"
+                 _set_attr(
+                     self,
+                     name.split("."),
+                     param,
+                     check_parent=True,
+                     parent_builder=self.__class__,
+                 )

      def __repr__(self):
          """
@@ -73,3 +87,30 @@ class ParameterDictModel(nn.Module):
              param_repr = f"{name}: {param.size()}"
              param_reprs.append(param_repr)
          return f"{self.__class__.__name__}({', '.join(param_reprs)})"
+
+     def __getitem__(self, key: str):
+         if not has_attr(self, key.split(".")):
+             raise KeyError(f"Key {key} not found in {self}")
+         key = key.split(".")
+         obj = self
+         for k in key:
+             obj = getattr(obj, k)
+         return obj
+
+     def __setitem__(self, key: str, value: nn.Parameter):
+         if not has_attr(self, key.split(".")):
+             _set_attr(self, key.split("."), value, check_parent=True)
+         else:
+             _set_attr(self, key.split("."), value, check_parent=False)
+
+     def __contains__(self, key: str):
+         return has_attr(self, key.split("."))
+
+     def keys(self):
+         return [name for name, _ in self.named_parameters()]
+
+     def items(self) -> List[Tuple[str, nn.Parameter]]:
+         return [(name, self[name]) for name in self.keys()]
+
+     def values(self) -> List[nn.Parameter]:
+         return [self[name] for name in self.keys()]
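Taken together, these additions let a `ParameterDictModel` be used like a nested, dot-keyed parameter dictionary. A minimal sketch; the parameter names and shapes are illustrative:

    import torch
    import torch.nn as nn

    from fusion_bench.models.parameter_dict import ParameterDictModel

    model = ParameterDictModel(
        {"encoder.layer0.weight": nn.Parameter(torch.randn(2, 2))}
    )

    "encoder.layer0.weight" in model    # True, via the new __contains__
    w = model["encoder.layer0.weight"]  # recursive getattr walk
    # __setitem__ creates missing parent modules on the fly.
    model["encoder.layer0.bias"] = nn.Parameter(torch.zeros(2))
    model.keys()  # ['encoder.layer0.weight', 'encoder.layer0.bias']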
fusion_bench/models/wrappers/layer_wise_fusion.py
@@ -16,6 +16,7 @@ import torch
  from torch import Tensor, nn
  from torch.func import functional_call

+ from fusion_bench.models.utils import del_attr, get_attr, set_attr
  from fusion_bench.utils.type import StateDictType, TorchModelType

  __all__ = ["get_layer_wise_weights", "fuse_weights", "LayerWiseMergedModel"]
@@ -23,52 +24,6 @@ __all__ = ["get_layer_wise_weights", "fuse_weights", "LayerWiseMergedModel"]
  log = logging.getLogger(__name__)


- def del_attr(obj, names: List[str]):
-     """
-     Deletes an attribute from an object recursively.
-
-     Args:
-         obj (object): Object to delete attribute from.
-         names (list): List of attribute names to delete recursively.
-     """
-     if len(names) == 1:
-         delattr(obj, names[0])
-     else:
-         del_attr(getattr(obj, names[0]), names[1:])
-
-
- def set_attr(obj, names: List[str], val):
-     """
-     Sets an attribute of an object recursively.
-
-     Args:
-         obj (object): Object to set attribute of.
-         names (list): List of attribute names to set recursively.
-         val (object): Value to set the attribute to.
-     """
-     if len(names) == 1:
-         setattr(obj, names[0], val)
-     else:
-         set_attr(getattr(obj, names[0]), names[1:], val)
-
-
- def get_attr(obj, names: List[str]):
-     """
-     Gets an attribute of an object recursively.
-
-     Args:
-         obj (object): Object to get attribute of.
-         names (list): List of attribute names to get recursively.
-
-     Returns:
-         object: The attribute of the object.
-     """
-     if len(names) == 1:
-         return getattr(obj, names[0])
-     else:
-         return get_attr(getattr(obj, names[0]), names[1:])
-
-
  def get_layer_wise_weights(
      num_models: int,
      num_layers: int,
fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py
@@ -10,132 +10,17 @@
  from torch import Tensor, nn
  from torch.func import functional_call

+ from fusion_bench.models.utils import del_attr, get_attr, set_attr
  from fusion_bench.utils.state_dict_arithmetic import state_dict_add
  from fusion_bench.utils.type import StateDictType

+ from .layer_wise_fusion import fuse_weights, get_layer_wise_weights
+
  __all__ = ["get_layer_wise_weights", "fuse_weights", "LayerWiseMergedModel"]

  log = logging.getLogger(__name__)


- def del_attr(obj, names: List[str]):
-     """
-     Deletes an attribute from an object recursively.
-
-     Args:
-         obj (object): Object to delete attribute from.
-         names (list): List of attribute names to delete recursively.
-     """
-     if len(names) == 1:
-         delattr(obj, names[0])
-     else:
-         del_attr(getattr(obj, names[0]), names[1:])
-
-
- def set_attr(obj, names: List[str], val):
-     """
-     Sets an attribute of an object recursively.
-
-     Args:
-         obj (object): Object to set attribute of.
-         names (list): List of attribute names to set recursively.
-         val (object): Value to set the attribute to.
-     """
-     if len(names) == 1:
-         setattr(obj, names[0], val)
-     else:
-         set_attr(getattr(obj, names[0]), names[1:], val)
-
-
- def get_attr(obj, names: List[str]):
-     """
-     Gets an attribute of an object recursively.
-
-     Args:
-         obj (object): Object to get attribute of.
-         names (list): List of attribute names to get recursively.
-
-     Returns:
-         object: The attribute of the object.
-     """
-     if len(names) == 1:
-         return getattr(obj, names[0])
-     else:
-         return get_attr(getattr(obj, names[0]), names[1:])
-
-
- def get_layer_wise_weights(
-     num_models: int,
-     num_layers: int,
-     init_values: float = None,
-     dtype: torch.dtype = torch.float32,
- ):
-     """
-     Return a tensor of layer-wise weights for the given number of models and layers.
-
-     Args:
-         num_models (int): The number of models to fuse.
-         num_layers (int): The number of layers in each model.
-         init_values (float, optional): The initial value for each weight. Defaults to 1.0 / num_models.
-         dtype (torch.dtype): dtype of weights. This should be the same with model dtype.
-
-     Returns:
-         Tensor: A tensor of shape (num_models, num_layers) containing the layer-wise weights.
-     """
-     assert num_models >= 1, f"num_models must be >= 1, got {num_models}"
-     assert num_layers >= 1, f"num_layers must be >= 1, got {num_layers}"
-     if init_values is None:
-         init_values = 1.0 / num_models
-     return torch.full((num_models, num_layers), init_values, dtype=dtype)
-
-
- def _fuse_weights(layer_wise_weight: Tensor, tensors: List[Tensor]):
-     """
-     Fuse the layer-wise weights with the given state dictionaries.
-
-     Args:
-         layer_wise_weight (Tensor): A tensor of shape (num_models,) containing the layer-wise weights.
-         state_dicts (List[Tensor]): A list of state dictionaries, each containing the weights for a single layer.
-
-     Returns:
-         Tensor: A tensor of shape (num_params,) containing the fused weights.
-     """
-     assert len(layer_wise_weight) == len(
-         tensors
-     ), f"layer_wise_weight.shape={layer_wise_weight.shape}, len(tensors)={len(tensors)}"
-     return sum(
-         layer_wise_weight[i] * w.to(layer_wise_weight.device)
-         for i, w in enumerate(tensors)
-     )
-
-
- def fuse_weights(
-     layer_wise_weight: Tensor, state_dicts: List[StateDictType]
- ) -> StateDictType:
-     """
-     Fuse the weights of multiple models using layer-wise fusion.
-
-     Args:
-         layer_wise_weight (Tensor): A tensor of shape (num_models, num_layers) representing the weight of each layer for each model.
-         state_dicts (List[StateDict]): A list of state dictionaries, one for each model.
-
-     Returns:
-         A dictionary mapping each weight tensor key to the fused weight tensor.
-     """
-     num_models = len(state_dicts)
-     num_layers = len(state_dicts[0])
-     assert layer_wise_weight.shape == (
-         num_models,
-         num_layers,
-     ), f"layer_wise_weight.shape={layer_wise_weight.shape}, expected (num_models, num_layers): ({num_models}, {num_layers})"
-     return {
-         k: _fuse_weights(
-             layer_wise_weight[:, i], [state_dict[k] for state_dict in state_dicts]
-         )
-         for i, k in enumerate(state_dicts[0].keys())
-     }
-
-
  class LayerWiseMergedModel(nn.Module):
      _merged_state_dict: StateDictType = None

@@ -390,7 +275,7 @@
          layer_vectors_scale = layer_vectors * layer_lamdas.view(-1, 1, 1)
          sum_over_num_vectors = layer_vectors_scale.sum(dim=0)

-         layer_delta_scale = layer_delta.unsqueeze(0) * layer_lamdas.view(-1, 1, 1)
+         layer_delta_scale = layer_delta * layer_lamdas.view(-1, 1, 1)
          sum_over_delta = layer_delta_scale.sum(dim=0)

          # Iterate through each vector and calculate the loss one by one
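With the duplicated helpers removed, both wrappers now share a single fusion implementation. A small sketch of that shared API, assuming the imports from `fusion_bench.models.wrappers.layer_wise_fusion` shown in the diff; the toy state dicts are illustrative:

    import torch

    from fusion_bench.models.wrappers.layer_wise_fusion import (
        fuse_weights,
        get_layer_wise_weights,
    )

    sd_a = {"w1": torch.ones(3), "w2": torch.ones(3)}
    sd_b = {"w1": torch.zeros(3), "w2": torch.zeros(3)}

    # One merging weight per (model, layer), initialized to 1 / num_models.
    weights = get_layer_wise_weights(num_models=2, num_layers=2)

    fused = fuse_weights(weights, [sd_a, sd_b])
    # fused["w1"] -> tensor([0.5000, 0.5000, 0.5000])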
fusion_bench/scripts/nyuv2_mtl_train.py
@@ -1,5 +1,5 @@
  R"""
- This script is used to train a multi-task learning (MTL) model on the NYUv2 dataset. 
+ This script is used to train a multi-task learning (MTL) model on the NYUv2 dataset.
  """

  import importlib