fusion-bench 0.2.18 → 0.2.20 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. fusion_bench/__init__.py +6 -0
  2. fusion_bench/constants/banner.py +12 -0
  3. fusion_bench/method/__init__.py +2 -0
  4. fusion_bench/method/linear/simple_average_for_llama.py +30 -5
  5. fusion_bench/method/regmean_plusplus/__init__.py +3 -0
  6. fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +192 -0
  7. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +365 -0
  8. fusion_bench/method/simple_average.py +29 -3
  9. fusion_bench/modelpool/causal_lm/causal_lm.py +37 -6
  10. fusion_bench/modelpool/clip_vision/modelpool.py +45 -12
  11. fusion_bench/scripts/cli.py +1 -1
  12. fusion_bench/tasks/clip_classification/imagenet.py +1008 -2004
  13. fusion_bench/utils/lazy_state_dict.py +75 -3
  14. fusion_bench/utils/misc.py +66 -2
  15. fusion_bench/utils/modelscope.py +146 -0
  16. fusion_bench/utils/state_dict_arithmetic.py +10 -5
  17. {fusion_bench-0.2.18.dist-info → fusion_bench-0.2.20.dist-info}/METADATA +9 -1
  18. {fusion_bench-0.2.18.dist-info → fusion_bench-0.2.20.dist-info}/RECORD +50 -43
  19. fusion_bench_config/method/regmean/clip_regmean.yaml +1 -1
  20. fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +11 -0
  21. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +73 -8
  22. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +27 -7
  23. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +34 -4
  24. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +14 -17
  25. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +14 -3
  26. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +39 -5
  27. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +49 -5
  28. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +55 -5
  29. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +21 -4
  30. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +61 -5
  31. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +67 -5
  32. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +73 -5
  33. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +26 -3
  34. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +7 -5
  35. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +6 -10
  36. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +6 -7
  37. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +6 -7
  38. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +7 -8
  39. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +8 -0
  40. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +4 -6
  41. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +32 -7
  42. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +14 -6
  43. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +73 -8
  44. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +27 -7
  45. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +6 -10
  46. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +11 -0
  47. {fusion_bench-0.2.18.dist-info → fusion_bench-0.2.20.dist-info}/WHEEL +0 -0
  48. {fusion_bench-0.2.18.dist-info → fusion_bench-0.2.20.dist-info}/entry_points.txt +0 -0
  49. {fusion_bench-0.2.18.dist-info → fusion_bench-0.2.20.dist-info}/licenses/LICENSE +0 -0
  50. {fusion_bench-0.2.18.dist-info → fusion_bench-0.2.20.dist-info}/top_level.txt +0 -0
--- a/fusion_bench/method/simple_average.py
+++ b/fusion_bench/method/simple_average.py
@@ -8,6 +8,7 @@ from torch import nn
 from fusion_bench.method.base_algorithm import BaseAlgorithm
 from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
 from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.utils import LazyStateDict
 from fusion_bench.utils.state_dict_arithmetic import (
     state_dict_add,
     state_dict_avg,
@@ -62,6 +63,18 @@ class SimpleAverageAlgorithm(
     BaseAlgorithm,
     SimpleProfilerMixin,
 ):
+    _config_mapping = BaseAlgorithm._config_mapping | {
+        "show_pbar": "show_pbar",
+    }
+
+    def __init__(self, show_pbar: bool = False):
+        """
+        Args:
+            show_pbar (bool): If True, shows a progress bar during model loading and merging. Default is False.
+        """
+        super().__init__()
+        self.show_pbar = show_pbar
+
     @torch.no_grad()
     def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]):
         """
@@ -99,11 +112,24 @@ class SimpleAverageAlgorithm(
                 forward_model = model
             else:
                 # Add the current model's state dictionary to the accumulated state dictionary
-                sd = state_dict_add(sd, model.state_dict(keep_vars=True))
+                sd = state_dict_add(
+                    sd, model.state_dict(keep_vars=True), show_pbar=self.show_pbar
+                )
         with self.profile("merge weights"):
             # Divide the accumulated state dictionary by the number of models to get the average
-            sd = state_dict_div(sd, len(modelpool.model_names))
-
+            sd = state_dict_div(
+                sd, len(modelpool.model_names), show_pbar=self.show_pbar
+            )
+
+        if isinstance(forward_model, LazyStateDict):
+            # if the model is a LazyStateDict, convert it to an empty module
+            forward_model = forward_model.meta_module.to_empty(
+                device=(
+                    "cpu"
+                    if forward_model._torch_dtype is None
+                    else forward_model._torch_dtype
+                )
+            )
         forward_model.load_state_dict(sd)
         # print profile report and log the merged models
         self.print_profile_summary()
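For orientation, a minimal usage sketch of the new `show_pbar` option. It is illustrative only: the toy models are hypothetical, and it assumes `run()` accepts a plain `Dict[str, nn.Module]` as its type hint indicates.

```python
import torch.nn as nn

from fusion_bench.method.simple_average import SimpleAverageAlgorithm

# Hypothetical toy models; any mapping of same-architecture nn.Modules
# fits run()'s Union[BaseModelPool, Dict[str, nn.Module]] signature.
models = {
    "expert_a": nn.Linear(4, 4),
    "expert_b": nn.Linear(4, 4),
}

# show_pbar=True is forwarded to state_dict_add / state_dict_div, so the
# per-parameter accumulation and division display progress bars.
algorithm = SimpleAverageAlgorithm(show_pbar=True)
merged_model = algorithm.run(models)
```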
--- a/fusion_bench/modelpool/causal_lm/causal_lm.py
+++ b/fusion_bench/modelpool/causal_lm/causal_lm.py
@@ -22,6 +22,8 @@ from typing_extensions import override
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils import instantiate
 from fusion_bench.utils.dtype import parse_dtype
+from fusion_bench.utils.lazy_state_dict import LazyStateDict
+from fusion_bench.utils.packages import import_object
 
 log = logging.getLogger(__name__)
 
@@ -30,6 +32,7 @@ class CausalLMPool(BaseModelPool):
     _config_mapping = BaseModelPool._config_mapping | {
         "_tokenizer": "tokenizer",
         "_model_kwargs": "model_kwargs",
+        "load_lazy": "load_lazy",
     }
 
     def __init__(
@@ -38,6 +41,7 @@ class CausalLMPool(BaseModelPool):
         *,
         tokenizer: Optional[DictConfig],
         model_kwargs: Optional[DictConfig] = None,
+        load_lazy: bool = False,
         **kwargs,
     ):
         super().__init__(models, **kwargs)
@@ -51,6 +55,7 @@
             self._model_kwargs.torch_dtype = parse_dtype(
                 self._model_kwargs.torch_dtype
             )
+        self.load_lazy = load_lazy
 
     @override
     def load_model(
@@ -88,21 +93,41 @@
         model_kwargs.update(kwargs)
 
         if isinstance(model_name_or_config, str):
+            # If model_name_or_config is a string, it is the name or the path of the model
             log.info(f"Loading model: {model_name_or_config}", stacklevel=2)
             if model_name_or_config in self._models.keys():
                 model_config = self._models[model_name_or_config]
                 if isinstance(model_config, str):
                     # model_config is a string
-                    model = AutoModelForCausalLM.from_pretrained(
-                        model_config,
-                        *args,
-                        **model_kwargs,
-                    )
+                    if not self.load_lazy:
+                        model = AutoModelForCausalLM.from_pretrained(
+                            model_config,
+                            *args,
+                            **model_kwargs,
+                        )
+                    else:
+                        # model_config is a string, but we want to use LazyStateDict
+                        model = LazyStateDict(
+                            checkpoint=model_config,
+                            meta_module_class=AutoModelForCausalLM,
+                            *args,
+                            **model_kwargs,
+                        )
                     return model
         elif isinstance(model_name_or_config, (DictConfig, Dict)):
             model_config = model_name_or_config
 
-            model = instantiate(model_config, *args, **model_kwargs)
+            if not self.load_lazy:
+                model = instantiate(model_config, *args, **model_kwargs)
+            else:
+                meta_module_class = model_config.pop("_target_")
+                checkpoint = model_config.pop("pretrained_model_name_or_path")
+                model = LazyStateDict(
+                    checkpoint=checkpoint,
+                    meta_module_class=meta_module_class,
+                    *args,
+                    **model_kwargs,
+                )
             return model
 
     def load_tokenizer(self, *args, **kwargs) -> PreTrainedTokenizer:
@@ -179,6 +204,12 @@ class CausalLMBackbonePool(CausalLMPool):
     def load_model(
         self, model_name_or_config: str | DictConfig, *args, **kwargs
     ) -> Module:
+        if self.load_lazy:
+            log.warning(
+                "CausalLMBackbonePool does not support lazy loading. "
+                "Falling back to normal loading."
+            )
+            self.load_lazy = False
         model: AutoModelForCausalLM = super().load_model(
             model_name_or_config, *args, **kwargs
         )
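A sketch of how the new `load_lazy` flag might be used. The repo ids are placeholders, and passing the tokenizer as a plain string is an assumption here (the annotation is `Optional[DictConfig]`; only string *model* entries are shown being handled in this diff).

```python
from omegaconf import DictConfig

from fusion_bench.modelpool.causal_lm.causal_lm import CausalLMPool
from fusion_bench.utils.lazy_state_dict import LazyStateDict

# Placeholder repo ids; per the diff, a plain string entry is treated as
# a model name or path.
pool = CausalLMPool(
    models=DictConfig(
        {
            "math": "org/math-model",
            "coder": "org/coder-model",
        }
    ),
    tokenizer="org/math-model",
    load_lazy=True,
)

# With load_lazy=True, load_model returns a LazyStateDict wrapping an
# AutoModelForCausalLM meta-module instead of materializing the weights.
model = pool.load_model("math")
assert isinstance(model, LazyStateDict)
```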
--- a/fusion_bench/modelpool/clip_vision/modelpool.py
+++ b/fusion_bench/modelpool/clip_vision/modelpool.py
@@ -1,6 +1,6 @@
 import logging
 from copy import deepcopy
-from typing import Optional, Union
+from typing import Literal, Optional, Union
 
 from datasets import load_dataset
 from lightning.fabric.utilities import rank_zero_only
@@ -11,6 +11,9 @@ from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
 from typing_extensions import override
 
 from fusion_bench.utils import instantiate, timeit_context
+from fusion_bench.utils.modelscope import (
+    resolve_repo_path,
+)
 
 from ..base_pool import BaseModelPool
 
@@ -25,25 +28,32 @@ class CLIPVisionModelPool(BaseModelPool):
     the specifics of the CLIP Vision models provided by the Hugging Face Transformers library.
     """
 
-    _config_mapping = BaseModelPool._config_mapping | {"_processor": "processor"}
+    _config_mapping = BaseModelPool._config_mapping | {
+        "_processor": "processor",
+        "_platform": "hf",
+    }
 
     def __init__(
         self,
         models: DictConfig,
         *,
         processor: Optional[DictConfig] = None,
+        platform: Literal["hf", "huggingface", "modelscope"] = "hf",
         **kwargs,
     ):
         super().__init__(models, **kwargs)
-
         self._processor = processor
+        self._platform = platform
 
     def load_processor(self, *args, **kwargs) -> CLIPProcessor:
         assert self._processor is not None, "Processor is not defined in the config"
         if isinstance(self._processor, str):
             if rank_zero_only.rank == 0:
                 log.info(f"Loading `transformers.CLIPProcessor`: {self._processor}")
-            processor = CLIPProcessor.from_pretrained(self._processor)
+            repo_path = resolve_repo_path(
+                repo_id=self._processor, repo_type="model", platform=self._platform
+            )
+            processor = CLIPProcessor.from_pretrained(repo_path, *args, **kwargs)
         else:
             processor = instantiate(self._processor, *args, **kwargs)
         return processor
@@ -54,7 +64,10 @@
         if isinstance(model_config, str):
             if rank_zero_only.rank == 0:
                 log.info(f"Loading `transformers.CLIPModel`: {model_config}")
-            clip_model = CLIPModel.from_pretrained(model_config, *args, **kwargs)
+            repo_path = resolve_repo_path(
+                repo_id=model_config, repo_type="model", platform=self._platform
+            )
+            clip_model = CLIPModel.from_pretrained(repo_path, *args, **kwargs)
             return clip_model
         else:
             assert isinstance(
@@ -107,14 +120,17 @@
         if isinstance(model, str):
             if rank_zero_only.rank == 0:
                 log.info(f"Loading `transformers.CLIPVisionModel`: {model}")
-            return CLIPVisionModel.from_pretrained(model, *args, **kwargs)
+            repo_path = resolve_repo_path(
+                model, repo_type="model", platform=self._platform
+            )
+            return CLIPVisionModel.from_pretrained(repo_path, *args, **kwargs)
         if isinstance(model, nn.Module):
             if rank_zero_only.rank == 0:
                 log.info(f"Returning existing model: {model}")
             return model
-
-        # If the model is not a string, we use the default load_model method
-        return super().load_model(model_name_or_config, *args, **kwargs)
+        else:
+            # If the model is not a string, we use the default load_model method
+            return super().load_model(model_name_or_config, *args, **kwargs)
 
     def load_train_dataset(self, dataset_name: str, *args, **kwargs):
         dataset_config = self._train_datasets[dataset_name]
@@ -123,7 +139,7 @@
             log.info(
                 f"Loading train dataset using `datasets.load_dataset`: {dataset_config}"
             )
-            dataset = load_dataset(dataset_config, split="train")
+            dataset = self._load_dataset(dataset_config, split="train")
         else:
             dataset = super().load_train_dataset(dataset_name, *args, **kwargs)
         return dataset
@@ -135,7 +151,7 @@
             log.info(
                 f"Loading validation dataset using `datasets.load_dataset`: {dataset_config}"
            )
-            dataset = load_dataset(dataset_config, split="validation")
+            dataset = self._load_dataset(dataset_config, split="validation")
         else:
             dataset = super().load_val_dataset(dataset_name, *args, **kwargs)
         return dataset
@@ -147,7 +163,24 @@
             log.info(
                 f"Loading test dataset using `datasets.load_dataset`: {dataset_config}"
             )
-            dataset = load_dataset(dataset_config, split="test")
+            dataset = self._load_dataset(dataset_config, split="test")
         else:
             dataset = super().load_test_dataset(dataset_name, *args, **kwargs)
         return dataset
+
+    def _load_dataset(self, name: str, split: str):
+        """
+        Load a dataset by its name and split.
+
+        Args:
+            dataset_name (str): The name of the dataset.
+            split (str): The split of the dataset to load (e.g., "train", "validation", "test").
+
+        Returns:
+            Dataset: The loaded dataset.
+        """
+        datset_dir = resolve_repo_path(
+            name, repo_type="dataset", platform=self._platform
+        )
+        dataset = load_dataset(datset_dir, split=split)
+        return dataset
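A sketch of the new `platform` switch. The checkpoint id is a placeholder, the `_pretrained_` key follows FusionBench's usual modelpool convention, and it is assumed the same id is mirrored on ModelScope so `resolve_repo_path` can fetch it there.

```python
from omegaconf import DictConfig

from fusion_bench.modelpool.clip_vision.modelpool import CLIPVisionModelPool

# With platform="modelscope", string model ids, processor ids, and dataset
# names are routed through resolve_repo_path(..., platform=self._platform)
# before being handed to from_pretrained / load_dataset.
pool = CLIPVisionModelPool(
    models=DictConfig({"_pretrained_": "openai/clip-vit-base-patch32"}),
    processor="openai/clip-vit-base-patch32",
    platform="modelscope",  # default is "hf"
)
vision_model = pool.load_model("_pretrained_")
```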
--- a/fusion_bench/scripts/cli.py
+++ b/fusion_bench/scripts/cli.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 """
-This is the CLI script that is executed when the user runs the `fusion-bench` command.
+This is the CLI script that is executed when the user runs the `fusion_bench` command.
 The script is responsible for parsing the command-line arguments, loading the configuration file, and running the fusion algorithm.
 """
 