fusion-bench 0.2.18__py3-none-any.whl → 0.2.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fusion_bench/__init__.py CHANGED
@@ -1,3 +1,9 @@
+# ███████╗██╗ ██╗███████╗██╗ ██████╗ ███╗ ██╗ ██████╗ ███████╗███╗ ██╗ ██████╗██╗ ██╗
+# ██╔════╝██║ ██║██╔════╝██║██╔═══██╗████╗ ██║ ██╔══██╗██╔════╝████╗ ██║██╔════╝██║ ██║
+# █████╗ ██║ ██║███████╗██║██║ ██║██╔██╗ ██║█████╗██████╔╝█████╗ ██╔██╗ ██║██║ ███████║
+# ██╔══╝ ██║ ██║╚════██║██║██║ ██║██║╚██╗██║╚════╝██╔══██╗██╔══╝ ██║╚██╗██║██║ ██╔══██║
+# ██║ ╚██████╔╝███████║██║╚██████╔╝██║ ╚████║ ██████╔╝███████╗██║ ╚████║╚██████╗██║ ██║
+# ╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═════╝ ╚═╝ ╚═══╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝ ╚═════╝╚═╝ ╚═╝
 # flake8: noqa: F401
 from . import (
     constants,
@@ -0,0 +1,12 @@
+FUSION_BENCH_BANNER = (
+    ""
+    + "███████╗██╗ ██╗███████╗██╗ ██████╗ ███╗ ██╗ ██████╗ ███████╗███╗ ██╗ ██████╗██╗ ██╗\n"
+    + "██╔════╝██║ ██║██╔════╝██║██╔═══██╗████╗ ██║ ██╔══██╗██╔════╝████╗ ██║██╔════╝██║ ██║\n"
+    + "█████╗ ██║ ██║███████╗██║██║ ██║██╔██╗ ██║█████╗██████╔╝█████╗ ██╔██╗ ██║██║ ███████║\n"
+    + "██╔══╝ ██║ ██║╚════██║██║██║ ██║██║╚██╗██║╚════╝██╔══██╗██╔══╝ ██║╚██╗██║██║ ██╔══██║\n"
+    + "██║ ╚██████╔╝███████║██║╚██████╔╝██║ ╚████║ ██████╔╝███████╗██║ ╚████║╚██████╗██║ ██║\n"
+    + "╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═════╝ ╚═╝ ╚═══╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝ ╚═════╝╚═╝ ╚═╝\n"
+)
+
+if __name__ == "__main__":
+    print(FUSION_BENCH_BANNER)
@@ -1,4 +1,5 @@
-from typing import Optional
+from copy import deepcopy
+from typing import TYPE_CHECKING, Optional
 
 from typing_extensions import override
 
@@ -6,6 +7,11 @@ from fusion_bench import timeit_context
 from fusion_bench.method.base_algorithm import BaseAlgorithm
 from fusion_bench.method.simple_average import SimpleAverageAlgorithm
 from fusion_bench.modelpool import CausalLMBackbonePool, CausalLMPool
+from fusion_bench.utils.pylogger import getRankZeroLogger
+from omegaconf import flag_override
+from fusion_bench.utils import instantiate
+
+log = getRankZeroLogger(__name__)
 
 
 class SimpleAverageForLlama(BaseAlgorithm):
@@ -40,12 +46,20 @@ class SimpleAverageForLlama(BaseAlgorithm):
 
         if self.merge_backbone:
             assert modelpool.has_pretrained
-            backbone_modelpool = CausalLMBackbonePool(**modelpool.config)
+            log.info(
+                "Merging backbone of the model pool, use CausalLMBackbonePool instead of CausalLMPool."
+            )
+            modelpool_config = deepcopy(modelpool.config)
+            with flag_override(modelpool_config, "allow_objects", True):
+                modelpool_config._target_ = (
+                    "fusion_bench.modelpool.causal_lm.CausalLMBackbonePool"
+                )
+            backbone_modelpool = instantiate(modelpool_config)
             model = modelpool.load_model("_pretrained_")
             backbone_model = SimpleAverageAlgorithm().run(backbone_modelpool)
             model.model.layers = backbone_model
         else:
-            model = SimpleAverageAlgorithm().run()
+            model = SimpleAverageAlgorithm().run(modelpool=modelpool)
 
         if self.model_save_path is not None:
             with timeit_context(f"Saving the model to {self.model_save_path}"):
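
The hunk above replaces direct construction of `CausalLMBackbonePool` with a config-driven `instantiate` call: the pool's config is deep-copied, its `_target_` is repointed at the backbone pool class under OmegaConf's `flag_override` context manager, and the copy is instantiated. It also fixes the `else` branch, which previously called `SimpleAverageAlgorithm().run()` without the required `modelpool` argument. A minimal sketch of the retargeting pattern, using toy `_target_` values rather than fusion_bench classes:

```python
from copy import deepcopy

from omegaconf import OmegaConf, flag_override

# Toy Hydra-style config; "collections.OrderedDict" stands in for a real class path.
cfg = OmegaConf.create({"_target_": "collections.OrderedDict", "verbose": True})

retargeted = deepcopy(cfg)  # copy first so the caller's config is left untouched
with flag_override(retargeted, "allow_objects", True):
    # temporarily relax the config's flags, then point it at a different class
    retargeted._target_ = "collections.Counter"

print(retargeted["_target_"])  # collections.Counter
```
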
@@ -8,6 +8,7 @@ from torch import nn
 from fusion_bench.method.base_algorithm import BaseAlgorithm
 from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
 from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.utils import LazyStateDict
 from fusion_bench.utils.state_dict_arithmetic import (
     state_dict_add,
     state_dict_avg,
@@ -104,6 +105,15 @@ class SimpleAverageAlgorithm(
         # Divide the accumulated state dictionary by the number of models to get the average
         sd = state_dict_div(sd, len(modelpool.model_names))
 
+        if isinstance(forward_model, LazyStateDict):
+            # if the model is a LazyStateDict, convert it to an empty module
+            forward_model = forward_model.meta_module.to_empty(
+                device=(
+                    "cpu"
+                    if forward_model._torch_dtype is None
+                    else forward_model._torch_dtype
+                )
+            )
         forward_model.load_state_dict(sd)
         # print profile report and log the merged models
         self.print_profile_summary()
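
The new branch above relies on PyTorch's meta-device workflow: a module built on the `meta` device carries parameter shapes but no storage, and `Module.to_empty()` allocates uninitialized storage so the averaged weights can be loaded into it. A self-contained sketch of that workflow in plain `torch`, with no fusion_bench types:

```python
import torch
from torch import nn

# Build a module on the meta device: parameter shapes exist, but no storage is allocated.
with torch.device("meta"):
    meta_module = nn.Linear(4, 4)

# Allocate uninitialized storage on a real device, then load actual weights into it.
module = meta_module.to_empty(device="cpu")
module.load_state_dict({"weight": torch.zeros(4, 4), "bias": torch.zeros(4)})
print(module.weight.device)  # cpu
```
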
@@ -22,6 +22,8 @@ from typing_extensions import override
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils import instantiate
 from fusion_bench.utils.dtype import parse_dtype
+from fusion_bench.utils.lazy_state_dict import LazyStateDict
+from fusion_bench.utils.packages import import_object
 
 log = logging.getLogger(__name__)
 
@@ -30,6 +32,7 @@ class CausalLMPool(BaseModelPool):
     _config_mapping = BaseModelPool._config_mapping | {
         "_tokenizer": "tokenizer",
         "_model_kwargs": "model_kwargs",
+        "load_lazy": "load_lazy",
     }
 
     def __init__(
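
The `_config_mapping` change above extends the parent pool's mapping with the dict union operator (`|`, Python 3.9+); illustrated with plain dicts using the keys from the hunk:

```python
# Dict union (PEP 584): right-hand entries are added and win on key collisions.
base = {"_tokenizer": "tokenizer", "_model_kwargs": "model_kwargs"}
extended = base | {"load_lazy": "load_lazy"}
print(extended)
# {'_tokenizer': 'tokenizer', '_model_kwargs': 'model_kwargs', 'load_lazy': 'load_lazy'}
```
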
@@ -38,6 +41,7 @@ class CausalLMPool(BaseModelPool):
         *,
         tokenizer: Optional[DictConfig],
         model_kwargs: Optional[DictConfig] = None,
+        load_lazy: bool = False,
         **kwargs,
     ):
         super().__init__(models, **kwargs)
@@ -51,6 +55,7 @@ class CausalLMPool(BaseModelPool):
             self._model_kwargs.torch_dtype = parse_dtype(
                 self._model_kwargs.torch_dtype
            )
+        self.load_lazy = load_lazy
 
     @override
     def load_model(
@@ -88,21 +93,41 @@ class CausalLMPool(BaseModelPool):
         model_kwargs.update(kwargs)
 
         if isinstance(model_name_or_config, str):
+            # If model_name_or_config is a string, it is the name or the path of the model
             log.info(f"Loading model: {model_name_or_config}", stacklevel=2)
             if model_name_or_config in self._models.keys():
                 model_config = self._models[model_name_or_config]
                 if isinstance(model_config, str):
                     # model_config is a string
-                    model = AutoModelForCausalLM.from_pretrained(
-                        model_config,
-                        *args,
-                        **model_kwargs,
-                    )
+                    if not self.load_lazy:
+                        model = AutoModelForCausalLM.from_pretrained(
+                            model_config,
+                            *args,
+                            **model_kwargs,
+                        )
+                    else:
+                        # model_config is a string, but we want to use LazyStateDict
+                        model = LazyStateDict(
+                            checkpoint=model_config,
+                            meta_module_class=AutoModelForCausalLM,
+                            *args,
+                            **model_kwargs,
+                        )
             return model
         elif isinstance(model_name_or_config, (DictConfig, Dict)):
             model_config = model_name_or_config
 
-            model = instantiate(model_config, *args, **model_kwargs)
+            if not self.load_lazy:
+                model = instantiate(model_config, *args, **model_kwargs)
+            else:
+                meta_module_class = model_config.pop("_target_")
+                checkpoint = model_config.pop("pretrained_model_name_or_path")
+                model = LazyStateDict(
+                    checkpoint=checkpoint,
+                    meta_module_class=meta_module_class,
+                    *args,
+                    **model_kwargs,
+                )
             return model
 
     def load_tokenizer(self, *args, **kwargs) -> PreTrainedTokenizer:
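
With both branches above, `load_model` now returns either an eagerly materialized model or a `LazyStateDict` wrapper, depending on the pool-level `load_lazy` flag. A hypothetical construction showing the flag in use; the model paths and the tokenizer config are illustrative stand-ins, not values taken from this diff:

```python
from omegaconf import DictConfig

from fusion_bench.modelpool import CausalLMPool

pool = CausalLMPool(
    # hypothetical checkpoints; any Hugging Face repo id or local path would do
    models=DictConfig({"base": "path/to/model-a", "finetuned": "path/to/model-b"}),
    tokenizer=DictConfig(
        {
            "_target_": "transformers.AutoTokenizer.from_pretrained",
            "pretrained_model_name_or_path": "path/to/model-a",
        }
    ),
    load_lazy=True,  # load_model() returns LazyStateDict objects instead of loading weights
)

lazy_model = pool.load_model("base")  # no weight tensors are read eagerly here
```
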
@@ -179,6 +204,12 @@ class CausalLMBackbonePool(CausalLMPool):
     def load_model(
         self, model_name_or_config: str | DictConfig, *args, **kwargs
     ) -> Module:
+        if self.load_lazy:
+            log.warning(
+                "CausalLMBackbonePool does not support lazy loading. "
+                "Falling back to normal loading."
+            )
+            self.load_lazy = False
         model: AutoModelForCausalLM = super().load_model(
             model_name_or_config, *args, **kwargs
         )