fusion-bench 0.2.22__py3-none-any.whl → 0.2.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. fusion_bench/__init__.py +4 -0
  2. fusion_bench/compat/method/__init__.py +5 -2
  3. fusion_bench/compat/method/base_algorithm.py +3 -2
  4. fusion_bench/compat/modelpool/base_pool.py +3 -3
  5. fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
  6. fusion_bench/dataset/gpt2_glue.py +1 -1
  7. fusion_bench/method/__init__.py +12 -2
  8. fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
  9. fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
  10. fusion_bench/method/bitdelta/bitdelta.py +7 -23
  11. fusion_bench/method/ensemble.py +17 -2
  12. fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
  13. fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
  14. fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
  15. fusion_bench/method/linear/__init__.py +6 -2
  16. fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
  17. fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
  18. fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
  19. fusion_bench/method/model_stock/__init__.py +1 -0
  20. fusion_bench/method/model_stock/model_stock.py +309 -0
  21. fusion_bench/method/regmean/clip_regmean.py +3 -6
  22. fusion_bench/method/regmean/regmean.py +27 -56
  23. fusion_bench/method/regmean/utils.py +56 -0
  24. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
  25. fusion_bench/method/simple_average.py +2 -2
  26. fusion_bench/method/slerp/__init__.py +1 -1
  27. fusion_bench/method/slerp/slerp.py +110 -14
  28. fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
  29. fusion_bench/method/ties_merging/ties_merging.py +22 -6
  30. fusion_bench/method/we_moe/flan_t5_we_moe.py +9 -20
  31. fusion_bench/method/wudi/__init__.py +1 -0
  32. fusion_bench/method/wudi/wudi.py +105 -0
  33. fusion_bench/mixins/clip_classification.py +26 -6
  34. fusion_bench/mixins/lightning_fabric.py +4 -0
  35. fusion_bench/mixins/serialization.py +40 -83
  36. fusion_bench/modelpool/base_pool.py +1 -1
  37. fusion_bench/modelpool/causal_lm/causal_lm.py +285 -44
  38. fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
  39. fusion_bench/models/hf_clip.py +4 -0
  40. fusion_bench/models/hf_utils.py +10 -4
  41. fusion_bench/models/linearized/vision_model.py +6 -6
  42. fusion_bench/models/model_card_templates/default.md +8 -1
  43. fusion_bench/models/modeling_smile_mistral/__init__.py +1 -0
  44. fusion_bench/models/we_moe.py +8 -8
  45. fusion_bench/models/wrappers/ensemble.py +136 -7
  46. fusion_bench/scripts/cli.py +2 -2
  47. fusion_bench/taskpool/base_pool.py +99 -17
  48. fusion_bench/taskpool/clip_vision/taskpool.py +12 -5
  49. fusion_bench/taskpool/dummy.py +101 -13
  50. fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
  51. fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
  52. fusion_bench/utils/__init__.py +1 -0
  53. fusion_bench/utils/data.py +6 -4
  54. fusion_bench/utils/devices.py +36 -11
  55. fusion_bench/utils/dtype.py +3 -2
  56. fusion_bench/utils/lazy_state_dict.py +85 -19
  57. fusion_bench/utils/packages.py +3 -3
  58. fusion_bench/utils/parameters.py +0 -2
  59. fusion_bench/utils/rich_utils.py +7 -3
  60. fusion_bench/utils/timer.py +92 -10
  61. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/METADATA +10 -3
  62. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/RECORD +77 -64
  63. fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
  64. fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
  65. fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
  66. fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
  67. fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
  68. fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
  69. fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
  70. fusion_bench_config/method/wudi/wudi.yaml +4 -0
  71. fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
  72. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
  73. fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
  74. fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
  75. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/WHEEL +0 -0
  76. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/entry_points.txt +0 -0
  77. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/licenses/LICENSE +0 -0
  78. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/top_level.txt +0 -0
fusion_bench/__init__.py CHANGED
@@ -32,6 +32,10 @@ from .models import (
32
32
  from .programs import BaseHydraProgram
33
33
  from .taskpool import BaseTaskPool
34
34
  from .utils import (
35
+ BoolStateDictType,
36
+ LazyStateDict,
37
+ StateDictType,
38
+ TorchModelType,
35
39
  cache_with_joblib,
36
40
  get_rankzero_logger,
37
41
  import_object,
@@ -1,4 +1,5 @@
1
1
  import warnings
2
+ from typing import Any, List, Type
2
3
 
3
4
  from omegaconf import DictConfig
4
5
 
@@ -76,7 +77,9 @@ class AlgorithmFactory:
76
77
  return algorithm_cls(method_config)
77
78
 
78
79
  @staticmethod
79
- def register_algorithm(name: str, algorithm_cls):
80
+ def register_algorithm(
81
+ name: str, algorithm_cls: Type[ModelFusionAlgorithm]
82
+ ) -> None:
80
83
  """
81
84
  Register a new algorithm with the factory.
82
85
 
@@ -87,7 +90,7 @@ class AlgorithmFactory:
87
90
  AlgorithmFactory._aglorithms[name] = algorithm_cls
88
91
 
89
92
  @classmethod
90
- def available_algorithms(cls):
93
+ def available_algorithms(cls) -> List[str]:
91
94
  """
92
95
  Get a list of available algorithms.
93
96
 
@@ -1,9 +1,10 @@
1
1
  from abc import ABC, abstractmethod
2
- from typing import TYPE_CHECKING, Optional
2
+ from typing import TYPE_CHECKING, Any, Optional
3
3
 
4
4
  from omegaconf import DictConfig
5
5
 
6
6
  if TYPE_CHECKING:
7
+ from fusion_bench import BaseModelPool
7
8
  from fusion_bench.programs.base_program import BaseHydraProgram
8
9
 
9
10
  __all__ = ["ModelFusionAlgorithm"]
@@ -51,7 +52,7 @@ class ModelFusionAlgorithm(ABC):
51
52
  pass
52
53
 
53
54
  @abstractmethod
54
- def run(self, modelpool):
55
+ def run(self, modelpool: "BaseModelPool") -> Any:
55
56
  """
56
57
  Fuse the models in the given model pool.
57
58
 
@@ -42,7 +42,7 @@ class ModelPool(ABC):
42
42
  ), "Duplicate model names found in model pool"
43
43
  self._model_names = model_names
44
44
 
45
- def __len__(self):
45
+ def __len__(self) -> int:
46
46
  """
47
47
  Return the number of models in the model pool, exclude special models such as `_pretrained_`.
48
48
 
@@ -66,7 +66,7 @@ class ModelPool(ABC):
66
66
  return names
67
67
 
68
68
  @property
69
- def has_pretrained(self):
69
+ def has_pretrained(self) -> bool:
70
70
  """
71
71
  Check if the pretrained model is available in the model pool.
72
72
 
@@ -78,7 +78,7 @@ class ModelPool(ABC):
78
78
  return True
79
79
  return False
80
80
 
81
- def get_model_config(self, model_name: str):
81
+ def get_model_config(self, model_name: str) -> Dict:
82
82
  """
83
83
  Retrieves the configuration for a specific model from the model pool.
84
84
 
@@ -169,7 +169,7 @@ class CLIPImageClassificationTaskPool(TaskPool):
169
169
  self._fabric = L.Fabric(devices=1)
170
170
  self._fabric.launch()
171
171
 
172
- # CLIPVisionModel works the same with CLIPVisonTransformer, so we can use it directly
172
+ # CLIPVisionModel works the same with CLIPVisionTransformer, so we can use it directly
173
173
  self.clip_model.vision_model = model
174
174
  report = {}
175
175
  training_params, all_params = count_parameters(model)
@@ -121,7 +121,7 @@ class TokenizedGLUE:
121
121
 
122
122
  def load_dataset(
123
123
  self, name: Literal["mrpc", "mnli", "cola", "sst2", "qnli", "qqp", "rte"]
124
- ):
124
+ ) -> Dataset:
125
125
  """
126
126
  Load and tokenize a GLUE dataset.
127
127
 
@@ -26,11 +26,14 @@ _import_structure = {
26
26
  "linear": [
27
27
  "ExPOAlgorithm",
28
28
  "ExPOAlgorithmForLlama",
29
+ "SimpleAverageForCausalLM",
29
30
  "SimpleAverageForLlama",
31
+ "TaskArithmeticForCausalLM",
30
32
  "TaskArithmeticForLlama",
31
33
  "LinearInterpolationAlgorithm",
34
+ "TiesMergingForCausalLM",
32
35
  ],
33
- "slerp": ["SlerpMergeAlgorithm"],
36
+ "slerp": ["SlerpMergeAlgorithm", "SlerpForCausalLM"],
34
37
  "simple_average": ["SimpleAverageAlgorithm"],
35
38
  "weighted_average": ["WeightedAverageAlgorithm", "WeightedAverageForLLama"],
36
39
  "task_arithmetic": ["TaskArithmeticAlgorithm"],
@@ -71,6 +74,8 @@ _import_structure = {
71
74
  ],
72
75
  "fw_merging": ["FrankWolfeHardAlgorithm", "FrankWolfeSoftAlgorithm"],
73
76
  "tall_mask": ["TallMaskTaskArithmeticAlgorithm"],
77
+ "model_stock": ["ModelStock"],
78
+ "wudi": ["wudi_merging", "WUDIMerging"],
74
79
  # plug-and-play model merging methods
75
80
  "concrete_subspace": [
76
81
  "ConcreteTaskArithmeticAlgorithmForCLIP",
@@ -183,8 +188,11 @@ if TYPE_CHECKING:
183
188
  ExPOAlgorithm,
184
189
  ExPOAlgorithmForLlama,
185
190
  LinearInterpolationAlgorithm,
191
+ SimpleAverageForCausalLM,
186
192
  SimpleAverageForLlama,
193
+ TaskArithmeticForCausalLM,
187
194
  TaskArithmeticForLlama,
195
+ TiesMergingForCausalLM,
188
196
  )
189
197
  from .lm_finetune import *
190
198
  from .mixture_of_experts import (
@@ -194,6 +202,7 @@ if TYPE_CHECKING:
194
202
  MixtralUpscalingAlgorithm,
195
203
  )
196
204
  from .model_recombination import ModelRecombinationAlgorithm
205
+ from .model_stock import ModelStock
197
206
  from .opcm import OPCMForCLIP
198
207
  from .pruning import (
199
208
  MagnitudeDiffPruningAlgorithm,
@@ -213,7 +222,7 @@ if TYPE_CHECKING:
213
222
  RegMeanAlgorithmPlusPlus,
214
223
  )
215
224
  from .simple_average import SimpleAverageAlgorithm
216
- from .slerp import SlerpMergeAlgorithm
225
+ from .slerp import SlerpForCausalLM, SlerpMergeAlgorithm
217
226
  from .smile_upscaling import (
218
227
  SingularProjectionMergingAlgorithm,
219
228
  SmileUpscalingAlgorithm,
@@ -236,6 +245,7 @@ if TYPE_CHECKING:
236
245
  FlanT5WeightEnsemblingMoEAlgorithm,
237
246
  )
238
247
  from .weighted_average import WeightedAverageAlgorithm, WeightedAverageForLLama
248
+ from .wudi import WUDIMerging, wudi_merging
239
249
 
240
250
  else:
241
251
  sys.modules[__name__] = LazyImporter(
@@ -11,7 +11,7 @@ from torch import nn
11
11
  from tqdm.auto import tqdm
12
12
 
13
13
  from fusion_bench.method import BaseAlgorithm
14
- from fusion_bench.mixins import LightningFabricMixin
14
+ from fusion_bench.mixins import LightningFabricMixin, auto_register_config
15
15
  from fusion_bench.modelpool import BaseModelPool
16
16
  from fusion_bench.utils.parameters import (
17
17
  StateDictType,
@@ -23,14 +23,50 @@ from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
23
23
  log = logging.getLogger(__name__)
24
24
 
25
25
 
26
- class TaskVectorCosSimilarity(BaseAlgorithm, LightningFabricMixin):
26
+ @auto_register_config
27
+ class TaskVectorCosSimilarity(
28
+ LightningFabricMixin,
29
+ BaseAlgorithm,
30
+ ):
27
31
  """
28
- This class is similar to the Dummy algorithm,
29
- but it also print (or save) the cosine similarity matrix between the task vectors of the models in the model pool.
32
+ Computes and analyzes cosine similarity between task vectors of models in a model pool.
33
+
34
+ This algorithm extracts task vectors from fine-tuned models by computing the difference
35
+ between their parameters and a pretrained base model. It then calculates the pairwise
36
+ cosine similarity between all task vectors to understand the relationships and overlap
37
+ between different tasks.
38
+
39
+ The task vector for a model is defined as:
40
+ task_vector = finetuned_model_params - pretrained_model_params
41
+
42
+ Args:
43
+ plot_heatmap (bool): Whether to generate and save a heatmap visualization
44
+ trainable_only (bool, optional): If True, only consider trainable parameters
45
+ when computing task vectors. Defaults to True.
46
+ max_points_per_model (int, optional): Maximum number of parameters to sample
47
+ per model for memory efficiency. If None, uses all parameters.
48
+ output_path (str, optional): Directory to save outputs. If None, uses the
49
+ fabric logger directory.
50
+
51
+ Outputs:
52
+ - task_vector_cos_similarity.csv: Pairwise cosine similarity matrix
53
+ - task_vector_cos_similarity.pdf: Heatmap visualization (if plot_heatmap=True)
54
+
55
+ Returns:
56
+ The pretrained model from the model pool.
57
+
58
+ Example:
59
+ ```python
60
+ >>> algorithm = TaskVectorCosSimilarity(
61
+ ... plot_heatmap=True,
62
+ ... trainable_only=True,
63
+ ... output_path="/path/to/outputs"
64
+ ... )
65
+ >>> result = algorithm.run(modelpool)
66
+ ```
30
67
  """
31
68
 
32
69
  _config_mapping = BaseAlgorithm._config_mapping | {
33
- "plot_heatmap": "plot_heatmap",
34
70
  "_output_path": "output_path",
35
71
  }
36
72
 
@@ -42,11 +78,8 @@ class TaskVectorCosSimilarity(BaseAlgorithm, LightningFabricMixin):
42
78
  output_path: Optional[str] = None,
43
79
  **kwargs,
44
80
  ):
45
- self.plot_heatmap = plot_heatmap
46
- self.trainable_only = trainable_only
47
- self.max_points_per_model = max_points_per_model
48
- self._output_path = output_path
49
81
  super().__init__(**kwargs)
82
+ self._output_path = output_path
50
83
 
51
84
  @property
52
85
  def output_path(self):
@@ -57,6 +90,22 @@ class TaskVectorCosSimilarity(BaseAlgorithm, LightningFabricMixin):
57
90
 
58
91
  @torch.no_grad()
59
92
  def run(self, modelpool: BaseModelPool):
93
+ """
94
+ Execute the task vector cosine similarity analysis.
95
+
96
+ This method:
97
+ 1. Loads the pretrained base model from the model pool
98
+ 2. Computes task vectors for each fine-tuned model
99
+ 3. Calculates pairwise cosine similarities between all task vectors
100
+ 4. Saves the similarity matrix as a CSV file
101
+ 5. Optionally generates and saves a heatmap visualization
102
+
103
+ Args:
104
+ modelpool (BaseModelPool): Pool containing pretrained and fine-tuned models
105
+
106
+ Returns:
107
+ nn.Module: The pretrained model from the model pool
108
+ """
60
109
  pretrained_model = modelpool.load_pretrained_model()
61
110
 
62
111
  task_vectors = []
@@ -103,11 +152,14 @@ class TaskVectorCosSimilarity(BaseAlgorithm, LightningFabricMixin):
103
152
 
104
153
  def _plot_heatmap(self, data: pd.DataFrame):
105
154
  """
106
- This function plots a heatmap of the provided data using seaborn.
155
+ Generate and save a heatmap visualization of the cosine similarity matrix.
156
+
157
+ Creates a color-coded heatmap showing pairwise cosine similarities between
158
+ task vectors. The heatmap is saved as a PDF file in the output directory.
107
159
 
108
160
  Args:
109
- data (pd.DataFrame): A pandas DataFrame containing the data to be plotted.
110
- figsize (tuple): A tuple specifying the size of the figure. Default is (4, 3).
161
+ data (pd.DataFrame): Symmetric matrix of cosine similarities between
162
+ task vectors, with model names as both index and columns.
111
163
 
112
164
  Returns:
113
165
  None
@@ -141,6 +193,26 @@ class TaskVectorCosSimilarity(BaseAlgorithm, LightningFabricMixin):
141
193
  def get_task_vector(
142
194
  self, pretrained_model: nn.Module, finetuned_model: nn.Module
143
195
  ) -> torch.Tensor:
196
+ """
197
+ Compute the task vector for a fine-tuned model.
198
+
199
+ The task vector represents the parameter changes from pretraining to
200
+ fine-tuning and is computed as:
201
+ task_vector = finetuned_params - pretrained_params
202
+
203
+ Args:
204
+ pretrained_model (nn.Module): The base pretrained model
205
+ finetuned_model (nn.Module): The fine-tuned model for a specific task
206
+
207
+ Returns:
208
+ torch.Tensor: Flattened task vector containing parameter differences.
209
+ If max_points_per_model is set, the vector may be downsampled.
210
+
211
+ Note:
212
+ - Converts parameters to float64 for numerical precision
213
+ - Supports optional downsampling for memory efficiency
214
+ - Uses only trainable parameters if trainable_only=True
215
+ """
144
216
  task_vector = state_dict_sub(
145
217
  self.get_state_dict(finetuned_model),
146
218
  self.get_state_dict(pretrained_model),
@@ -166,6 +238,17 @@ class TaskVectorCosSimilarity(BaseAlgorithm, LightningFabricMixin):
166
238
  return task_vector
167
239
 
168
240
  def get_state_dict(self, model: nn.Module):
241
+ """
242
+ Extract the state dictionary from a model.
243
+
244
+ Args:
245
+ model (nn.Module): The model to extract parameters from
246
+
247
+ Returns:
248
+ Dict[str, torch.Tensor]: State dictionary containing model parameters.
249
+ Returns only trainable parameters if trainable_only=True,
250
+ otherwise returns all parameters.
251
+ """
169
252
  if self.trainable_only:
170
253
  return trainable_state_dict(model)
171
254
  else:
@@ -11,30 +11,81 @@ from numpy.typing import NDArray
11
11
  from torch import nn
12
12
  from tqdm.auto import tqdm
13
13
 
14
- from fusion_bench import BaseAlgorithm, BaseModelPool
15
- from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
16
- from fusion_bench.utils import timeit_context
17
- from fusion_bench.utils.parameters import (
18
- StateDictType,
19
- state_dict_to_vector,
20
- trainable_state_dict,
14
+ from fusion_bench import BaseAlgorithm, BaseModelPool, StateDictType, timeit_context
15
+ from fusion_bench.mixins import (
16
+ LightningFabricMixin,
17
+ SimpleProfilerMixin,
18
+ auto_register_config,
21
19
  )
20
+ from fusion_bench.utils import state_dict_to_vector, trainable_state_dict
22
21
  from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
23
22
 
24
23
  log = logging.getLogger(__name__)
25
24
 
26
25
 
27
- class TaskVectorViolinPlot(BaseAlgorithm, LightningFabricMixin, SimpleProfilerMixin):
28
- R"""
29
- Plot violin plots of task vectors as in:
30
- [L.Shen, A.Tang, E.Yang et al. Efficient and Effective Weight-Ensembling Mixture of Experts for Multi-Task Model Merging](https://arxiv.org/abs/2410.21804)
26
+ @auto_register_config
27
+ class TaskVectorViolinPlot(
28
+ LightningFabricMixin,
29
+ SimpleProfilerMixin,
30
+ BaseAlgorithm,
31
+ ):
32
+ """
33
+ Creates violin plots to visualize the distribution of task vector values across models.
34
+
35
+ This class implements the task vector visualization technique described in:
36
+ "Efficient and Effective Weight-Ensembling Mixture of Experts for Multi-Task Model Merging"
37
+ by L. Shen, A. Tang, E. Yang et al. (https://arxiv.org/abs/2410.21804)
38
+
39
+ Task vectors represent the parameter differences between fine-tuned models and their
40
+ pretrained base model, computed as:
41
+ task_vector = finetuned_params - pretrained_params
42
+
43
+ The algorithm generates two types of violin plots:
44
+ 1. Distribution of raw task vector values (positive and negative)
45
+ 2. Distribution of absolute task vector values
46
+
47
+ Args:
48
+ trainable_only (bool): If True, only consider trainable parameters when computing
49
+ task vectors. If False, use all parameters.
50
+ max_points_per_model (int, optional): Maximum number of parameters to sample
51
+ per model for memory efficiency. If None or 0, uses all parameters.
52
+ Defaults to 1000.
53
+ fig_kwargs (dict, optional): Dictionary of keyword arguments to pass to
54
+ matplotlib.pyplot.subplots. Common options include:
55
+ - figsize: Tuple of (width, height) in inches
56
+ - dpi: Dots per inch for resolution
57
+ - facecolor: Figure background color
58
+ Defaults to None.
59
+ output_path (str, optional): Directory to save the violin plots. If None,
60
+ uses the fabric logger's log directory. Defaults to None.
61
+
62
+ Outputs:
63
+ - task_vector_violin.pdf: Violin plot of raw task vector value distributions
64
+ - task_vector_violin_abs.pdf: Violin plot of absolute task vector value distributions
65
+
66
+ Returns:
67
+ The pretrained model from the model pool.
68
+
69
+ Example:
70
+ ```python
71
+ plotter = TaskVectorViolinPlot(
72
+ trainable_only=True,
73
+ max_points_per_model=5000,
74
+ fig_kwargs={'figsize': (12, 8), 'dpi': 300},
75
+ output_path='./analysis_plots'
76
+ )
77
+ pretrained_model = plotter.run(modelpool)
78
+ ```
79
+
80
+ Note:
81
+ This visualization is particularly useful for understanding:
82
+ - How different tasks affect model parameters
83
+ - The magnitude and distribution of parameter changes
84
+ - Similarities and differences between task adaptations
31
85
  """
32
86
 
33
87
  # config_mapping is a mapping from the attributes to the key in the configuration files
34
88
  _config_mapping = BaseAlgorithm._config_mapping | {
35
- "trainable_only": "trainable_only",
36
- "max_points_per_model": "max_points_per_model",
37
- "fig_kwargs": "fig_kwargs",
38
89
  "_output_path": "output_path",
39
90
  }
40
91
 
@@ -46,40 +97,34 @@ class TaskVectorViolinPlot(BaseAlgorithm, LightningFabricMixin, SimpleProfilerMi
46
97
  output_path: Optional[str] = None,
47
98
  **kwargs,
48
99
  ):
49
- R"""
50
- This class creates violin plots to visualize task vectors, which represent the differences
51
- between fine-tuned models and their pretrained base model.
100
+ """
101
+ Initialize the TaskVectorViolinPlot analyzer.
52
102
 
53
103
  Args:
54
- trainable_only (bool): If True, only consider trainable parameters when computing
55
- task vectors. If False, use all parameters.
56
- fig_kwargs (dict, optional): Dictionary of keyword arguments to pass to
57
- `matplotlib.pyplot.subplots`. Common options include:
58
- - figsize: Tuple of (width, height) in inches
59
- - dpi: Dots per inch
60
- - facecolor: Figure background color
61
- Defaults to None.
62
- output_path (str, optional): Path where the violin plot will be saved. If None,
63
- uses the fabric logger's log directory. Defaults to None.
64
- kwargs: Additional keyword arguments passed to the parent class(es).
65
-
66
- Example:
67
-
68
- ```python
69
- plotter = TaskVectorViolinPlot(
70
- trainable_only=True,
71
- fig_kwargs={'figsize': (10, 6), 'dpi': 300},
72
- output_path='./plots'
73
- )
74
-
75
- plotter.run(modelpool)
76
- ```
104
+ trainable_only (bool): Whether to consider only trainable parameters when
105
+ computing task vectors. Set to True to focus on learnable parameters,
106
+ False to include all parameters including frozen ones.
107
+ max_points_per_model (int, optional): Maximum number of parameter values
108
+ to sample per model for visualization. Useful for large models to
109
+ manage memory usage and plot clarity. Set to None or 0 to use all
110
+ parameters. Defaults to 1000.
111
+ fig_kwargs (dict, optional): Keyword arguments passed to matplotlib's
112
+ subplots function for plot customization. Examples:
113
+ - {'figsize': (10, 6)} for plot dimensions
114
+ - {'dpi': 300} for high resolution
115
+ - {'facecolor': 'white'} for background color
116
+ Defaults to None (uses matplotlib defaults).
117
+ output_path (str, optional): Directory path where violin plots will be saved.
118
+ If None, uses the fabric logger's log directory. The directory will be
119
+ created if it doesn't exist. Defaults to None.
120
+ **kwargs: Additional keyword arguments passed to parent classes.
121
+
122
+ Note:
123
+ The parameter name 'fig_kwawrgs' appears to be a typo for 'fig_kwargs'.
124
+ This should be corrected in the parameter name for consistency.
77
125
  """
78
- self.trainable_only = trainable_only
79
- self.fig_kwargs = fig_kwawrgs
80
- self.max_points_per_model = max_points_per_model
81
- self._output_path = output_path
82
126
  super().__init__(**kwargs)
127
+ self._output_path = output_path
83
128
 
84
129
  @property
85
130
  def output_path(self):
@@ -89,20 +134,39 @@ class TaskVectorViolinPlot(BaseAlgorithm, LightningFabricMixin, SimpleProfilerMi
89
134
  return self._output_path
90
135
 
91
136
  def run(self, modelpool: BaseModelPool):
92
- """Create violin plots of task vectors comparing different fine-tuned models against a pretrained model.
137
+ """
138
+ Execute the task vector violin plot analysis and visualization.
93
139
 
94
- This method implements the visualization technique from the paper "Efficient and Effective
95
- Weight-Ensembling Mixture of Experts for Multi-Task Model Merging". It:
140
+ This method implements the core algorithm that:
141
+ 1. Loads the pretrained base model from the model pool
142
+ 2. Computes task vectors for each fine-tuned model (parameter differences)
143
+ 3. Creates two violin plots showing the distribution of task vector values:
144
+ - Raw values plot: Shows positive and negative parameter changes
145
+ - Absolute values plot: Shows magnitude of parameter changes
146
+ 4. Saves both plots as PDF files in the output directory
96
147
 
97
- 1. Loads the pretrained model
98
- 2. Computes task vectors (differences between fine-tuned and pretrained models)
99
- 3. Creates violin plots showing the distribution of values in these task vectors
148
+ The visualization technique follows the approach described in:
149
+ "Efficient and Effective Weight-Ensembling Mixture of Experts for Multi-Task Model Merging"
100
150
 
101
151
  Args:
102
- modelpool (BaseModelPool): Model pool containing the pretrained model and fine-tuned models
152
+ modelpool (BaseModelPool): Pool containing both a pretrained model and
153
+ fine-tuned models. Must have `has_pretrained=True`.
103
154
 
104
155
  Returns:
105
- pretrained_model (nn.Model): The plot is saved to the specified output path.
156
+ nn.Module: The pretrained model loaded from the model pool.
157
+
158
+ Raises:
159
+ AssertionError: If the model pool doesn't contain a pretrained model.
160
+
161
+ Side Effects:
162
+ - Creates output directory if it doesn't exist
163
+ - Saves 'task_vector_violin.pdf' (raw values distribution)
164
+ - Saves 'task_vector_violin_abs.pdf' (absolute values distribution)
165
+ - Prints progress information during task vector computation
166
+
167
+ Example Output Files:
168
+ - task_vector_violin.pdf: Shows how parameters change (+ and -)
169
+ - task_vector_violin_abs.pdf: Shows magnitude of parameter changes
106
170
  """
107
171
  assert modelpool.has_pretrained
108
172
  pretrained_model = modelpool.load_pretrained_model()
@@ -175,6 +239,34 @@ class TaskVectorViolinPlot(BaseAlgorithm, LightningFabricMixin, SimpleProfilerMi
175
239
  return pretrained_model
176
240
 
177
241
  def get_task_vector(self, pretrained_model, finetuned_model):
242
+ """
243
+ Compute the task vector representing parameter changes from pretraining to fine-tuning.
244
+
245
+ The task vector quantifies how model parameters have changed during task-specific
246
+ fine-tuning and is computed as:
247
+ task_vector = finetuned_params - pretrained_params
248
+
249
+ Args:
250
+ pretrained_model (nn.Module): The base pretrained model
251
+ finetuned_model (nn.Module): The fine-tuned model for a specific task
252
+
253
+ Returns:
254
+ np.ndarray: Flattened numpy array containing parameter differences.
255
+ If max_points_per_model is set, the array may be randomly downsampled
256
+ for memory efficiency and visualization clarity.
257
+
258
+ Processing Steps:
259
+ 1. Extract state dictionaries from both models
260
+ 2. Compute parameter differences (subtraction)
261
+ 3. Flatten to 1D vector
262
+ 4. Convert to numpy array with float32 precision
263
+ 5. Optionally downsample if max_points_per_model is specified
264
+
265
+ Note:
266
+ - Uses only trainable parameters if trainable_only=True
267
+ - Downsampling uses random sampling without replacement
268
+ - Preserves the relative distribution of parameter changes
269
+ """
178
270
  task_vector = state_dict_sub(
179
271
  self.get_state_dict(finetuned_model),
180
272
  self.get_state_dict(pretrained_model),
@@ -199,6 +291,22 @@ class TaskVectorViolinPlot(BaseAlgorithm, LightningFabricMixin, SimpleProfilerMi
199
291
  return task_vector
200
292
 
201
293
  def get_state_dict(self, model: nn.Module):
294
+ """
295
+ Extract the state dictionary from a model based on parameter filtering settings.
296
+
297
+ Args:
298
+ model (nn.Module): The PyTorch model to extract parameters from
299
+
300
+ Returns:
301
+ Dict[str, torch.Tensor]: State dictionary containing model parameters.
302
+ If trainable_only=True, returns only parameters with requires_grad=True.
303
+ If trainable_only=False, returns all parameters including frozen ones.
304
+
305
+ Note:
306
+ This method respects the trainable_only configuration to focus analysis
307
+ on either learnable parameters or the complete parameter set depending
308
+ on the research question being addressed.
309
+ """
202
310
  if self.trainable_only:
203
311
  return trainable_state_dict(model)
204
312
  else:
@@ -6,7 +6,11 @@ import torch.nn.functional as F
6
6
  from tqdm.auto import tqdm
7
7
 
8
8
  from fusion_bench import BaseAlgorithm, BaseModelPool
9
- from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
9
+ from fusion_bench.mixins import (
10
+ LightningFabricMixin,
11
+ SimpleProfilerMixin,
12
+ auto_register_config,
13
+ )
10
14
  from fusion_bench.modelpool import CausalLMPool
11
15
 
12
16
  from .bitdelta_utils.data import get_dataloader, get_dataset
@@ -15,23 +19,12 @@ from .bitdelta_utils.diff import compress_diff, save_diff, save_full_model
15
19
  log = logging.getLogger(__name__)
16
20
 
17
21
 
22
+ @auto_register_config
18
23
  class BitDeltaAlgorithm(
19
- BaseAlgorithm,
20
24
  LightningFabricMixin,
21
25
  SimpleProfilerMixin,
26
+ BaseAlgorithm,
22
27
  ):
23
- _config_mapping = BaseAlgorithm._config_mapping | {
24
- "save_dir": "save_dir",
25
- "save_full_model": "save_full_model",
26
- "lr": "lr",
27
- "batch_size": "batch_size",
28
- "num_steps": "num_steps",
29
- "dataset_name": "dataset_name",
30
- "subset": "subset",
31
- "split": "split",
32
- "max_length": "max_length",
33
- }
34
-
35
28
  def __init__(
36
29
  self,
37
30
  save_dir: str,
@@ -46,15 +39,6 @@ class BitDeltaAlgorithm(
46
39
  **kwargs,
47
40
  ):
48
41
  super().__init__(**kwargs)
49
- self.save_dir = save_dir
50
- self.save_full_model = save_full_model
51
- self.lr = lr
52
- self.batch_size = batch_size
53
- self.num_steps = num_steps
54
- self.dataset_name = dataset_name
55
- self.subset = subset
56
- self.split = split
57
- self.max_length = max_length
58
42
 
59
43
  def run(self, modelpool: CausalLMPool):
60
44
  if self.save_dir is None: