fusion-bench 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. fusion_bench/__main__.py +4 -0
  2. fusion_bench/dataset/fer2013.py +1 -0
  3. fusion_bench/method/__init__.py +26 -4
  4. fusion_bench/method/classification/__init__.py +1 -0
  5. fusion_bench/method/classification/clip_finetune.py +1 -3
  6. fusion_bench/method/classification/continual_clip_finetune.py +297 -0
  7. fusion_bench/method/dare/__init__.py +1 -0
  8. fusion_bench/method/dare/task_arithmetic.py +14 -7
  9. fusion_bench/method/dare/ties_merging.py +100 -0
  10. fusion_bench/method/isotropic_merging/__init__.py +15 -0
  11. fusion_bench/method/isotropic_merging/iso.py +114 -0
  12. fusion_bench/method/isotropic_merging/iso_utils.py +176 -0
  13. fusion_bench/method/opcm/__init__.py +4 -0
  14. fusion_bench/method/opcm/opcm.py +277 -0
  15. fusion_bench/method/opcm/task_arithmetic.py +115 -0
  16. fusion_bench/method/opcm/ties_merging.py +156 -0
  17. fusion_bench/method/opcm/utils.py +73 -0
  18. fusion_bench/method/opcm/weight_average.py +120 -0
  19. fusion_bench/method/slerp/slerp.py +1 -1
  20. fusion_bench/method/task_singular_vector/TSVM.py +22 -2
  21. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +91 -93
  22. fusion_bench/method/ties_merging/ties_merging.py +10 -0
  23. fusion_bench/metrics/continual_learning/backward_transfer.py +22 -0
  24. fusion_bench/mixins/clip_classification.py +4 -1
  25. fusion_bench/programs/fabric_fusion_program.py +22 -11
  26. fusion_bench/scripts/cli.py +1 -0
  27. fusion_bench/taskpool/base_pool.py +1 -1
  28. fusion_bench/taskpool/clip_vision/taskpool.py +12 -7
  29. fusion_bench/utils/__init__.py +2 -1
  30. fusion_bench/utils/dict.py +43 -0
  31. fusion_bench/utils/expr.py +90 -0
  32. fusion_bench/utils/fabric.py +17 -0
  33. fusion_bench/utils/instantiate.py +7 -1
  34. fusion_bench/utils/json.py +30 -0
  35. fusion_bench/utils/parameters.py +27 -7
  36. fusion_bench/utils/path.py +15 -0
  37. fusion_bench/utils/plot/color_data.py +1726 -0
  38. fusion_bench/utils/rich_utils.py +15 -0
  39. fusion_bench/utils/set.py +8 -0
  40. fusion_bench/utils/tensorboard.py +51 -0
  41. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/METADATA +17 -18
  42. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/RECORD +58 -29
  43. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/WHEEL +1 -1
  44. fusion_bench_config/method/classification/clip_continual_finetune.yaml +28 -0
  45. fusion_bench_config/method/classification/clip_finetune.yaml +26 -0
  46. fusion_bench_config/method/clip_finetune.yaml +2 -2
  47. fusion_bench_config/method/dare/ties_merging.yaml +15 -0
  48. fusion_bench_config/method/isotropic_merging/iso_c.yaml +4 -0
  49. fusion_bench_config/method/isotropic_merging/iso_cts.yaml +5 -0
  50. fusion_bench_config/method/opcm/opcm.yaml +12 -0
  51. fusion_bench_config/method/opcm/task_arithmetic.yaml +12 -0
  52. fusion_bench_config/method/opcm/ties_merging.yaml +18 -0
  53. fusion_bench_config/method/opcm/weight_average.yaml +10 -0
  54. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +6 -0
  55. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +18 -0
  56. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/LICENSE +0 -0
  57. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/entry_points.txt +0 -0
  58. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,156 @@
1
+ import os
2
+ import random
3
+ import time
4
+ from collections import defaultdict
5
+ from copy import deepcopy
6
+ from pathlib import Path
7
+ from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, cast
8
+
9
+ import lightning as L
10
+ import numpy as np
11
+ import torch
12
+ from omegaconf import DictConfig
13
+ from torch import Tensor, nn
14
+ from tqdm.auto import tqdm
15
+ from transformers import CLIPVisionModel
16
+
17
+ from fusion_bench import BaseAlgorithm, BaseModelPool
18
+ from fusion_bench.method.ties_merging.ties_merging_utils import (
19
+ state_dict_to_vector,
20
+ ties_merging,
21
+ vector_to_state_dict,
22
+ )
23
+ from fusion_bench.mixins import LightningFabricMixin
24
+ from fusion_bench.taskpool import CLIPVisionModelTaskPool
25
+ from fusion_bench.utils.json import load_from_json, save_to_json
26
+ from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_sub
27
+
28
+ if TYPE_CHECKING:
29
+ from torch.utils.tensorboard import SummaryWriter
30
+
31
+
32
class ContinualTiesMergingForCLIP(BaseAlgorithm, LightningFabricMixin):
    """Continual model merging for CLIP vision models via TIES-Merging.

    Models from the pool are merged one at a time: at each step the new
    task vector is TIES-merged with the task vector of the current merged
    model, and the result (scaled by ``scaling_factor``) is added back onto
    the merged model.
    """

    def __init__(
        self,
        scaling_factor: float,
        threshold: float,
        remove_keys: Optional[List[str]] = None,
        merge_func: Literal["sum", "mean", "max"] = "sum",
        shuffle_order: bool = True,
        seed: Optional[int] = None,
        save_on_every_step: bool = True,
        evaluate_on_every_step: bool = False,
        **kwargs,
    ):
        """
        Continual Model Merging via Ties-Merging.

        Args:
            scaling_factor (float): scaling factor applied to the merged task vector.
            threshold (float): trimming threshold (``reset_thresh``) used by TIES-Merging.
            remove_keys (Optional[List[str]]): state-dict keys excluded from
                flattening/merging. Defaults to an empty list.
            merge_func (Literal["sum", "mean", "max"]): elect-and-merge function
                used by TIES-Merging.
            shuffle_order (bool): whether to shuffle the order of the models.
            seed (Optional[int]): the seed to use.
            save_on_every_step (bool): whether to save the merged model on every step.
            evaluate_on_every_step (bool): whether to evaluate the merged model on every step.
        """
        self.scaling_factor = scaling_factor
        self.threshold = threshold
        self.remove_keys = remove_keys if remove_keys is not None else []
        self.merge_func = merge_func
        self.shuffle_order = shuffle_order
        self.seed = seed
        self.save_on_every_step = save_on_every_step
        self.evaluate_on_every_step = evaluate_on_every_step
        super().__init__(**kwargs)

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool):
        """Sequentially merge the models in ``modelpool`` using TIES-Merging.

        Args:
            modelpool (BaseModelPool): pool providing the pretrained model and
                the fine-tuned task models to merge.

        Returns:
            The merged CLIP vision model.
        """
        if self.seed is not None:
            L.seed_everything(self.seed)

        # Fix: copy the list so that shuffling does not mutate the model
        # pool's own ``model_names`` sequence in place.
        model_names = list(modelpool.model_names)
        if self.shuffle_order:
            random.shuffle(model_names)

        self.taskpool = cast(CLIPVisionModelTaskPool, self._program.taskpool)
        # Pristine copy of the test-dataset configuration; it is restricted
        # to the already-seen tasks at each evaluation step below.
        self._test_datasets = deepcopy(self.taskpool._test_datasets)

        # log the (possibly shuffled) model order
        if self.log_dir is not None:
            save_to_json(model_names, Path(self.log_dir) / "model_names.json")
            tensorboard_summarywriter: "SummaryWriter" = self.tensorboard_summarywriter
            tensorboard_summarywriter.add_text(
                "global/model_names", str(model_names), global_step=0
            )

        # start from the pretrained model
        pretrained_model = modelpool.load_pretrained_model()
        merged_model = deepcopy(pretrained_model)

        for model_idx, model_name in tqdm(
            enumerate(model_names), desc="Processing models"
        ):
            task_model = modelpool.load_model(model_name)

            task_vector = state_dict_sub(
                task_model.state_dict(),
                pretrained_model.state_dict(),
            )
            if model_idx == 0:
                # if is the first model, the merged task vector is equal to the task vector
                ties_merging_state_dict = task_vector
            else:
                # if is not the first model, we need to merge the task vector with the previous merged task vector
                merged_tv = state_dict_sub(
                    merged_model.state_dict(),
                    pretrained_model.state_dict(),
                )
                tv_flat_checks = torch.vstack(
                    [
                        state_dict_to_vector(merged_tv, remove_keys=self.remove_keys),
                        state_dict_to_vector(task_vector, remove_keys=self.remove_keys),
                    ]
                )
                # perform the TIES merging
                ties_merging_tv = ties_merging(
                    tv_flat_checks,
                    reset_thresh=self.threshold,
                    merge_func=self.merge_func,
                )
                # convert the merged task vector back to a state dict
                ties_merging_state_dict = vector_to_state_dict(
                    ties_merging_tv,
                    merged_model.state_dict(),
                    remove_keys=self.remove_keys,
                )

            # apply the (scaled) merged task vector to the trainable parameters
            for param_name, param in task_model.named_parameters():
                if not param.requires_grad:
                    continue

                merged_param = merged_model.get_parameter(param_name)
                new_param = (
                    merged_param
                    + self.scaling_factor * ties_merging_state_dict[param_name]
                )
                merged_model.get_parameter(param_name).data = new_param

            if self.save_on_every_step:
                self.save_merged_model(merged_model, model_idx)

            if self.evaluate_on_every_step:
                # evaluate only on the tasks seen so far
                self.taskpool._is_setup = False
                self.taskpool._test_datasets = DictConfig(
                    {n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
                )
                report = self.taskpool.evaluate(deepcopy(merged_model))
                # Fix: guard the save — the original dereferenced
                # ``self.log_dir`` here even when it is None.
                if self.log_dir is not None:
                    save_to_json(
                        report, Path(self.log_dir) / f"report_{model_idx}.json"
                    )

        return merged_model

    def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
        """Save the merged model's state dict under ``<log_dir>/checkpoints``.

        Args:
            merged_model (CLIPVisionModel): the model to save.
            step (int): index used in the checkpoint file name.
        """
        os.makedirs(Path(self.log_dir) / "checkpoints", exist_ok=True)
        torch.save(
            merged_model.state_dict(),
            Path(self.log_dir) / "checkpoints" / f"model_{step}.pth",
        )
@@ -0,0 +1,73 @@
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ from torch import Tensor, nn
5
+
6
+ from fusion_bench.utils.parameters import state_dict_to_vector
7
+ from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
8
+
9
+
10
+ def _svd(w: Tensor, full_matrices=True) -> Tuple[Tensor, Tensor, Tensor]:
11
+ """
12
+ Perform Singular Value Decomposition (SVD) on a tensor.
13
+
14
+ Args:
15
+ w (Tensor): The input tensor.
16
+ full_matrices (bool): Whether to compute the full-sized U and V matrices.
17
+
18
+ Returns:
19
+ Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD.
20
+ """
21
+ u, s, vh = torch.linalg.svd(
22
+ w, full_matrices=full_matrices, driver="gesvd" if w.is_cuda else None
23
+ )
24
+ v = vh.T
25
+ return u, s, v
26
+
27
+
28
def svd(
    w: Tensor, full_matrices=True, accelerator=None
) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Perform SVD on a tensor, optionally using a specified accelerator.

    Args:
        w (Tensor): The input tensor.
        full_matrices (bool): Whether to compute the full-sized U and V matrices.
        accelerator: The device to perform the computation on. ``None`` keeps
            the computation on ``w``'s current device.

    Returns:
        Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD,
        moved back to ``w``'s original device.
    """
    if accelerator is None:
        return _svd(w, full_matrices=full_matrices)
    original_device = w.device
    w = w.to(accelerator)
    # Fix: ``full_matrices`` was previously dropped on this path, so the flag
    # was silently ignored whenever an accelerator was specified.
    u, s, v = _svd(w, full_matrices=full_matrices)
    return u.to(original_device), s.to(original_device), v.to(original_device)
48
+
49
+
50
def frobenius_inner_product(w1: Tensor, w2: Tensor) -> Tensor:
    """Return the Frobenius inner product ``<w1, w2>_F = trace(w1ᵀ w2)``.

    Computed as an elementwise multiply-and-sum, which equals
    ``torch.trace(w1.T @ w2)`` but avoids materializing the full matrix
    product (O(n·m) work instead of O(n·m²)).
    """
    return (w1 * w2).sum()
52
+
53
+
54
def is_leaf_module(module: nn.Module) -> bool:
    """Return ``True`` if ``module`` has no child modules."""
    return next(iter(module.children()), None) is None
56
+
57
+
58
def get_task_vector_norm(model: nn.Module, pretrained_model: nn.Module) -> Tensor:
    """
    Get the vector norm of the task model.

    Args:
        model (nn.Module): The task model.
        pretrained_model (nn.Module): The pretrained model.

    Returns:
        Tensor: The L2 norm of the flattened task vector
        (``model`` minus ``pretrained_model``, parameter-wise).
    """
    # Build the task vector as a state-dict difference, flatten it, then
    # take its L2 norm.
    task_vector = state_dict_sub(model.state_dict(), pretrained_model.state_dict())
    flat_vector = state_dict_to_vector(task_vector)
    return torch.linalg.norm(flat_vector)
@@ -0,0 +1,120 @@
1
+ import os
2
+ import random
3
+ import time
4
+ from collections import defaultdict
5
+ from copy import deepcopy
6
+ from pathlib import Path
7
+ from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, cast
8
+
9
+ import lightning as L
10
+ import numpy as np
11
+ import torch
12
+ from omegaconf import DictConfig
13
+ from torch import Tensor, nn
14
+ from tqdm.auto import tqdm
15
+ from transformers import CLIPVisionModel
16
+
17
+ from fusion_bench import BaseAlgorithm, BaseModelPool
18
+ from fusion_bench.mixins import LightningFabricMixin
19
+ from fusion_bench.taskpool import CLIPVisionModelTaskPool
20
+ from fusion_bench.utils.json import load_from_json, save_to_json
21
+
22
+ if TYPE_CHECKING:
23
+ from torch.utils.tensorboard import SummaryWriter
24
+
25
+
26
class ContinualWeightAverageForCLIP(
    BaseAlgorithm,
    LightningFabricMixin,
):
    """Continual model merging for CLIP vision models via a running weight average.

    Models from the pool are folded into the merged model one at a time,
    keeping the merged weights equal to the arithmetic mean of all models
    processed so far.
    """

    def __init__(
        self,
        shuffle_order: bool = True,
        seed: Optional[int] = None,
        save_on_every_step: bool = True,
        evaluate_on_every_step: bool = False,
        **kwargs,
    ):
        """
        Continual Model Merging via Weight Average.

        Args:
            shuffle_order (bool): whether to shuffle the order of the models.
            seed (Optional[int]): the seed to use.
            save_on_every_step (bool): whether to save the merged model on every step.
            evaluate_on_every_step (bool): whether to evaluate the merged model on every step.
        """
        self.shuffle_order = shuffle_order
        self.seed = seed
        self.save_on_every_step = save_on_every_step
        self.evaluate_on_every_step = evaluate_on_every_step
        super().__init__(**kwargs)

    # Consistency fix: the sibling continual-merging algorithms run under
    # no_grad; averaging weights needs no autograd graph either.
    @torch.no_grad()
    def run(self, modelpool: BaseModelPool):
        """Sequentially average the models in ``modelpool``.

        Args:
            modelpool (BaseModelPool): pool providing the task models to average.

        Returns:
            The merged CLIP vision model (mean of all processed models).
        """
        if self.seed is not None:
            L.seed_everything(self.seed)

        # Fix: copy the list so that shuffling does not mutate the model
        # pool's own ``model_names`` sequence in place.
        model_names = list(modelpool.model_names)
        if self.shuffle_order:
            random.shuffle(model_names)

        self.taskpool = cast(CLIPVisionModelTaskPool, self._program.taskpool)
        # Pristine copy of the test-dataset configuration; restricted to the
        # already-seen tasks at each evaluation step below.
        self._test_datasets = deepcopy(self.taskpool._test_datasets)

        # log the (possibly shuffled) model order
        if self.log_dir is not None:
            save_to_json(model_names, Path(self.log_dir) / "model_names.json")
            tensorboard_summarywriter: "SummaryWriter" = self.tensorboard_summarywriter
            tensorboard_summarywriter.add_text(
                "global/model_names", str(model_names), global_step=0
            )

        # the running average starts as the first model in the (shuffled) order
        merged_model = modelpool.load_model(model_names[0])

        if self.evaluate_on_every_step:
            self.taskpool._is_setup = False
            self.taskpool._test_datasets = DictConfig(
                {model_names[0]: self._test_datasets[model_names[0]]}
            )
            report = self.taskpool.evaluate(deepcopy(merged_model))
            # Fix: guard the save — the original dereferenced ``self.log_dir``
            # here even when it is None.
            if self.log_dir is not None:
                save_to_json(report, Path(self.log_dir) / "report_0.json")

        if self.save_on_every_step:
            self.save_merged_model(merged_model, 0)

        for model_idx, model_name in tqdm(
            enumerate(model_names[1:]), desc="Processing models"
        ):
            model_idx += 1  # model_names[0] was consumed above
            task_model = modelpool.load_model(model_name)

            for param_name, param in task_model.named_parameters():
                if not param.requires_grad:
                    continue

                task_param = param
                merged_param = merged_model.get_parameter(param_name)

                # incremental mean: merged currently averages ``model_idx`` models
                new_param = (merged_param * model_idx + task_param) / (model_idx + 1)
                merged_model.get_parameter(param_name).data = new_param

            if self.save_on_every_step:
                self.save_merged_model(merged_model, model_idx)

            if self.evaluate_on_every_step:
                # evaluate only on the tasks seen so far
                self.taskpool._is_setup = False
                self.taskpool._test_datasets = DictConfig(
                    {n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
                )
                report = self.taskpool.evaluate(deepcopy(merged_model))
                if self.log_dir is not None:
                    save_to_json(
                        report, Path(self.log_dir) / f"report_{model_idx}.json"
                    )

        return merged_model

    def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
        """Save the merged model in HF format under ``<log_dir>/checkpoints``.

        Args:
            merged_model (CLIPVisionModel): the model to save.
            step (int): index used in the checkpoint directory name.
        """
        os.makedirs(Path(self.log_dir) / "checkpoints", exist_ok=True)
        merged_model.save_pretrained(
            Path(self.log_dir) / "checkpoints" / f"merged_model_{step}"
        )
@@ -51,7 +51,7 @@ class SlerpMergeAlgorithm(BaseAlgorithm):
51
51
  General purpose implementation of Slerp (Spherical Linear Interpolation) for PyTorch models.
52
52
  """
53
53
 
54
- _config_mapping = BaseAlgorithm._config_mapping + {
54
+ _config_mapping = BaseAlgorithm._config_mapping | {
55
55
  "t": "t",
56
56
  "DOT_THRESHOLD": "DOT_THRESHOLD",
57
57
  "epsilon": "epsilon",
@@ -9,15 +9,20 @@ fusion_bench \
9
9
  ```
10
10
  """
11
11
 
12
- from typing import List, Optional
12
+ from typing import List, Optional, Union, Iterable
13
13
 
14
14
  import torch
15
15
  from torch import Tensor, nn
16
+ from omegaconf import ListConfig
16
17
 
17
18
  from fusion_bench import BaseAlgorithm
18
19
  from fusion_bench.mixins import LightningFabricMixin
19
20
  from fusion_bench.utils import timeit_context
20
- from fusion_bench.utils.state_dict_arithmetic import state_dict_sub, state_dict_add
21
+ from fusion_bench.utils.state_dict_arithmetic import (
22
+ state_dict_add,
23
+ state_dict_sub,
24
+ state_dict_mul,
25
+ )
21
26
  from fusion_bench.utils.type import StateDictType
22
27
 
23
28
  from .utils import (
@@ -33,9 +38,11 @@ class TaskSingularVectorMerging(BaseAlgorithm, LightningFabricMixin):
33
38
 
34
39
  def __init__(
35
40
  self,
41
+ alpha: Union[float, Iterable[float]] = None,
36
42
  remove_keys: Optional[List[str]] = None,
37
43
  **kwargs,
38
44
  ):
45
+ self.alpha = alpha
39
46
  self.remove_keys = remove_keys if remove_keys is not None else []
40
47
  super().__init__(**kwargs)
41
48
 
@@ -50,6 +57,14 @@ class TaskSingularVectorMerging(BaseAlgorithm, LightningFabricMixin):
50
57
 
51
58
  with timeit_context("Flattening out Checkpoints"):
52
59
  task_vectors = [state_dict_sub(check, ptm_check) for check in ft_checks]
60
+ if isinstance(self.alpha, Iterable):
61
+ assert len(self.alpha) == len(
62
+ task_vectors
63
+ ), "Alpha and task vectors must have the same length"
64
+ task_vectors = [
65
+ state_dict_mul(state_dict=tv, scalar=alpha)
66
+ for alpha, tv in zip(self.alpha, task_vectors)
67
+ ]
53
68
 
54
69
  new_merged_tv = TSVM_utils.compute_and_sum_svd_mem_reduction(
55
70
  task_vectors,
@@ -57,6 +72,11 @@ class TaskSingularVectorMerging(BaseAlgorithm, LightningFabricMixin):
57
72
  accelerator=self.fabric.device,
58
73
  )
59
74
 
75
+ # If alpha is a float, we need to scale the new merged task vector by alpha
76
+ if self.alpha is not None and isinstance(self.alpha, float):
77
+ print(f"Scaling new merged task vector by alpha: {self.alpha}")
78
+ new_merged_tv = state_dict_mul(state_dict=new_merged_tv, scalar=self.alpha)
79
+
60
80
  pretrained_model.load_state_dict(
61
81
  state_dict_add(new_merged_tv, pretrained_model.state_dict())
62
82
  )