fusion-bench 0.2.31__py3-none-any.whl → 0.2.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. fusion_bench/__init__.py +6 -0
  2. fusion_bench/__main__.py +2 -2
  3. fusion_bench/dataset/__init__.py +2 -0
  4. fusion_bench/dataset/clip_dataset.py +4 -72
  5. fusion_bench/dataset/image_dataset.py +44 -18
  6. fusion_bench/method/base_algorithm.py +4 -0
  7. fusion_bench/method/dop/dop.py +0 -22
  8. fusion_bench/method/dop/dop_general.py +489 -0
  9. fusion_bench/method/dop/utils.py +24 -4
  10. fusion_bench/method/emr_merging/__init__.py +1 -0
  11. fusion_bench/method/emr_merging/emr_merging.py +53 -0
  12. fusion_bench/method/emr_merging/utils.py +162 -0
  13. fusion_bench/method/opcm/opcm.py +6 -2
  14. fusion_bench/method/opcm/opcm_general.py +356 -0
  15. fusion_bench/method/opcm/utils.py +1 -4
  16. fusion_bench/method/simple_average.py +52 -18
  17. fusion_bench/method/task_arithmetic/task_arithmetic.py +1 -1
  18. fusion_bench/mixins/lightning_fabric.py +108 -3
  19. fusion_bench/mixins/serialization.py +1 -1
  20. fusion_bench/modelpool/base_pool.py +37 -1
  21. fusion_bench/modelpool/convnext_for_image_classification.py +5 -2
  22. fusion_bench/models/hf_clip.py +20 -0
  23. fusion_bench/models/modulator/__init__.py +1 -0
  24. fusion_bench/models/modulator/base.py +123 -0
  25. fusion_bench/models/parameter_dict.py +119 -29
  26. fusion_bench/models/utils.py +190 -2
  27. fusion_bench/models/wrappers/switch.py +90 -0
  28. fusion_bench/programs/base_program.py +6 -0
  29. fusion_bench/programs/fabric_fusion_program.py +4 -0
  30. fusion_bench/scripts/cli.py +19 -8
  31. fusion_bench/taskpool/image_classification.py +270 -0
  32. fusion_bench/utils/__init__.py +18 -1
  33. fusion_bench/utils/data.py +1 -1
  34. fusion_bench/utils/dict.py +19 -0
  35. fusion_bench/utils/dtype.py +19 -0
  36. fusion_bench/utils/misc.py +1 -0
  37. fusion_bench/utils/packages.py +4 -0
  38. fusion_bench/utils/state_dict_arithmetic.py +183 -1
  39. fusion_bench/utils/tensorboard.py +21 -3
  40. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/METADATA +3 -1
  41. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/RECORD +51 -37
  42. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/WHEEL +1 -1
  43. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/entry_points.txt +1 -1
  44. fusion_bench_config/fabric/loggers/mlflow_logger.yaml +4 -0
  45. fusion_bench_config/method/dop/dop_general.yaml +33 -0
  46. fusion_bench_config/method/emr_merging/emr_merging.yaml +1 -0
  47. fusion_bench_config/method/opcm/opcm_general.yaml +18 -0
  48. fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224_8-tasks.yaml +15 -0
  49. fusion_bench_config/taskpool/ImageClassificationTaskPool/convnext-base-224_8-tasks.yaml +17 -0
  50. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/licenses/LICENSE +0 -0
  51. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/top_level.txt +0 -0
fusion_bench/__init__.py CHANGED
@@ -86,6 +86,9 @@ _import_structure = {
         "set_print_function_call",
         "set_print_function_call_permeanent",
         "timeit_context",
+        "initialize_hydra_config",
+        "get_default_config_path",
+        "get_hydra_output_dir",
    ],
 }
 
@@ -144,8 +147,11 @@ if TYPE_CHECKING:
         StateDictType,
         TorchModelType,
         cache_with_joblib,
+        get_default_config_path,
+        get_hydra_output_dir,
         get_rankzero_logger,
         import_object,
+        initialize_hydra_config,
         instantiate,
         parse_dtype,
         print_parameters,
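
The three new Hydra helpers are re-exported at the package top level (both in `_import_structure` and under `TYPE_CHECKING`), so they can be imported directly from `fusion_bench`. A minimal sketch; the call signatures are not shown in this diff, so only the imports are demonstrated:

from fusion_bench import (
    get_default_config_path,
    get_hydra_output_dir,
    initialize_hydra_config,
)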
fusion_bench/__main__.py CHANGED
@@ -1,4 +1,4 @@
-from fusion_bench.scripts.cli import main
+from fusion_bench.scripts.cli import _hydra_main
 
 if __name__ == "__main__":
-    main()
+    _hydra_main()
fusion_bench/dataset/__init__.py CHANGED
@@ -38,10 +38,12 @@ _extra_objects = {
 }
 _import_structure = {
     "clip_dataset": ["CLIPDataset"],
+    "image_dataset": ["ImageClassificationDataset"],
 }
 
 if TYPE_CHECKING:
     from .clip_dataset import CLIPDataset
+    from .image_dataset import ImageClassificationDataset
 
 else:
     sys.modules[__name__] = LazyImporter(
fusion_bench/dataset/clip_dataset.py CHANGED
@@ -2,80 +2,12 @@
 This module provides a class to convert a dataset whose object is a list of dictionaries with keys "image" and "label" to a dataset whose object is a tuple of tensors (inputs, label) for CLIP models.
 """
 
-from typing import Optional, Tuple
+from fusion_bench.utils import DeprecationWarningMeta
 
-import torch
-from torch.utils.data import Dataset
-from transformers import BaseImageProcessor, CLIPProcessor, ProcessorMixin
+from .image_dataset import ImageClassificationDataset
 
 __all__ = ["CLIPDataset"]
 
 
-class CLIPDataset(torch.utils.data.Dataset):
-    """
-    A dataset class for CLIP models that converts a dataset of dictionaries or tuples
-    into a format suitable for CLIP processing.
-
-    This class wraps an existing dataset and applies CLIP preprocessing to the images.
-    It expects each item in the dataset to be either a dictionary with 'image' and 'label' keys,
-    or a tuple/list of (image, label).
-
-    Args:
-        dataset: The original dataset to wrap.
-        processor (CLIPProcessor): The CLIP processor for preparing inputs. If None, no preprocessing is applied and raw images are returned.
-
-    Attributes:
-        dataset: The wrapped dataset.
-        processor (CLIPProcessor): The CLIP processor used for image preprocessing.
-    """
-
-    def __init__(self, dataset: Dataset, processor: Optional[CLIPProcessor] = None):
-        self.dataset = dataset
-        self.processor = processor
-
-    def __len__(self):
-        """Returns the number of items in the dataset."""
-        return len(self.dataset)
-
-    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
-        """
-        Retrieves and processes an item from the dataset.
-
-        Args:
-            idx (int): The index of the item to retrieve.
-
-        Returns:
-            tuple: A tuple containing the processed image tensor and the label.
-
-        Raises:
-            ValueError: If the item is neither a dictionary nor a tuple/list of length 2.
-        """
-        item = self.dataset[idx]
-        if isinstance(item, dict):
-            item = item
-        elif isinstance(item, (tuple, list)):
-            assert len(item) == 2, "Each item should be a tuple or list of length 2"
-            item = {"image": item[0], "label": item[1]}
-        else:
-            raise ValueError("Each item should be a dictionary or a tuple of length 2")
-        image = item["image"]
-        if self.processor is not None:
-            if isinstance(self.processor, (ProcessorMixin, BaseImageProcessor)):
-                # Apply the processor to the image to get the input tensor
-                image = image.convert("RGB")  # ensure image is in RGB format
-                inputs = self.processor(images=[image], return_tensors="pt")[
-                    "pixel_values"
-                ][0]
-            elif callable(self.processor):
-                inputs = self.processor(image)
-            else:
-                raise ValueError(
-                    "The processor should be a CLIPProcessor or a callable function"
-                )
-        else:
-            # if processor is None, return the raw image directly
-            inputs = image
-        # convert boolean label to int, this is for the case when the label is a binary classification task
-        if isinstance(item["label"], bool):
-            item["label"] = 1 if item["label"] else 0
-        return inputs, item["label"]
+class CLIPDataset(ImageClassificationDataset, metaclass=DeprecationWarningMeta):
+    pass
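
With this change, `CLIPDataset` becomes a thin alias of `ImageClassificationDataset`; the `DeprecationWarningMeta` metaclass presumably warns when the old name is used. A minimal migration sketch, assuming a Hugging Face `CLIPProcessor` and a toy in-memory dataset (both illustrative, not part of this diff):

from PIL import Image
from transformers import CLIPProcessor

from fusion_bench.dataset import ImageClassificationDataset

# toy stand-in: items may be (image, label) pairs or {"image": ..., "label": ...} dicts
raw_dataset = [(Image.new("RGB", (224, 224)), 0)]
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# old: CLIPDataset(raw_dataset, processor)  # still works, but is now deprecated
dataset = ImageClassificationDataset(raw_dataset, processor)
inputs, label = dataset[0]  # inputs: preprocessed pixel_values tensor; label: int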
fusion_bench/dataset/image_dataset.py CHANGED
@@ -1,35 +1,39 @@
-from typing import Any, Callable, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union
 
+import torch
 from torch.utils.data import Dataset
+from transformers import BaseImageProcessor, ProcessorMixin
 
 
-class TransformedImageDataset(Dataset):
+class ImageClassificationDataset(Dataset):
     """
-    A dataset class for image classification tasks that applies a transform to images.
+    A dataset class for image classification models that converts a dataset of dictionaries or tuples
+    into a format suitable for model processing.
 
-    This class wraps an existing dataset and applies a specified transform to the images.
+    This class wraps an existing dataset and applies preprocessing to the images.
     It expects each item in the dataset to be either a dictionary with 'image' and 'label' keys,
     or a tuple/list of (image, label).
-
-    Args:
-        dataset: The original dataset to wrap.
-        transform (Callable): A function/transform to apply on the image.
-
-    Attributes:
-        dataset: The wrapped dataset.
-        transform (Callable): The transform to be applied to the images.
     """
 
-    def __init__(self, dataset: Dataset, transform: Callable):
-        super().__init__()
+    def __init__(
+        self,
+        dataset: Dataset,
+        processor: Optional[Union["ProcessorMixin", "BaseImageProcessor"]] = None,
+    ):
+        """
+        Args:
+            dataset (Dataset): The original dataset to wrap.
+            processor (Optional[Union[ProcessorMixin, BaseImageProcessor]]): The processor for preparing inputs.
+                If None, no preprocessing is applied and raw images are returned.
+        """
         self.dataset = dataset
-        self.transform = transform
+        self.processor = processor
 
     def __len__(self):
         """Returns the number of items in the dataset."""
         return len(self.dataset)
 
-    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
         """
         Retrieves and processes an item from the dataset.
 
@@ -37,11 +41,13 @@ class TransformedImageDataset(Dataset):
             idx (int): The index of the item to retrieve.
 
         Returns:
-            tuple: A tuple containing the processed image and the label.
+            tuple: A tuple containing the processed image tensor and the label.
 
         Raises:
             ValueError: If the item is neither a dictionary nor a tuple/list of length 2.
         """
+        # Standardize the item to a dictionary format
+        # {"image": ..., "label": ...}
         item = self.dataset[idx]
         if isinstance(item, dict):
             item = item
@@ -50,6 +56,26 @@ class TransformedImageDataset(Dataset):
             item = {"image": item[0], "label": item[1]}
         else:
             raise ValueError("Each item should be a dictionary or a tuple of length 2")
+
+        # Process the image using the provided processor, if any
         image = item["image"]
-        inputs = self.transform(image)
+        if self.processor is not None:
+            if isinstance(self.processor, (ProcessorMixin, BaseImageProcessor)):
+                # Apply the processor to the image to get the input tensor
+                image = image.convert("RGB")  # ensure image is in RGB format
+                inputs = self.processor(images=[image], return_tensors="pt")[
+                    "pixel_values"
+                ][0]
+            elif callable(self.processor):
+                inputs = self.processor(image)
+            else:
+                raise ValueError(
+                    "The processor should be a transformers Processor or a callable function"
+                )
+        else:
+            # if processor is None, return the raw image directly
+            inputs = image
+        # convert boolean label to int, this is for the case when the label is a binary classification task
+        if isinstance(item["label"], bool):
+            item["label"] = 1 if item["label"] else 0
         return inputs, item["label"]
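
Because `processor` may also be any plain callable, the renamed class subsumes the old `TransformedImageDataset` use case. A sketch using a torchvision transform (the transform pipeline and toy dataset here are illustrative):

from PIL import Image
from torchvision import transforms

from fusion_bench.dataset import ImageClassificationDataset

raw_dataset = [{"image": Image.new("RGB", (256, 256)), "label": True}]
transform = transforms.Compose([transforms.Resize(224), transforms.ToTensor()])

dataset = ImageClassificationDataset(raw_dataset, processor=transform)
inputs, label = dataset[0]  # inputs: 3x224x224 tensor; boolean label coerced to int 1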
fusion_bench/method/base_algorithm.py CHANGED
@@ -59,6 +59,10 @@ class BaseAlgorithm(BaseYAMLSerializable):
     core fusion logic in the `run` method, while optional lifecycle hooks allow for
     setup and cleanup operations.
 
+    If the model has a `_fusion_bench_target_modules` attribute, the algorithm will only
+    fuse the specified target modules. This is useful for models where only certain layers
+    should be fused (e.g., classification heads on top of a shared backbone are not merged).
+
     Attributes:
         _program: Optional program reference for algorithm execution context.
         _config_key (str): Configuration key used for YAML serialization, defaults to "method".
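
The docstring above is the only contract shown in this diff; the sketch below assumes `_fusion_bench_target_modules` is a list of module-name prefixes that a fusion algorithm consults, which is an assumption rather than a documented format:

import torch.nn as nn

class BackboneWithHead(nn.Module):
    # ASSUMPTION: module-name prefixes that a fusion algorithm restricts itself to;
    # the exact expected format is not shown in this diff.
    _fusion_bench_target_modules = ["backbone"]

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(32, 64), nn.ReLU())
        self.head = nn.Linear(64, num_classes)  # per-task head, left unmerged

    def forward(self, x):
        return self.head(self.backbone(x))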
fusion_bench/method/dop/dop.py CHANGED
@@ -79,28 +79,6 @@ class ContinualDOPForCLIP(BaseAlgorithm, LightningFabricMixin):
         ), "The alpha should be in the range of [0, 1]"
         super().__init__(**kwargs)
 
-    def print_params(self, pretrained_model):
-        total_params = 0
-        linear_params = 0
-        linear_weight_params = 0
-        for module_name, module in pretrained_model.named_modules():
-            if not is_leaf_module(module):
-                continue
-            if isinstance(module, nn.Linear):
-                linear_params += sum(p.numel() for n, p in module.named_parameters())
-                linear_weight_params += sum(
-                    p.numel() for n, p in module.named_parameters() if "weight" in n
-                )
-            total_params += sum(p.numel() for p in module.parameters())
-
-        linear_ratio = linear_params / total_params * 100
-        linear_weight_ratio = linear_weight_params / total_params * 100
-        print(f"Total Parameters: {total_params}")
-        print(f"Linear Parameters: {linear_params}")
-        print(f"Linear Weight Parameters: {linear_weight_params}")
-        print(f"Linear Ratio: {linear_ratio:.2f}%")
-        print(f"Linear Weight Ratio: {linear_weight_ratio:.2f}%")
-
     def run(self, modelpool: BaseModelPool):
         if self.seed is not None:
             L.seed_everything(self.seed)