fusion-bench 0.2.26__py3-none-any.whl → 0.2.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (125)
  1. fusion_bench/__init__.py +4 -0
  2. fusion_bench/dataset/clip_dataset.py +1 -0
  3. fusion_bench/method/__init__.py +2 -0
  4. fusion_bench/method/adamerging/__init__.py +28 -5
  5. fusion_bench/method/adamerging/resnet_adamerging.py +279 -0
  6. fusion_bench/method/adamerging/task_wise_adamerging.py +2 -14
  7. fusion_bench/method/adamerging/utils.py +58 -0
  8. fusion_bench/method/classification/image_classification_finetune.py +168 -12
  9. fusion_bench/method/dare/simple_average.py +3 -2
  10. fusion_bench/method/dare/task_arithmetic.py +3 -2
  11. fusion_bench/method/simple_average.py +6 -4
  12. fusion_bench/method/task_arithmetic/task_arithmetic.py +4 -1
  13. fusion_bench/mixins/lightning_fabric.py +9 -0
  14. fusion_bench/modelpool/__init__.py +24 -2
  15. fusion_bench/modelpool/base_pool.py +8 -1
  16. fusion_bench/modelpool/causal_lm/causal_lm.py +2 -1
  17. fusion_bench/modelpool/convnext_for_image_classification.py +198 -0
  18. fusion_bench/modelpool/dinov2_for_image_classification.py +197 -0
  19. fusion_bench/modelpool/resnet_for_image_classification.py +289 -5
  20. fusion_bench/models/hf_clip.py +4 -7
  21. fusion_bench/models/hf_utils.py +4 -1
  22. fusion_bench/models/model_card_templates/default.md +1 -1
  23. fusion_bench/taskpool/__init__.py +2 -0
  24. fusion_bench/taskpool/clip_vision/taskpool.py +1 -1
  25. fusion_bench/taskpool/resnet_for_image_classification.py +231 -0
  26. fusion_bench/utils/json.py +49 -8
  27. fusion_bench/utils/state_dict_arithmetic.py +91 -10
  28. {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/METADATA +2 -2
  29. {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/RECORD +124 -62
  30. fusion_bench_config/fabric/auto.yaml +1 -1
  31. fusion_bench_config/fabric/loggers/swandb_logger.yaml +5 -0
  32. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  33. fusion_bench_config/fabric_model_fusion.yaml +1 -0
  34. fusion_bench_config/method/adamerging/resnet.yaml +18 -0
  35. fusion_bench_config/method/classification/clip_finetune.yaml +5 -0
  36. fusion_bench_config/method/classification/image_classification_finetune.yaml +9 -0
  37. fusion_bench_config/method/linear/expo.yaml +5 -0
  38. fusion_bench_config/method/linear/llama_expo.yaml +5 -0
  39. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +3 -0
  40. fusion_bench_config/method/linear/simple_average_for_causallm.yaml +5 -0
  41. fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +3 -0
  42. fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +5 -0
  43. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +5 -0
  44. fusion_bench_config/method/mixtral_moe_merging.yaml +3 -0
  45. fusion_bench_config/method/mixtral_moe_upscaling.yaml +5 -0
  46. fusion_bench_config/method/regmean/clip_regmean.yaml +3 -0
  47. fusion_bench_config/method/regmean/gpt2_regmean.yaml +3 -0
  48. fusion_bench_config/method/regmean/regmean.yaml +3 -0
  49. fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +3 -0
  50. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +6 -0
  51. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  52. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +5 -0
  53. fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +3 -0
  54. fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -0
  55. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +5 -0
  56. fusion_bench_config/method/wudi/wudi.yaml +3 -0
  57. fusion_bench_config/model_fusion.yaml +2 -1
  58. fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224.yaml +10 -0
  59. fusion_bench_config/modelpool/Dinov2ForImageClassification/dinov2-base-imagenet1k-1-layer.yaml +10 -0
  60. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/_generate_config.py +138 -0
  61. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet152_cifar10.yaml +1 -1
  62. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet152_cifar100.yaml +1 -1
  63. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_dtd.yaml +14 -0
  64. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_emnist_letters.yaml +14 -0
  65. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_eurosat.yaml +14 -0
  66. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_fashion_mnist.yaml +14 -0
  67. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_fer2013.yaml +14 -0
  68. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_food101.yaml +14 -0
  69. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_gtsrb.yaml +14 -0
  70. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_kmnist.yaml +14 -0
  71. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_mnist.yaml +14 -0
  72. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_oxford-iiit-pet.yaml +14 -0
  73. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_oxford_flowers102.yaml +14 -0
  74. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_pcam.yaml +14 -0
  75. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_rendered-sst2.yaml +14 -0
  76. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_resisc45.yaml +14 -0
  77. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_stanford-cars.yaml +14 -0
  78. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_stl10.yaml +14 -0
  79. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_sun397.yaml +14 -0
  80. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_svhn.yaml +14 -0
  81. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet18_cifar10.yaml +1 -1
  82. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet18_cifar100.yaml +1 -1
  83. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_dtd.yaml +14 -0
  84. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_emnist_letters.yaml +14 -0
  85. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_eurosat.yaml +14 -0
  86. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_fashion_mnist.yaml +14 -0
  87. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_fer2013.yaml +14 -0
  88. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_food101.yaml +14 -0
  89. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_gtsrb.yaml +14 -0
  90. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_kmnist.yaml +14 -0
  91. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_mnist.yaml +14 -0
  92. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_oxford-iiit-pet.yaml +14 -0
  93. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_oxford_flowers102.yaml +14 -0
  94. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_pcam.yaml +14 -0
  95. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_rendered-sst2.yaml +14 -0
  96. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_resisc45.yaml +14 -0
  97. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_stanford-cars.yaml +14 -0
  98. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_stl10.yaml +14 -0
  99. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_sun397.yaml +14 -0
  100. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_svhn.yaml +14 -0
  101. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet50_cifar10.yaml +1 -1
  102. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet50_cifar100.yaml +1 -1
  103. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_dtd.yaml +14 -0
  104. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_emnist_letters.yaml +14 -0
  105. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_eurosat.yaml +14 -0
  106. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_fashion_mnist.yaml +14 -0
  107. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_fer2013.yaml +14 -0
  108. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_food101.yaml +14 -0
  109. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_gtsrb.yaml +14 -0
  110. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_kmnist.yaml +14 -0
  111. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_mnist.yaml +14 -0
  112. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_oxford-iiit-pet.yaml +14 -0
  113. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_oxford_flowers102.yaml +14 -0
  114. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_pcam.yaml +14 -0
  115. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_rendered-sst2.yaml +14 -0
  116. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_resisc45.yaml +14 -0
  117. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_stanford-cars.yaml +14 -0
  118. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_stl10.yaml +14 -0
  119. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_sun397.yaml +14 -0
  120. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_svhn.yaml +14 -0
  121. fusion_bench_config/method/clip_finetune.yaml +0 -26
  122. {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/WHEEL +0 -0
  123. {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/entry_points.txt +0 -0
  124. {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/licenses/LICENSE +0 -0
  125. {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,9 @@
1
+ """Image Classification Fine-tuning Module.
2
+
3
+ This module provides algorithms for fine-tuning and evaluating image classification models
4
+ using PyTorch Lightning.
5
+ """
6
+
1
7
  import os
2
8
  from typing import Optional
3
9
 
@@ -23,35 +29,93 @@ from fusion_bench import (
23
29
  from fusion_bench.dataset import CLIPDataset
24
30
  from fusion_bench.modelpool import ResNetForImageClassificationPool
25
31
  from fusion_bench.tasks.clip_classification import get_num_classes
32
+ from torch.utils.data import random_split
26
33
 
27
34
  log = get_rankzero_logger(__name__)
28
35
 
29
36
 
37
+ def _get_base_model_name(model) -> Optional[str]:
38
+ if hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
39
+ return model.config._name_or_path
40
+ else:
41
+ return None
42
+
43
+
30
44
  @auto_register_config
31
45
  class ImageClassificationFineTuning(BaseAlgorithm):
46
+ """Fine-tuning algorithm for image classification models.
47
+
48
+ This class implements end-to-end fine-tuning for image classification tasks using PyTorch Lightning.
49
+ It supports both epoch-based and step-based training with configurable optimizers, learning rate
50
+ schedulers, and data loaders.
51
+
52
+ Args:
53
+ max_epochs (Optional[int]): Maximum number of training epochs. Mutually exclusive with max_steps.
54
+ max_steps (Optional[int]): Maximum number of training steps. Mutually exclusive with max_epochs.
55
+ label_smoothing (float): Label smoothing factor for cross-entropy loss (0.0 = no smoothing).
56
+ optimizer (DictConfig): Configuration for the optimizer (e.g., Adam, SGD).
57
+ lr_scheduler (DictConfig): Configuration for the learning rate scheduler.
58
+ dataloader_kwargs (DictConfig): Additional keyword arguments for DataLoader construction.
59
+ **kwargs: Additional arguments passed to the base class.
60
+
61
+ Raises:
62
+ AssertionError: If both max_epochs and max_steps are provided.
63
+
64
+ Example:
65
+ ```python
66
+ >>> config = {
67
+ ... 'max_epochs': 10,
68
+ ... 'max_steps': None,
69
+ ... 'label_smoothing': 0.1,
70
+ ... 'optimizer': {'_target_': 'torch.optim.Adam', 'lr': 0.001},
71
+ ... 'lr_scheduler': {'_target_': 'torch.optim.lr_scheduler.StepLR', 'step_size': 5},
72
+ ... 'dataloader_kwargs': {'batch_size': 32, 'num_workers': 4}
73
+ ... }
74
+ >>> algorithm = ImageClassificationFineTuning(**config)
75
+ ```
76
+ """
77
+
32
78
  def __init__(
33
79
  self,
34
80
  max_epochs: Optional[int],
35
81
  max_steps: Optional[int],
82
+ training_data_ratio: Optional[float],
36
83
  label_smoothing: float,
37
84
  optimizer: DictConfig,
38
85
  lr_scheduler: DictConfig,
39
86
  dataloader_kwargs: DictConfig,
87
+ save_top_k: int,
88
+ save_interval: int,
89
+ save_on_train_epoch_end: bool,
40
90
  **kwargs,
41
91
  ):
42
92
  super().__init__(**kwargs)
43
- assert (max_epochs is None) or (
93
+ assert (max_epochs is None or max_epochs < 0) or (
44
94
  max_steps is None or max_steps < 0
45
95
  ), "Only one of max_epochs or max_steps should be set."
46
- self.training_interval = "epoch" if max_epochs is not None else "step"
96
+ self.training_interval = (
97
+ "epoch" if max_epochs is not None and max_epochs > 0 else "step"
98
+ )
47
99
  if self.training_interval == "epoch":
48
100
  self.max_steps = -1
49
101
  log.info(f"Training interval: {self.training_interval}")
50
102
  log.info(f"Max epochs: {max_epochs}, max steps: {max_steps}")
51
103
 
52
104
  def run(self, modelpool: ResNetForImageClassificationPool):
105
+ """Execute the fine-tuning process on the provided model pool.
106
+
107
+ This method performs the complete fine-tuning workflow:
108
+ 1. Loads the pretrained model from the model pool
109
+ 2. Prepares training and validation datasets
110
+ 3. Configures optimizer and learning rate scheduler
111
+ 4. Sets up Lightning trainer with appropriate callbacks
112
+ 5. Executes the training process
113
+ 6. Saves the final fine-tuned model
114
+ """
53
115
  # load model and dataset
54
116
  model = modelpool.load_pretrained_or_first_model()
117
+ base_model_name = _get_base_model_name(model)
118
+
55
119
  assert isinstance(model, nn.Module), "Loaded model is not a nn.Module."
56
120
 
57
121
  assert (
@@ -59,7 +123,17 @@ class ImageClassificationFineTuning(BaseAlgorithm):
59
123
  ), "Exactly one training dataset is required."
60
124
  self.dataset_name = dataset_name = modelpool.train_dataset_names[0]
61
125
  num_classes = get_num_classes(dataset_name)
126
+ log.info(f"Number of classes for dataset {dataset_name}: {num_classes}")
62
127
  train_dataset = modelpool.load_train_dataset(dataset_name)
128
+ log.info(f"Training dataset size: {len(train_dataset)}")
129
+ if self.training_data_ratio is not None and 0 < self.training_data_ratio < 1:
130
+ train_dataset, _ = random_split(
131
+ train_dataset,
132
+ lengths=[self.training_data_ratio, 1 - self.training_data_ratio],
133
+ )
134
+ log.info(
135
+ f"Using {len(train_dataset)} samples for training after applying training_data_ratio={self.training_data_ratio}."
136
+ )
63
137
  train_dataset = CLIPDataset(
64
138
  train_dataset, processor=modelpool.load_processor(stage="train")
65
139
  )
@@ -70,6 +144,8 @@ class ImageClassificationFineTuning(BaseAlgorithm):
70
144
  val_dataset, processor=modelpool.load_processor(stage="val")
71
145
  )
72
146
  val_loader = self.get_dataloader(val_dataset, stage="val")
147
+ else:
148
+ val_loader = None
73
149
 
74
150
  # configure optimizer
75
151
  optimizer = instantiate(self.optimizer, params=model.parameters())
@@ -91,7 +167,11 @@ class ImageClassificationFineTuning(BaseAlgorithm):
91
167
  objective=nn.CrossEntropyLoss(label_smoothing=self.label_smoothing),
92
168
  metrics={
93
169
  "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
94
- "acc@5": Accuracy(task="multiclass", num_classes=num_classes, top_k=5),
170
+ f"acc@{min(5,num_classes)}": Accuracy(
171
+ task="multiclass",
172
+ num_classes=num_classes,
173
+ top_k=min(5, num_classes),
174
+ ),
95
175
  },
96
176
  )
97
177
 
@@ -108,11 +188,21 @@ class ImageClassificationFineTuning(BaseAlgorithm):
108
188
  callbacks=[
109
189
  pl_callbacks.LearningRateMonitor(logging_interval="step"),
110
190
  pl_callbacks.DeviceStatsMonitor(),
191
+ pl_callbacks.ModelCheckpoint(
192
+ save_top_k=self.save_top_k,
193
+ every_n_train_steps=(
194
+ self.save_interval if self.training_interval == "step" else None
195
+ ),
196
+ every_n_epochs=(
197
+ self.save_interval
198
+ if self.training_interval == "epoch"
199
+ else None
200
+ ),
201
+ save_on_train_epoch_end=self.save_on_train_epoch_end,
202
+ save_last=True,
203
+ ),
111
204
  ],
112
- logger=TensorBoardLogger(
113
- save_dir=log_dir,
114
- name="",
115
- ),
205
+ logger=TensorBoardLogger(save_dir=log_dir, name="", version=""),
116
206
  fast_dev_run=RuntimeConstants.debug,
117
207
  )
118
208
 
@@ -129,10 +219,27 @@ class ImageClassificationFineTuning(BaseAlgorithm):
129
219
  "raw_checkpoints",
130
220
  "final",
131
221
  ),
222
+ algorithm_config=self.config,
223
+ description=f"Fine-tuned ResNet model on dataset {dataset_name}.",
224
+ base_model=base_model_name,
132
225
  )
133
226
  return model
134
227
 
135
228
  def get_dataloader(self, dataset, stage: str):
229
+ """Create a DataLoader for the specified dataset and training stage.
230
+
231
+ Constructs a PyTorch DataLoader with stage-appropriate configurations:
232
+ - Training stage: shuffling enabled by default
233
+ - Validation/test stages: shuffling disabled by default
234
+
235
+ Args:
236
+ dataset: The dataset to wrap in a DataLoader.
237
+ stage (str): Training stage, must be one of "train", "val", or "test".
238
+ Determines default shuffling behavior.
239
+
240
+ Returns:
241
+ DataLoader: Configured DataLoader for the given dataset and stage.
242
+ """
136
243
  assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
137
244
  dataloader_kwargs = dict(self.dataloader_kwargs)
138
245
  if "shuffle" not in dataloader_kwargs:
@@ -142,10 +249,42 @@ class ImageClassificationFineTuning(BaseAlgorithm):
142
249
 
143
250
  @auto_register_config
144
251
  class ImageClassificationFineTuning_Test(BaseAlgorithm):
252
+ """Test/evaluation algorithm for fine-tuned image classification models.
253
+
254
+ This class implements model evaluation on test or validation datasets using PyTorch Lightning.
255
+ It can either evaluate a model directly or load a model from a checkpoint before evaluation.
256
+ The evaluation computes standard classification metrics including top-1 and top-5 accuracy.
257
+
258
+ Args:
259
+ checkpoint_path (str): Path to the model checkpoint file. If None, uses the model
260
+ directly from the model pool without loading from checkpoint.
261
+ dataloader_kwargs (DictConfig): Additional keyword arguments for DataLoader construction.
262
+ **kwargs: Additional arguments passed to the base class.
263
+
264
+ Example:
265
+ ```python
266
+ >>> config = {
267
+ ... 'checkpoint_path': '/path/to/model/checkpoint.ckpt',
268
+ ... 'dataloader_kwargs': {'batch_size': 64, 'num_workers': 4}
269
+ ... }
270
+ >>> test_algorithm = ImageClassificationFineTuning_Test(**config)
271
+ ```
272
+ """
273
+
145
274
  def __init__(self, checkpoint_path: str, dataloader_kwargs: DictConfig, **kwargs):
146
275
  super().__init__(**kwargs)
147
276
 
148
- def run(self, modelpool: BaseModelPool):
277
+ def run(self, modelpool: ResNetForImageClassificationPool):
278
+ """Execute model evaluation on the provided model pool's test/validation dataset.
279
+
280
+ This method performs the complete evaluation workflow:
281
+ 1. Loads the model from the model pool (pretrained or first available)
282
+ 2. Prepares the test or validation dataset (prioritizes test if both available)
283
+ 3. Sets up the Lightning module with appropriate metrics (top-1 and top-5 accuracy)
284
+ 4. Loads from checkpoint if specified, otherwise uses the model directly
285
+ 5. Executes the evaluation using Lightning trainer
286
+ 6. Logs and returns the test metrics
287
+ """
149
288
  assert (
150
289
  modelpool.has_val_dataset or modelpool.has_test_dataset
151
290
  ), "No validation or test dataset found in the model pool."
@@ -181,8 +320,10 @@ class ImageClassificationFineTuning_Test(BaseAlgorithm):
181
320
  model,
182
321
  metrics={
183
322
  "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
184
- "acc@5": Accuracy(
185
- task="multiclass", num_classes=num_classes, top_k=5
323
+ f"acc@{min(5,num_classes)}": Accuracy(
324
+ task="multiclass",
325
+ num_classes=num_classes,
326
+ top_k=min(5, num_classes),
186
327
  ),
187
328
  },
188
329
  )
@@ -192,8 +333,10 @@ class ImageClassificationFineTuning_Test(BaseAlgorithm):
192
333
  model=model,
193
334
  metrics={
194
335
  "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
195
- "acc@5": Accuracy(
196
- task="multiclass", num_classes=num_classes, top_k=5
336
+ f"acc@{min(5,num_classes)}": Accuracy(
337
+ task="multiclass",
338
+ num_classes=num_classes,
339
+ top_k=min(5, num_classes),
197
340
  ),
198
341
  },
199
342
  )
@@ -207,6 +350,19 @@ class ImageClassificationFineTuning_Test(BaseAlgorithm):
207
350
  return model
208
351
 
209
352
  def get_dataloader(self, dataset, stage: str):
353
+ """Create a DataLoader for the specified dataset and evaluation stage.
354
+
355
+ Constructs a PyTorch DataLoader with stage-appropriate configurations for evaluation.
356
+ Similar to the training version but typically used for test/validation datasets.
357
+
358
+ Args:
359
+ dataset: The dataset to wrap in a DataLoader.
360
+ stage (str): Evaluation stage, must be one of "train", "val", or "test".
361
+ Determines default shuffling behavior (disabled for non-train stages).
362
+
363
+ Returns:
364
+ DataLoader: Configured DataLoader for the given dataset and stage.
365
+ """
210
366
  assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
211
367
  dataloader_kwargs = dict(self.dataloader_kwargs)
212
368
  if "shuffle" not in dataloader_kwargs:
@@ -1,6 +1,6 @@
1
1
  import logging
2
2
 
3
- from fusion_bench import BaseAlgorithm, BaseModelPool
3
+ from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
4
4
  from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_mul
5
5
 
6
6
  from .task_arithmetic import DareTaskArithmetic
@@ -8,6 +8,7 @@ from .task_arithmetic import DareTaskArithmetic
8
8
  log = logging.getLogger(__name__)
9
9
 
10
10
 
11
+ @auto_register_config
11
12
  class DareSimpleAverage(BaseAlgorithm):
12
13
 
13
14
  def __init__(
@@ -17,10 +18,10 @@ class DareSimpleAverage(BaseAlgorithm):
17
18
  rescale: bool = True,
18
19
  **kwargs,
19
20
  ):
21
+ super().__init__(**kwargs)
20
22
  self.sparsity_ratio = sparsity_ratio
21
23
  self.only_on_linear_weight = only_on_linear_weights
22
24
  self.rescale = rescale
23
- super().__init__(**kwargs)
24
25
 
25
26
  def run(self, modelpool: BaseModelPool):
26
27
  return DareTaskArithmetic(
@@ -1,7 +1,7 @@
1
1
  import torch
2
2
  from torch import Tensor, nn
3
3
 
4
- from fusion_bench import BaseAlgorithm, BaseModelPool
4
+ from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
5
5
  from fusion_bench.utils.state_dict_arithmetic import state_dict_sum
6
6
 
7
7
  from .utils import (
@@ -12,6 +12,7 @@ from .utils import (
12
12
  )
13
13
 
14
14
 
15
+ @auto_register_config
15
16
  class DareTaskArithmetic(BaseAlgorithm):
16
17
  """
17
18
  Implementation of Task Arithmetic w/ DARE.
@@ -27,11 +28,11 @@ class DareTaskArithmetic(BaseAlgorithm):
27
28
  rescale: bool = True,
28
29
  **kwargs,
29
30
  ):
31
+ super().__init__(**kwargs)
30
32
  self.scaling_factor = scaling_factor
31
33
  self.sparsity_ratio = sparsity_ratio
32
34
  self.only_on_linear_weights = only_on_linear_weights
33
35
  self.rescale = rescale
34
- super().__init__(**kwargs)
35
36
 
36
37
  def _load_task_vector(
37
38
  self,
@@ -64,10 +64,12 @@ class SimpleAverageAlgorithm(
64
64
  SimpleProfilerMixin,
65
65
  BaseAlgorithm,
66
66
  ):
67
- def __init__(self, show_pbar: bool = False, **kwargs):
67
+ def __init__(self, show_pbar: bool = False, inplace: bool = True, **kwargs):
68
68
  """
69
69
  Args:
70
70
  show_pbar (bool): If True, shows a progress bar during model loading and merging. Default is False.
71
+ inplace (bool): If True, overwrites the weights of the first model in the model pool.
72
+ If False, creates a new model for the merged weights. Default is True.
71
73
  """
72
74
  super().__init__(**kwargs)
73
75
 
@@ -104,12 +106,12 @@ class SimpleAverageAlgorithm(
104
106
  with self.profile("merge weights"):
105
107
  if sd is None:
106
108
  # Initialize the state dictionary with the first model's state dictionary
107
- sd = model.state_dict(keep_vars=True)
108
- forward_model = model
109
+ sd = model.state_dict()
110
+ forward_model = model if self.inplace else deepcopy(model)
109
111
  else:
110
112
  # Add the current model's state dictionary to the accumulated state dictionary
111
113
  sd = state_dict_add(
112
- sd, model.state_dict(keep_vars=True), show_pbar=self.show_pbar
114
+ sd, model.state_dict(), show_pbar=self.show_pbar
113
115
  )
114
116
  with self.profile("merge weights"):
115
117
  # Divide the accumulated state dictionary by the number of models to get the average
@@ -149,7 +149,10 @@ class TaskArithmeticAlgorithm(
149
149
  )
150
150
  with self.profile("merge weights"):
151
151
  # scale the task vector
152
- task_vector = state_dict_mul(task_vector, self.config.scaling_factor)
152
+ # here we keep the dtype when the elements of value are all zeros to avoid dtype mismatch
153
+ task_vector = state_dict_mul(
154
+ task_vector, self.config.scaling_factor, keep_dtype_when_zero=True
155
+ )
153
156
  # add the task vector to the pretrained model
154
157
  state_dict = state_dict_add(pretrained_model.state_dict(), task_vector)
155
158
 
@@ -111,6 +111,15 @@ class LightningFabricMixin:
111
111
  """
112
112
  if self.fabric is not None and len(self.fabric._loggers) > 0:
113
113
  log_dir = self.fabric.logger.log_dir
114
+
115
+ # Special handling for SwanLabLogger to get the correct log directory
116
+ if (
117
+ log_dir is None
118
+ and self.fabric.logger.__class__.__name__ == "SwanLabLogger"
119
+ ):
120
+ log_dir = self.fabric.logger.save_dir or self.fabric.logger._logdir
121
+
122
+ assert log_dir is not None, "log_dir should not be None"
114
123
  if self.fabric.is_global_zero and not os.path.exists(log_dir):
115
124
  os.makedirs(log_dir, exist_ok=True)
116
125
  return log_dir
@@ -8,6 +8,14 @@ _import_structure = {
8
8
  "base_pool": ["BaseModelPool"],
9
9
  "causal_lm": ["CausalLMPool", "CausalLMBackbonePool"],
10
10
  "clip_vision": ["CLIPVisionModelPool"],
11
+ "convnext_for_image_classification": [
12
+ "ConvNextForImageClassificationPool",
13
+ "load_transformers_convnext",
14
+ ],
15
+ "dinov2_for_image_classification": [
16
+ "Dinov2ForImageClassificationPool",
17
+ "load_transformers_dinov2",
18
+ ],
11
19
  "nyuv2_modelpool": ["NYUv2ModelPool"],
12
20
  "huggingface_automodel": ["AutoModelPool"],
13
21
  "seq2seq_lm": ["Seq2SeqLMPool"],
@@ -18,7 +26,10 @@ _import_structure = {
18
26
  "GPT2ForSequenceClassificationPool",
19
27
  ],
20
28
  "seq_classification_lm": ["SequenceClassificationModelPool"],
21
- "resnet_for_image_classification": ["ResNetForImageClassificationPool"],
29
+ "resnet_for_image_classification": [
30
+ "ResNetForImageClassificationPool",
31
+ "load_transformers_resnet",
32
+ ],
22
33
  }
23
34
 
24
35
 
@@ -26,6 +37,14 @@ if TYPE_CHECKING:
26
37
  from .base_pool import BaseModelPool
27
38
  from .causal_lm import CausalLMBackbonePool, CausalLMPool
28
39
  from .clip_vision import CLIPVisionModelPool
40
+ from .convnext_for_image_classification import (
41
+ ConvNextForImageClassificationPool,
42
+ load_transformers_convnext,
43
+ )
44
+ from .dinov2_for_image_classification import (
45
+ Dinov2ForImageClassificationPool,
46
+ load_transformers_dinov2,
47
+ )
29
48
  from .huggingface_automodel import AutoModelPool
30
49
  from .huggingface_gpt2_classification import (
31
50
  GPT2ForSequenceClassificationPool,
@@ -34,7 +53,10 @@ if TYPE_CHECKING:
34
53
  from .nyuv2_modelpool import NYUv2ModelPool
35
54
  from .openclip_vision import OpenCLIPVisionModelPool
36
55
  from .PeftModelForSeq2SeqLM import PeftModelForSeq2SeqLMPool
37
- from .resnet_for_image_classification import ResNetForImageClassificationPool
56
+ from .resnet_for_image_classification import (
57
+ ResNetForImageClassificationPool,
58
+ load_transformers_resnet,
59
+ )
38
60
  from .seq2seq_lm import Seq2SeqLMPool
39
61
  from .seq_classification_lm import SequenceClassificationModelPool
40
62
 
@@ -3,7 +3,7 @@ from copy import deepcopy
3
3
  from typing import Dict, Generator, List, Optional, Tuple, Union
4
4
 
5
5
  import torch
6
- from omegaconf import DictConfig
6
+ from omegaconf import DictConfig, OmegaConf, UnsupportedValueType
7
7
  from torch import nn
8
8
  from torch.utils.data import Dataset
9
9
 
@@ -52,6 +52,13 @@ class BaseModelPool(
52
52
  ):
53
53
  if isinstance(models, List):
54
54
  models = {str(model_idx): model for model_idx, model in enumerate(models)}
55
+
56
+ if isinstance(models, dict):
57
+ try: # try to convert to DictConfig
58
+ models = OmegaConf.create(models)
59
+ except UnsupportedValueType:
60
+ pass
61
+
55
62
  self._models = models
56
63
  self._train_datasets = train_datasets
57
64
  self._val_datasets = val_datasets
@@ -8,6 +8,7 @@ from copy import deepcopy
8
8
  from typing import Any, Dict, Optional, TypeAlias, Union, cast # noqa: F401
9
9
 
10
10
  import peft
11
+ from lightning_utilities.core.rank_zero import rank_zero_only
11
12
  from omegaconf import DictConfig, OmegaConf, flag_override
12
13
  from torch import nn
13
14
  from torch.nn.modules import Module
@@ -342,7 +343,7 @@ class CausalLMPool(BaseModelPool):
342
343
  )
343
344
 
344
345
  # Create and save model card if algorithm_config is provided
345
- if algorithm_config is not None:
346
+ if algorithm_config is not None and rank_zero_only.rank == 0:
346
347
  if description is None:
347
348
  description = "Model created using FusionBench."
348
349
  model_card_str = create_default_model_card(