fusion-bench 0.2.25__py3-none-any.whl → 0.2.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. fusion_bench/dataset/clip_dataset.py +1 -0
  2. fusion_bench/method/__init__.py +4 -0
  3. fusion_bench/method/adamerging/__init__.py +28 -5
  4. fusion_bench/method/adamerging/resnet_adamerging.py +279 -0
  5. fusion_bench/method/adamerging/task_wise_adamerging.py +2 -14
  6. fusion_bench/method/adamerging/utils.py +58 -0
  7. fusion_bench/method/classification/clip_finetune.py +6 -4
  8. fusion_bench/method/classification/image_classification_finetune.py +156 -12
  9. fusion_bench/method/dare/simple_average.py +3 -2
  10. fusion_bench/method/dare/task_arithmetic.py +3 -2
  11. fusion_bench/method/dop/__init__.py +1 -0
  12. fusion_bench/method/dop/dop.py +366 -0
  13. fusion_bench/method/dop/min_norm_solvers.py +227 -0
  14. fusion_bench/method/dop/utils.py +73 -0
  15. fusion_bench/method/simple_average.py +6 -4
  16. fusion_bench/mixins/lightning_fabric.py +9 -0
  17. fusion_bench/modelpool/causal_lm/causal_lm.py +2 -1
  18. fusion_bench/modelpool/resnet_for_image_classification.py +285 -4
  19. fusion_bench/models/hf_clip.py +4 -7
  20. fusion_bench/models/hf_utils.py +4 -1
  21. fusion_bench/taskpool/__init__.py +2 -0
  22. fusion_bench/taskpool/clip_vision/taskpool.py +1 -1
  23. fusion_bench/taskpool/resnet_for_image_classification.py +231 -0
  24. fusion_bench/utils/state_dict_arithmetic.py +91 -10
  25. {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/METADATA +9 -3
  26. {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/RECORD +140 -77
  27. fusion_bench_config/fabric/auto.yaml +1 -1
  28. fusion_bench_config/fabric/loggers/swandb_logger.yaml +5 -0
  29. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  30. fusion_bench_config/fabric_model_fusion.yaml +1 -0
  31. fusion_bench_config/method/adamerging/resnet.yaml +18 -0
  32. fusion_bench_config/method/bitdelta/bitdelta.yaml +3 -0
  33. fusion_bench_config/method/classification/clip_finetune.yaml +5 -0
  34. fusion_bench_config/method/classification/image_classification_finetune.yaml +9 -0
  35. fusion_bench_config/method/depth_upscaling.yaml +9 -0
  36. fusion_bench_config/method/dop/dop.yaml +30 -0
  37. fusion_bench_config/method/dummy.yaml +6 -0
  38. fusion_bench_config/method/ensemble/max_model_predictor.yaml +6 -0
  39. fusion_bench_config/method/ensemble/simple_ensemble.yaml +8 -1
  40. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +8 -0
  41. fusion_bench_config/method/linear/expo.yaml +5 -0
  42. fusion_bench_config/method/linear/linear_interpolation.yaml +8 -0
  43. fusion_bench_config/method/linear/llama_expo.yaml +5 -0
  44. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +3 -0
  45. fusion_bench_config/method/linear/simple_average_for_causallm.yaml +5 -0
  46. fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +3 -0
  47. fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +5 -0
  48. fusion_bench_config/method/linear/weighted_average.yaml +3 -0
  49. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +6 -1
  50. fusion_bench_config/method/mixtral_moe_merging.yaml +3 -0
  51. fusion_bench_config/method/mixtral_moe_upscaling.yaml +5 -0
  52. fusion_bench_config/method/model_recombination.yaml +8 -0
  53. fusion_bench_config/method/model_stock/model_stock.yaml +4 -1
  54. fusion_bench_config/method/opcm/opcm.yaml +5 -0
  55. fusion_bench_config/method/opcm/task_arithmetic.yaml +6 -0
  56. fusion_bench_config/method/opcm/ties_merging.yaml +5 -0
  57. fusion_bench_config/method/opcm/weight_average.yaml +5 -0
  58. fusion_bench_config/method/regmean/clip_regmean.yaml +3 -0
  59. fusion_bench_config/method/regmean/gpt2_regmean.yaml +3 -0
  60. fusion_bench_config/method/regmean/regmean.yaml +3 -0
  61. fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +3 -0
  62. fusion_bench_config/method/simple_average.yaml +9 -0
  63. fusion_bench_config/method/slerp/slerp.yaml +9 -0
  64. fusion_bench_config/method/slerp/slerp_lm.yaml +5 -0
  65. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +6 -0
  66. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  67. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +5 -0
  68. fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +3 -0
  69. fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -0
  70. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +5 -0
  71. fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +3 -0
  72. fusion_bench_config/method/task_arithmetic.yaml +9 -0
  73. fusion_bench_config/method/ties_merging.yaml +3 -0
  74. fusion_bench_config/method/wudi/wudi.yaml +3 -0
  75. fusion_bench_config/model_fusion.yaml +2 -1
  76. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/_generate_config.py +138 -0
  77. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet152_cifar10.yaml +1 -1
  78. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet152_cifar100.yaml +1 -1
  79. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_dtd.yaml +14 -0
  80. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_emnist_letters.yaml +14 -0
  81. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_eurosat.yaml +14 -0
  82. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_fashion_mnist.yaml +14 -0
  83. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_fer2013.yaml +14 -0
  84. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_food101.yaml +14 -0
  85. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_gtsrb.yaml +14 -0
  86. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_kmnist.yaml +14 -0
  87. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_mnist.yaml +14 -0
  88. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_oxford-iiit-pet.yaml +14 -0
  89. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_oxford_flowers102.yaml +14 -0
  90. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_pcam.yaml +14 -0
  91. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_rendered-sst2.yaml +14 -0
  92. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_resisc45.yaml +14 -0
  93. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_stanford-cars.yaml +14 -0
  94. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_stl10.yaml +14 -0
  95. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_sun397.yaml +14 -0
  96. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_svhn.yaml +14 -0
  97. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet18_cifar10.yaml +1 -1
  98. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet18_cifar100.yaml +1 -1
  99. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_dtd.yaml +14 -0
  100. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_emnist_letters.yaml +14 -0
  101. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_eurosat.yaml +14 -0
  102. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_fashion_mnist.yaml +14 -0
  103. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_fer2013.yaml +14 -0
  104. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_food101.yaml +14 -0
  105. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_gtsrb.yaml +14 -0
  106. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_kmnist.yaml +14 -0
  107. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_mnist.yaml +14 -0
  108. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_oxford-iiit-pet.yaml +14 -0
  109. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_oxford_flowers102.yaml +14 -0
  110. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_pcam.yaml +14 -0
  111. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_rendered-sst2.yaml +14 -0
  112. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_resisc45.yaml +14 -0
  113. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_stanford-cars.yaml +14 -0
  114. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_stl10.yaml +14 -0
  115. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_sun397.yaml +14 -0
  116. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_svhn.yaml +14 -0
  117. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet50_cifar10.yaml +1 -1
  118. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet50_cifar100.yaml +1 -1
  119. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_dtd.yaml +14 -0
  120. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_emnist_letters.yaml +14 -0
  121. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_eurosat.yaml +14 -0
  122. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_fashion_mnist.yaml +14 -0
  123. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_fer2013.yaml +14 -0
  124. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_food101.yaml +14 -0
  125. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_gtsrb.yaml +14 -0
  126. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_kmnist.yaml +14 -0
  127. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_mnist.yaml +14 -0
  128. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_oxford-iiit-pet.yaml +14 -0
  129. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_oxford_flowers102.yaml +14 -0
  130. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_pcam.yaml +14 -0
  131. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_rendered-sst2.yaml +14 -0
  132. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_resisc45.yaml +14 -0
  133. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_stanford-cars.yaml +14 -0
  134. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_stl10.yaml +14 -0
  135. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_sun397.yaml +14 -0
  136. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_svhn.yaml +14 -0
  137. fusion_bench_config/method/clip_finetune.yaml +0 -26
  138. {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/WHEEL +0 -0
  139. {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/entry_points.txt +0 -0
  140. {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/licenses/LICENSE +0 -0
  141. {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/top_level.txt +0 -0
fusion_bench/modelpool/resnet_for_image_classification.py
@@ -1,3 +1,34 @@
+"""ResNet Model Pool for Image Classification.
+
+This module provides a flexible model pool implementation for ResNet models used in image
+classification tasks. It supports both torchvision and transformers implementations of ResNet
+architectures with configurable preprocessing, loading, and saving capabilities.
+
+Example Usage:
+    Create a pool with a torchvision ResNet model:
+
+    ```python
+    >>> # Torchvision ResNet pool
+    >>> pool = ResNetForImageClassificationPool(
+    ...     type="torchvision",
+    ...     models={"resnet18_cifar10": {"model_name": "resnet18", "dataset_name": "cifar10"}}
+    ... )
+    >>> model = pool.load_model("resnet18_cifar10")
+    >>> processor = pool.load_processor(stage="train")
+    ```
+
+    Create a pool with a transformers ResNet model:
+
+    ```python
+    >>> # Transformers ResNet pool
+    >>> pool = ResNetForImageClassificationPool(
+    ...     type="transformers",
+    ...     models={"resnet_model": {"config_path": "microsoft/resnet-50", "pretrained": True}}
+    ... )
+    ```
+"""
+
+import os
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -11,6 +42,7 @@ from typing import (
 )
 
 import torch
+from lightning_utilities.core.rank_zero import rank_zero_only
 from omegaconf import DictConfig
 from torch import nn
 
@@ -26,6 +58,31 @@ log = get_rankzero_logger(__name__)
 def load_torchvision_resnet(
     model_name: str, weights: Optional[str], num_classes: Optional[int]
 ) -> "TorchVisionResNet":
+    """Load a ResNet model from torchvision with an optional custom classifier head.
+
+    This function creates a ResNet model using torchvision's model zoo and optionally
+    replaces the final classification layer to match the required number of classes.
+
+    Args:
+        model_name (str): Name of the ResNet model to load (e.g., 'resnet18', 'resnet50').
+            Must be a valid torchvision model name.
+        weights (Optional[str]): Pretrained weights to load. Can be 'DEFAULT', 'IMAGENET1K_V1',
+            or None for random initialization. See the torchvision documentation for available options.
+        num_classes (Optional[int]): Number of output classes. If provided, replaces the final
+            fully connected layer. If None, keeps the original classifier (typically 1000 classes).
+
+    Returns:
+        TorchVisionResNet: The loaded ResNet model with the appropriate classifier head.
+
+    Raises:
+        AttributeError: If model_name is not a valid torchvision model.
+
+    Example:
+        ```python
+        >>> model = load_torchvision_resnet("resnet18", "DEFAULT", 10)  # CIFAR-10
+        >>> model = load_torchvision_resnet("resnet50", None, 100)  # Random init, 100 classes
+        ```
+    """
     import torchvision.models
 
     model_fn = getattr(torchvision.models, model_name)
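
As a rough illustration of the head replacement this docstring describes, here is a minimal sketch, assuming torchvision's ResNet family where the classifier lives in the `fc` attribute; `replace_resnet_head` is a hypothetical helper, not part of fusion_bench:

```python
import torch.nn as nn
import torchvision.models

def replace_resnet_head(model: nn.Module, num_classes: int) -> nn.Module:
    # torchvision ResNets expose the classifier as `model.fc`; swap it for a
    # fresh Linear layer with the requested number of outputs while keeping
    # the input feature width of the original head.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

model = torchvision.models.resnet18(weights="DEFAULT")
model = replace_resnet_head(model, num_classes=10)  # e.g., CIFAR-10
```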
@@ -40,6 +97,31 @@ def load_torchvision_resnet(
 def load_transformers_resnet(
     config_path: str, pretrained: bool, dataset_name: Optional[str]
 ):
+    """Load a ResNet model from transformers with optional dataset-specific adaptation.
+
+    This function creates a ResNet model using the transformers library and optionally
+    adapts it for a specific dataset by updating the classifier head and label mappings.
+
+    Args:
+        config_path (str): Path or identifier for the model configuration. Can be a local path
+            or a Hugging Face model identifier (e.g., 'microsoft/resnet-50').
+        pretrained (bool): Whether to load pretrained weights. If True, loads from the
+            specified config_path. If False, initializes with random weights using the config.
+        dataset_name (Optional[str]): Name of the target dataset for adaptation. If provided,
+            updates the model's classifier and label mappings to match the dataset's classes.
+            If None, keeps the original model configuration.
+
+    Returns:
+        ResNetForImageClassification: The loaded and optionally adapted ResNet model.
+
+    Example:
+        ```python
+        >>> # Load a pretrained model adapted for CIFAR-10
+        >>> model = load_transformers_resnet("microsoft/resnet-50", True, "cifar10")
+        >>> # Load a randomly initialized model with default classes
+        >>> model = load_transformers_resnet("microsoft/resnet-50", False, None)
+        ```
+    """
     from transformers import AutoConfig, ResNetForImageClassification
 
     if pretrained:
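
A minimal sketch of the label-mapping adaptation this docstring describes, using standard transformers APIs; the class names below are placeholders, and the exact logic inside fusion_bench may differ:

```python
from transformers import AutoConfig, ResNetForImageClassification

classnames = ["airplane", "automobile", "bird"]  # placeholder class names

# Rewrite the label mappings in the config so the classifier head matches
# the target dataset.
config = AutoConfig.from_pretrained("microsoft/resnet-50")
config.num_labels = len(classnames)
config.id2label = {i: name for i, name in enumerate(classnames)}
config.label2id = {name: i for i, name in enumerate(classnames)}

# ignore_mismatched_sizes re-initializes the classifier head when its shape
# no longer matches the pretrained checkpoint.
model = ResNetForImageClassification.from_pretrained(
    "microsoft/resnet-50", config=config, ignore_mismatched_sizes=True
)
```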
@@ -70,13 +152,107 @@ def load_transformers_resnet(
 
 @auto_register_config
 class ResNetForImageClassificationPool(BaseModelPool):
-    def __init__(self, type: str, **kwargs):
-        super().__init__(**kwargs)
-        assert type in ["torchvision", "transformers"]
+    """Model pool for ResNet-based image classification models.
+
+    This class provides a unified interface for managing ResNet models from different sources
+    (torchvision and transformers) with automatic preprocessing, loading, and saving capabilities.
+    It supports multiple ResNet architectures and can automatically adapt models to different
+    datasets by adjusting the number of output classes.
+
+    The pool supports two main types:
+    - "torchvision": Uses torchvision's ResNet implementations with standard ImageNet preprocessing
+    - "transformers": Uses Hugging Face transformers' ResNetForImageClassification with auto processors
+
+    Args:
+        type (str): Model source type, must be either "torchvision" or "transformers".
+        **kwargs: Additional arguments passed to the base BaseModelPool class.
+
+    Attributes:
+        type (str): The model source type specified during initialization.
+
+    Raises:
+        AssertionError: If type is not "torchvision" or "transformers".
+
+    Example:
+        Create a pool with a torchvision ResNet model:
+
+        ```python
+        >>> # Torchvision-based pool
+        >>> pool = ResNetForImageClassificationPool(
+        ...     type="torchvision",
+        ...     models={
+        ...         "resnet18_cifar10": {
+        ...             "model_name": "resnet18",
+        ...             "weights": "DEFAULT",
+        ...             "dataset_name": "cifar10"
+        ...         }
+        ...     }
+        ... )
+        ```
+
+        Create a pool with a transformers ResNet model:
+
+        ```python
+        >>> # Transformers-based pool
+        >>> pool = ResNetForImageClassificationPool(
+        ...     type="transformers",
+        ...     models={
+        ...         "resnet_model": {
+        ...             "config_path": "microsoft/resnet-50",
+        ...             "pretrained": True,
+        ...             "dataset_name": "imagenet"
+        ...         }
+        ...     }
+        ... )
+        ```
+    """
+
+    def __init__(self, models, type: str, **kwargs):
+        super().__init__(models=models, **kwargs)
+        assert type in [
+            "torchvision",
+            "transformers",
+        ], "type must be either 'torchvision' or 'transformers'"
 
     def load_processor(
         self, stage: Literal["train", "val", "test"] = "test", *args, **kwargs
     ):
+        """Load the appropriate image processor/transform for the specified training stage.
+
+        Creates stage-specific image preprocessing pipelines optimized for the model type:
+
+        For torchvision models:
+        - Train stage: includes data augmentation (random resized crop, horizontal flip)
+        - Val/test stages: standard preprocessing (resize, center crop) without augmentation
+        - All stages: apply ImageNet normalization (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+        For transformers models:
+        - Uses AutoImageProcessor from the pretrained model configuration
+        - Automatically handles model-specific preprocessing requirements
+
+        Args:
+            stage (Literal["train", "val", "test"]): The training stage determining the preprocessing type.
+                - "train": applies data augmentation for training
+                - "val"/"test": uses standard preprocessing for evaluation
+            *args: Additional positional arguments (unused).
+            **kwargs: Additional keyword arguments (unused).
+
+        Returns:
+            Union[transforms.Compose, AutoImageProcessor]: The image processor/transform pipeline
+                appropriate for the specified stage and model type.
+
+        Raises:
+            ValueError: If no valid config_path can be found for transformers models.
+
+        Example:
+            ```python
+            >>> # Get training transforms for a torchvision model
+            >>> train_transform = pool.load_processor(stage="train")
+            >>> # Get an evaluation processor for a transformers model
+            >>> eval_processor = pool.load_processor(stage="test")
+            ```
+        """
         if self.type == "torchvision":
             from torchvision import transforms
 
@@ -122,6 +298,58 @@ class ResNetForImageClassificationPool(BaseModelPool):
 
     @override
     def load_model(self, model_name_or_config: Union[str, DictConfig], *args, **kwargs):
+        """Load a ResNet model based on the provided configuration or model name.
+
+        This method supports flexible model loading from different sources and configurations:
+        - Direct model names (e.g., "resnet18", "resnet50") for standard architectures
+        - Model pool keys that map to configurations
+        - Dictionary/DictConfig objects with detailed model specifications
+        - Hugging Face model identifiers for transformers models
+
+        For torchvision models, supports:
+        - Standard ResNet architectures: resnet18, resnet34, resnet50, resnet101, resnet152
+        - Custom configurations with model_name, weights, and num_classes specifications
+        - Automatic dataset adaptation with class-number inference
+
+        For transformers models:
+        - Loading from the Hugging Face Hub or local paths
+        - Pretrained or randomly initialized models
+        - Automatic logits extraction by overriding the forward method
+        - Dataset-specific label mapping configuration
+
+        Args:
+            model_name_or_config (Union[str, DictConfig]): Model specification that can be:
+                - A string model name (e.g., "resnet18") for standard architectures
+                - A model pool key referencing a stored configuration
+                - A dict/DictConfig with model parameters like:
+                    * For torchvision: {"model_name": "resnet18", "weights": "DEFAULT", "num_classes": 10}
+                    * For transformers: {"config_path": "microsoft/resnet-50", "pretrained": True, "dataset_name": "cifar10"}
+            *args: Additional positional arguments (unused).
+            **kwargs: Additional keyword arguments (unused).
+
+        Returns:
+            Union[TorchVisionResNet, ResNetForImageClassification]: The loaded ResNet model
+                configured for the specified task. For transformers models, the forward method
+                is modified to return logits directly instead of the full model output.
+
+        Raises:
+            ValueError: If the model_name_or_config type is invalid or the model type is unknown.
+            AssertionError: If num_classes inferred from the dataset does not match an explicit num_classes specification.
+
+        Example:
+            ```python
+            >>> # Load a standard torchvision model
+            >>> model = pool.load_model("resnet18")
+
+            >>> # Load with a custom configuration
+            >>> config = {"model_name": "resnet50", "weights": "DEFAULT", "dataset_name": "cifar10"}
+            >>> model = pool.load_model(config)
+
+            >>> # Load a transformers model
+            >>> config = {"config_path": "microsoft/resnet-50", "pretrained": True}
+            >>> model = pool.load_model(config)
+            ```
+        """
         log.debug(f"Loading model: {model_name_or_config}", stacklevel=2)
         if (
             isinstance(model_name_or_config, str)
@@ -198,11 +426,64 @@ class ResNetForImageClassificationPool(BaseModelPool):
         return model
 
     @override
-    def save_model(self, model, path, *args, **kwargs):
+    def save_model(
+        self,
+        model,
+        path,
+        algorithm_config: Optional[DictConfig] = None,
+        description: Optional[str] = None,
+        *args,
+        **kwargs,
+    ):
+        """Save a ResNet model to the specified path using the appropriate format.
+
+        This method handles model saving based on the model pool type:
+        - For torchvision models: saves only the state_dict using torch.save()
+        - For transformers models: saves the complete model and processor using save_pretrained()
+
+        The saving format ensures compatibility with the corresponding loading mechanisms
+        and preserves all necessary components for model restoration.
+
+        Args:
+            model: The ResNet model to save. Should be compatible with the pool's model type.
+            path (str): Destination path for saving the model. For torchvision models, this
+                should be a file path (e.g., "model.pth"). For transformers models, this
+                should be a directory path where model files will be stored.
+            *args: Additional positional arguments (unused).
+            **kwargs: Additional keyword arguments (unused).
+
+        Raises:
+            ValueError: If the model type is unknown or unsupported.
+
+        Note:
+            For transformers models, both the model weights and the associated image processor
+            are saved to ensure complete reproducibility of the preprocessing pipeline.
+
+        Example:
+            ```python
+            >>> # Save a torchvision model
+            >>> pool.save_model(model, "checkpoints/resnet18_cifar10.pth")
+
+            >>> # Save a transformers model (saves to a directory)
+            >>> pool.save_model(model, "checkpoints/resnet50_model/")
+            ```
+        """
         if self.type == "torchvision":
+            os.makedirs(os.path.dirname(path), exist_ok=True)
             torch.save(model.state_dict(), path)
         elif self.type == "transformers":
             model.save_pretrained(path)
             self.load_processor().save_pretrained(path)
+
+            if algorithm_config is not None and rank_zero_only.rank == 0:
+                from fusion_bench.models.hf_utils import create_default_model_card
+
+                model_card_str = create_default_model_card(
+                    algorithm_config=algorithm_config,
+                    description=description,
+                    modelpool_config=self.config,
+                )
+                with open(os.path.join(path, "README.md"), "w") as f:
+                    f.write(model_card_str)
         else:
             raise ValueError(f"Unknown model type: {self.type}")
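
The two new keyword arguments, algorithm_config and description, are not yet listed in the docstring's Args section; they feed the generated model card. A hedged usage sketch follows, where `pool` and `merged_model` are placeholders for whatever the surrounding fusion run produces:

```python
from omegaconf import DictConfig

# Only meaningful for transformers pools: on rank zero, a README.md model
# card is generated via create_default_model_card and written next to the
# checkpoint; other ranks skip the write.
pool.save_model(
    merged_model,
    "checkpoints/resnet50_merged/",
    algorithm_config=DictConfig({"name": "simple_average"}),
    description="ResNet-50 models merged by simple averaging.",
)
```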
fusion_bench/models/hf_clip.py
@@ -1,5 +1,5 @@
 import logging
-from typing import TYPE_CHECKING, Callable, Iterable, List  # noqa: F401
+from typing import TYPE_CHECKING, Callable, Iterable, List, Optional  # noqa: F401
 
 import torch
 from torch import Tensor, nn
@@ -39,7 +39,6 @@ class HFCLIPClassifier(nn.Module):
         self,
         clip_model: CLIPModel,
         processor: CLIPProcessor,
-        extra_module=None,
     ):
         """
         Initialize the HFCLIPClassifier.
@@ -63,8 +62,6 @@
             persistent=False,
         )
 
-        self.extra_module = extra_module
-
     @property
     def text_model(self):
         """Get the text model component of CLIP."""
@@ -123,9 +120,9 @@
     def forward(
         self,
         images: Tensor,
-        return_image_embeds=False,
-        return_dict=False,
-        task_name=None,
+        return_image_embeds: bool = False,
+        return_dict: bool = False,
+        task_name: Optional[str] = None,
     ):
         """
         Perform forward pass for zero-shot image classification.
fusion_bench/models/hf_utils.py
@@ -142,7 +142,7 @@ def save_pretrained_with_remote_code(
 
 
 def create_default_model_card(
-    models: list[str],
+    models: Optional[list[str]] = None,
     base_model: Optional[str] = None,
     title: str = "Deep Model Fusion",
     tags: list[str] = ["fusion-bench", "merge"],
@@ -152,6 +152,9 @@
 ):
     from jinja2 import Template
 
+    if models is None:
+        models = []
+
     template: Template = Template(load_model_card_template("default.md"))
     card = template.render(
         base_model=base_model,
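
Making `models` optional lets callers such as the new ResNet save_model path build a card without naming source models. A brief sketch; keyword arguments beyond those visible in this hunk are inferred from the save_model call site above and may carry different defaults:

```python
from fusion_bench.models.hf_utils import create_default_model_card

# Renders the default card template with an empty source-model list.
card = create_default_model_card(
    base_model="microsoft/resnet-50",
    title="Deep Model Fusion",
    description="Merged ResNet checkpoint.",  # same kwarg save_model passes through
)
```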
fusion_bench/taskpool/__init__.py
@@ -18,6 +18,7 @@ _import_structure = {
     "lm_eval_harness": ["LMEvalHarnessTaskPool"],
     "nyuv2_taskpool": ["NYUv2TaskPool"],
     "openclip_vision": ["OpenCLIPVisionModelTaskPool"],
+    "resnet_for_image_classification": ["ResNetForImageClassificationTaskPool"],
 }
 
 
@@ -34,6 +35,7 @@ if TYPE_CHECKING:
     from .lm_eval_harness import LMEvalHarnessTaskPool
     from .nyuv2_taskpool import NYUv2TaskPool
     from .openclip_vision import OpenCLIPVisionModelTaskPool
+    from .resnet_for_image_classification import ResNetForImageClassificationTaskPool
 
 else:
     sys.modules[__name__] = LazyImporter(
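
Registering the class in both _import_structure and the TYPE_CHECKING branch keeps the LazyImporter pattern intact: the submodule is only imported when the symbol is first accessed, while static type checkers still see a real import. End-user code is unchanged:

```python
# Resolved lazily at first access; `import fusion_bench.taskpool` stays cheap.
from fusion_bench.taskpool import ResNetForImageClassificationTaskPool
```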
fusion_bench/taskpool/clip_vision/taskpool.py
@@ -264,7 +264,7 @@ class CLIPVisionModelTaskPool(
 
         pbar = tqdm(
             test_loader,
-            desc=f"Evaluating {task_name}",
+            desc=f"Evaluating {task_name}" if task_name is not None else "Evaluating",
             leave=False,
             dynamic_ncols=True,
         )
fusion_bench/taskpool/resnet_for_image_classification.py (new file)
@@ -0,0 +1,231 @@
+import itertools
+import json
+import os
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Literal,
+    Optional,
+    TypeVar,
+    Union,
+    override,
+)
+
+import lightning as L
+import torch
+from lightning_utilities.core.rank_zero import rank_zero_only
+from omegaconf import DictConfig
+from torch import Tensor, nn
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from torchmetrics import Accuracy, MeanMetric
+from tqdm.auto import tqdm
+
+from fusion_bench import (
+    BaseTaskPool,
+    LightningFabricMixin,
+    RuntimeConstants,
+    auto_register_config,
+    get_rankzero_logger,
+)
+from fusion_bench.dataset import CLIPDataset
+from fusion_bench.modelpool.resnet_for_image_classification import (
+    ResNetForImageClassificationPool,
+    load_torchvision_resnet,
+    load_transformers_resnet,
+)
+from fusion_bench.tasks.clip_classification import get_classnames, get_num_classes
+from fusion_bench.utils import count_parameters
+
+if TYPE_CHECKING:
+    from torchvision.models import ResNet as TorchVisionResNet
+    from transformers import ResNetForImageClassification
+
+log = get_rankzero_logger(__name__)
+
+__all__ = ["ResNetForImageClassificationTaskPool"]
+
+
+@auto_register_config
+class ResNetForImageClassificationTaskPool(
+    BaseTaskPool,
+    LightningFabricMixin,
+    ResNetForImageClassificationPool,
+):
+
+    _is_setup = False
+
+    def __init__(
+        self,
+        type: str,
+        test_datasets: DictConfig,
+        dataloader_kwargs: DictConfig,
+        processor_config_path: str,
+        **kwargs,
+    ):
+        if type == "transformers":
+            super().__init__(
+                models=DictConfig(
+                    {"_pretrained_": {"config_path": processor_config_path}}
+                ),
+                type=type,
+                test_datasets=test_datasets,
+                **kwargs,
+            )
+        elif type == "torchvision":
+            super().__init__(type=type, test_datasets=test_datasets, **kwargs)
+        else:
+            raise ValueError(f"Unknown ResNet type: {type}")
+
+    def setup(self):
+        processor = self.load_processor(stage="test")
+
+        # Load test datasets
+        test_datasets = {
+            ds_name: CLIPDataset(self.load_test_dataset(ds_name), processor=processor)
+            for ds_name in self._test_datasets
+        }
+        self.test_dataloaders = {
+            ds_name: self.fabric.setup_dataloaders(
+                self.get_dataloader(ds, stage="test")
+            )
+            for ds_name, ds in test_datasets.items()
+        }
+
+    def _evaluate(
+        self,
+        classifier,
+        test_loader,
+        num_classes: int,
+        task_name: str = None,
+    ):
+        classifier.eval()
+        accuracy = Accuracy(task="multiclass", num_classes=num_classes)
+        loss_metric = MeanMetric()
+        if RuntimeConstants.debug:
+            log.info("Running under fast_dev_run mode, evaluating on a single batch.")
+            test_loader = itertools.islice(test_loader, 1)
+        else:
+            test_loader = test_loader
+
+        pbar = tqdm(
+            test_loader,
+            desc=f"Evaluating {task_name}" if task_name is not None else "Evaluating",
+            leave=False,
+            dynamic_ncols=True,
+        )
+        for batch in pbar:
+            inputs, targets = batch
+            outputs = classifier(inputs)
+            logits: Tensor = outputs["logits"]
+            if logits.device != targets.device:
+                targets = targets.to(logits.device)
+
+            loss = F.cross_entropy(logits, targets)
+            loss_metric.update(loss.detach().cpu())
+            acc = accuracy(logits.detach().cpu(), targets.detach().cpu())
+            pbar.set_postfix(
+                {
+                    "accuracy": accuracy.compute().item(),
+                    "loss": loss_metric.compute().item(),
+                }
+            )
+
+        acc = accuracy.compute().item()
+        loss = loss_metric.compute().item()
+        results = {"accuracy": acc, "loss": loss}
+        return results
+
+    def evaluate(
+        self,
+        model: Union["ResNetForImageClassification", "TorchVisionResNet"],
+        name: str = None,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        assert isinstance(
+            model, nn.Module
+        ), f"Expected model to be an instance of nn.Module, but got {type(model)}"
+
+        if not self._is_setup:
+            self.setup()
+
+        classifier = self.fabric.to_device(model)
+        classifier.eval()
+        report = {}
+        # collect basic model information
+        training_params, all_params = count_parameters(model)
+        report["model_info"] = {
+            "trainable_params": training_params,
+            "all_params": all_params,
+            "trainable_percentage": training_params / all_params,
+        }
+        if name is not None:
+            report["model_info"]["name"] = name
+
+        # evaluate on each task
+        pbar = tqdm(
+            self.test_dataloaders.items(),
+            desc="Evaluating tasks",
+            total=len(self.test_dataloaders),
+        )
+        for task_name, test_dataloader in pbar:
+            num_classes = get_num_classes(task_name)
+            result = self._evaluate(
+                classifier,
+                test_dataloader,
+                num_classes=num_classes,
+                task_name=task_name,
+            )
+            report[task_name] = result
+
+        # calculate the average accuracy and loss
+        if "average" not in report:
+            report["average"] = {}
+        accuracies = [
+            value["accuracy"]
+            for key, value in report.items()
+            if "accuracy" in value
+        ]
+        if len(accuracies) > 0:
+            average_accuracy = sum(accuracies) / len(accuracies)
+            report["average"]["accuracy"] = average_accuracy
+        losses = [value["loss"] for key, value in report.items() if "loss" in value]
+        if len(losses) > 0:
+            average_loss = sum(losses) / len(losses)
+            report["average"]["loss"] = average_loss
+
+        log.info(f"Evaluation Result: {report}")
+        if self.fabric.is_global_zero and len(self.fabric._loggers) > 0:
+            save_path = os.path.join(self.log_dir, "report.json")
+            for version in itertools.count(1):
+                if not os.path.exists(save_path):
+                    break
+                # if the file already exists, increment the version to avoid overwriting
+                save_path = os.path.join(self.log_dir, f"report_{version}.json")
+            with open(save_path, "w") as fp:
+                json.dump(report, fp)
+            log.info(f"Evaluation report saved to {save_path}")
+        return report
+
+    def get_dataloader(self, dataset, stage: str):
+        """Create a DataLoader for the specified dataset and training stage.
+
+        Constructs a PyTorch DataLoader with stage-appropriate configurations:
+        - Training stage: shuffling enabled by default
+        - Validation/test stages: shuffling disabled by default
+
+        Args:
+            dataset: The dataset to wrap in a DataLoader.
+            stage (str): Training stage, must be one of "train", "val", or "test".
+                Determines the default shuffling behavior.
+
+        Returns:
+            DataLoader: Configured DataLoader for the given dataset and stage.
+        """
+        assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
+        dataloader_kwargs = dict(self.dataloader_kwargs)
+        if "shuffle" not in dataloader_kwargs:
+            dataloader_kwargs["shuffle"] = stage == "train"
+        return DataLoader(dataset, **dataloader_kwargs)
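
To close, a hedged end-to-end sketch of the new taskpool; the dataset and dataloader settings are illustrative stand-ins for the Hydra configs shipped under fusion_bench_config/, not values taken from this release:

```python
from omegaconf import DictConfig
from transformers import ResNetForImageClassification
from fusion_bench.taskpool import ResNetForImageClassificationTaskPool

taskpool = ResNetForImageClassificationTaskPool(
    type="transformers",
    # Placeholder: real runs reference dataset configs from fusion_bench_config/.
    test_datasets=DictConfig({"cifar10": {"_target_": "..."}}),
    dataloader_kwargs=DictConfig({"batch_size": 128, "num_workers": 4}),
    processor_config_path="microsoft/resnet-50",
)

# Stand-in for a merged model produced by a fusion algorithm; transformers
# models return a ModelOutput, so the taskpool's `outputs["logits"]` access works.
model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
report = taskpool.evaluate(model)
# report holds per-task {"accuracy", "loss"}, a "model_info" entry, and "average" aggregates.
```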