fusion-bench 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89):
  1. fusion_bench/__init__.py +25 -2
  2. fusion_bench/compat/method/__init__.py +5 -2
  3. fusion_bench/compat/method/base_algorithm.py +3 -2
  4. fusion_bench/compat/modelpool/base_pool.py +3 -3
  5. fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
  6. fusion_bench/constants/__init__.py +1 -0
  7. fusion_bench/constants/runtime.py +57 -0
  8. fusion_bench/dataset/gpt2_glue.py +1 -1
  9. fusion_bench/method/__init__.py +12 -4
  10. fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
  11. fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
  12. fusion_bench/method/bitdelta/__init__.py +1 -0
  13. fusion_bench/method/bitdelta/bitdelta.py +7 -23
  14. fusion_bench/method/classification/clip_finetune.py +1 -1
  15. fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
  16. fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
  17. fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
  18. fusion_bench/method/fisher_merging/clip_fisher_merging.py +0 -4
  19. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +2 -2
  20. fusion_bench/method/linear/simple_average_for_llama.py +16 -11
  21. fusion_bench/method/model_stock/__init__.py +1 -0
  22. fusion_bench/method/model_stock/model_stock.py +309 -0
  23. fusion_bench/method/regmean/clip_regmean.py +3 -6
  24. fusion_bench/method/regmean/regmean.py +27 -56
  25. fusion_bench/method/regmean/utils.py +56 -0
  26. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
  27. fusion_bench/method/simple_average.py +7 -7
  28. fusion_bench/method/slerp/__init__.py +1 -1
  29. fusion_bench/method/slerp/slerp.py +110 -14
  30. fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
  31. fusion_bench/method/smile_upscaling/projected_energy.py +1 -2
  32. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
  33. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +40 -31
  34. fusion_bench/method/smile_upscaling/smile_upscaling.py +1 -1
  35. fusion_bench/method/we_moe/__init__.py +1 -0
  36. fusion_bench/method/we_moe/entropy_loss.py +25 -0
  37. fusion_bench/method/we_moe/flan_t5_we_moe.py +320 -0
  38. fusion_bench/method/we_moe/utils.py +15 -0
  39. fusion_bench/method/weighted_average/llama.py +1 -1
  40. fusion_bench/mixins/clip_classification.py +37 -48
  41. fusion_bench/mixins/serialization.py +30 -10
  42. fusion_bench/modelpool/base_pool.py +1 -1
  43. fusion_bench/modelpool/causal_lm/causal_lm.py +293 -75
  44. fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
  45. fusion_bench/models/__init__.py +5 -0
  46. fusion_bench/models/hf_utils.py +69 -86
  47. fusion_bench/models/linearized/vision_model.py +6 -6
  48. fusion_bench/models/model_card_templates/default.md +46 -0
  49. fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
  50. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +1 -8
  51. fusion_bench/models/modeling_smile_mistral/__init__.py +2 -1
  52. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +1 -5
  53. fusion_bench/models/we_moe.py +8 -8
  54. fusion_bench/programs/fabric_fusion_program.py +29 -60
  55. fusion_bench/scripts/cli.py +34 -1
  56. fusion_bench/taskpool/base_pool.py +99 -17
  57. fusion_bench/taskpool/clip_vision/taskpool.py +10 -5
  58. fusion_bench/taskpool/dummy.py +101 -13
  59. fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
  60. fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
  61. fusion_bench/utils/__init__.py +2 -0
  62. fusion_bench/utils/cache_utils.py +101 -1
  63. fusion_bench/utils/data.py +6 -4
  64. fusion_bench/utils/devices.py +7 -4
  65. fusion_bench/utils/dtype.py +3 -2
  66. fusion_bench/utils/fabric.py +2 -2
  67. fusion_bench/utils/lazy_imports.py +23 -0
  68. fusion_bench/utils/lazy_state_dict.py +117 -19
  69. fusion_bench/utils/modelscope.py +3 -3
  70. fusion_bench/utils/packages.py +3 -3
  71. fusion_bench/utils/parameters.py +0 -2
  72. fusion_bench/utils/path.py +56 -0
  73. fusion_bench/utils/pylogger.py +1 -1
  74. fusion_bench/utils/timer.py +92 -10
  75. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/METADATA +1 -23
  76. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/RECORD +89 -75
  77. fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
  78. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
  79. fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
  80. fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
  81. fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
  82. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
  83. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -1
  84. fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
  85. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +1 -1
  86. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/WHEEL +0 -0
  87. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/entry_points.txt +0 -0
  88. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/licenses/LICENSE +0 -0
  89. {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/top_level.txt +0 -0
@@ -11,30 +11,81 @@ from numpy.typing import NDArray
11
11
  from torch import nn
12
12
  from tqdm.auto import tqdm
13
13
 
14
- from fusion_bench import BaseAlgorithm, BaseModelPool
15
- from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
16
- from fusion_bench.utils import timeit_context
17
- from fusion_bench.utils.parameters import (
18
- StateDictType,
19
- state_dict_to_vector,
20
- trainable_state_dict,
14
+ from fusion_bench import BaseAlgorithm, BaseModelPool, StateDictType, timeit_context
15
+ from fusion_bench.mixins import (
16
+ LightningFabricMixin,
17
+ SimpleProfilerMixin,
18
+ auto_register_config,
21
19
  )
20
+ from fusion_bench.utils import state_dict_to_vector, trainable_state_dict
22
21
  from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
23
22
 
24
23
  log = logging.getLogger(__name__)
25
24
 
26
25
 
27
- class TaskVectorViolinPlot(BaseAlgorithm, LightningFabricMixin, SimpleProfilerMixin):
28
- R"""
29
- Plot violin plots of task vectors as in:
30
- [L.Shen, A.Tang, E.Yang et al. Efficient and Effective Weight-Ensembling Mixture of Experts for Multi-Task Model Merging](https://arxiv.org/abs/2410.21804)
26
+ @auto_register_config
27
+ class TaskVectorViolinPlot(
28
+ LightningFabricMixin,
29
+ SimpleProfilerMixin,
30
+ BaseAlgorithm,
31
+ ):
32
+ """
33
+ Creates violin plots to visualize the distribution of task vector values across models.
34
+
35
+ This class implements the task vector visualization technique described in:
36
+ "Efficient and Effective Weight-Ensembling Mixture of Experts for Multi-Task Model Merging"
37
+ by L. Shen, A. Tang, E. Yang et al. (https://arxiv.org/abs/2410.21804)
38
+
39
+ Task vectors represent the parameter differences between fine-tuned models and their
40
+ pretrained base model, computed as:
41
+ task_vector = finetuned_params - pretrained_params
42
+
43
+ The algorithm generates two types of violin plots:
44
+ 1. Distribution of raw task vector values (positive and negative)
45
+ 2. Distribution of absolute task vector values
46
+
47
+ Args:
48
+ trainable_only (bool): If True, only consider trainable parameters when computing
49
+ task vectors. If False, use all parameters.
50
+ max_points_per_model (int, optional): Maximum number of parameters to sample
51
+ per model for memory efficiency. If None or 0, uses all parameters.
52
+ Defaults to 1000.
53
+ fig_kwargs (dict, optional): Dictionary of keyword arguments to pass to
54
+ matplotlib.pyplot.subplots. Common options include:
55
+ - figsize: Tuple of (width, height) in inches
56
+ - dpi: Dots per inch for resolution
57
+ - facecolor: Figure background color
58
+ Defaults to None.
59
+ output_path (str, optional): Directory to save the violin plots. If None,
60
+ uses the fabric logger's log directory. Defaults to None.
61
+
62
+ Outputs:
63
+ - task_vector_violin.pdf: Violin plot of raw task vector value distributions
64
+ - task_vector_violin_abs.pdf: Violin plot of absolute task vector value distributions
65
+
66
+ Returns:
67
+ The pretrained model from the model pool.
68
+
69
+ Example:
70
+ ```python
71
+ plotter = TaskVectorViolinPlot(
72
+ trainable_only=True,
73
+ max_points_per_model=5000,
74
+ fig_kwargs={'figsize': (12, 8), 'dpi': 300},
75
+ output_path='./analysis_plots'
76
+ )
77
+ pretrained_model = plotter.run(modelpool)
78
+ ```
79
+
80
+ Note:
81
+ This visualization is particularly useful for understanding:
82
+ - How different tasks affect model parameters
83
+ - The magnitude and distribution of parameter changes
84
+ - Similarities and differences between task adaptations
31
85
  """
32
86
 
33
87
  # config_mapping is a mapping from the attributes to the key in the configuration files
34
88
  _config_mapping = BaseAlgorithm._config_mapping | {
35
- "trainable_only": "trainable_only",
36
- "max_points_per_model": "max_points_per_model",
37
- "fig_kwargs": "fig_kwargs",
38
89
  "_output_path": "output_path",
39
90
  }
40
91
 
@@ -46,40 +97,34 @@ class TaskVectorViolinPlot(BaseAlgorithm, LightningFabricMixin, SimpleProfilerMi
46
97
  output_path: Optional[str] = None,
47
98
  **kwargs,
48
99
  ):
49
- R"""
50
- This class creates violin plots to visualize task vectors, which represent the differences
51
- between fine-tuned models and their pretrained base model.
100
+ """
101
+ Initialize the TaskVectorViolinPlot analyzer.
52
102
 
53
103
  Args:
54
- trainable_only (bool): If True, only consider trainable parameters when computing
55
- task vectors. If False, use all parameters.
56
- fig_kwargs (dict, optional): Dictionary of keyword arguments to pass to
57
- `matplotlib.pyplot.subplots`. Common options include:
58
- - figsize: Tuple of (width, height) in inches
59
- - dpi: Dots per inch
60
- - facecolor: Figure background color
61
- Defaults to None.
62
- output_path (str, optional): Path where the violin plot will be saved. If None,
63
- uses the fabric logger's log directory. Defaults to None.
64
- kwargs: Additional keyword arguments passed to the parent class(es).
65
-
66
- Example:
67
-
68
- ```python
69
- plotter = TaskVectorViolinPlot(
70
- trainable_only=True,
71
- fig_kwargs={'figsize': (10, 6), 'dpi': 300},
72
- output_path='./plots'
73
- )
74
-
75
- plotter.run(modelpool)
76
- ```
104
+ trainable_only (bool): Whether to consider only trainable parameters when
105
+ computing task vectors. Set to True to focus on learnable parameters,
106
+ False to include all parameters including frozen ones.
107
+ max_points_per_model (int, optional): Maximum number of parameter values
108
+ to sample per model for visualization. Useful for large models to
109
+ manage memory usage and plot clarity. Set to None or 0 to use all
110
+ parameters. Defaults to 1000.
111
+ fig_kwargs (dict, optional): Keyword arguments passed to matplotlib's
112
+ subplots function for plot customization. Examples:
113
+ - {'figsize': (10, 6)} for plot dimensions
114
+ - {'dpi': 300} for high resolution
115
+ - {'facecolor': 'white'} for background color
116
+ Defaults to None (uses matplotlib defaults).
117
+ output_path (str, optional): Directory path where violin plots will be saved.
118
+ If None, uses the fabric logger's log directory. The directory will be
119
+ created if it doesn't exist. Defaults to None.
120
+ **kwargs: Additional keyword arguments passed to parent classes.
121
+
122
+ Note:
123
+ The parameter name 'fig_kwawrgs' appears to be a typo for 'fig_kwargs'.
124
+ This should be corrected in the parameter name for consistency.
77
125
  """
78
- self.trainable_only = trainable_only
79
- self.fig_kwargs = fig_kwawrgs
80
- self.max_points_per_model = max_points_per_model
81
- self._output_path = output_path
82
126
  super().__init__(**kwargs)
127
+ self._output_path = output_path
83
128
 
84
129
  @property
85
130
  def output_path(self):
@@ -89,20 +134,39 @@ class TaskVectorViolinPlot(BaseAlgorithm, LightningFabricMixin, SimpleProfilerMi
89
134
  return self._output_path
90
135
 
91
136
  def run(self, modelpool: BaseModelPool):
92
- """Create violin plots of task vectors comparing different fine-tuned models against a pretrained model.
137
+ """
138
+ Execute the task vector violin plot analysis and visualization.
93
139
 
94
- This method implements the visualization technique from the paper "Efficient and Effective
95
- Weight-Ensembling Mixture of Experts for Multi-Task Model Merging". It:
140
+ This method implements the core algorithm that:
141
+ 1. Loads the pretrained base model from the model pool
142
+ 2. Computes task vectors for each fine-tuned model (parameter differences)
143
+ 3. Creates two violin plots showing the distribution of task vector values:
144
+ - Raw values plot: Shows positive and negative parameter changes
145
+ - Absolute values plot: Shows magnitude of parameter changes
146
+ 4. Saves both plots as PDF files in the output directory
96
147
 
97
- 1. Loads the pretrained model
98
- 2. Computes task vectors (differences between fine-tuned and pretrained models)
99
- 3. Creates violin plots showing the distribution of values in these task vectors
148
+ The visualization technique follows the approach described in:
149
+ "Efficient and Effective Weight-Ensembling Mixture of Experts for Multi-Task Model Merging"
100
150
 
101
151
  Args:
102
- modelpool (BaseModelPool): Model pool containing the pretrained model and fine-tuned models
152
+ modelpool (BaseModelPool): Pool containing both a pretrained model and
153
+ fine-tuned models. Must have `has_pretrained=True`.
103
154
 
104
155
  Returns:
105
- pretrained_model (nn.Model): The plot is saved to the specified output path.
156
+ nn.Module: The pretrained model loaded from the model pool.
157
+
158
+ Raises:
159
+ AssertionError: If the model pool doesn't contain a pretrained model.
160
+
161
+ Side Effects:
162
+ - Creates output directory if it doesn't exist
163
+ - Saves 'task_vector_violin.pdf' (raw values distribution)
164
+ - Saves 'task_vector_violin_abs.pdf' (absolute values distribution)
165
+ - Prints progress information during task vector computation
166
+
167
+ Example Output Files:
168
+ - task_vector_violin.pdf: Shows how parameters change (+ and -)
169
+ - task_vector_violin_abs.pdf: Shows magnitude of parameter changes
106
170
  """
107
171
  assert modelpool.has_pretrained
108
172
  pretrained_model = modelpool.load_pretrained_model()
@@ -175,6 +239,34 @@ class TaskVectorViolinPlot(BaseAlgorithm, LightningFabricMixin, SimpleProfilerMi
175
239
  return pretrained_model
176
240
 
177
241
  def get_task_vector(self, pretrained_model, finetuned_model):
242
+ """
243
+ Compute the task vector representing parameter changes from pretraining to fine-tuning.
244
+
245
+ The task vector quantifies how model parameters have changed during task-specific
246
+ fine-tuning and is computed as:
247
+ task_vector = finetuned_params - pretrained_params
248
+
249
+ Args:
250
+ pretrained_model (nn.Module): The base pretrained model
251
+ finetuned_model (nn.Module): The fine-tuned model for a specific task
252
+
253
+ Returns:
254
+ np.ndarray: Flattened numpy array containing parameter differences.
255
+ If max_points_per_model is set, the array may be randomly downsampled
256
+ for memory efficiency and visualization clarity.
257
+
258
+ Processing Steps:
259
+ 1. Extract state dictionaries from both models
260
+ 2. Compute parameter differences (subtraction)
261
+ 3. Flatten to 1D vector
262
+ 4. Convert to numpy array with float32 precision
263
+ 5. Optionally downsample if max_points_per_model is specified
264
+
265
+ Note:
266
+ - Uses only trainable parameters if trainable_only=True
267
+ - Downsampling uses random sampling without replacement
268
+ - Preserves the relative distribution of parameter changes
269
+ """
178
270
  task_vector = state_dict_sub(
179
271
  self.get_state_dict(finetuned_model),
180
272
  self.get_state_dict(pretrained_model),
@@ -199,6 +291,22 @@ class TaskVectorViolinPlot(BaseAlgorithm, LightningFabricMixin, SimpleProfilerMi
199
291
  return task_vector
200
292
 
201
293
  def get_state_dict(self, model: nn.Module):
294
+ """
295
+ Extract the state dictionary from a model based on parameter filtering settings.
296
+
297
+ Args:
298
+ model (nn.Module): The PyTorch model to extract parameters from
299
+
300
+ Returns:
301
+ Dict[str, torch.Tensor]: State dictionary containing model parameters.
302
+ If trainable_only=True, returns only parameters with requires_grad=True.
303
+ If trainable_only=False, returns all parameters including frozen ones.
304
+
305
+ Note:
306
+ This method respects the trainable_only configuration to focus analysis
307
+ on either learnable parameters or the complete parameter set depending
308
+ on the research question being addressed.
309
+ """
202
310
  if self.trainable_only:
203
311
  return trainable_state_dict(model)
204
312
  else:
@@ -1,4 +1,5 @@
1
1
  """
2
2
  Adapted from https://github.com/FasterDecoding/BitDelta
3
3
  """
4
+
4
5
  from .bitdelta import BitDeltaAlgorithm
@@ -6,7 +6,11 @@ import torch.nn.functional as F
6
6
  from tqdm.auto import tqdm
7
7
 
8
8
  from fusion_bench import BaseAlgorithm, BaseModelPool
9
- from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
9
+ from fusion_bench.mixins import (
10
+ LightningFabricMixin,
11
+ SimpleProfilerMixin,
12
+ auto_register_config,
13
+ )
10
14
  from fusion_bench.modelpool import CausalLMPool
11
15
 
12
16
  from .bitdelta_utils.data import get_dataloader, get_dataset
@@ -15,23 +19,12 @@ from .bitdelta_utils.diff import compress_diff, save_diff, save_full_model
15
19
  log = logging.getLogger(__name__)
16
20
 
17
21
 
22
+ @auto_register_config
18
23
  class BitDeltaAlgorithm(
19
- BaseAlgorithm,
20
24
  LightningFabricMixin,
21
25
  SimpleProfilerMixin,
26
+ BaseAlgorithm,
22
27
  ):
23
- _config_mapping = BaseAlgorithm._config_mapping | {
24
- "save_dir": "save_dir",
25
- "save_full_model": "save_full_model",
26
- "lr": "lr",
27
- "batch_size": "batch_size",
28
- "num_steps": "num_steps",
29
- "dataset_name": "dataset_name",
30
- "subset": "subset",
31
- "split": "split",
32
- "max_length": "max_length",
33
- }
34
-
35
28
  def __init__(
36
29
  self,
37
30
  save_dir: str,
@@ -46,15 +39,6 @@ class BitDeltaAlgorithm(
46
39
  **kwargs,
47
40
  ):
48
41
  super().__init__(**kwargs)
49
- self.save_dir = save_dir
50
- self.save_full_model = save_full_model
51
- self.lr = lr
52
- self.batch_size = batch_size
53
- self.num_steps = num_steps
54
- self.dataset_name = dataset_name
55
- self.subset = subset
56
- self.split = split
57
- self.max_length = max_length
58
42
 
59
43
  def run(self, modelpool: CausalLMPool):
60
44
  if self.save_dir is None:
@@ -393,7 +393,7 @@ def convert_l_lora_state_dict_to_hf(
393
393
  base_model_name: Optional[str] = None,
394
394
  ):
395
395
  """
396
- Convert a linearized Lora model's checkpoint to Hugggingface's format.
396
+ Convert a linearized Lora model's checkpoint to huggingface's format.
397
397
 
398
398
  Args:
399
399
  pretrained_path (str): The path to the pretrained model.
@@ -23,6 +23,7 @@ from transformers import MixtralForCausalLM
23
23
  from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
24
24
 
25
25
  import fusion_bench as fb
26
+ from fusion_bench import auto_register_config
26
27
  from fusion_bench.method.expert_sparsity.utils.calibration_data import (
27
28
  build_calib_loader,
28
29
  )
@@ -97,6 +98,7 @@ def dynamic_skipping(
97
98
  return model, (res_median, res_mean)
98
99
 
99
100
 
101
+ @auto_register_config
100
102
  class DynamicSkippingPruningForMixtral(
101
103
  fb.BaseAlgorithm,
102
104
  fb.mixins.LightningFabricMixin,
@@ -22,6 +22,7 @@ from transformers import MixtralForCausalLM
22
22
  from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
23
23
 
24
24
  import fusion_bench as fb
25
+ from fusion_bench import auto_register_config
25
26
  from fusion_bench.method.expert_sparsity.utils.calibration_data import (
26
27
  build_calib_loader,
27
28
  )
@@ -81,6 +82,7 @@ def layerwise_pruning(
81
82
  return model, (global_loss_history,)
82
83
 
83
84
 
85
+ @auto_register_config
84
86
  class LayerWisePruningForMixtral(
85
87
  fb.BaseAlgorithm,
86
88
  fb.mixins.LightningFabricMixin,
@@ -20,6 +20,7 @@ from tqdm import tqdm
20
20
  from transformers import MixtralForCausalLM
21
21
 
22
22
  import fusion_bench as fb
23
+ from fusion_bench import auto_register_config
23
24
  from fusion_bench.method.expert_sparsity.utils.calibration_data import (
24
25
  build_calib_loader,
25
26
  )
@@ -95,6 +96,7 @@ def progressive_pruning(
95
96
  return model, (global_loss_history,)
96
97
 
97
98
 
99
+ @auto_register_config
98
100
  class ProgressivePruningForMixtral(
99
101
  fb.BaseAlgorithm,
100
102
  fb.mixins.LightningFabricMixin,
@@ -32,7 +32,6 @@ class FisherMergingForCLIPVisionModel(
32
32
  zeroshot_weights = {}
33
33
 
34
34
  _config_mapping = FisherMergingAlgorithm._config_mapping | {
35
- "zeroshot_weights_cache_dir": "zeroshot_weights_cache_dir",
36
35
  "_dataloader_kwargs": "dataloader_kwargs",
37
36
  }
38
37
 
@@ -44,7 +43,6 @@ class FisherMergingForCLIPVisionModel(
44
43
  minimal_fisher_weight,
45
44
  num_fisher_examples,
46
45
  dataloader_kwargs: DictConfig,
47
- zeroshot_weights_cache_dir=None,
48
46
  **kwargs,
49
47
  ):
50
48
  """
@@ -56,7 +54,6 @@ class FisherMergingForCLIPVisionModel(
56
54
  minimal_fisher_weight (float): Minimal value for Fisher weights to avoid numerical issues.
57
55
  num_fisher_examples (int): Number of examples to compute Fisher weights.
58
56
  dataloader_kwargs (DictConfig): Configuration for the dataloader.
59
- zeroshot_weights_cache_dir (str, optional): Directory to cache zero-shot weights. Defaults to None.
60
57
  **kwargs: Additional keyword arguments.
61
58
  """
62
59
  super().__init__(
@@ -66,7 +63,6 @@ class FisherMergingForCLIPVisionModel(
66
63
  num_fisher_examples=num_fisher_examples,
67
64
  )
68
65
  self.dataloader_kwargs = dataloader_kwargs
69
- self.zeroshot_weights_cache_dir = zeroshot_weights_cache_dir
70
66
  for key, value in kwargs.items():
71
67
  log.warning(f"Unused argument: {key}={value}")
72
68
  setattr(self, key, value)
@@ -15,10 +15,10 @@ from transformers import GPT2ForSequenceClassification, GPT2Model
15
15
  from transformers.data import default_data_collator
16
16
  from transformers.models.gpt2.modeling_gpt2 import Conv1D
17
17
 
18
- from fusion_bench.mixins import LightningFabricMixin
18
+ from fusion_bench.mixins import LightningFabricMixin, auto_register_config
19
19
  from fusion_bench.modelpool import GPT2ForSequenceClassificationPool
20
20
  from fusion_bench.utils import timeit_context
21
- from fusion_bench.mixins import auto_register_config
21
+
22
22
  from .fisher_merging import FisherMergingAlgorithm, get_param_squared_gradients
23
23
 
24
24
 
@@ -1,3 +1,4 @@
1
+ import os
1
2
  from copy import deepcopy
2
3
  from typing import TYPE_CHECKING, Optional
3
4
 
@@ -7,13 +8,16 @@ from typing_extensions import override
7
8
  from fusion_bench import timeit_context
8
9
  from fusion_bench.method.base_algorithm import BaseAlgorithm
9
10
  from fusion_bench.method.simple_average import SimpleAverageAlgorithm
11
+ from fusion_bench.mixins import auto_register_config
10
12
  from fusion_bench.modelpool import CausalLMBackbonePool, CausalLMPool
13
+ from fusion_bench.models.hf_utils import create_default_model_card
11
14
  from fusion_bench.utils import instantiate
12
- from fusion_bench.utils.pylogger import getRankZeroLogger
15
+ from fusion_bench.utils.pylogger import get_rankzero_logger
13
16
 
14
- log = getRankZeroLogger(__name__)
17
+ log = get_rankzero_logger(__name__)
15
18
 
16
19
 
20
+ @auto_register_config
17
21
  class SimpleAverageForLlama(BaseAlgorithm):
18
22
  R"""
19
23
  A simple averaging algorithm for LLama models. If `merge_backbone` is set to `True`, the backbone of the model will be averaged and the rest of the model will be loaded from the pre-trained model.
@@ -29,21 +33,14 @@ class SimpleAverageForLlama(BaseAlgorithm):
29
33
  ```
30
34
  """
31
35
 
32
- _config_mapping = BaseAlgorithm._config_mapping | {
33
- "merge_backbone": "merge_backbone",
34
- "show_pbar": "show_pbar",
35
- }
36
-
37
36
  def __init__(
38
37
  self,
39
38
  merge_backbone: bool,
40
39
  model_save_path: Optional[str] = None,
41
40
  show_pbar: bool = False,
41
+ **kwargs,
42
42
  ):
43
- super().__init__()
44
- self.merge_backbone = merge_backbone
45
- self.model_save_path = model_save_path
46
- self.show_pbar = show_pbar
43
+ super().__init__(**kwargs)
47
44
 
48
45
  @override
49
46
  def run(self, modelpool: CausalLMPool):
@@ -75,4 +72,12 @@ class SimpleAverageForLlama(BaseAlgorithm):
75
72
  with timeit_context(f"Saving the model to {self.model_save_path}"):
76
73
  tokenizer.save_pretrained(self.model_save_path)
77
74
  model.save_pretrained(self.model_save_path)
75
+ model_card_str = create_default_model_card(
76
+ models=[modelpool.get_model_path(m) for m in modelpool.model_names],
77
+ description="Merged model using simple averaging.",
78
+ algorithm_config=self.config,
79
+ modelpool_config=modelpool.config,
80
+ )
81
+ with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
82
+ f.write(model_card_str)
78
83
  return model
@@ -0,0 +1 @@
1
+ from .model_stock import ModelStock