fusion-bench 0.2.25__py3-none-any.whl → 0.2.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. fusion_bench/dataset/clip_dataset.py +1 -0
  2. fusion_bench/method/__init__.py +4 -0
  3. fusion_bench/method/adamerging/__init__.py +28 -5
  4. fusion_bench/method/adamerging/resnet_adamerging.py +279 -0
  5. fusion_bench/method/adamerging/task_wise_adamerging.py +2 -14
  6. fusion_bench/method/adamerging/utils.py +58 -0
  7. fusion_bench/method/classification/clip_finetune.py +6 -4
  8. fusion_bench/method/classification/image_classification_finetune.py +156 -12
  9. fusion_bench/method/dare/simple_average.py +3 -2
  10. fusion_bench/method/dare/task_arithmetic.py +3 -2
  11. fusion_bench/method/dop/__init__.py +1 -0
  12. fusion_bench/method/dop/dop.py +366 -0
  13. fusion_bench/method/dop/min_norm_solvers.py +227 -0
  14. fusion_bench/method/dop/utils.py +73 -0
  15. fusion_bench/method/simple_average.py +6 -4
  16. fusion_bench/mixins/lightning_fabric.py +9 -0
  17. fusion_bench/modelpool/causal_lm/causal_lm.py +2 -1
  18. fusion_bench/modelpool/resnet_for_image_classification.py +285 -4
  19. fusion_bench/models/hf_clip.py +4 -7
  20. fusion_bench/models/hf_utils.py +4 -1
  21. fusion_bench/taskpool/__init__.py +2 -0
  22. fusion_bench/taskpool/clip_vision/taskpool.py +1 -1
  23. fusion_bench/taskpool/resnet_for_image_classification.py +231 -0
  24. fusion_bench/utils/state_dict_arithmetic.py +91 -10
  25. {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/METADATA +9 -3
  26. {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/RECORD +140 -77
  27. fusion_bench_config/fabric/auto.yaml +1 -1
  28. fusion_bench_config/fabric/loggers/swandb_logger.yaml +5 -0
  29. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  30. fusion_bench_config/fabric_model_fusion.yaml +1 -0
  31. fusion_bench_config/method/adamerging/resnet.yaml +18 -0
  32. fusion_bench_config/method/bitdelta/bitdelta.yaml +3 -0
  33. fusion_bench_config/method/classification/clip_finetune.yaml +5 -0
  34. fusion_bench_config/method/classification/image_classification_finetune.yaml +9 -0
  35. fusion_bench_config/method/depth_upscaling.yaml +9 -0
  36. fusion_bench_config/method/dop/dop.yaml +30 -0
  37. fusion_bench_config/method/dummy.yaml +6 -0
  38. fusion_bench_config/method/ensemble/max_model_predictor.yaml +6 -0
  39. fusion_bench_config/method/ensemble/simple_ensemble.yaml +8 -1
  40. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +8 -0
  41. fusion_bench_config/method/linear/expo.yaml +5 -0
  42. fusion_bench_config/method/linear/linear_interpolation.yaml +8 -0
  43. fusion_bench_config/method/linear/llama_expo.yaml +5 -0
  44. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +3 -0
  45. fusion_bench_config/method/linear/simple_average_for_causallm.yaml +5 -0
  46. fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +3 -0
  47. fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +5 -0
  48. fusion_bench_config/method/linear/weighted_average.yaml +3 -0
  49. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +6 -1
  50. fusion_bench_config/method/mixtral_moe_merging.yaml +3 -0
  51. fusion_bench_config/method/mixtral_moe_upscaling.yaml +5 -0
  52. fusion_bench_config/method/model_recombination.yaml +8 -0
  53. fusion_bench_config/method/model_stock/model_stock.yaml +4 -1
  54. fusion_bench_config/method/opcm/opcm.yaml +5 -0
  55. fusion_bench_config/method/opcm/task_arithmetic.yaml +6 -0
  56. fusion_bench_config/method/opcm/ties_merging.yaml +5 -0
  57. fusion_bench_config/method/opcm/weight_average.yaml +5 -0
  58. fusion_bench_config/method/regmean/clip_regmean.yaml +3 -0
  59. fusion_bench_config/method/regmean/gpt2_regmean.yaml +3 -0
  60. fusion_bench_config/method/regmean/regmean.yaml +3 -0
  61. fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +3 -0
  62. fusion_bench_config/method/simple_average.yaml +9 -0
  63. fusion_bench_config/method/slerp/slerp.yaml +9 -0
  64. fusion_bench_config/method/slerp/slerp_lm.yaml +5 -0
  65. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +6 -0
  66. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  67. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +5 -0
  68. fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +3 -0
  69. fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -0
  70. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +5 -0
  71. fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +3 -0
  72. fusion_bench_config/method/task_arithmetic.yaml +9 -0
  73. fusion_bench_config/method/ties_merging.yaml +3 -0
  74. fusion_bench_config/method/wudi/wudi.yaml +3 -0
  75. fusion_bench_config/model_fusion.yaml +2 -1
  76. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/_generate_config.py +138 -0
  77. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet152_cifar10.yaml +1 -1
  78. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet152_cifar100.yaml +1 -1
  79. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_dtd.yaml +14 -0
  80. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_emnist_letters.yaml +14 -0
  81. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_eurosat.yaml +14 -0
  82. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_fashion_mnist.yaml +14 -0
  83. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_fer2013.yaml +14 -0
  84. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_food101.yaml +14 -0
  85. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_gtsrb.yaml +14 -0
  86. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_kmnist.yaml +14 -0
  87. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_mnist.yaml +14 -0
  88. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_oxford-iiit-pet.yaml +14 -0
  89. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_oxford_flowers102.yaml +14 -0
  90. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_pcam.yaml +14 -0
  91. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_rendered-sst2.yaml +14 -0
  92. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_resisc45.yaml +14 -0
  93. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_stanford-cars.yaml +14 -0
  94. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_stl10.yaml +14 -0
  95. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_sun397.yaml +14 -0
  96. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_svhn.yaml +14 -0
  97. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet18_cifar10.yaml +1 -1
  98. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet18_cifar100.yaml +1 -1
  99. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_dtd.yaml +14 -0
  100. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_emnist_letters.yaml +14 -0
  101. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_eurosat.yaml +14 -0
  102. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_fashion_mnist.yaml +14 -0
  103. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_fer2013.yaml +14 -0
  104. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_food101.yaml +14 -0
  105. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_gtsrb.yaml +14 -0
  106. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_kmnist.yaml +14 -0
  107. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_mnist.yaml +14 -0
  108. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_oxford-iiit-pet.yaml +14 -0
  109. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_oxford_flowers102.yaml +14 -0
  110. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_pcam.yaml +14 -0
  111. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_rendered-sst2.yaml +14 -0
  112. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_resisc45.yaml +14 -0
  113. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_stanford-cars.yaml +14 -0
  114. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_stl10.yaml +14 -0
  115. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_sun397.yaml +14 -0
  116. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_svhn.yaml +14 -0
  117. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet50_cifar10.yaml +1 -1
  118. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet50_cifar100.yaml +1 -1
  119. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_dtd.yaml +14 -0
  120. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_emnist_letters.yaml +14 -0
  121. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_eurosat.yaml +14 -0
  122. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_fashion_mnist.yaml +14 -0
  123. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_fer2013.yaml +14 -0
  124. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_food101.yaml +14 -0
  125. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_gtsrb.yaml +14 -0
  126. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_kmnist.yaml +14 -0
  127. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_mnist.yaml +14 -0
  128. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_oxford-iiit-pet.yaml +14 -0
  129. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_oxford_flowers102.yaml +14 -0
  130. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_pcam.yaml +14 -0
  131. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_rendered-sst2.yaml +14 -0
  132. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_resisc45.yaml +14 -0
  133. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_stanford-cars.yaml +14 -0
  134. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_stl10.yaml +14 -0
  135. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_sun397.yaml +14 -0
  136. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_svhn.yaml +14 -0
  137. fusion_bench_config/method/clip_finetune.yaml +0 -26
  138. {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/WHEEL +0 -0
  139. {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/entry_points.txt +0 -0
  140. {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/licenses/LICENSE +0 -0
  141. {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/top_level.txt +0 -0
fusion_bench/method/dop/min_norm_solvers.py (new file)
@@ -0,0 +1,227 @@
+ # This code is from
+ # Multi-Task Learning as Multi-Objective Optimization
+ # Ozan Sener, Vladlen Koltun
+ # Neural Information Processing Systems (NeurIPS) 2018
+ # https://github.com/intel-isl/MultiObjectiveOptimization
+ from typing import Union
+
+ import numpy as np
+ import torch
+
+
+ def np_sum(x: Union[torch.Tensor, np.ndarray]) -> float:
+     if isinstance(x, torch.Tensor):
+         return x.sum().item()
+     return np.sum(x)
+
+
+ def to_numpy(x: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
+     if isinstance(x, torch.Tensor):
+         return x.detach().cpu().numpy()
+     return x
+
+
+ class MinNormSolver:
+     MAX_ITER = 250
+     STOP_CRIT = 1e-5
+
+     def _min_norm_element_from2(v1v1, v1v2, v2v2):
+         """
+         Analytical solution for min_{c} |cx_1 + (1-c)x_2|_2^2
+         d is the distance (objective) optimzed
+         v1v1 = <x1,x1>
+         v1v2 = <x1,x2>
+         v2v2 = <x2,x2>
+         """
+         if v1v2 >= v1v1:
+             # Case: Fig 1, third column
+             gamma = 0.999
+             cost = v1v1
+             return gamma, cost
+         if v1v2 >= v2v2:
+             # Case: Fig 1, first column
+             gamma = 0.001
+             cost = v2v2
+             return gamma, cost
+         # Case: Fig 1, second column
+         gamma = -1.0 * ((v1v2 - v2v2) / (v1v1 + v2v2 - 2 * v1v2))
+         cost = v2v2 + gamma * (v1v2 - v2v2)
+         return gamma, cost
+
+     def _min_norm_2d(vecs, dps):
+         R"""
+         Find the minimum norm solution as combination of two points
+         This is correct only in 2D
+         ie. min_c |\sum c_i x_i|_2^2 st. \sum c_i = 1 , 1 >= c_1 >= 0 for all i, c_i + c_j = 1.0 for some i, j
+         """
+         dmin = 1e8
+         for i in range(len(vecs)):
+             for j in range(i + 1, len(vecs)):
+                 if (i, j) not in dps:
+                     dps[(i, j)] = 0.0
+                     for k in range(len(vecs[i])):
+                         dps[(i, j)] += (
+                             torch.mul(vecs[i][k], vecs[j][k]).sum().data.cpu()
+                         )
+                     dps[(j, i)] = dps[(i, j)]
+                 if (i, i) not in dps:
+                     dps[(i, i)] = 0.0
+                     for k in range(len(vecs[i])):
+                         dps[(i, i)] += (
+                             torch.mul(vecs[i][k], vecs[i][k]).sum().data.cpu()
+                         )
+                 if (j, j) not in dps:
+                     dps[(j, j)] = 0.0
+                     for k in range(len(vecs[i])):
+                         dps[(j, j)] += (
+                             torch.mul(vecs[j][k], vecs[j][k]).sum().data.cpu()
+                         )
+                 c, d = MinNormSolver._min_norm_element_from2(
+                     dps[(i, i)], dps[(i, j)], dps[(j, j)]
+                 )
+                 if d < dmin:
+                     dmin = d
+                     sol = [(i, j), c, d]
+         return sol, dps
+
+     def _projection2simplex(y):
+         R"""
+         Given y, it solves argmin_z |y-z|_2 st \sum z = 1 , 1 >= z_i >= 0 for all i
+         """
+         m = len(y)
+         sorted_y = np.flip(np.sort(y), axis=0)
+         tmpsum = 0.0
+         tmax_f = (np.sum(y) - 1.0) / m
+         for i in range(m - 1):
+             tmpsum += sorted_y[i]
+             tmax = (tmpsum - 1) / (i + 1.0)
+             if tmax > sorted_y[i + 1]:
+                 tmax_f = tmax
+                 break
+         return np.maximum(y - tmax_f, np.zeros(y.shape))
+
+     def _next_point(cur_val, grad, n):
+         proj_grad = grad - (np.sum(grad) / n)
+         tm1 = -1.0 * cur_val[proj_grad < 0] / proj_grad[proj_grad < 0]
+         tm2 = (1.0 - cur_val[proj_grad > 0]) / (proj_grad[proj_grad > 0])
+
+         skippers = np_sum(tm1 < 1e-7) + np_sum(tm2 < 1e-7)
+         t = 1
+         if len(tm1[tm1 > 1e-7]) > 0:
+             t = np.min(to_numpy(tm1[tm1 > 1e-7]))
+         if len(tm2[tm2 > 1e-7]) > 0:
+             t = min(t, np.min(to_numpy(tm2[tm2 > 1e-7])))
+
+         next_point = proj_grad * t + to_numpy(cur_val)
+         next_point = MinNormSolver._projection2simplex(next_point)
+         return next_point
+
+     def find_min_norm_element(vecs):
+         R"""
+         Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
+         as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
+         It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j})
+         Hence, we find the best 2-task solution, and then run the projected gradient descent until convergence
+         """
+         # Solution lying at the combination of two points
+         dps = {}
+         init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)
+
+         n = len(vecs)
+         sol_vec = np.zeros(n)
+         sol_vec[init_sol[0][0]] = init_sol[1]
+         sol_vec[init_sol[0][1]] = 1 - init_sol[1]
+
+         if n < 3:
+             # This is optimal for n=2, so return the solution
+             return sol_vec, init_sol[2]
+
+         iter_count = 0
+
+         grad_mat = np.zeros((n, n))
+         for i in range(n):
+             for j in range(n):
+                 grad_mat[i, j] = dps[(i, j)]
+
+         while iter_count < MinNormSolver.MAX_ITER:
+             grad_dir = -1.0 * np.dot(grad_mat, sol_vec)
+             new_point = MinNormSolver._next_point(sol_vec, grad_dir, n)
+             # Re-compute the inner products for line search
+             v1v1 = 0.0
+             v1v2 = 0.0
+             v2v2 = 0.0
+             for i in range(n):
+                 for j in range(n):
+                     v1v1 += sol_vec[i] * sol_vec[j] * dps[(i, j)]
+                     v1v2 += sol_vec[i] * new_point[j] * dps[(i, j)]
+                     v2v2 += new_point[i] * new_point[j] * dps[(i, j)]
+             nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
+             new_sol_vec = nc * sol_vec + (1 - nc) * new_point
+             change = new_sol_vec - sol_vec
+             if np_sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
+                 return sol_vec, nd
+             sol_vec = new_sol_vec
+
+     def find_min_norm_element_FW(vecs):
+         R"""
+         Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
+         as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
+         It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j})
+         Hence, we find the best 2-task solution, and then run the Frank Wolfe until convergence
+         """
+         # Solution lying at the combination of two points
+         dps = {}
+         init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)
+
+         n = len(vecs)
+         sol_vec = np.zeros(n)
+         sol_vec[init_sol[0][0]] = init_sol[1]
+         sol_vec[init_sol[0][1]] = 1 - init_sol[1]
+
+         if n < 3:
+             # This is optimal for n=2, so return the solution
+             return sol_vec, init_sol[2]
+
+         iter_count = 0
+
+         grad_mat = np.zeros((n, n))
+         for i in range(n):
+             for j in range(n):
+                 grad_mat[i, j] = dps[(i, j)]
+
+         while iter_count < MinNormSolver.MAX_ITER:
+             t_iter = np.argmin(np.dot(grad_mat, sol_vec))
+
+             v1v1 = np.dot(sol_vec, np.dot(grad_mat, sol_vec))
+             v1v2 = np.dot(sol_vec, grad_mat[:, t_iter])
+             v2v2 = grad_mat[t_iter, t_iter]
+
+             nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
+             new_sol_vec = nc * sol_vec
+             new_sol_vec[t_iter] += 1 - nc
+
+             change = new_sol_vec - sol_vec
+             if np_sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
+                 return sol_vec, nd
+             sol_vec = new_sol_vec
+
+
+ def gradient_normalizers(grads, losses, normalization_type):
+     gn = {}
+     if normalization_type == "l2":
+         for t in grads:
+             gn[t] = np.sqrt(np.sum([gr.pow(2).sum().data.cpu() for gr in grads[t]]))
+     elif normalization_type == "loss":
+         for t in grads:
+             gn[t] = losses[t]
+     elif normalization_type == "loss+":
+         for t in grads:
+             gn[t] = losses[t] * np.sqrt(
+                 np.sum([gr.pow(2).sum().data.cpu() for gr in grads[t]])
+             )
+     elif normalization_type == "none":
+         for t in grads:
+             gn[t] = 1.0
+     else:
+         print("ERROR: Invalid Normalization Type")
+     return gn
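
Note (illustrative sketch, not part of the published diff): the new file above ports the MinNormSolver from Sener & Koltun's multi-objective optimization code. Assuming the import path implied by the file list (fusion_bench.method.dop.min_norm_solvers) and made-up gradient shapes, the intended call pattern is to normalize per-task gradients and then solve for the convex combination of minimum norm:

    import torch
    from fusion_bench.method.dop.min_norm_solvers import MinNormSolver, gradient_normalizers

    # Hypothetical per-task gradients: one list of parameter-shaped tensors per task.
    grads = {
        "task_a": [torch.randn(10), torch.randn(5)],
        "task_b": [torch.randn(10), torch.randn(5)],
    }
    losses = {"task_a": 0.7, "task_b": 1.3}

    # Scale each task's gradients by the chosen normalizer ("l2", "loss", "loss+", or "none").
    gn = gradient_normalizers(grads, losses, normalization_type="l2")
    normed = [[g / gn[t] for g in grads[t]] for t in grads]

    # Solve min ||sum_i c_i g_i||_2^2 subject to sum_i c_i = 1, c_i >= 0.
    weights, min_norm = MinNormSolver.find_min_norm_element(normed)
    print(weights)  # convex task weights that sum to 1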
fusion_bench/method/dop/utils.py (new file)
@@ -0,0 +1,73 @@
+ from typing import Tuple
+
+ import torch
+ from torch import Tensor, nn
+
+ from fusion_bench.utils.parameters import state_dict_to_vector
+ from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
+
+
+ def _svd(w: Tensor, full_matrices=True) -> Tuple[Tensor, Tensor, Tensor]:
+     """
+     Perform Singular Value Decomposition (SVD) on a tensor.
+
+     Args:
+         w (Tensor): The input tensor.
+         full_matrices (bool): Whether to compute the full-sized U and V matrices.
+
+     Returns:
+         Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD.
+     """
+     u, s, vh = torch.linalg.svd(
+         w, full_matrices=full_matrices, driver="gesvd" if w.is_cuda else None
+     )
+     v = vh.T
+     return u, s, v
+
+
+ def svd(
+     w: Tensor, full_matrices=True, accelerator=None
+ ) -> Tuple[Tensor, Tensor, Tensor]:
+     """
+     Perform SVD on a tensor, optionally using a specified accelerator.
+
+     Args:
+         w (Tensor): The input tensor.
+         full_matrices (bool): Whether to compute the full-sized U and V matrices.
+         accelerator (str): The device to perform the computation on.
+
+     Returns:
+         Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD.
+     """
+     if accelerator is None:
+         return _svd(w, full_matrices=full_matrices)
+     original_device = w.device
+     w = w.to(accelerator)
+     u, s, v = _svd(w)
+     return u.to(original_device), s.to(original_device), v.to(original_device)
+
+
+ def frobenius_inner_product(w1: Tensor, w2: Tensor) -> Tensor:
+     return torch.trace(w1.T @ w2)
+
+
+ def is_leaf_module(module: nn.Module) -> bool:
+     return len(list(module.children())) == 0
+
+
+ def get_task_vector_norm(model: nn.Module, pretrained_model: nn.Module) -> Tensor:
+     """
+     Get the vector norm of the task model.
+
+     Args:
+         model (nn.Module): The task model.
+         pretrained_model (nn.Module): The pretrained model.
+
+     Returns:
+         Tensor: The vector norm of the task model.
+     """
+     return torch.linalg.norm(
+         state_dict_to_vector(
+             state_dict_sub(model.state_dict(), pretrained_model.state_dict())
+         )
+     )
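
Note (illustrative sketch, not part of the published diff): assuming the import path implied by the file list (fusion_bench.method.dop.utils), the helpers above can be exercised on toy tensors and modules:

    import torch
    from torch import nn
    from fusion_bench.method.dop.utils import frobenius_inner_product, get_task_vector_norm, svd

    w = torch.randn(8, 4)
    u, s, v = svd(w, full_matrices=False)          # thin SVD on CPU (accelerator=None)
    print(torch.dist(u @ torch.diag(s) @ v.T, w))  # reconstruction error, close to zero
    print(frobenius_inner_product(w, w))           # trace(w^T w), i.e. the squared Frobenius norm

    pretrained, finetuned = nn.Linear(4, 2), nn.Linear(4, 2)
    print(get_task_vector_norm(finetuned, pretrained))  # L2 norm of the flattened weight difference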
fusion_bench/method/simple_average.py
@@ -64,10 +64,12 @@ class SimpleAverageAlgorithm(
      SimpleProfilerMixin,
      BaseAlgorithm,
  ):
-     def __init__(self, show_pbar: bool = False, **kwargs):
+     def __init__(self, show_pbar: bool = False, inplace: bool = True, **kwargs):
          """
          Args:
              show_pbar (bool): If True, shows a progress bar during model loading and merging. Default is False.
+             inplace (bool): If True, overwrites the weights of the first model in the model pool.
+                 If False, creates a new model for the merged weights. Default is True.
          """
          super().__init__(**kwargs)
  
@@ -104,12 +106,12 @@ class SimpleAverageAlgorithm(
              with self.profile("merge weights"):
                  if sd is None:
                      # Initialize the state dictionary with the first model's state dictionary
-                     sd = model.state_dict(keep_vars=True)
-                     forward_model = model
+                     sd = model.state_dict()
+                     forward_model = model if self.inplace else deepcopy(model)
                  else:
                      # Add the current model's state dictionary to the accumulated state dictionary
                      sd = state_dict_add(
-                         sd, model.state_dict(keep_vars=True), show_pbar=self.show_pbar
+                         sd, model.state_dict(), show_pbar=self.show_pbar
                      )
          with self.profile("merge weights"):
              # Divide the accumulated state dictionary by the number of models to get the average
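
Note (illustrative sketch, not part of the published diff): the two hunks above add an inplace option to SimpleAverageAlgorithm. The default (inplace=True) keeps the old behaviour of writing the averaged weights into the first model loaded from the pool; inplace=False deep-copies that model first, leaving the pool's models untouched. Assuming the class is still re-exported from fusion_bench.method as in earlier releases:

    from fusion_bench.method import SimpleAverageAlgorithm

    # With inplace=False the merged weights land in a copy of the first model.
    algorithm = SimpleAverageAlgorithm(show_pbar=True, inplace=False)
    merged_model = algorithm.run(modelpool)  # `modelpool` is a model pool constructed elsewhere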
fusion_bench/mixins/lightning_fabric.py
@@ -111,6 +111,15 @@ class LightningFabricMixin:
          """
          if self.fabric is not None and len(self.fabric._loggers) > 0:
              log_dir = self.fabric.logger.log_dir
+
+             # Special handling for SwanLabLogger to get the correct log directory
+             if (
+                 log_dir is None
+                 and self.fabric.logger.__class__.__name__ == "SwanLabLogger"
+             ):
+                 log_dir = self.fabric.logger.save_dir or self.fabric.logger._logdir
+
+             assert log_dir is not None, "log_dir should not be None"
              if self.fabric.is_global_zero and not os.path.exists(log_dir):
                  os.makedirs(log_dir, exist_ok=True)
              return log_dir
fusion_bench/modelpool/causal_lm/causal_lm.py
@@ -8,6 +8,7 @@ from copy import deepcopy
  from typing import Any, Dict, Optional, TypeAlias, Union, cast  # noqa: F401
  
  import peft
+ from lightning_utilities.core.rank_zero import rank_zero_only
  from omegaconf import DictConfig, OmegaConf, flag_override
  from torch import nn
  from torch.nn.modules import Module
@@ -342,7 +343,7 @@ class CausalLMPool(BaseModelPool):
              )
  
          # Create and save model card if algorithm_config is provided
-         if algorithm_config is not None:
+         if algorithm_config is not None and rank_zero_only.rank == 0:
              if description is None:
                  description = "Model created using FusionBench."
              model_card_str = create_default_model_card(
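
Note (illustrative sketch, not part of the published diff): the last hunk guards model-card creation so that it runs only on the global rank-0 process during distributed runs. The guard reads the rank attribute that Lightning populates on rank_zero_only at startup; a minimal standalone check of the same pattern:

    from lightning_utilities.core.rank_zero import rank_zero_only

    # `rank_zero_only.rank` is set by Lightning/Fabric during setup; outside a
    # distributed launcher it may be unset, so this sketch defaults to 0.
    if getattr(rank_zero_only, "rank", 0) == 0:
        print("only the global rank-0 process creates and saves the model card")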