fusion-bench 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58) hide show
  1. fusion_bench/__main__.py +4 -0
  2. fusion_bench/dataset/fer2013.py +1 -0
  3. fusion_bench/method/__init__.py +26 -4
  4. fusion_bench/method/classification/__init__.py +1 -0
  5. fusion_bench/method/classification/clip_finetune.py +1 -3
  6. fusion_bench/method/classification/continual_clip_finetune.py +297 -0
  7. fusion_bench/method/dare/__init__.py +1 -0
  8. fusion_bench/method/dare/task_arithmetic.py +14 -7
  9. fusion_bench/method/dare/ties_merging.py +100 -0
  10. fusion_bench/method/isotropic_merging/__init__.py +15 -0
  11. fusion_bench/method/isotropic_merging/iso.py +114 -0
  12. fusion_bench/method/isotropic_merging/iso_utils.py +176 -0
  13. fusion_bench/method/opcm/__init__.py +4 -0
  14. fusion_bench/method/opcm/opcm.py +277 -0
  15. fusion_bench/method/opcm/task_arithmetic.py +115 -0
  16. fusion_bench/method/opcm/ties_merging.py +156 -0
  17. fusion_bench/method/opcm/utils.py +73 -0
  18. fusion_bench/method/opcm/weight_average.py +120 -0
  19. fusion_bench/method/slerp/slerp.py +1 -1
  20. fusion_bench/method/task_singular_vector/TSVM.py +22 -2
  21. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +91 -93
  22. fusion_bench/method/ties_merging/ties_merging.py +10 -0
  23. fusion_bench/metrics/continual_learning/backward_transfer.py +22 -0
  24. fusion_bench/mixins/clip_classification.py +4 -1
  25. fusion_bench/programs/fabric_fusion_program.py +22 -11
  26. fusion_bench/scripts/cli.py +1 -0
  27. fusion_bench/taskpool/base_pool.py +1 -1
  28. fusion_bench/taskpool/clip_vision/taskpool.py +12 -7
  29. fusion_bench/utils/__init__.py +2 -1
  30. fusion_bench/utils/dict.py +43 -0
  31. fusion_bench/utils/expr.py +90 -0
  32. fusion_bench/utils/fabric.py +17 -0
  33. fusion_bench/utils/instantiate.py +7 -1
  34. fusion_bench/utils/json.py +30 -0
  35. fusion_bench/utils/parameters.py +27 -7
  36. fusion_bench/utils/path.py +15 -0
  37. fusion_bench/utils/plot/color_data.py +1726 -0
  38. fusion_bench/utils/rich_utils.py +15 -0
  39. fusion_bench/utils/set.py +8 -0
  40. fusion_bench/utils/tensorboard.py +51 -0
  41. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/METADATA +17 -18
  42. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/RECORD +58 -29
  43. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/WHEEL +1 -1
  44. fusion_bench_config/method/classification/clip_continual_finetune.yaml +28 -0
  45. fusion_bench_config/method/classification/clip_finetune.yaml +26 -0
  46. fusion_bench_config/method/clip_finetune.yaml +2 -2
  47. fusion_bench_config/method/dare/ties_merging.yaml +15 -0
  48. fusion_bench_config/method/isotropic_merging/iso_c.yaml +4 -0
  49. fusion_bench_config/method/isotropic_merging/iso_cts.yaml +5 -0
  50. fusion_bench_config/method/opcm/opcm.yaml +12 -0
  51. fusion_bench_config/method/opcm/task_arithmetic.yaml +12 -0
  52. fusion_bench_config/method/opcm/ties_merging.yaml +18 -0
  53. fusion_bench_config/method/opcm/weight_average.yaml +10 -0
  54. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +6 -0
  55. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +18 -0
  56. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/LICENSE +0 -0
  57. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/entry_points.txt +0 -0
  58. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/top_level.txt +0 -0
@@ -116,7 +116,10 @@ def sum_svd_dict(svd_dict, config):
116
116
 
117
117
  ###############
118
118
  ##### LOSSLESS Orthogonalization
119
- def compute_and_sum_svd_mem_reduction_lossless(task_vectors, config):
119
+ def compute_and_sum_svd_mem_reduction_lossless(
120
+ task_vectors: List[StateDictType],
121
+ accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
122
+ ):
120
123
  """
121
124
  Computes the Singular Value Decomposition (SVD) for each task vector and merge the results.
122
125
 
@@ -129,40 +132,38 @@ def compute_and_sum_svd_mem_reduction_lossless(task_vectors, config):
129
132
 
130
133
  Args:
131
134
  task_vectors (list): A list of task vectors, where each task vector is a dictionary containing the vectors for each task.
132
- config (object): A configuration object containing the device and dataset information.
133
-
135
+ accelerator (torch.device): The device to use for the computation.
134
136
  Returns:
135
137
  dict: A dictionary containing the new vectors after summing the SVD components.
136
138
  """
137
139
  # becareful wit vit-l on 20 task it does not fit in GPU or in 64 GB RAM (try without last layer)
138
- device = config.device
139
140
  print("Computing SVD...")
140
141
  with torch.no_grad():
141
142
  new_vector = {}
142
- for key in task_vectors[0].vector:
143
+ for key in task_vectors[0]:
144
+ original_device = task_vectors[0][key].device
143
145
  new_vector[key] = {}
144
- for i, (task_vector, dataset) in enumerate(
145
- zip(task_vectors, config.DATASETS)
146
- ):
147
- vec = task_vector.vector[key].to(device)
146
+ for i, task_vector in enumerate(task_vectors):
147
+ vec = task_vector[key].to(accelerator)
148
148
 
149
- if (
150
- len(task_vector.vector[key].shape) == 2
151
- and "text_projection" not in key
152
- ):
149
+ if len(task_vector[key].shape) == 2 and "text_projection" not in key:
153
150
 
154
151
  u, s, v = torch.linalg.svd(vec, full_matrices=False)
155
152
 
156
153
  if i == 0:
157
154
  print(f"Computed SVD for {key}...")
158
155
  sum_u = torch.zeros(
159
- u.shape[0], u.shape[1] * config.num_tasks, device=device
156
+ u.shape[0],
157
+ u.shape[1] * len(task_vectors),
158
+ device=accelerator,
160
159
  )
161
160
  sum_s = torch.zeros(
162
- s.shape[0] * config.num_tasks, device=device
161
+ s.shape[0] * len(task_vectors), device=accelerator
163
162
  )
164
163
  sum_v = torch.zeros(
165
- v.shape[0] * config.num_tasks, v.shape[1], device=device
164
+ v.shape[0] * len(task_vectors),
165
+ v.shape[1],
166
+ device=accelerator,
166
167
  )
167
168
  reduced_index_s = s.shape[0]
168
169
 
@@ -184,7 +185,7 @@ def compute_and_sum_svd_mem_reduction_lossless(task_vectors, config):
184
185
  else:
185
186
  new_vector[key] += (vec - new_vector[key]) / (i + 1)
186
187
 
187
- if len(task_vector.vector[key].shape) == 2 and "text_projection" not in key:
188
+ if len(task_vector[key].shape) == 2 and "text_projection" not in key:
188
189
  u_u, s_u, v_u = torch.linalg.svd(sum_u, full_matrices=False)
189
190
  u_v, s_v, v_v = torch.linalg.svd(sum_v, full_matrices=False)
190
191
 
@@ -197,13 +198,16 @@ def compute_and_sum_svd_mem_reduction_lossless(task_vectors, config):
197
198
  v_v,
198
199
  )
199
200
  )
200
-
201
+ new_vector[key] = new_vector[key].to(original_device, non_blocking=True)
201
202
  return new_vector
202
203
 
203
204
 
204
205
  ###############
205
206
  ##### LOSSLESS EIGENDECOMP
206
- def compute_and_sum_svd_mem_reduction_lossless_eigen(task_vectors, config):
207
+ def compute_and_sum_svd_mem_reduction_lossless_eigen(
208
+ task_vectors: List[StateDictType],
209
+ accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
210
+ ):
207
211
  """
208
212
  Computes the Singular Value Decomposition (SVD) for each task vector and merge the results.
209
213
 
@@ -216,40 +220,39 @@ def compute_and_sum_svd_mem_reduction_lossless_eigen(task_vectors, config):
216
220
 
217
221
  Args:
218
222
  task_vectors (list): A list of task vectors, where each task vector is a dictionary containing the vectors for each task.
219
- config (object): A configuration object containing the device and dataset information.
223
+ accelerator (torch.device): The device to use for the computation.
220
224
 
221
225
  Returns:
222
226
  dict: A dictionary containing the new vectors after merging the SVD components.
223
227
  """
224
228
  # becareful wit vit-l on 20 task it does not fit in GPU or in 64 GB RAM (try without last layer)
225
- device = config.device
226
229
  print("Computing SVD...")
227
230
  with torch.no_grad():
228
231
  new_vector = {}
229
- for key in task_vectors[0].vector:
232
+ for key in task_vectors[0]:
233
+ original_device = task_vectors[0][key].device
230
234
  new_vector[key] = {}
231
- for i, (task_vector, dataset) in enumerate(
232
- zip(task_vectors, config.DATASETS)
233
- ):
234
- vec = task_vector.vector[key].to(device)
235
+ for i, task_vector in enumerate(task_vectors):
236
+ vec = task_vector[key].to(accelerator)
235
237
 
236
- if (
237
- len(task_vector.vector[key].shape) == 2
238
- and "text_projection" not in key
239
- ):
238
+ if len(task_vector[key].shape) == 2 and "text_projection" not in key:
240
239
 
241
240
  u, s, v = torch.linalg.svd(vec, full_matrices=False)
242
241
 
243
242
  if i == 0:
244
243
  print(f"Computed SVD for {key}...")
245
244
  sum_u = torch.zeros(
246
- u.shape[0], u.shape[1] * config.num_tasks, device=device
245
+ u.shape[0],
246
+ u.shape[1] * len(task_vectors),
247
+ device=accelerator,
247
248
  )
248
249
  sum_s = torch.zeros(
249
- s.shape[0] * config.num_tasks, device=device
250
+ s.shape[0] * len(task_vectors), device=accelerator
250
251
  )
251
252
  sum_v = torch.zeros(
252
- v.shape[0] * config.num_tasks, v.shape[1], device=device
253
+ v.shape[0] * len(task_vectors),
254
+ v.shape[1],
255
+ device=accelerator,
253
256
  )
254
257
  reduced_index_s = s.shape[0]
255
258
 
@@ -271,7 +274,7 @@ def compute_and_sum_svd_mem_reduction_lossless_eigen(task_vectors, config):
271
274
  else:
272
275
  new_vector[key] += (vec - new_vector[key]) / (i + 1)
273
276
 
274
- if len(task_vector.vector[key].shape) == 2 and "text_projection" not in key:
277
+ if len(task_vector[key].shape) == 2 and "text_projection" not in key:
275
278
  sum_s, indices = torch.sort(sum_s, stable=True)
276
279
 
277
280
  sum_u = torch.index_select(sum_u, 1, indices)
@@ -293,12 +296,14 @@ def compute_and_sum_svd_mem_reduction_lossless_eigen(task_vectors, config):
293
296
 
294
297
  new_vector[key] = torch.linalg.multi_dot( # bool_mask *
295
298
  (
299
+ sum_u,
296
300
  u_orth,
297
301
  torch.diag(sum_s),
298
302
  v_orth,
303
+ sum_v,
299
304
  )
300
305
  )
301
-
306
+ new_vector[key] = new_vector[key].to(original_device, non_blocking=True)
302
307
  return new_vector
303
308
 
304
309
 
@@ -394,7 +399,10 @@ def compute_and_sum_svd_mem_reduction(
394
399
 
395
400
  ###############
396
401
  #### TSV Merge Eigendecomp
397
- def compute_and_sum_svd_mem_reduction_2(task_vectors, config):
402
+ def compute_and_sum_svd_mem_reduction_2(
403
+ task_vectors: List[StateDictType],
404
+ accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
405
+ ):
398
406
  """
399
407
  Computes the Singular Value Decomposition (SVD) for each vector in the task_vectors,
400
408
  reduces the dimensionality of the vectors based on the sv_reduction factor, and concatenate
@@ -404,36 +412,30 @@ def compute_and_sum_svd_mem_reduction_2(task_vectors, config):
404
412
  Args:
405
413
  task_vectors (list): A list of task vector objects, where each object contains a
406
414
  dictionary of vectors.
407
- config (object): Configuration object containing the following attributes:
408
- - DATASETS (list): List of datasets.
409
- - device (torch.device): The device to perform computations on.
415
+ accelerator (torch.device): The device to use for the computation.
410
416
 
411
417
  Returns:
412
418
  dict: A dictionary containing the new vectors after SVD computation and merging.
413
419
  """
414
- sv_reduction = 1 / len(config.DATASETS)
415
- device = config.device
420
+ sv_reduction = 1 / len(task_vectors)
421
+
416
422
  print("Computing SVD...")
417
423
  with torch.no_grad():
418
424
  new_vector = {}
419
- for key in task_vectors[0].vector:
425
+ for key in task_vectors[0]:
426
+ original_device = task_vectors[0][key].device
420
427
  new_vector[key] = {}
421
- for i, (task_vector, dataset) in enumerate(
422
- zip(task_vectors, config.DATASETS)
423
- ):
424
- vec = task_vector.vector[key].to(device)
428
+ for i, task_vector in enumerate(task_vectors):
429
+ vec = task_vector[key].to(accelerator)
425
430
 
426
- if (
427
- len(task_vector.vector[key].shape) == 2
428
- and "text_projection" not in key
429
- ):
431
+ if len(task_vector[key].shape) == 2 and "text_projection" not in key:
430
432
  u, s, v = torch.linalg.svd(vec, full_matrices=False)
431
433
 
432
434
  if i == 0:
433
435
  print(f"Computed SVD for {key}...")
434
- sum_u = torch.zeros_like(u, device=device)
435
- sum_s = torch.zeros_like(s, device=device)
436
- sum_v = torch.zeros_like(v, device=device)
436
+ sum_u = torch.zeros_like(u, device=accelerator)
437
+ sum_s = torch.zeros_like(s, device=accelerator)
438
+ sum_v = torch.zeros_like(v, device=accelerator)
437
439
  reduced_index_s = int(s.shape[0] * sv_reduction)
438
440
 
439
441
  # select only the first reduced_index_s columns of u and place them
@@ -454,7 +456,7 @@ def compute_and_sum_svd_mem_reduction_2(task_vectors, config):
454
456
  else:
455
457
  new_vector[key] += (vec - new_vector[key]) / (i + 1)
456
458
 
457
- if len(task_vector.vector[key].shape) == 2 and "text_projection" not in key:
459
+ if len(task_vector[key].shape) == 2 and "text_projection" not in key:
458
460
  sum_s, indices = torch.sort(sum_s, stable=True)
459
461
 
460
462
  sum_u = torch.index_select(sum_u, 1, indices)
@@ -483,13 +485,17 @@ def compute_and_sum_svd_mem_reduction_2(task_vectors, config):
483
485
  sum_v,
484
486
  )
485
487
  )
488
+ new_vector[key] = new_vector[key].to(original_device, non_blocking=True)
486
489
 
487
490
  return new_vector
488
491
 
489
492
 
490
493
  ###############
491
494
  #### Rank Reduction TV
492
- def compute_and_sum_svd_mem_reduction_rank_reduction(task_vectors, config):
495
+ def compute_and_sum_svd_mem_reduction_rank_reduction(
496
+ task_vectors: List[StateDictType],
497
+ accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
498
+ ):
493
499
  """
494
500
  Compute and sum the Singular Value Decomposition (SVD) of task vectors with rank reduction.
495
501
 
@@ -499,36 +505,29 @@ def compute_and_sum_svd_mem_reduction_rank_reduction(task_vectors, config):
499
505
  Args:
500
506
  task_vectors (list): A list of task vector objects. Each object should have a `vector` attribute
501
507
  which is a dictionary where keys are vector names and values are tensors.
502
- config (object): Configuration object containing the following attributes:
503
- - DATASETS (list): List of datasets.
504
- - device (torch.device): The device to perform computations on.
508
+ accelerator (torch.device): The device to use for the computation.
505
509
 
506
510
  Returns:
507
511
  dict: A dictionary containing the new vectors after SVD computation and summation.
508
512
  """
509
- sv_reduction = 1 / len(config.DATASETS)
510
- device = config.device
513
+ sv_reduction = 1 / len(task_vectors)
511
514
  print("Computing SVD...")
512
515
  with torch.no_grad():
513
516
  new_vector = {}
514
- for key in task_vectors[0].vector:
517
+ for key in task_vectors[0]:
518
+ original_device = task_vectors[0][key].device
515
519
  new_vector[key] = {}
516
- for i, (task_vector, dataset) in enumerate(
517
- zip(task_vectors, config.DATASETS)
518
- ):
519
- vec = task_vector.vector[key].to(device)
520
+ for i, task_vector in enumerate(task_vectors):
521
+ vec = task_vector[key].to(accelerator)
520
522
 
521
- if (
522
- len(task_vector.vector[key].shape) == 2
523
- and "text_projection" not in key
524
- ):
523
+ if len(task_vector[key].shape) == 2 and "text_projection" not in key:
525
524
  u, s, v = torch.linalg.svd(vec, full_matrices=False)
526
525
 
527
526
  if i == 0:
528
527
  print(f"Computed SVD for {key}...")
529
- sum_u = torch.zeros_like(u, device=device)
530
- sum_s = torch.zeros_like(s, device=device)
531
- sum_v = torch.zeros_like(v, device=device)
528
+ sum_u = torch.zeros_like(u, device=accelerator)
529
+ sum_s = torch.zeros_like(s, device=accelerator)
530
+ sum_v = torch.zeros_like(v, device=accelerator)
532
531
  reduced_index_s = int(s.shape[0] * sv_reduction)
533
532
 
534
533
  # select only the first reduced_index_s columns of u and place them
@@ -549,7 +548,7 @@ def compute_and_sum_svd_mem_reduction_rank_reduction(task_vectors, config):
549
548
  else:
550
549
  new_vector[key] += (vec - new_vector[key]) / (i + 1)
551
550
 
552
- if len(task_vector.vector[key].shape) == 2 and "text_projection" not in key:
551
+ if len(task_vector[key].shape) == 2 and "text_projection" not in key:
553
552
  new_vector[key] = torch.linalg.multi_dot(
554
553
  (
555
554
  sum_u,
@@ -557,26 +556,29 @@ def compute_and_sum_svd_mem_reduction_rank_reduction(task_vectors, config):
557
556
  sum_v,
558
557
  )
559
558
  )
559
+
560
+ new_vector[key] = new_vector[key].to(original_device, non_blocking=True)
560
561
  return new_vector
561
562
 
562
563
 
563
- def compute_and_sum_svd_mem_reduction_dummy(task_vectors, config):
564
+ def compute_and_sum_svd_mem_reduction_dummy(
565
+ task_vectors: List[StateDictType],
566
+ accelerator: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
567
+ ):
564
568
  """To perform dummy operations."""
565
- sv_reduction = 1 / 8
569
+ sv_reduction = 1 / len(task_vectors)
566
570
  print("Computing SVD...")
567
571
  with torch.no_grad():
568
572
  new_vector = {}
569
- for key in task_vectors[0].vector:
573
+ for key in task_vectors[0]:
574
+ original_device = task_vectors[0][key].device
570
575
  new_vector[key] = {}
571
- for i in range(0, 8):
572
- if (
573
- len(task_vectors[0].vector[key].shape) == 2
574
- and "text_projection" not in key
575
- ):
576
+ for i, task_vector in enumerate(task_vectors):
577
+ vec = task_vector[key].to(accelerator)
578
+
579
+ if len(task_vector[key].shape) == 2 and "text_projection" not in key:
576
580
  if i == 0:
577
- u, s, v = torch.linalg.svd(
578
- task_vectors[0].vector[key], full_matrices=False
579
- )
581
+ u, s, v = torch.linalg.svd(vec, full_matrices=False)
580
582
  reduced_index_s = int(s.shape[0] * sv_reduction)
581
583
 
582
584
  print(f"Computed SVD for {key}...")
@@ -620,16 +622,11 @@ def compute_and_sum_svd_mem_reduction_dummy(task_vectors, config):
620
622
 
621
623
  else:
622
624
  if i == 0:
623
- new_vector[key] = task_vectors[0].vector[key]
624
- # else:
625
- # new_vector[key] += (
626
- # task_vector.vector[key] - new_vector[key]
627
- # ) / (i + 1)
625
+ new_vector[key] = vec.clone()
626
+ else:
627
+ new_vector[key] += (vec - new_vector[key]) / (i + 1)
628
628
 
629
- if (
630
- len(task_vectors[0].vector[key].shape) == 2
631
- and "text_projection" not in key
632
- ):
629
+ if len(task_vector[key].shape) == 2 and "text_projection" not in key:
633
630
 
634
631
  new_vector[key] = torch.linalg.multi_dot(
635
632
  (
@@ -639,4 +636,5 @@ def compute_and_sum_svd_mem_reduction_dummy(task_vectors, config):
639
636
  )
640
637
  )
641
638
 
639
+ new_vector[key] = new_vector[key].to(original_device, non_blocking=True)
642
640
  return new_vector
@@ -1,3 +1,13 @@
1
+ R"""
2
+ Overview of Ties-Merging:
3
+
4
+ 1. Trim: For each task t, we trim the redundant parameters from the task vector $\tau_t$ to create $\hat{\tau}_t$ by keeping the top-k% values according to their magnitude and trimming the bottom $(100 - k)\%$ of the redundant parameters by resetting them to 0. This can be decomposed further as $\hat{\tau}_t = \hat{\gamma}_t \odot \hat{\mu}_t$.
5
+
6
+ 2. Elect: Next, we create an aggregate elected sign vector $\gamma_m$ for the merged model that resolves the disagreements in the sign for each parameter p across different models. To create the elected sign vector, we choose the sign with the highest total magnitude across all relevant models. For each parameter $p \in \{1, 2, \ldots, d\}$, we separate the values $\{\hat{\tau}_t^p\}_{t=1}^n$ based on their sign $(+1$ or $-1)$ and take their sum to calculate the total mass (i.e., total magnitude) in the positive and the negative direction. We then assign $\gamma_m^p$ as the sign with greater total movement. This can be efficiently computed using $\gamma_m^p = \text{sgn}(\sum_{t=1}^n \hat{\tau}_t^p)$.
7
+
8
+ 3. Disjoint Merge: Then, for each parameter p, we compute a disjoint mean by only keeping the parameter values from the models whose signs are the same as the aggregated elected sign and calculate their mean. Formally, let $A_p = \{t \in [n] \mid \hat{\gamma}_t^p = \gamma_m^p\}$, then $\tau_m^p = \frac{1}{|A_p|}\sum_{t\in A_p} \hat{\tau}_t^p$. Note that the disjoint mean always ignores the zero values.
9
+ """
10
+
1
11
  import logging
2
12
  from typing import Dict, List, Literal, Mapping, Union # noqa: F401
3
13
 
@@ -0,0 +1,22 @@
1
+ from typing import Dict
2
+
3
+ import numpy as np
4
+
5
+
6
def compute_backward_transfer(
    acc_Ti: Dict[str, float], acc_ii: Dict[str, float]
) -> float:
    R"""
    Compute the backward transfer (BWT) of a model on a set of tasks.

    Equation:
        BWT = \frac{1}{n} \sum_{k=1}^{n} (acc_{Ti}[k] - acc_{ii}[k])

    where ``n`` is the number of tasks (the number of keys shared by both
    dictionaries). Note that the average is taken over all ``n`` tasks.

    Args:
        acc_Ti: per-task accuracy keyed by task name. Presumably measured
            after the model has been trained on the final task ``T``
            (``R_{T,i}`` in the usual BWT notation) -- confirm with caller.
        acc_ii: per-task accuracy keyed by task name. Presumably measured
            right after the model was trained on that same task
            (``R_{i,i}``) -- confirm with caller.

    Returns:
        float: The backward transfer of the model.

    Raises:
        ValueError: if the two dictionaries do not contain exactly the same
            task names, or if they are empty.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if set(acc_ii.keys()) != set(acc_Ti.keys()):
        mismatched = sorted(map(str, set(acc_Ti) ^ set(acc_ii)))
        raise ValueError(
            "acc_Ti and acc_ii must contain the same task names; "
            f"mismatched tasks: {mismatched}"
        )
    if not acc_ii:
        raise ValueError("at least one task is required to compute backward transfer")

    # Same accumulation order as iterating `acc_ii`, starting from 0.
    bwt = sum(acc_Ti[task_name] - acc_ii[task_name] for task_name in acc_ii)
    return bwt / len(acc_ii)
161
161
  cache_dir, os.path.normpath(f"{task}_zeroshot_weights.pt")
162
162
  )
163
163
  if os.path.exists(cache_file):
164
- log.info(f"Loading cached zeroshot weights for task: {task}")
165
164
  zeroshot_weights = torch.load(
166
165
  cache_file,
167
166
  map_location="cpu",
168
167
  weights_only=True,
169
168
  ).detach()
169
+ log.info(
170
+ f"Loadded cached zeroshot weights for task: {task}, shape: {zeroshot_weights.shape}"
171
+ )
170
172
  else:
171
173
  log.info(
172
174
  f"Construct zero shot classification head for task: {task}"
@@ -180,6 +182,7 @@ class CLIPClassificationMixin(LightningFabricMixin):
180
182
  self.fabric.barrier()
181
183
  self.zeroshot_weights[task] = self.fabric.broadcast(zeroshot_weights, src=0)
182
184
  self.zeroshot_weights[task] = self.to_device(self.zeroshot_weights[task])
185
+ self.fabric.barrier()
183
186
 
184
187
  del clip_classifier
185
188
  if torch.cuda.is_available():
@@ -103,12 +103,13 @@ class FabricModelFusionProgram(
103
103
  )
104
104
  if compat_load_fn is not None:
105
105
  compat_load_fn = import_object(compat_load_fn)
106
- print_bordered(
107
- OmegaConf.to_yaml(config),
108
- title="instantiate compat object",
109
- style="magenta",
110
- code_style="yaml",
111
- )
106
+ if rank_zero_only.rank == 0:
107
+ print_bordered(
108
+ OmegaConf.to_yaml(config),
109
+ title="instantiate compat object",
110
+ style="magenta",
111
+ code_style="yaml",
112
+ )
112
113
  obj = compat_load_fn(config)
113
114
  else:
114
115
  raise ValueError(
@@ -159,7 +160,11 @@ class FabricModelFusionProgram(
159
160
  print("No save path specified for the merged model. Skipping saving.")
160
161
 
161
162
  def evaluate_merged_model(
162
- self, taskpool: BaseTaskPool, merged_model: Union[nn.Module, Dict, Iterable]
163
+ self,
164
+ taskpool: BaseTaskPool,
165
+ merged_model: Union[nn.Module, Dict, Iterable],
166
+ *args,
167
+ **kwargs,
163
168
  ):
164
169
  """
165
170
  Evaluates the merged model using the provided task pool.
@@ -174,6 +179,8 @@ class FabricModelFusionProgram(
174
179
  Args:
175
180
  taskpool: The task pool used for evaluating the merged model.
176
181
  merged_model: The merged model to be evaluated. It can be an instance of `nn.Module`, a dictionary, or an iterable.
182
+ *args: Additional positional arguments to be passed to the `evaluate` method of the taskpool.
183
+ **kwargs: Additional keyword arguments to be passed to the `evaluate` method of the taskpool.
177
184
 
178
185
  Returns:
179
186
  The evaluation report. The type of the report depends on the type of the merged model:
@@ -182,20 +189,20 @@ class FabricModelFusionProgram(
182
189
  - If the merged model is an iterable, the report is a list of evaluation reports.
183
190
  """
184
191
  if isinstance(merged_model, nn.Module):
185
- report = taskpool.evaluate(merged_model)
192
+ report = taskpool.evaluate(merged_model, *args, **kwargs)
186
193
  return report
187
194
  elif isinstance(merged_model, Dict):
188
195
  report = {}
189
196
  for key, item in merged_model.items():
190
197
  if isinstance(item, nn.Module):
191
- report[key] = taskpool.evaluate(item)
198
+ report[key] = taskpool.evaluate(item, *args, **kwargs)
192
199
  else:
193
200
  # metadata
194
201
  report[key] = item
195
202
  return report
196
203
  elif isinstance(merged_model, Iterable):
197
204
  return [
198
- self.evaluate_merged_model(taskpool, m)
205
+ self.evaluate_merged_model(taskpool, m, *args, **kwargs)
199
206
  for m in tqdm(merged_model, desc="Evaluating models")
200
207
  ]
201
208
  else:
@@ -272,7 +279,11 @@ class FabricModelFusionProgram(
272
279
  """
273
280
  if self.log_dir is not None:
274
281
  # make symlink to the hydra output directory
275
- hydra_output_dir = get_hydra_output_dir()
282
+ try:
283
+ hydra_output_dir = get_hydra_output_dir()
284
+ except Exception as e:
285
+ hydra_output_dir = None
286
+
276
287
  if hydra_output_dir is not None:
277
288
  os.makedirs(self.log_dir, exist_ok=True)
278
289
  try:
@@ -1,3 +1,4 @@
1
+ #!/usr/bin/env python3
1
2
  """
2
3
  This is the CLI script that is executed when the user runs the `fusion-bench` command.
3
4
  The script is responsible for parsing the command-line arguments, loading the configuration file, and running the fusion algorithm.
@@ -7,7 +7,7 @@ class BaseTaskPool(BaseYAMLSerializableModel):
7
7
  _program = None
8
8
 
9
9
  @abstractmethod
10
- def evaluate(self, model):
10
+ def evaluate(self, model, *args, **kwargs):
11
11
  """
12
12
  Evaluate the model on all tasks in the task pool, and return a report.
13
13
 
@@ -238,11 +238,13 @@ class CLIPVisionModelTaskPool(
238
238
  else:
239
239
  test_loader = test_loader
240
240
 
241
- for batch in (
242
- pbar := tqdm(
243
- test_loader, desc="Evaluating", leave=False, dynamic_ncols=True
244
- )
245
- ):
241
+ pbar = tqdm(
242
+ test_loader,
243
+ desc=f"Evaluating {task_name}",
244
+ leave=False,
245
+ dynamic_ncols=True,
246
+ )
247
+ for batch in pbar:
246
248
  inputs, targets = batch
247
249
  outputs = classifier(
248
250
  inputs,
@@ -309,11 +311,14 @@ class CLIPVisionModelTaskPool(
309
311
  }
310
312
  if name is not None:
311
313
  report["model_info"]["name"] = name
312
- for task_name, test_dataloader in tqdm(
314
+
315
+ # evaluate on each task
316
+ pbar = tqdm(
313
317
  self.test_dataloaders.items(),
314
318
  desc="Evaluating tasks",
315
319
  total=len(self.test_dataloaders),
316
- ):
320
+ )
321
+ for task_name, test_dataloader in pbar:
317
322
  classnames, templates = get_classnames_and_templates(task_name)
318
323
  self.on_task_evaluation_begin(classifier, task_name)
319
324
  classifier.set_classification_task(classnames, templates)
@@ -6,7 +6,8 @@ from . import data, functools, path
6
6
  from .cache_utils import *
7
7
  from .devices import *
8
8
  from .dtype import parse_dtype
9
- from .instantiate import instantiate
9
+ from .fabric import seed_everything_by_time
10
+ from .instantiate import instantiate, is_instantiable
10
11
  from .misc import *
11
12
  from .packages import import_object
12
13
  from .parameters import *
@@ -0,0 +1,43 @@
1
+ from copy import deepcopy
2
+ from typing import Iterable, List, Tuple, Union
3
+
4
+
5
def dict_get(d: dict, keys: Iterable[str], default=None):
    """Look up several keys of ``d`` in one call.

    Args:
        d (dict): mapping to read from.
        keys (Iterable[str]): keys to fetch; results follow this order.
        default (optional): value used for any key absent from ``d``.
            Defaults to None.

    Returns:
        list: one entry per key -- ``d[key]`` when present, else ``default``.
    """
    values = []
    for key in keys:
        values.append(d.get(key, default))
    return values
7
+
8
+
9
def dict_map(f, d: dict, *, max_level: int = -1, skip_levels=0, inplace=False):
    """Apply ``f`` to every non-dict value of a (possibly nested) dictionary.

    Nested dictionaries are walked recursively; each non-dict value found is
    replaced by ``f(value)``. ``f`` is always called on values read from the
    input ``d``, never on the copy.

    Args:
        f (callable): function applied to each leaf value.
        d (dict): input dictionary.
        max_level (int, optional): number of nesting levels to process. The
            top level is level 0 and processing stops once a level equal to
            ``max_level`` is reached, so ``1`` maps only top-level values and
            ``0`` maps nothing. ``-1`` means unlimited. Defaults to -1.
        skip_levels (int, optional): leaf values at levels below this number
            are left unchanged (nested dicts there are still descended into).
            Defaults to 0.
        inplace (bool, optional): if True, ``d`` itself is modified and
            returned; otherwise a ``deepcopy`` of ``d`` is modified and
            returned. Defaults to False.

    Returns:
        dict: the transformed dictionary.

    Raises:
        TypeError: if ``d`` is not a dict.
    """
    if not isinstance(d, dict):
        raise TypeError("dict_map: d must be a dict")

    result = d if inplace else deepcopy(d)

    def _walk(src, dst, depth):
        # Depth limit reached: leave this subtree as it is.
        if depth == max_level:
            return
        for key in src.keys():
            value = src[key]
            if isinstance(value, dict):
                _walk(value, dst[key], depth + 1)
            elif depth >= skip_levels:
                dst[key] = f(value)

    _walk(d, result, 0)
    return result