fusion-bench 0.2.25__py3-none-any.whl → 0.2.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/dataset/clip_dataset.py +1 -0
- fusion_bench/method/__init__.py +4 -0
- fusion_bench/method/adamerging/__init__.py +28 -5
- fusion_bench/method/adamerging/resnet_adamerging.py +279 -0
- fusion_bench/method/adamerging/task_wise_adamerging.py +2 -14
- fusion_bench/method/adamerging/utils.py +58 -0
- fusion_bench/method/classification/clip_finetune.py +6 -4
- fusion_bench/method/classification/image_classification_finetune.py +156 -12
- fusion_bench/method/dare/simple_average.py +3 -2
- fusion_bench/method/dare/task_arithmetic.py +3 -2
- fusion_bench/method/dop/__init__.py +1 -0
- fusion_bench/method/dop/dop.py +366 -0
- fusion_bench/method/dop/min_norm_solvers.py +227 -0
- fusion_bench/method/dop/utils.py +73 -0
- fusion_bench/method/simple_average.py +6 -4
- fusion_bench/mixins/lightning_fabric.py +9 -0
- fusion_bench/modelpool/causal_lm/causal_lm.py +2 -1
- fusion_bench/modelpool/resnet_for_image_classification.py +285 -4
- fusion_bench/models/hf_clip.py +4 -7
- fusion_bench/models/hf_utils.py +4 -1
- fusion_bench/taskpool/__init__.py +2 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -1
- fusion_bench/taskpool/resnet_for_image_classification.py +231 -0
- fusion_bench/utils/state_dict_arithmetic.py +91 -10
- {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/METADATA +9 -3
- {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/RECORD +140 -77
- fusion_bench_config/fabric/auto.yaml +1 -1
- fusion_bench_config/fabric/loggers/swandb_logger.yaml +5 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric_model_fusion.yaml +1 -0
- fusion_bench_config/method/adamerging/resnet.yaml +18 -0
- fusion_bench_config/method/bitdelta/bitdelta.yaml +3 -0
- fusion_bench_config/method/classification/clip_finetune.yaml +5 -0
- fusion_bench_config/method/classification/image_classification_finetune.yaml +9 -0
- fusion_bench_config/method/depth_upscaling.yaml +9 -0
- fusion_bench_config/method/dop/dop.yaml +30 -0
- fusion_bench_config/method/dummy.yaml +6 -0
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +6 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +8 -1
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +8 -0
- fusion_bench_config/method/linear/expo.yaml +5 -0
- fusion_bench_config/method/linear/linear_interpolation.yaml +8 -0
- fusion_bench_config/method/linear/llama_expo.yaml +5 -0
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +3 -0
- fusion_bench_config/method/linear/simple_average_for_causallm.yaml +5 -0
- fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +3 -0
- fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +5 -0
- fusion_bench_config/method/linear/weighted_average.yaml +3 -0
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +6 -1
- fusion_bench_config/method/mixtral_moe_merging.yaml +3 -0
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +5 -0
- fusion_bench_config/method/model_recombination.yaml +8 -0
- fusion_bench_config/method/model_stock/model_stock.yaml +4 -1
- fusion_bench_config/method/opcm/opcm.yaml +5 -0
- fusion_bench_config/method/opcm/task_arithmetic.yaml +6 -0
- fusion_bench_config/method/opcm/ties_merging.yaml +5 -0
- fusion_bench_config/method/opcm/weight_average.yaml +5 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +3 -0
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +3 -0
- fusion_bench_config/method/regmean/regmean.yaml +3 -0
- fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +3 -0
- fusion_bench_config/method/simple_average.yaml +9 -0
- fusion_bench_config/method/slerp/slerp.yaml +9 -0
- fusion_bench_config/method/slerp/slerp_lm.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +6 -0
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +3 -0
- fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +3 -0
- fusion_bench_config/method/task_arithmetic.yaml +9 -0
- fusion_bench_config/method/ties_merging.yaml +3 -0
- fusion_bench_config/method/wudi/wudi.yaml +3 -0
- fusion_bench_config/model_fusion.yaml +2 -1
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/_generate_config.py +138 -0
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet152_cifar10.yaml +1 -1
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet152_cifar100.yaml +1 -1
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_dtd.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_emnist_letters.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_eurosat.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_fashion_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_fer2013.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_food101.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_gtsrb.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_kmnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_oxford-iiit-pet.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_oxford_flowers102.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_pcam.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_rendered-sst2.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_resisc45.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_stanford-cars.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_stl10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_sun397.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_svhn.yaml +14 -0
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet18_cifar10.yaml +1 -1
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet18_cifar100.yaml +1 -1
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_dtd.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_emnist_letters.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_eurosat.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_fashion_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_fer2013.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_food101.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_gtsrb.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_kmnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_oxford-iiit-pet.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_oxford_flowers102.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_pcam.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_rendered-sst2.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_resisc45.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_stanford-cars.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_stl10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_sun397.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_svhn.yaml +14 -0
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet50_cifar10.yaml +1 -1
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet50_cifar100.yaml +1 -1
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_dtd.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_emnist_letters.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_eurosat.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_fashion_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_fer2013.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_food101.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_gtsrb.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_kmnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_oxford-iiit-pet.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_oxford_flowers102.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_pcam.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_rendered-sst2.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_resisc45.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_stanford-cars.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_stl10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_sun397.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_svhn.yaml +14 -0
- fusion_bench_config/method/clip_finetune.yaml +0 -26
- {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
# This code is from
|
|
2
|
+
# Multi-Task Learning as Multi-Objective Optimization
|
|
3
|
+
# Ozan Sener, Vladlen Koltun
|
|
4
|
+
# Neural Information Processing Systems (NeurIPS) 2018
|
|
5
|
+
# https://github.com/intel-isl/MultiObjectiveOptimization
|
|
6
|
+
from typing import Union
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def np_sum(x: Union[torch.Tensor, np.ndarray]) -> float:
    """Sum all elements of a tensor or array and return a plain Python float.

    Args:
        x: A torch tensor or a NumPy array (boolean masks are counted as 0/1).

    Returns:
        float: The sum of all elements, always as a built-in ``float`` so the
        result type does not depend on the input container or dtype.
    """
    if isinstance(x, torch.Tensor):
        # ``.item()`` yields int/bool for integer/boolean tensors; normalize it.
        return float(x.sum().item())
    # ``np.sum`` yields a NumPy scalar; convert to honor the declared return type.
    return float(np.sum(x))
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def to_numpy(x: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
    """Convert a tensor (or array-like) to a NumPy array.

    Args:
        x: A torch tensor, a NumPy array, or any array-like (list, scalar).

    Returns:
        np.ndarray: For tensors, a detached CPU copy converted with
        ``.numpy()``. NumPy arrays are returned as the same object (no copy);
        other array-likes are wrapped so the declared return type always holds.
    """
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    # ``asanyarray`` passes existing arrays (and subclasses) through untouched.
    return np.asanyarray(x)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class MinNormSolver:
    R"""
    Solver for the minimum-norm element in the convex hull of a set of vectors.

    Every "vector" is a list of tensors (the solver sums the element-wise
    products of corresponding tensors to form inner products). All methods are
    static and are meant to be called on the class itself, e.g.
    ``MinNormSolver.find_min_norm_element(vecs)``.
    """

    # Upper bound on the number of solver iterations.
    MAX_ITER = 250
    # Iteration stops once the L1 change of the coefficients drops below this.
    STOP_CRIT = 1e-5

    @staticmethod
    def _inner_product(a, b):
        """
        Inner product of two vectors given as equally long lists of tensors.

        Returns 0.0 for empty lists, otherwise a 0-dim CPU tensor.
        Raises ValueError if the two lists differ in length.
        """
        if len(a) != len(b):
            raise ValueError(
                f"Vectors must have the same number of tensors, got {len(a)} and {len(b)}"
            )
        total = 0.0
        for x, y in zip(a, b):
            total += torch.mul(x, y).sum().detach().cpu()
        return total

    @staticmethod
    def _min_norm_element_from2(v1v1, v1v2, v2v2):
        """
        Analytical solution for min_{c} |cx_1 + (1-c)x_2|_2^2
        d is the distance (objective) optimized
        v1v1 = <x1,x1>
        v1v2 = <x1,x2>
        v2v2 = <x2,x2>

        Returns the pair (c, d).
        """
        if v1v2 >= v1v1:
            # Case: Fig 1, third column
            gamma = 0.999
            cost = v1v1
            return gamma, cost
        if v1v2 >= v2v2:
            # Case: Fig 1, first column
            gamma = 0.001
            cost = v2v2
            return gamma, cost
        # Case: Fig 1, second column
        # Both guards above failed, so v1v1 + v2v2 - 2 * v1v2 > 0 here.
        gamma = -1.0 * ((v1v2 - v2v2) / (v1v1 + v2v2 - 2 * v1v2))
        cost = v2v2 + gamma * (v1v2 - v2v2)
        return gamma, cost

    @staticmethod
    def _min_norm_2d(vecs, dps):
        R"""
        Find the minimum norm solution as combination of two points
        This is correct only in 2D
        ie. min_c |\sum c_i x_i|_2^2 st. \sum c_i = 1 , 1 >= c_1 >= 0 for all i, c_i + c_j = 1.0 for some i, j

        Args:
            vecs: list of vectors, each a list of tensors.
            dps: cache of inner products keyed by index pairs; filled in place.

        Returns:
            (sol, dps) where sol = [(i, j), c, d] describes the best pair.

        Raises:
            ValueError: if fewer than two vectors are given, or if no pair
                yields a comparable (e.g. non-NaN) cost.
        """
        n = len(vecs)
        if n < 2:
            raise ValueError(f"_min_norm_2d needs at least two vectors, got {n}")

        sol = None
        # Start from +inf so arbitrarily large norms still produce a solution.
        dmin = float("inf")
        for i in range(n):
            for j in range(i + 1, n):
                if (i, j) not in dps:
                    dps[(i, j)] = MinNormSolver._inner_product(vecs[i], vecs[j])
                    dps[(j, i)] = dps[(i, j)]
                if (i, i) not in dps:
                    dps[(i, i)] = MinNormSolver._inner_product(vecs[i], vecs[i])
                if (j, j) not in dps:
                    # Iterate over vecs[j] itself (not the length of vecs[i]).
                    dps[(j, j)] = MinNormSolver._inner_product(vecs[j], vecs[j])
                c, d = MinNormSolver._min_norm_element_from2(
                    dps[(i, i)], dps[(i, j)], dps[(j, j)]
                )
                if d < dmin:
                    dmin = d
                    sol = [(i, j), c, d]
        if sol is None:
            raise ValueError(
                "Could not find a pairwise minimum-norm solution; "
                "the inner products may be non-finite"
            )
        return sol, dps

    @staticmethod
    def _projection2simplex(y):
        R"""
        Given y, it solves argmin_z |y-z|_2 st \sum z = 1 , 1 >= z_i >= 0 for all i
        """
        m = len(y)
        sorted_y = np.flip(np.sort(y), axis=0)
        tmpsum = 0.0
        tmax_f = (np.sum(y) - 1.0) / m
        for i in range(m - 1):
            tmpsum += sorted_y[i]
            tmax = (tmpsum - 1) / (i + 1.0)
            if tmax > sorted_y[i + 1]:
                tmax_f = tmax
                break
        return np.maximum(y - tmax_f, np.zeros(y.shape))

    @staticmethod
    def _next_point(cur_val, grad, n):
        """
        Take one step from ``cur_val`` along the mean-centered ``grad`` and
        project the result back onto the simplex.

        The step length is the smallest positive step (above 1e-7) at which a
        coordinate would reach 0 or 1, or 1 when no such step exists.
        """
        proj_grad = grad - (np.sum(grad) / n)
        tm1 = -1.0 * cur_val[proj_grad < 0] / proj_grad[proj_grad < 0]
        tm2 = (1.0 - cur_val[proj_grad > 0]) / (proj_grad[proj_grad > 0])

        t = 1
        if len(tm1[tm1 > 1e-7]) > 0:
            t = np.min(to_numpy(tm1[tm1 > 1e-7]))
        if len(tm2[tm2 > 1e-7]) > 0:
            t = min(t, np.min(to_numpy(tm2[tm2 > 1e-7])))

        next_point = proj_grad * t + to_numpy(cur_val)
        next_point = MinNormSolver._projection2simplex(next_point)
        return next_point

    @staticmethod
    def _initial_solution(vecs):
        """
        Build the starting point shared by both solvers.

        Returns (sol_vec, cost, dps): the coefficient vector, the matching
        cost, and the inner-product cache. A single vector is its own
        min-norm element; otherwise the best two-vector combination is used.
        Raises ValueError for an empty ``vecs``.
        """
        n = len(vecs)
        if n == 0:
            raise ValueError("vecs must contain at least one vector")
        if n == 1:
            norm_sq = MinNormSolver._inner_product(vecs[0], vecs[0])
            return np.ones(1), norm_sq, {(0, 0): norm_sq}

        init_sol, dps = MinNormSolver._min_norm_2d(vecs, {})
        (i, j), gamma, cost = init_sol
        sol_vec = np.zeros(n)
        sol_vec[i] = float(gamma)
        sol_vec[j] = float(1 - gamma)
        return sol_vec, cost, dps

    @staticmethod
    def _gram_matrix(dps, n):
        """Copy the cached inner products into a dense (n, n) NumPy matrix."""
        grad_mat = np.zeros((n, n))
        for i in range(n):
            for j in range(n):
                grad_mat[i, j] = float(dps[(i, j)])
        return grad_mat

    @staticmethod
    def find_min_norm_element(vecs):
        R"""
        Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
        as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
        It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j})
        Hence, we find the best 2-task solution, and then run the projected gradient descent until convergence

        Returns:
            (sol_vec, cost): the coefficient vector (NumPy array of length
            ``len(vecs)``) and the objective value of the last line search.
            If ``MAX_ITER`` iterations pass without meeting ``STOP_CRIT``, the
            latest iterate is returned.

        Raises:
            ValueError: if ``vecs`` is empty.
        """
        # Solution lying at the combination of two points
        sol_vec, nd, dps = MinNormSolver._initial_solution(vecs)

        n = len(vecs)
        if n < 3:
            # This is optimal for n<=2, so return the solution
            return sol_vec, nd

        grad_mat = MinNormSolver._gram_matrix(dps, n)

        # A bounded loop: the iteration count is what enforces MAX_ITER.
        for _ in range(MinNormSolver.MAX_ITER):
            grad_dir = -1.0 * np.dot(grad_mat, sol_vec)
            new_point = MinNormSolver._next_point(sol_vec, grad_dir, n)
            # Re-compute the inner products for line search
            v1v1 = 0.0
            v1v2 = 0.0
            v2v2 = 0.0
            for i in range(n):
                for j in range(n):
                    v1v1 += sol_vec[i] * sol_vec[j] * dps[(i, j)]
                    v1v2 += sol_vec[i] * new_point[j] * dps[(i, j)]
                    v2v2 += new_point[i] * new_point[j] * dps[(i, j)]
            nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
            new_sol_vec = nc * sol_vec + (1 - nc) * new_point
            change = new_sol_vec - sol_vec
            if np_sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
                return sol_vec, nd
            sol_vec = new_sol_vec
        # Iteration cap reached: return the best iterate instead of None.
        return sol_vec, nd

    @staticmethod
    def find_min_norm_element_FW(vecs):
        R"""
        Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
        as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
        It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j})
        Hence, we find the best 2-task solution, and then run the Frank Wolfe until convergence

        Returns:
            (sol_vec, cost): the coefficient vector (NumPy array of length
            ``len(vecs)``) and the objective value of the last line search.
            If ``MAX_ITER`` iterations pass without meeting ``STOP_CRIT``, the
            latest iterate is returned.

        Raises:
            ValueError: if ``vecs`` is empty.
        """
        # Solution lying at the combination of two points
        sol_vec, nd, dps = MinNormSolver._initial_solution(vecs)

        n = len(vecs)
        if n < 3:
            # This is optimal for n<=2, so return the solution
            return sol_vec, nd

        grad_mat = MinNormSolver._gram_matrix(dps, n)

        # A bounded loop: the iteration count is what enforces MAX_ITER.
        for _ in range(MinNormSolver.MAX_ITER):
            t_iter = np.argmin(np.dot(grad_mat, sol_vec))

            v1v1 = np.dot(sol_vec, np.dot(grad_mat, sol_vec))
            v1v2 = np.dot(sol_vec, grad_mat[:, t_iter])
            v2v2 = grad_mat[t_iter, t_iter]

            nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
            new_sol_vec = nc * sol_vec
            new_sol_vec[t_iter] += 1 - nc

            change = new_sol_vec - sol_vec
            if np_sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
                return sol_vec, nd
            sol_vec = new_sol_vec
        # Iteration cap reached: return the best iterate instead of None.
        return sol_vec, nd
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def gradient_normalizers(grads, losses, normalization_type):
    """
    Compute one normalization factor per task.

    Args:
        grads: mapping from task key to a list of gradient tensors.
        losses: mapping from task key to that task's loss value (only read
            for the "loss" and "loss+" types).
        normalization_type: one of
            "l2"    - L2 norm of all gradient tensors of the task,
            "loss"  - the task loss,
            "loss+" - the task loss times the gradient L2 norm,
            "none"  - the constant 1.0.

    Returns:
        dict: task key -> normalization factor.

    Raises:
        ValueError: if ``normalization_type`` is not one of the values above.
    """

    def _l2_norm(task_grads):
        # float() moves each partial sum to the host explicitly and
        # accumulates in double precision, independent of the gradient dtype.
        return np.sqrt(sum(float(gr.pow(2).sum()) for gr in task_grads))

    if normalization_type == "l2":
        return {t: _l2_norm(grads[t]) for t in grads}
    if normalization_type == "loss":
        return {t: losses[t] for t in grads}
    if normalization_type == "loss+":
        return {t: losses[t] * _l2_norm(grads[t]) for t in grads}
    if normalization_type == "none":
        return {t: 1.0 for t in grads}
    # Fail loudly: an empty result would only surface later as a KeyError.
    raise ValueError(
        f"Invalid normalization type: {normalization_type!r}; "
        "expected one of 'l2', 'loss', 'loss+', 'none'"
    )
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor, nn
|
|
5
|
+
|
|
6
|
+
from fusion_bench.utils.parameters import state_dict_to_vector
|
|
7
|
+
from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _svd(w: Tensor, full_matrices=True) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Perform Singular Value Decomposition (SVD) on a tensor.

    Args:
        w (Tensor): The input tensor (a matrix or a batch of matrices).
        full_matrices (bool): Whether to compute the full-sized U and V matrices.

    Returns:
        Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD,
        such that ``w == U @ diag(S) @ V^H``.
    """
    u, s, vh = torch.linalg.svd(
        w, full_matrices=full_matrices, driver="gesvd" if w.is_cuda else None
    )
    # Conjugate-transpose only the last two dims: identical to ``.T`` for real
    # 2-D input, and also correct for batched and complex input.
    v = vh.mH
    return u, s, v


def svd(
    w: Tensor, full_matrices=True, accelerator=None
) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Perform SVD on a tensor, optionally using a specified accelerator.

    Args:
        w (Tensor): The input tensor.
        full_matrices (bool): Whether to compute the full-sized U and V matrices.
        accelerator (str): The device to perform the computation on. If None,
            the computation runs on the device ``w`` already lives on.

    Returns:
        Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD,
        placed on the original device of ``w``.
    """
    if accelerator is None:
        return _svd(w, full_matrices=full_matrices)
    original_device = w.device
    # Forward ``full_matrices`` so both code paths honor the caller's choice.
    u, s, v = _svd(w.to(accelerator), full_matrices=full_matrices)
    return u.to(original_device), s.to(original_device), v.to(original_device)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def frobenius_inner_product(w1: Tensor, w2: Tensor) -> Tensor:
    """
    Compute the Frobenius inner product of two equally shaped tensors.

    This equals ``trace(w1.T @ w2)`` for matrices, but is evaluated as the sum
    of element-wise products, which avoids building the intermediate matrix
    product.

    Args:
        w1 (Tensor): The first tensor.
        w2 (Tensor): The second tensor, with the same shape as ``w1``.

    Returns:
        Tensor: A 0-dim tensor holding the inner product.

    Raises:
        ValueError: If the two tensors differ in shape.
    """
    if w1.shape != w2.shape:
        # Reject mismatches explicitly so broadcasting cannot hide them.
        raise ValueError(
            f"Shapes must match, got {tuple(w1.shape)} and {tuple(w2.shape)}"
        )
    return torch.sum(w1 * w2)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def is_leaf_module(module: nn.Module) -> bool:
    """
    Check whether a module has no child modules.

    Args:
        module (nn.Module): The module to inspect.

    Returns:
        bool: True if ``module.children()`` yields nothing, False otherwise.
    """
    # Stop at the first child instead of materializing the full list.
    return next(iter(module.children()), None) is None
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def get_task_vector_norm(model: nn.Module, pretrained_model: nn.Module) -> Tensor:
    """
    Compute the norm of the difference between a task model and its pretrained base.

    Args:
        model (nn.Module): The task model.
        pretrained_model (nn.Module): The pretrained model.

    Returns:
        Tensor: ``torch.linalg.norm`` of the vectorized state-dict difference.
    """
    # Task vector: the task model's state dict minus the pretrained one.
    task_vector = state_dict_sub(model.state_dict(), pretrained_model.state_dict())
    # NOTE(review): presumably flattens all entries into one 1-D tensor - confirm
    # against fusion_bench.utils.parameters.state_dict_to_vector.
    flat_task_vector = state_dict_to_vector(task_vector)
    return torch.linalg.norm(flat_task_vector)
|
|
@@ -64,10 +64,12 @@ class SimpleAverageAlgorithm(
|
|
|
64
64
|
SimpleProfilerMixin,
|
|
65
65
|
BaseAlgorithm,
|
|
66
66
|
):
|
|
67
|
-
def __init__(self, show_pbar: bool = False, **kwargs):
|
|
67
|
+
def __init__(self, show_pbar: bool = False, inplace: bool = True, **kwargs):
|
|
68
68
|
"""
|
|
69
69
|
Args:
|
|
70
70
|
show_pbar (bool): If True, shows a progress bar during model loading and merging. Default is False.
|
|
71
|
+
inplace (bool): If True, overwrites the weights of the first model in the model pool.
|
|
72
|
+
If False, creates a new model for the merged weights. Default is True.
|
|
71
73
|
"""
|
|
72
74
|
super().__init__(**kwargs)
|
|
73
75
|
|
|
@@ -104,12 +106,12 @@ class SimpleAverageAlgorithm(
|
|
|
104
106
|
with self.profile("merge weights"):
|
|
105
107
|
if sd is None:
|
|
106
108
|
# Initialize the state dictionary with the first model's state dictionary
|
|
107
|
-
sd = model.state_dict(
|
|
108
|
-
forward_model = model
|
|
109
|
+
sd = model.state_dict()
|
|
110
|
+
forward_model = model if self.inplace else deepcopy(model)
|
|
109
111
|
else:
|
|
110
112
|
# Add the current model's state dictionary to the accumulated state dictionary
|
|
111
113
|
sd = state_dict_add(
|
|
112
|
-
sd, model.state_dict(
|
|
114
|
+
sd, model.state_dict(), show_pbar=self.show_pbar
|
|
113
115
|
)
|
|
114
116
|
with self.profile("merge weights"):
|
|
115
117
|
# Divide the accumulated state dictionary by the number of models to get the average
|
|
@@ -111,6 +111,15 @@ class LightningFabricMixin:
|
|
|
111
111
|
"""
|
|
112
112
|
if self.fabric is not None and len(self.fabric._loggers) > 0:
|
|
113
113
|
log_dir = self.fabric.logger.log_dir
|
|
114
|
+
|
|
115
|
+
# Special handling for SwanLabLogger to get the correct log directory
|
|
116
|
+
if (
|
|
117
|
+
log_dir is None
|
|
118
|
+
and self.fabric.logger.__class__.__name__ == "SwanLabLogger"
|
|
119
|
+
):
|
|
120
|
+
log_dir = self.fabric.logger.save_dir or self.fabric.logger._logdir
|
|
121
|
+
|
|
122
|
+
assert log_dir is not None, "log_dir should not be None"
|
|
114
123
|
if self.fabric.is_global_zero and not os.path.exists(log_dir):
|
|
115
124
|
os.makedirs(log_dir, exist_ok=True)
|
|
116
125
|
return log_dir
|
|
@@ -8,6 +8,7 @@ from copy import deepcopy
|
|
|
8
8
|
from typing import Any, Dict, Optional, TypeAlias, Union, cast # noqa: F401
|
|
9
9
|
|
|
10
10
|
import peft
|
|
11
|
+
from lightning_utilities.core.rank_zero import rank_zero_only
|
|
11
12
|
from omegaconf import DictConfig, OmegaConf, flag_override
|
|
12
13
|
from torch import nn
|
|
13
14
|
from torch.nn.modules import Module
|
|
@@ -342,7 +343,7 @@ class CausalLMPool(BaseModelPool):
|
|
|
342
343
|
)
|
|
343
344
|
|
|
344
345
|
# Create and save model card if algorithm_config is provided
|
|
345
|
-
if algorithm_config is not None:
|
|
346
|
+
if algorithm_config is not None and rank_zero_only.rank == 0:
|
|
346
347
|
if description is None:
|
|
347
348
|
description = "Model created using FusionBench."
|
|
348
349
|
model_card_str = create_default_model_card(
|