fusion-bench 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__main__.py +4 -0
- fusion_bench/dataset/fer2013.py +1 -0
- fusion_bench/method/__init__.py +26 -4
- fusion_bench/method/classification/__init__.py +1 -0
- fusion_bench/method/classification/clip_finetune.py +1 -3
- fusion_bench/method/classification/continual_clip_finetune.py +297 -0
- fusion_bench/method/dare/__init__.py +1 -0
- fusion_bench/method/dare/task_arithmetic.py +14 -7
- fusion_bench/method/dare/ties_merging.py +100 -0
- fusion_bench/method/isotropic_merging/__init__.py +15 -0
- fusion_bench/method/isotropic_merging/iso.py +114 -0
- fusion_bench/method/isotropic_merging/iso_utils.py +176 -0
- fusion_bench/method/opcm/__init__.py +4 -0
- fusion_bench/method/opcm/opcm.py +277 -0
- fusion_bench/method/opcm/task_arithmetic.py +115 -0
- fusion_bench/method/opcm/ties_merging.py +156 -0
- fusion_bench/method/opcm/utils.py +73 -0
- fusion_bench/method/opcm/weight_average.py +120 -0
- fusion_bench/method/slerp/slerp.py +1 -1
- fusion_bench/method/task_singular_vector/TSVM.py +22 -2
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +91 -93
- fusion_bench/method/ties_merging/ties_merging.py +10 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +22 -0
- fusion_bench/mixins/clip_classification.py +4 -1
- fusion_bench/programs/fabric_fusion_program.py +22 -11
- fusion_bench/scripts/cli.py +1 -0
- fusion_bench/taskpool/base_pool.py +1 -1
- fusion_bench/taskpool/clip_vision/taskpool.py +12 -7
- fusion_bench/utils/__init__.py +2 -1
- fusion_bench/utils/dict.py +43 -0
- fusion_bench/utils/expr.py +90 -0
- fusion_bench/utils/fabric.py +17 -0
- fusion_bench/utils/instantiate.py +7 -1
- fusion_bench/utils/json.py +30 -0
- fusion_bench/utils/parameters.py +27 -7
- fusion_bench/utils/path.py +15 -0
- fusion_bench/utils/plot/color_data.py +1726 -0
- fusion_bench/utils/rich_utils.py +15 -0
- fusion_bench/utils/set.py +8 -0
- fusion_bench/utils/tensorboard.py +51 -0
- {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/METADATA +17 -18
- {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/RECORD +58 -29
- {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/WHEEL +1 -1
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +28 -0
- fusion_bench_config/method/classification/clip_finetune.yaml +26 -0
- fusion_bench_config/method/clip_finetune.yaml +2 -2
- fusion_bench_config/method/dare/ties_merging.yaml +15 -0
- fusion_bench_config/method/isotropic_merging/iso_c.yaml +4 -0
- fusion_bench_config/method/isotropic_merging/iso_cts.yaml +5 -0
- fusion_bench_config/method/opcm/opcm.yaml +12 -0
- fusion_bench_config/method/opcm/task_arithmetic.yaml +12 -0
- fusion_bench_config/method/opcm/ties_merging.yaml +18 -0
- fusion_bench_config/method/opcm/weight_average.yaml +10 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +18 -0
- {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/top_level.txt +0 -0
|
@@ -116,7 +116,10 @@ def sum_svd_dict(svd_dict, config):
|
|
|
116
116
|
|
|
117
117
|
###############
|
|
118
118
|
##### LOSSLESS Orthogonalization
|
|
119
|
-
def compute_and_sum_svd_mem_reduction_lossless(
|
|
119
|
+
def compute_and_sum_svd_mem_reduction_lossless(
|
|
120
|
+
task_vectors: List[StateDictType],
|
|
121
|
+
accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
|
|
122
|
+
):
|
|
120
123
|
"""
|
|
121
124
|
Computes the Singular Value Decomposition (SVD) for each task vector and merge the results.
|
|
122
125
|
|
|
@@ -129,40 +132,38 @@ def compute_and_sum_svd_mem_reduction_lossless(task_vectors, config):
|
|
|
129
132
|
|
|
130
133
|
Args:
|
|
131
134
|
task_vectors (list): A list of task vectors, where each task vector is a dictionary containing the vectors for each task.
|
|
132
|
-
|
|
133
|
-
|
|
135
|
+
accelerator (torch.device): The device to use for the computation.
|
|
134
136
|
Returns:
|
|
135
137
|
dict: A dictionary containing the new vectors after summing the SVD components.
|
|
136
138
|
"""
|
|
137
139
|
# be careful with ViT-L on 20 tasks: it does not fit on the GPU or in 64 GB of RAM (try without the last layer)
|
|
138
|
-
device = config.device
|
|
139
140
|
print("Computing SVD...")
|
|
140
141
|
with torch.no_grad():
|
|
141
142
|
new_vector = {}
|
|
142
|
-
for key in task_vectors[0]
|
|
143
|
+
for key in task_vectors[0]:
|
|
144
|
+
original_device = task_vectors[0][key].device
|
|
143
145
|
new_vector[key] = {}
|
|
144
|
-
for i,
|
|
145
|
-
|
|
146
|
-
):
|
|
147
|
-
vec = task_vector.vector[key].to(device)
|
|
146
|
+
for i, task_vector in enumerate(task_vectors):
|
|
147
|
+
vec = task_vector[key].to(accelerator)
|
|
148
148
|
|
|
149
|
-
if (
|
|
150
|
-
len(task_vector.vector[key].shape) == 2
|
|
151
|
-
and "text_projection" not in key
|
|
152
|
-
):
|
|
149
|
+
if len(task_vector[key].shape) == 2 and "text_projection" not in key:
|
|
153
150
|
|
|
154
151
|
u, s, v = torch.linalg.svd(vec, full_matrices=False)
|
|
155
152
|
|
|
156
153
|
if i == 0:
|
|
157
154
|
print(f"Computed SVD for {key}...")
|
|
158
155
|
sum_u = torch.zeros(
|
|
159
|
-
u.shape[0],
|
|
156
|
+
u.shape[0],
|
|
157
|
+
u.shape[1] * len(task_vectors),
|
|
158
|
+
device=accelerator,
|
|
160
159
|
)
|
|
161
160
|
sum_s = torch.zeros(
|
|
162
|
-
s.shape[0] *
|
|
161
|
+
s.shape[0] * len(task_vectors), device=accelerator
|
|
163
162
|
)
|
|
164
163
|
sum_v = torch.zeros(
|
|
165
|
-
v.shape[0] *
|
|
164
|
+
v.shape[0] * len(task_vectors),
|
|
165
|
+
v.shape[1],
|
|
166
|
+
device=accelerator,
|
|
166
167
|
)
|
|
167
168
|
reduced_index_s = s.shape[0]
|
|
168
169
|
|
|
@@ -184,7 +185,7 @@ def compute_and_sum_svd_mem_reduction_lossless(task_vectors, config):
|
|
|
184
185
|
else:
|
|
185
186
|
new_vector[key] += (vec - new_vector[key]) / (i + 1)
|
|
186
187
|
|
|
187
|
-
if len(task_vector
|
|
188
|
+
if len(task_vector[key].shape) == 2 and "text_projection" not in key:
|
|
188
189
|
u_u, s_u, v_u = torch.linalg.svd(sum_u, full_matrices=False)
|
|
189
190
|
u_v, s_v, v_v = torch.linalg.svd(sum_v, full_matrices=False)
|
|
190
191
|
|
|
@@ -197,13 +198,16 @@ def compute_and_sum_svd_mem_reduction_lossless(task_vectors, config):
|
|
|
197
198
|
v_v,
|
|
198
199
|
)
|
|
199
200
|
)
|
|
200
|
-
|
|
201
|
+
new_vector[key] = new_vector[key].to(original_device, non_blocking=True)
|
|
201
202
|
return new_vector
|
|
202
203
|
|
|
203
204
|
|
|
204
205
|
###############
|
|
205
206
|
##### LOSSLESS EIGENDECOMP
|
|
206
|
-
def compute_and_sum_svd_mem_reduction_lossless_eigen(
|
|
207
|
+
def compute_and_sum_svd_mem_reduction_lossless_eigen(
|
|
208
|
+
task_vectors: List[StateDictType],
|
|
209
|
+
accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
|
|
210
|
+
):
|
|
207
211
|
"""
|
|
208
212
|
Computes the Singular Value Decomposition (SVD) for each task vector and merge the results.
|
|
209
213
|
|
|
@@ -216,40 +220,39 @@ def compute_and_sum_svd_mem_reduction_lossless_eigen(task_vectors, config):
|
|
|
216
220
|
|
|
217
221
|
Args:
|
|
218
222
|
task_vectors (list): A list of task vectors, where each task vector is a dictionary containing the vectors for each task.
|
|
219
|
-
|
|
223
|
+
accelerator (torch.device): The device to use for the computation.
|
|
220
224
|
|
|
221
225
|
Returns:
|
|
222
226
|
dict: A dictionary containing the new vectors after merging the SVD components.
|
|
223
227
|
"""
|
|
224
228
|
# be careful with ViT-L on 20 tasks: it does not fit on the GPU or in 64 GB of RAM (try without the last layer)
|
|
225
|
-
device = config.device
|
|
226
229
|
print("Computing SVD...")
|
|
227
230
|
with torch.no_grad():
|
|
228
231
|
new_vector = {}
|
|
229
|
-
for key in task_vectors[0]
|
|
232
|
+
for key in task_vectors[0]:
|
|
233
|
+
original_device = task_vectors[0][key].device
|
|
230
234
|
new_vector[key] = {}
|
|
231
|
-
for i,
|
|
232
|
-
|
|
233
|
-
):
|
|
234
|
-
vec = task_vector.vector[key].to(device)
|
|
235
|
+
for i, task_vector in enumerate(task_vectors):
|
|
236
|
+
vec = task_vector[key].to(accelerator)
|
|
235
237
|
|
|
236
|
-
if (
|
|
237
|
-
len(task_vector.vector[key].shape) == 2
|
|
238
|
-
and "text_projection" not in key
|
|
239
|
-
):
|
|
238
|
+
if len(task_vector[key].shape) == 2 and "text_projection" not in key:
|
|
240
239
|
|
|
241
240
|
u, s, v = torch.linalg.svd(vec, full_matrices=False)
|
|
242
241
|
|
|
243
242
|
if i == 0:
|
|
244
243
|
print(f"Computed SVD for {key}...")
|
|
245
244
|
sum_u = torch.zeros(
|
|
246
|
-
u.shape[0],
|
|
245
|
+
u.shape[0],
|
|
246
|
+
u.shape[1] * len(task_vectors),
|
|
247
|
+
device=accelerator,
|
|
247
248
|
)
|
|
248
249
|
sum_s = torch.zeros(
|
|
249
|
-
s.shape[0] *
|
|
250
|
+
s.shape[0] * len(task_vectors), device=accelerator
|
|
250
251
|
)
|
|
251
252
|
sum_v = torch.zeros(
|
|
252
|
-
v.shape[0] *
|
|
253
|
+
v.shape[0] * len(task_vectors),
|
|
254
|
+
v.shape[1],
|
|
255
|
+
device=accelerator,
|
|
253
256
|
)
|
|
254
257
|
reduced_index_s = s.shape[0]
|
|
255
258
|
|
|
@@ -271,7 +274,7 @@ def compute_and_sum_svd_mem_reduction_lossless_eigen(task_vectors, config):
|
|
|
271
274
|
else:
|
|
272
275
|
new_vector[key] += (vec - new_vector[key]) / (i + 1)
|
|
273
276
|
|
|
274
|
-
if len(task_vector
|
|
277
|
+
if len(task_vector[key].shape) == 2 and "text_projection" not in key:
|
|
275
278
|
sum_s, indices = torch.sort(sum_s, stable=True)
|
|
276
279
|
|
|
277
280
|
sum_u = torch.index_select(sum_u, 1, indices)
|
|
@@ -293,12 +296,14 @@ def compute_and_sum_svd_mem_reduction_lossless_eigen(task_vectors, config):
|
|
|
293
296
|
|
|
294
297
|
new_vector[key] = torch.linalg.multi_dot( # bool_mask *
|
|
295
298
|
(
|
|
299
|
+
sum_u,
|
|
296
300
|
u_orth,
|
|
297
301
|
torch.diag(sum_s),
|
|
298
302
|
v_orth,
|
|
303
|
+
sum_v,
|
|
299
304
|
)
|
|
300
305
|
)
|
|
301
|
-
|
|
306
|
+
new_vector[key] = new_vector[key].to(original_device, non_blocking=True)
|
|
302
307
|
return new_vector
|
|
303
308
|
|
|
304
309
|
|
|
@@ -394,7 +399,10 @@ def compute_and_sum_svd_mem_reduction(
|
|
|
394
399
|
|
|
395
400
|
###############
|
|
396
401
|
#### TSV Merge Eigendecomp
|
|
397
|
-
def compute_and_sum_svd_mem_reduction_2(
|
|
402
|
+
def compute_and_sum_svd_mem_reduction_2(
|
|
403
|
+
task_vectors: List[StateDictType],
|
|
404
|
+
accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
|
|
405
|
+
):
|
|
398
406
|
"""
|
|
399
407
|
Computes the Singular Value Decomposition (SVD) for each vector in the task_vectors,
|
|
400
408
|
reduces the dimensionality of the vectors based on the sv_reduction factor, and concatenate
|
|
@@ -404,36 +412,30 @@ def compute_and_sum_svd_mem_reduction_2(task_vectors, config):
|
|
|
404
412
|
Args:
|
|
405
413
|
task_vectors (list): A list of task vector objects, where each object contains a
|
|
406
414
|
dictionary of vectors.
|
|
407
|
-
|
|
408
|
-
- DATASETS (list): List of datasets.
|
|
409
|
-
- device (torch.device): The device to perform computations on.
|
|
415
|
+
accelerator (torch.device): The device to use for the computation.
|
|
410
416
|
|
|
411
417
|
Returns:
|
|
412
418
|
dict: A dictionary containing the new vectors after SVD computation and merging.
|
|
413
419
|
"""
|
|
414
|
-
sv_reduction = 1 / len(
|
|
415
|
-
|
|
420
|
+
sv_reduction = 1 / len(task_vectors)
|
|
421
|
+
|
|
416
422
|
print("Computing SVD...")
|
|
417
423
|
with torch.no_grad():
|
|
418
424
|
new_vector = {}
|
|
419
|
-
for key in task_vectors[0]
|
|
425
|
+
for key in task_vectors[0]:
|
|
426
|
+
original_device = task_vectors[0][key].device
|
|
420
427
|
new_vector[key] = {}
|
|
421
|
-
for i,
|
|
422
|
-
|
|
423
|
-
):
|
|
424
|
-
vec = task_vector.vector[key].to(device)
|
|
428
|
+
for i, task_vector in enumerate(task_vectors):
|
|
429
|
+
vec = task_vector[key].to(accelerator)
|
|
425
430
|
|
|
426
|
-
if (
|
|
427
|
-
len(task_vector.vector[key].shape) == 2
|
|
428
|
-
and "text_projection" not in key
|
|
429
|
-
):
|
|
431
|
+
if len(task_vector[key].shape) == 2 and "text_projection" not in key:
|
|
430
432
|
u, s, v = torch.linalg.svd(vec, full_matrices=False)
|
|
431
433
|
|
|
432
434
|
if i == 0:
|
|
433
435
|
print(f"Computed SVD for {key}...")
|
|
434
|
-
sum_u = torch.zeros_like(u, device=
|
|
435
|
-
sum_s = torch.zeros_like(s, device=
|
|
436
|
-
sum_v = torch.zeros_like(v, device=
|
|
436
|
+
sum_u = torch.zeros_like(u, device=accelerator)
|
|
437
|
+
sum_s = torch.zeros_like(s, device=accelerator)
|
|
438
|
+
sum_v = torch.zeros_like(v, device=accelerator)
|
|
437
439
|
reduced_index_s = int(s.shape[0] * sv_reduction)
|
|
438
440
|
|
|
439
441
|
# select only the first reduced_index_s columns of u and place them
|
|
@@ -454,7 +456,7 @@ def compute_and_sum_svd_mem_reduction_2(task_vectors, config):
|
|
|
454
456
|
else:
|
|
455
457
|
new_vector[key] += (vec - new_vector[key]) / (i + 1)
|
|
456
458
|
|
|
457
|
-
if len(task_vector
|
|
459
|
+
if len(task_vector[key].shape) == 2 and "text_projection" not in key:
|
|
458
460
|
sum_s, indices = torch.sort(sum_s, stable=True)
|
|
459
461
|
|
|
460
462
|
sum_u = torch.index_select(sum_u, 1, indices)
|
|
@@ -483,13 +485,17 @@ def compute_and_sum_svd_mem_reduction_2(task_vectors, config):
|
|
|
483
485
|
sum_v,
|
|
484
486
|
)
|
|
485
487
|
)
|
|
488
|
+
new_vector[key] = new_vector[key].to(original_device, non_blocking=True)
|
|
486
489
|
|
|
487
490
|
return new_vector
|
|
488
491
|
|
|
489
492
|
|
|
490
493
|
###############
|
|
491
494
|
#### Rank Reduction TV
|
|
492
|
-
def compute_and_sum_svd_mem_reduction_rank_reduction(
|
|
495
|
+
def compute_and_sum_svd_mem_reduction_rank_reduction(
|
|
496
|
+
task_vectors: List[StateDictType],
|
|
497
|
+
accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
|
|
498
|
+
):
|
|
493
499
|
"""
|
|
494
500
|
Compute and sum the Singular Value Decomposition (SVD) of task vectors with rank reduction.
|
|
495
501
|
|
|
@@ -499,36 +505,29 @@ def compute_and_sum_svd_mem_reduction_rank_reduction(task_vectors, config):
|
|
|
499
505
|
Args:
|
|
500
506
|
task_vectors (list): A list of task vectors. Each task vector is a state dict, i.e. a dictionary
|
|
501
507
|
whose keys are parameter names and whose values are tensors.
|
|
502
|
-
|
|
503
|
-
- DATASETS (list): List of datasets.
|
|
504
|
-
- device (torch.device): The device to perform computations on.
|
|
508
|
+
accelerator (torch.device): The device to use for the computation.
|
|
505
509
|
|
|
506
510
|
Returns:
|
|
507
511
|
dict: A dictionary containing the new vectors after SVD computation and summation.
|
|
508
512
|
"""
|
|
509
|
-
sv_reduction = 1 / len(
|
|
510
|
-
device = config.device
|
|
513
|
+
sv_reduction = 1 / len(task_vectors)
|
|
511
514
|
print("Computing SVD...")
|
|
512
515
|
with torch.no_grad():
|
|
513
516
|
new_vector = {}
|
|
514
|
-
for key in task_vectors[0]
|
|
517
|
+
for key in task_vectors[0]:
|
|
518
|
+
original_device = task_vectors[0][key].device
|
|
515
519
|
new_vector[key] = {}
|
|
516
|
-
for i,
|
|
517
|
-
|
|
518
|
-
):
|
|
519
|
-
vec = task_vector.vector[key].to(device)
|
|
520
|
+
for i, task_vector in enumerate(task_vectors):
|
|
521
|
+
vec = task_vector[key].to(accelerator)
|
|
520
522
|
|
|
521
|
-
if (
|
|
522
|
-
len(task_vector.vector[key].shape) == 2
|
|
523
|
-
and "text_projection" not in key
|
|
524
|
-
):
|
|
523
|
+
if len(task_vector[key].shape) == 2 and "text_projection" not in key:
|
|
525
524
|
u, s, v = torch.linalg.svd(vec, full_matrices=False)
|
|
526
525
|
|
|
527
526
|
if i == 0:
|
|
528
527
|
print(f"Computed SVD for {key}...")
|
|
529
|
-
sum_u = torch.zeros_like(u, device=
|
|
530
|
-
sum_s = torch.zeros_like(s, device=
|
|
531
|
-
sum_v = torch.zeros_like(v, device=
|
|
528
|
+
sum_u = torch.zeros_like(u, device=accelerator)
|
|
529
|
+
sum_s = torch.zeros_like(s, device=accelerator)
|
|
530
|
+
sum_v = torch.zeros_like(v, device=accelerator)
|
|
532
531
|
reduced_index_s = int(s.shape[0] * sv_reduction)
|
|
533
532
|
|
|
534
533
|
# select only the first reduced_index_s columns of u and place them
|
|
@@ -549,7 +548,7 @@ def compute_and_sum_svd_mem_reduction_rank_reduction(task_vectors, config):
|
|
|
549
548
|
else:
|
|
550
549
|
new_vector[key] += (vec - new_vector[key]) / (i + 1)
|
|
551
550
|
|
|
552
|
-
if len(task_vector
|
|
551
|
+
if len(task_vector[key].shape) == 2 and "text_projection" not in key:
|
|
553
552
|
new_vector[key] = torch.linalg.multi_dot(
|
|
554
553
|
(
|
|
555
554
|
sum_u,
|
|
@@ -557,26 +556,29 @@ def compute_and_sum_svd_mem_reduction_rank_reduction(task_vectors, config):
|
|
|
557
556
|
sum_v,
|
|
558
557
|
)
|
|
559
558
|
)
|
|
559
|
+
|
|
560
|
+
new_vector[key] = new_vector[key].to(original_device, non_blocking=True)
|
|
560
561
|
return new_vector
|
|
561
562
|
|
|
562
563
|
|
|
563
|
-
def compute_and_sum_svd_mem_reduction_dummy(
|
|
564
|
+
def compute_and_sum_svd_mem_reduction_dummy(
|
|
565
|
+
task_vectors: List[StateDictType],
|
|
566
|
+
accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
|
|
567
|
+
):
|
|
564
568
|
"""To perform dummy operations."""
|
|
565
|
-
sv_reduction = 1 /
|
|
569
|
+
sv_reduction = 1 / len(task_vectors)
|
|
566
570
|
print("Computing SVD...")
|
|
567
571
|
with torch.no_grad():
|
|
568
572
|
new_vector = {}
|
|
569
|
-
for key in task_vectors[0]
|
|
573
|
+
for key in task_vectors[0]:
|
|
574
|
+
original_device = task_vectors[0][key].device
|
|
570
575
|
new_vector[key] = {}
|
|
571
|
-
for i in
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
):
|
|
576
|
+
for i, task_vector in enumerate(task_vectors):
|
|
577
|
+
vec = task_vector[key].to(accelerator)
|
|
578
|
+
|
|
579
|
+
if len(task_vector[key].shape) == 2 and "text_projection" not in key:
|
|
576
580
|
if i == 0:
|
|
577
|
-
u, s, v = torch.linalg.svd(
|
|
578
|
-
task_vectors[0].vector[key], full_matrices=False
|
|
579
|
-
)
|
|
581
|
+
u, s, v = torch.linalg.svd(vec, full_matrices=False)
|
|
580
582
|
reduced_index_s = int(s.shape[0] * sv_reduction)
|
|
581
583
|
|
|
582
584
|
print(f"Computed SVD for {key}...")
|
|
@@ -620,16 +622,11 @@ def compute_and_sum_svd_mem_reduction_dummy(task_vectors, config):
|
|
|
620
622
|
|
|
621
623
|
else:
|
|
622
624
|
if i == 0:
|
|
623
|
-
new_vector[key] =
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
# task_vector.vector[key] - new_vector[key]
|
|
627
|
-
# ) / (i + 1)
|
|
625
|
+
new_vector[key] = vec.clone()
|
|
626
|
+
else:
|
|
627
|
+
new_vector[key] += (vec - new_vector[key]) / (i + 1)
|
|
628
628
|
|
|
629
|
-
if (
|
|
630
|
-
len(task_vectors[0].vector[key].shape) == 2
|
|
631
|
-
and "text_projection" not in key
|
|
632
|
-
):
|
|
629
|
+
if len(task_vector[key].shape) == 2 and "text_projection" not in key:
|
|
633
630
|
|
|
634
631
|
new_vector[key] = torch.linalg.multi_dot(
|
|
635
632
|
(
|
|
@@ -639,4 +636,5 @@ def compute_and_sum_svd_mem_reduction_dummy(task_vectors, config):
|
|
|
639
636
|
)
|
|
640
637
|
)
|
|
641
638
|
|
|
639
|
+
new_vector[key] = new_vector[key].to(original_device, non_blocking=True)
|
|
642
640
|
return new_vector
|
|
@@ -1,3 +1,13 @@
|
|
|
1
|
+
R"""
|
|
2
|
+
Overview of Ties-Merging:
|
|
3
|
+
|
|
4
|
+
1. Trim: For each task t, we trim the redundant parameters from the task vector $\tau_t$ to create $\hat{\tau}_t$ by keeping the top-k% values according to their magnitude and trimming the bottom $(100 - k)\%$ of the redundant parameters by resetting them to 0. This can be decomposed further as $\hat{\tau}_t = \hat{\gamma}_t \odot \hat{\mu}_t$.
|
|
5
|
+
|
|
6
|
+
2. Elect: Next, we create an aggregate elected sign vector $\gamma_m$ for the merged model that resolves the disagreements in the sign for each parameter p across different models. To create the elected sign vector, we choose the sign with the highest total magnitude across all relevant models. For each parameter $p \in \{1, 2, \ldots, d\}$, we separate the values $\{\hat{\tau}_t^p\}_{t=1}^n$ based on their sign ($+1$ or $-1$) and take their sum to calculate the total mass (i.e., total magnitude) in the positive and the negative direction. We then assign $\gamma_m^p$ as the sign with greater total movement. This can be efficiently computed using $\gamma_m^p = \text{sgn}(\sum_{t=1}^n \hat{\tau}_t^p)$.
|
|
7
|
+
|
|
8
|
+
3. Disjoint Merge: Then, for each parameter p, we compute a disjoint mean by only keeping the parameter values from the models whose signs are the same as the aggregated elected sign and calculate their mean. Formally, let $A_p = \{t \in [n] \mid \hat{\gamma}_t^p = \gamma_m^p\}$, then $\tau_m^p = \frac{1}{|A_p|}\sum_{t\in A_p} \hat{\tau}_t^p$. Note that the disjoint mean always ignores the zero values.
|
|
9
|
+
"""
|
|
10
|
+
|
|
1
11
|
import logging
|
|
2
12
|
from typing import Dict, List, Literal, Mapping, Union # noqa: F401
|
|
3
13
|
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from typing import Dict
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def compute_backward_transfer(
    acc_Ti: Dict[str, float], acc_ii: Dict[str, float]
) -> float:
    R"""
    Compute the backward transfer (BWT) of a model on a set of tasks.

    Equation:
        BWT = \frac{1}{n} \sum_{k=1}^{n} (acc_{Ti}[k] - acc_{ii}[k])

    Args:
        acc_Ti (Dict[str, float]): accuracy per task name. Presumably the accuracy
            on each task after the final task has been learned -- confirm with the caller.
        acc_ii (Dict[str, float]): accuracy per task name. Presumably the accuracy
            on each task right after that task was learned -- confirm with the caller.

    Returns:
        float: The backward transfer of the model, i.e. the mean of
            `acc_Ti[task] - acc_ii[task]` over all tasks.

    Raises:
        ValueError: If the two mappings do not cover exactly the same task
            names, or if they are empty (the mean would be undefined).
    """
    # Validate explicitly instead of using `assert`: assertions are stripped
    # under `python -O`, which would let mismatched inputs slip through.
    if set(acc_ii.keys()) != set(acc_Ti.keys()):
        only_in_ii = sorted(set(acc_ii) - set(acc_Ti))
        only_in_Ti = sorted(set(acc_Ti) - set(acc_ii))
        raise ValueError(
            "acc_Ti and acc_ii must contain the same task names "
            f"(only in acc_ii: {only_in_ii}, only in acc_Ti: {only_in_Ti})"
        )
    if not acc_ii:
        raise ValueError("cannot compute backward transfer for an empty set of tasks")

    # Same iteration order and start value (0) as a manual accumulation loop,
    # so the floating-point result is unchanged.
    total = sum(acc_Ti[task_name] - acc_ii[task_name] for task_name in acc_ii)
    return total / len(acc_ii)
|
|
@@ -161,12 +161,14 @@ class CLIPClassificationMixin(LightningFabricMixin):
|
|
|
161
161
|
cache_dir, os.path.normpath(f"{task}_zeroshot_weights.pt")
|
|
162
162
|
)
|
|
163
163
|
if os.path.exists(cache_file):
|
|
164
|
-
log.info(f"Loading cached zeroshot weights for task: {task}")
|
|
165
164
|
zeroshot_weights = torch.load(
|
|
166
165
|
cache_file,
|
|
167
166
|
map_location="cpu",
|
|
168
167
|
weights_only=True,
|
|
169
168
|
).detach()
|
|
169
|
+
log.info(
|
|
170
|
+
f"Loadded cached zeroshot weights for task: {task}, shape: {zeroshot_weights.shape}"
|
|
171
|
+
)
|
|
170
172
|
else:
|
|
171
173
|
log.info(
|
|
172
174
|
f"Construct zero shot classification head for task: {task}"
|
|
@@ -180,6 +182,7 @@ class CLIPClassificationMixin(LightningFabricMixin):
|
|
|
180
182
|
self.fabric.barrier()
|
|
181
183
|
self.zeroshot_weights[task] = self.fabric.broadcast(zeroshot_weights, src=0)
|
|
182
184
|
self.zeroshot_weights[task] = self.to_device(self.zeroshot_weights[task])
|
|
185
|
+
self.fabric.barrier()
|
|
183
186
|
|
|
184
187
|
del clip_classifier
|
|
185
188
|
if torch.cuda.is_available():
|
|
@@ -103,12 +103,13 @@ class FabricModelFusionProgram(
|
|
|
103
103
|
)
|
|
104
104
|
if compat_load_fn is not None:
|
|
105
105
|
compat_load_fn = import_object(compat_load_fn)
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
106
|
+
if rank_zero_only.rank == 0:
|
|
107
|
+
print_bordered(
|
|
108
|
+
OmegaConf.to_yaml(config),
|
|
109
|
+
title="instantiate compat object",
|
|
110
|
+
style="magenta",
|
|
111
|
+
code_style="yaml",
|
|
112
|
+
)
|
|
112
113
|
obj = compat_load_fn(config)
|
|
113
114
|
else:
|
|
114
115
|
raise ValueError(
|
|
@@ -159,7 +160,11 @@ class FabricModelFusionProgram(
|
|
|
159
160
|
print("No save path specified for the merged model. Skipping saving.")
|
|
160
161
|
|
|
161
162
|
def evaluate_merged_model(
|
|
162
|
-
self,
|
|
163
|
+
self,
|
|
164
|
+
taskpool: BaseTaskPool,
|
|
165
|
+
merged_model: Union[nn.Module, Dict, Iterable],
|
|
166
|
+
*args,
|
|
167
|
+
**kwargs,
|
|
163
168
|
):
|
|
164
169
|
"""
|
|
165
170
|
Evaluates the merged model using the provided task pool.
|
|
@@ -174,6 +179,8 @@ class FabricModelFusionProgram(
|
|
|
174
179
|
Args:
|
|
175
180
|
taskpool: The task pool used for evaluating the merged model.
|
|
176
181
|
merged_model: The merged model to be evaluated. It can be an instance of `nn.Module`, a dictionary, or an iterable.
|
|
182
|
+
*args: Additional positional arguments to be passed to the `evaluate` method of the taskpool.
|
|
183
|
+
**kwargs: Additional keyword arguments to be passed to the `evaluate` method of the taskpool.
|
|
177
184
|
|
|
178
185
|
Returns:
|
|
179
186
|
The evaluation report. The type of the report depends on the type of the merged model:
|
|
@@ -182,20 +189,20 @@ class FabricModelFusionProgram(
|
|
|
182
189
|
- If the merged model is an iterable, the report is a list of evaluation reports.
|
|
183
190
|
"""
|
|
184
191
|
if isinstance(merged_model, nn.Module):
|
|
185
|
-
report = taskpool.evaluate(merged_model)
|
|
192
|
+
report = taskpool.evaluate(merged_model, *args, **kwargs)
|
|
186
193
|
return report
|
|
187
194
|
elif isinstance(merged_model, Dict):
|
|
188
195
|
report = {}
|
|
189
196
|
for key, item in merged_model.items():
|
|
190
197
|
if isinstance(item, nn.Module):
|
|
191
|
-
report[key] = taskpool.evaluate(item)
|
|
198
|
+
report[key] = taskpool.evaluate(item, *args, **kwargs)
|
|
192
199
|
else:
|
|
193
200
|
# metadata
|
|
194
201
|
report[key] = item
|
|
195
202
|
return report
|
|
196
203
|
elif isinstance(merged_model, Iterable):
|
|
197
204
|
return [
|
|
198
|
-
self.evaluate_merged_model(taskpool, m)
|
|
205
|
+
self.evaluate_merged_model(taskpool, m, *args, **kwargs)
|
|
199
206
|
for m in tqdm(merged_model, desc="Evaluating models")
|
|
200
207
|
]
|
|
201
208
|
else:
|
|
@@ -272,7 +279,11 @@ class FabricModelFusionProgram(
|
|
|
272
279
|
"""
|
|
273
280
|
if self.log_dir is not None:
|
|
274
281
|
# make symlink to the hydra output directory
|
|
275
|
-
|
|
282
|
+
try:
|
|
283
|
+
hydra_output_dir = get_hydra_output_dir()
|
|
284
|
+
except Exception as e:
|
|
285
|
+
hydra_output_dir = None
|
|
286
|
+
|
|
276
287
|
if hydra_output_dir is not None:
|
|
277
288
|
os.makedirs(self.log_dir, exist_ok=True)
|
|
278
289
|
try:
|
fusion_bench/scripts/cli.py
CHANGED
|
@@ -238,11 +238,13 @@ class CLIPVisionModelTaskPool(
|
|
|
238
238
|
else:
|
|
239
239
|
test_loader = test_loader
|
|
240
240
|
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
241
|
+
pbar = tqdm(
|
|
242
|
+
test_loader,
|
|
243
|
+
desc=f"Evaluating {task_name}",
|
|
244
|
+
leave=False,
|
|
245
|
+
dynamic_ncols=True,
|
|
246
|
+
)
|
|
247
|
+
for batch in pbar:
|
|
246
248
|
inputs, targets = batch
|
|
247
249
|
outputs = classifier(
|
|
248
250
|
inputs,
|
|
@@ -309,11 +311,14 @@ class CLIPVisionModelTaskPool(
|
|
|
309
311
|
}
|
|
310
312
|
if name is not None:
|
|
311
313
|
report["model_info"]["name"] = name
|
|
312
|
-
|
|
314
|
+
|
|
315
|
+
# evaluate on each task
|
|
316
|
+
pbar = tqdm(
|
|
313
317
|
self.test_dataloaders.items(),
|
|
314
318
|
desc="Evaluating tasks",
|
|
315
319
|
total=len(self.test_dataloaders),
|
|
316
|
-
)
|
|
320
|
+
)
|
|
321
|
+
for task_name, test_dataloader in pbar:
|
|
317
322
|
classnames, templates = get_classnames_and_templates(task_name)
|
|
318
323
|
self.on_task_evaluation_begin(classifier, task_name)
|
|
319
324
|
classifier.set_classification_task(classnames, templates)
|
fusion_bench/utils/__init__.py
CHANGED
|
@@ -6,7 +6,8 @@ from . import data, functools, path
|
|
|
6
6
|
from .cache_utils import *
|
|
7
7
|
from .devices import *
|
|
8
8
|
from .dtype import parse_dtype
|
|
9
|
-
from .
|
|
9
|
+
from .fabric import seed_everything_by_time
|
|
10
|
+
from .instantiate import instantiate, is_instantiable
|
|
10
11
|
from .misc import *
|
|
11
12
|
from .packages import import_object
|
|
12
13
|
from .parameters import *
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from copy import deepcopy
|
|
2
|
+
from typing import Iterable, List, Tuple, Union
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def dict_get(d: dict, keys: Iterable[str], default=None):
    """Look up several keys in a dictionary at once.

    Args:
        d (dict): dictionary to read from
        keys (Iterable[str]): keys to look up, in order
        default (optional): value used for every key missing from `d`. Defaults to None.

    Returns:
        list: one entry per key, in the order the keys were given
    """
    values = []
    for name in keys:
        values.append(d.get(name, default))
    return values
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def dict_map(f, d: dict, *, max_level: int = -1, skip_levels=0, inplace=False):
    """Transform the non-dict values of a (possibly nested) dictionary with `f`.

    Nested dictionaries are walked recursively; `f` is only ever called on
    values that are not themselves dictionaries. Levels are counted from 0
    for the top-level dictionary.

    Args:
        f (callable): function applied to each non-dict value
        d (dict): input dictionary
        max_level (int, optional): the walk stops once this level is reached, so only
            values at levels below it are transformed; -1 never matches a level and
            therefore means unlimited. Defaults to -1.
        skip_levels (int, optional): values at levels below this number are left
            untouched. Defaults to 0.
        inplace (bool, optional): write the results into `d` itself instead of into a
            deep copy of it. Defaults to False.

    Returns:
        dict: the transformed dictionary (`d` itself when `inplace` is True)

    Raises:
        TypeError: if `d` is not a dict
    """
    if not isinstance(d, dict):
        raise TypeError("dict_map: d must be a dict")

    # Start from a full deep copy unless asked to mutate the input, so values
    # that are never passed to `f` do not share state with `d`.
    result = d if inplace else deepcopy(d)

    def _walk(source, target, depth):
        if depth == max_level:
            return
        for key in source.keys():
            value = source[key]
            if isinstance(value, dict):
                _walk(value, target[key], depth + 1)
            elif depth >= skip_levels:
                target[key] = f(value)

    _walk(d, result, 0)
    return result
|