fusion-bench 0.2.23__py3-none-any.whl → 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. fusion_bench/__init__.py +152 -42
  2. fusion_bench/dataset/__init__.py +27 -4
  3. fusion_bench/dataset/clip_dataset.py +2 -2
  4. fusion_bench/method/__init__.py +18 -1
  5. fusion_bench/method/classification/__init__.py +27 -2
  6. fusion_bench/method/classification/image_classification_finetune.py +214 -0
  7. fusion_bench/method/ensemble.py +17 -2
  8. fusion_bench/method/linear/__init__.py +6 -2
  9. fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
  10. fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
  11. fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
  12. fusion_bench/method/opcm/opcm.py +1 -0
  13. fusion_bench/method/pwe_moe/module.py +0 -2
  14. fusion_bench/method/simple_average.py +2 -2
  15. fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
  16. fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
  17. fusion_bench/method/ties_merging/ties_merging.py +22 -6
  18. fusion_bench/method/wudi/__init__.py +1 -0
  19. fusion_bench/method/wudi/wudi.py +105 -0
  20. fusion_bench/mixins/__init__.py +2 -0
  21. fusion_bench/mixins/lightning_fabric.py +4 -0
  22. fusion_bench/mixins/pyinstrument.py +174 -0
  23. fusion_bench/mixins/serialization.py +25 -78
  24. fusion_bench/mixins/simple_profiler.py +106 -23
  25. fusion_bench/modelpool/__init__.py +2 -0
  26. fusion_bench/modelpool/base_pool.py +77 -14
  27. fusion_bench/modelpool/causal_lm/causal_lm.py +32 -10
  28. fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
  29. fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
  30. fusion_bench/models/__init__.py +35 -9
  31. fusion_bench/models/hf_clip.py +4 -0
  32. fusion_bench/models/hf_utils.py +2 -1
  33. fusion_bench/models/model_card_templates/default.md +8 -1
  34. fusion_bench/models/wrappers/ensemble.py +136 -7
  35. fusion_bench/optim/__init__.py +40 -2
  36. fusion_bench/optim/lr_scheduler/__init__.py +27 -1
  37. fusion_bench/optim/muon.py +339 -0
  38. fusion_bench/programs/__init__.py +2 -0
  39. fusion_bench/programs/fabric_fusion_program.py +2 -2
  40. fusion_bench/programs/fusion_program.py +271 -0
  41. fusion_bench/scripts/cli.py +2 -2
  42. fusion_bench/taskpool/clip_vision/taskpool.py +11 -4
  43. fusion_bench/tasks/clip_classification/__init__.py +15 -0
  44. fusion_bench/utils/__init__.py +167 -21
  45. fusion_bench/utils/devices.py +30 -8
  46. fusion_bench/utils/lazy_imports.py +91 -12
  47. fusion_bench/utils/lazy_state_dict.py +58 -5
  48. fusion_bench/utils/misc.py +104 -13
  49. fusion_bench/utils/packages.py +4 -0
  50. fusion_bench/utils/path.py +7 -0
  51. fusion_bench/utils/pylogger.py +6 -0
  52. fusion_bench/utils/rich_utils.py +8 -3
  53. fusion_bench/utils/state_dict_arithmetic.py +935 -162
  54. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/METADATA +10 -3
  55. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/RECORD +76 -55
  56. fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
  57. fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
  58. fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
  59. fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
  60. fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
  61. fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
  62. fusion_bench_config/method/wudi/wudi.yaml +4 -0
  63. fusion_bench_config/model_fusion.yaml +45 -0
  64. fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
  65. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
  66. fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
  67. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
  68. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
  69. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
  70. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
  71. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
  72. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
  73. fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
  74. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/WHEEL +0 -0
  75. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/entry_points.txt +0 -0
  76. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/licenses/LICENSE +0 -0
  77. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/top_level.txt +0 -0
fusion_bench/__init__.py CHANGED
@@ -5,46 +5,156 @@
  # ██║ ╚██████╔╝███████║██║╚██████╔╝██║ ╚████║ ██████╔╝███████╗██║ ╚████║╚██████╗██║ ██║
  # ╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═════╝ ╚═╝ ╚═══╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝ ╚═════╝╚═╝ ╚═╝
  # flake8: noqa: F401
- from . import (
-     constants,
-     dataset,
-     method,
-     metrics,
-     mixins,
-     modelpool,
-     models,
-     optim,
-     programs,
-     taskpool,
-     tasks,
-     utils,
- )
+ import sys
+ from typing import TYPE_CHECKING
+
+ from fusion_bench.utils.lazy_imports import LazyImporter
+
+ from . import constants, metrics, optim, tasks
  from .constants import RuntimeConstants
- from .method import BaseAlgorithm, BaseModelFusionAlgorithm
- from .mixins import auto_register_config
- from .modelpool import BaseModelPool
- from .models import (
-     create_default_model_card,
-     load_model_card_template,
-     save_pretrained_with_remote_code,
-     separate_io,
- )
- from .programs import BaseHydraProgram
- from .taskpool import BaseTaskPool
- from .utils import (
-     BoolStateDictType,
-     LazyStateDict,
-     StateDictType,
-     TorchModelType,
-     cache_with_joblib,
-     get_rankzero_logger,
-     import_object,
-     instantiate,
-     parse_dtype,
-     print_parameters,
-     seed_everything_by_time,
-     set_default_cache_dir,
-     set_print_function_call,
-     set_print_function_call_permeanent,
-     timeit_context,
- )
+ from .method import _available_algorithms
+
+ _extra_objects = {
+     "RuntimeConstants": RuntimeConstants,
+     "constants": constants,
+     "metrics": metrics,
+     "optim": optim,
+     "tasks": tasks,
+ }
+ _import_structure = {
+     "dataset": ["CLIPDataset"],
+     "method": _available_algorithms,
+     "mixins": [
+         "CLIPClassificationMixin",
+         "FabricTrainingMixin",
+         "HydraConfigMixin",
+         "LightningFabricMixin",
+         "OpenCLIPClassificationMixin",
+         "PyinstrumentProfilerMixin",
+         "SimpleProfilerMixin",
+         "YAMLSerializationMixin",
+         "auto_register_config",
+     ],
+     "modelpool": [
+         "AutoModelPool",
+         "BaseModelPool",
+         "CausalLMBackbonePool",
+         "CausalLMPool",
+         "CLIPVisionModelPool",
+         "GPT2ForSequenceClassificationPool",
+         "HuggingFaceGPT2ClassificationPool",
+         "NYUv2ModelPool",
+         "OpenCLIPVisionModelPool",
+         "PeftModelForSeq2SeqLMPool",
+         "ResNetForImageClassificationPool",
+         "Seq2SeqLMPool",
+         "SequenceClassificationModelPool",
+     ],
+     "models": [
+         "create_default_model_card",
+         "load_model_card_template",
+         "save_pretrained_with_remote_code",
+         "separate_load",
+         "separate_save",
+     ],
+     "programs": ["BaseHydraProgram", "FabricModelFusionProgram"],
+     "taskpool": [
+         "BaseTaskPool",
+         "CLIPVisionModelTaskPool",
+         "DummyTaskPool",
+         "GPT2TextClassificationTaskPool",
+         "LMEvalHarnessTaskPool",
+         "OpenCLIPVisionModelTaskPool",
+         "NYUv2TaskPool",
+     ],
+     "utils": [
+         "ArithmeticStateDict",
+         "BoolStateDictType",
+         "LazyStateDict",
+         "StateDictType",
+         "TorchModelType",
+         "cache_with_joblib",
+         "get_rankzero_logger",
+         "import_object",
+         "instantiate",
+         "parse_dtype",
+         "print_parameters",
+         "seed_everything_by_time",
+         "set_default_cache_dir",
+         "set_print_function_call",
+         "set_print_function_call_permeanent",
+         "timeit_context",
+     ],
+ }
+
+ if TYPE_CHECKING:
+     from .dataset import CLIPDataset
+     from .method import BaseAlgorithm, BaseModelFusionAlgorithm
+     from .mixins import (
+         CLIPClassificationMixin,
+         FabricTrainingMixin,
+         HydraConfigMixin,
+         LightningFabricMixin,
+         OpenCLIPClassificationMixin,
+         PyinstrumentProfilerMixin,
+         SimpleProfilerMixin,
+         YAMLSerializationMixin,
+         auto_register_config,
+     )
+     from .modelpool import (
+         AutoModelPool,
+         BaseModelPool,
+         CausalLMBackbonePool,
+         CausalLMPool,
+         CLIPVisionModelPool,
+         GPT2ForSequenceClassificationPool,
+         HuggingFaceGPT2ClassificationPool,
+         NYUv2ModelPool,
+         OpenCLIPVisionModelPool,
+         PeftModelForSeq2SeqLMPool,
+         ResNetForImageClassificationPool,
+         Seq2SeqLMPool,
+         SequenceClassificationModelPool,
+     )
+     from .models import (
+         create_default_model_card,
+         load_model_card_template,
+         save_pretrained_with_remote_code,
+         separate_load,
+         separate_save,
+     )
+     from .programs import BaseHydraProgram, FabricModelFusionProgram
+     from .taskpool import (
+         BaseTaskPool,
+         CLIPVisionModelTaskPool,
+         DummyTaskPool,
+         GPT2TextClassificationTaskPool,
+         LMEvalHarnessTaskPool,
+         NYUv2TaskPool,
+         OpenCLIPVisionModelTaskPool,
+     )
+     from .utils import (
+         ArithmeticStateDict,
+         BoolStateDictType,
+         LazyStateDict,
+         StateDictType,
+         TorchModelType,
+         cache_with_joblib,
+         get_rankzero_logger,
+         import_object,
+         instantiate,
+         parse_dtype,
+         print_parameters,
+         seed_everything_by_time,
+         set_default_cache_dir,
+         set_print_function_call,
+         set_print_function_call_permeanent,
+         timeit_context,
+     )
+ else:
+     sys.modules[__name__] = LazyImporter(
+         __name__,
+         globals()["__file__"],
+         _import_structure,
+         extra_objects=_extra_objects,
+     )
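The rewritten `__init__.py` defers submodule imports until an attribute is first accessed, so `import fusion_bench` stays cheap. A minimal sketch of the general deferred-import pattern (a simplified stand-in using the stdlib `json` module, not fusion-bench's actual `LazyImporter`):

```python
# Illustrative sketch of the deferred-import pattern this change adopts.
# Attributes are resolved on first access via PEP 562's module-level __getattr__.
import importlib

_import_structure = {"json": ["dumps", "loads"]}  # submodule -> exported names
_attr_to_module = {
    attr: module for module, attrs in _import_structure.items() for attr in attrs
}


def __getattr__(name):
    # Import the owning module only when the attribute is actually requested.
    if name in _attr_to_module:
        module = importlib.import_module(_attr_to_module[name])
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```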
fusion_bench/dataset/__init__.py CHANGED
@@ -1,16 +1,20 @@
  # flake8: noqa F401
- from datasets import load_dataset
- from omegaconf import DictConfig, open_dict
+ import sys
+ from typing import TYPE_CHECKING

- from fusion_bench.utils import instantiate
+ from omegaconf import DictConfig, open_dict

- from .clip_dataset import CLIPDataset
+ from fusion_bench.utils.lazy_imports import LazyImporter


  def load_dataset_from_config(dataset_config: DictConfig):
      """
      Load the dataset from the configuration.
      """
+     from datasets import load_dataset
+
+     from fusion_bench.utils import instantiate
+
      assert hasattr(dataset_config, "type"), "Dataset type not specified"
      if dataset_config.type == "instantiate":
          return instantiate(dataset_config.object)
@@ -27,3 +31,22 @@ def load_dataset_from_config(dataset_config: DictConfig):
          return dataset
      else:
          raise ValueError(f"Unknown dataset type: {dataset_config.type}")
+
+
+ _extra_objects = {
+     "load_dataset_from_config": load_dataset_from_config,
+ }
+ _import_structure = {
+     "clip_dataset": ["CLIPDataset"],
+ }
+
+ if TYPE_CHECKING:
+     from .clip_dataset import CLIPDataset
+
+ else:
+     sys.modules[__name__] = LazyImporter(
+         __name__,
+         globals()["__file__"],
+         _import_structure,
+         extra_objects=_extra_objects,
+     )
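With this change, `datasets.load_dataset` and `instantiate` are only imported when `load_dataset_from_config` is actually called. A hedged usage sketch, based only on the `type: instantiate` branch shown above; the `_target_` below is illustrative, not a fusion-bench default:

```python
# Hedged sketch: build a DictConfig of type "instantiate" and let the helper
# construct the dataset object via fusion_bench's instantiate().
from omegaconf import OmegaConf

from fusion_bench.dataset import load_dataset_from_config

cfg = OmegaConf.create(
    {
        "type": "instantiate",
        "object": {  # illustrative Hydra-style target, not a shipped default
            "_target_": "torchvision.datasets.FakeData",
            "size": 8,
        },
    }
)
dataset = load_dataset_from_config(cfg)
```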
fusion_bench/dataset/clip_dataset.py CHANGED
@@ -6,7 +6,7 @@ from typing import Optional, Tuple

  import torch
  from torch.utils.data import Dataset
- from transformers import CLIPProcessor, ProcessorMixin
+ from transformers import BaseImageProcessor, CLIPProcessor, ProcessorMixin

  __all__ = ["CLIPDataset"]

@@ -60,7 +60,7 @@ class CLIPDataset(torch.utils.data.Dataset):
              raise ValueError("Each item should be a dictionary or a tuple of length 2")
          image = item["image"]
          if self.processor is not None:
-             if isinstance(self.processor, ProcessorMixin):
+             if isinstance(self.processor, (ProcessorMixin, BaseImageProcessor)):
                  # Apply the processor to the image to get the input tensor
                  inputs = self.processor(images=[image], return_tensors="pt")[
                      "pixel_values"
fusion_bench/method/__init__.py CHANGED
@@ -2,6 +2,7 @@
  import sys
  from typing import TYPE_CHECKING

+ from fusion_bench.utils import join_lists
  from fusion_bench.utils.lazy_imports import LazyImporter

  _import_structure = {
@@ -12,6 +13,8 @@ _import_structure = {
      "classification": [
          "ImageClassificationFineTuningForCLIP",
          "ContinualImageClassificationFineTuningForCLIP",
+         "ImageClassificationFineTuning",
+         "ImageClassificationFineTuning_Test",
      ],
      "lm_finetune": ["FullFinetuneSFT", "PeftFinetuneSFT", "BradleyTerryRewardModeling"],
      # analysis
@@ -26,9 +29,12 @@ _import_structure = {
      "linear": [
          "ExPOAlgorithm",
          "ExPOAlgorithmForLlama",
+         "SimpleAverageForCausalLM",
          "SimpleAverageForLlama",
+         "TaskArithmeticForCausalLM",
          "TaskArithmeticForLlama",
          "LinearInterpolationAlgorithm",
+         "TiesMergingForCausalLM",
      ],
      "slerp": ["SlerpMergeAlgorithm", "SlerpForCausalLM"],
      "simple_average": ["SimpleAverageAlgorithm"],
@@ -72,6 +78,7 @@ _import_structure = {
      "fw_merging": ["FrankWolfeHardAlgorithm", "FrankWolfeSoftAlgorithm"],
      "tall_mask": ["TallMaskTaskArithmeticAlgorithm"],
      "model_stock": ["ModelStock"],
+     "wudi": ["wudi_merging", "WUDIMerging"],
      # plug-and-play model merging methods
      "concrete_subspace": [
          "ConcreteTaskArithmeticAlgorithmForCLIP",
@@ -127,7 +134,10 @@ _import_structure = {
          "ProgressivePruningForMixtral",
      ],
  }
-
+ _available_algorithms = join_lists(list(_import_structure.values()))
+ _extra_objects = {
+     "_available_algorithms": _available_algorithms,
+ }

  if TYPE_CHECKING:
      from .ada_svd import AdaSVDMergingForCLIPVisionModel
@@ -137,6 +147,8 @@ if TYPE_CHECKING:
      from .bitdelta import BitDeltaAlgorithm
      from .classification import (
          ContinualImageClassificationFineTuningForCLIP,
+         ImageClassificationFineTuning,
+         ImageClassificationFineTuning_Test,
          ImageClassificationFineTuningForCLIP,
      )
      from .concrete_subspace import (
@@ -184,8 +196,11 @@ if TYPE_CHECKING:
          ExPOAlgorithm,
          ExPOAlgorithmForLlama,
          LinearInterpolationAlgorithm,
+         SimpleAverageForCausalLM,
          SimpleAverageForLlama,
+         TaskArithmeticForCausalLM,
          TaskArithmeticForLlama,
+         TiesMergingForCausalLM,
      )
      from .lm_finetune import *
      from .mixture_of_experts import (
@@ -238,10 +253,12 @@ if TYPE_CHECKING:
          FlanT5WeightEnsemblingMoEAlgorithm,
      )
      from .weighted_average import WeightedAverageAlgorithm, WeightedAverageForLLama
+     from .wudi import WUDIMerging, wudi_merging

  else:
      sys.modules[__name__] = LazyImporter(
          __name__,
          globals()["__file__"],
          _import_structure,
+         extra_objects=_extra_objects,
      )
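The behavior assumed of `join_lists` here, based only on its usage above (the real helper lives in `fusion_bench.utils`), is to flatten the per-module export lists into a single list of algorithm names that the top-level package re-exports:

```python
# Minimal sketch of the assumed join_lists behavior, not the packaged helper.
def join_lists(lists):
    return [name for sublist in lists for name in sublist]


_import_structure = {
    "simple_average": ["SimpleAverageAlgorithm"],
    "wudi": ["wudi_merging", "WUDIMerging"],
}
_available_algorithms = join_lists(list(_import_structure.values()))
print(_available_algorithms)  # ['SimpleAverageAlgorithm', 'wudi_merging', 'WUDIMerging']
```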
fusion_bench/method/classification/__init__.py CHANGED
@@ -1,3 +1,28 @@
  # flake8: noqa F401
- from .clip_finetune import ImageClassificationFineTuningForCLIP
- from .continual_clip_finetune import ContinualImageClassificationFineTuningForCLIP
+ import sys
+ from typing import TYPE_CHECKING
+
+ from fusion_bench.utils.lazy_imports import LazyImporter
+
+ _import_structure = {
+     "clip_finetune": ["ImageClassificationFineTuningForCLIP"],
+     "continual_clip_finetune": ["ContinualImageClassificationFineTuningForCLIP"],
+     "image_classification_finetune": [
+         "ImageClassificationFineTuning",
+         "ImageClassificationFineTuning_Test",
+     ],
+ }
+
+ if TYPE_CHECKING:
+     from .clip_finetune import ImageClassificationFineTuningForCLIP
+     from .continual_clip_finetune import ContinualImageClassificationFineTuningForCLIP
+     from .image_classification_finetune import (
+         ImageClassificationFineTuning,
+         ImageClassificationFineTuning_Test,
+     )
+ else:
+     sys.modules[__name__] = LazyImporter(
+         __name__,
+         globals()["__file__"],
+         _import_structure,
+     )
fusion_bench/method/classification/image_classification_finetune.py ADDED
@@ -0,0 +1,214 @@
+ import os
+ from typing import Optional
+
+ import lightning as L
+ import lightning.pytorch.callbacks as pl_callbacks
+ import torch
+ from lightning.pytorch.loggers import TensorBoardLogger
+ from lightning_utilities.core.rank_zero import rank_zero_only
+ from lit_learn.lit_modules import ERM_LitModule
+ from omegaconf import DictConfig
+ from torch import nn
+ from torch.utils.data import DataLoader
+ from torchmetrics.classification import Accuracy
+
+ from fusion_bench import (
+     BaseAlgorithm,
+     BaseModelPool,
+     RuntimeConstants,
+     auto_register_config,
+     get_rankzero_logger,
+     instantiate,
+ )
+ from fusion_bench.dataset import CLIPDataset
+ from fusion_bench.modelpool import ResNetForImageClassificationPool
+ from fusion_bench.tasks.clip_classification import get_num_classes
+
+ log = get_rankzero_logger(__name__)
+
+
+ @auto_register_config
+ class ImageClassificationFineTuning(BaseAlgorithm):
+     def __init__(
+         self,
+         max_epochs: Optional[int],
+         max_steps: Optional[int],
+         label_smoothing: float,
+         optimizer: DictConfig,
+         lr_scheduler: DictConfig,
+         dataloader_kwargs: DictConfig,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         assert (max_epochs is None) or (
+             max_steps is None or max_steps < 0
+         ), "Only one of max_epochs or max_steps should be set."
+         self.training_interval = "epoch" if max_epochs is not None else "step"
+         if self.training_interval == "epoch":
+             self.max_steps = -1
+         log.info(f"Training interval: {self.training_interval}")
+         log.info(f"Max epochs: {max_epochs}, max steps: {max_steps}")
+
+     def run(self, modelpool: ResNetForImageClassificationPool):
+         # load model and dataset
+         model = modelpool.load_pretrained_or_first_model()
+         assert isinstance(model, nn.Module), "Loaded model is not a nn.Module."
+
+         assert (
+             len(modelpool.train_dataset_names) == 1
+         ), "Exactly one training dataset is required."
+         self.dataset_name = dataset_name = modelpool.train_dataset_names[0]
+         num_classes = get_num_classes(dataset_name)
+         train_dataset = modelpool.load_train_dataset(dataset_name)
+         train_dataset = CLIPDataset(
+             train_dataset, processor=modelpool.load_processor(stage="train")
+         )
+         train_loader = self.get_dataloader(train_dataset, stage="train")
+         if modelpool.has_val_dataset:
+             val_dataset = modelpool.load_val_dataset(dataset_name)
+             val_dataset = CLIPDataset(
+                 val_dataset, processor=modelpool.load_processor(stage="val")
+             )
+             val_loader = self.get_dataloader(val_dataset, stage="val")
+
+         # configure optimizer
+         optimizer = instantiate(self.optimizer, params=model.parameters())
+         if self.lr_scheduler is not None:
+             lr_scheduler = instantiate(self.lr_scheduler, optimizer=optimizer)
+             optimizer = {
+                 "optimizer": optimizer,
+                 "lr_scheduler": {
+                     "scheduler": lr_scheduler,
+                     "interval": self.training_interval,
+                     "frequency": 1,
+                 },
+             }
+         log.info(f"optimizer:\n{optimizer}")
+
+         lit_module = ERM_LitModule(
+             model,
+             optimizer,
+             objective=nn.CrossEntropyLoss(label_smoothing=self.label_smoothing),
+             metrics={
+                 "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
+                 "acc@5": Accuracy(task="multiclass", num_classes=num_classes, top_k=5),
+             },
+         )
+
+         log_dir = (
+             self._program.path.log_dir
+             if self._program is not None
+             else "outputs/lightning_logs"
+         )
+         trainer = L.Trainer(
+             max_epochs=self.max_epochs,
+             max_steps=self.max_steps,
+             accelerator="auto",
+             devices="auto",
+             callbacks=[
+                 pl_callbacks.LearningRateMonitor(logging_interval="step"),
+                 pl_callbacks.DeviceStatsMonitor(),
+             ],
+             logger=TensorBoardLogger(
+                 save_dir=log_dir,
+                 name="",
+             ),
+             fast_dev_run=RuntimeConstants.debug,
+         )
+
+         trainer.fit(
+             lit_module, train_dataloaders=train_loader, val_dataloaders=val_loader
+         )
+         model = lit_module.model
+         if rank_zero_only.rank == 0:
+             log.info(f"Saving the final model to {log_dir}/raw_checkpoints/final")
+             modelpool.save_model(
+                 model,
+                 path=os.path.join(
+                     trainer.log_dir if trainer.log_dir is not None else log_dir,
+                     "raw_checkpoints",
+                     "final",
+                 ),
+             )
+         return model
+
+     def get_dataloader(self, dataset, stage: str):
+         assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
+         dataloader_kwargs = dict(self.dataloader_kwargs)
+         if "shuffle" not in dataloader_kwargs:
+             dataloader_kwargs["shuffle"] = stage == "train"
+         return DataLoader(dataset, **dataloader_kwargs)
+
+
+ @auto_register_config
+ class ImageClassificationFineTuning_Test(BaseAlgorithm):
+     def __init__(self, checkpoint_path: str, dataloader_kwargs: DictConfig, **kwargs):
+         super().__init__(**kwargs)
+
+     def run(self, modelpool: BaseModelPool):
+         assert (
+             modelpool.has_val_dataset or modelpool.has_test_dataset
+         ), "No validation or test dataset found in the model pool."
+
+         # load model and dataset
+         model = modelpool.load_pretrained_or_first_model()
+         assert isinstance(model, nn.Module), "Loaded model is not a nn.Module."
+
+         if modelpool.has_test_dataset:
+             assert (
+                 len(modelpool.test_dataset_names) == 1
+             ), "Exactly one test dataset is required."
+             self.dataset_name = dataset_name = modelpool.test_dataset_names[0]
+             dataset = modelpool.load_test_dataset(dataset_name)
+             dataset = CLIPDataset(
+                 dataset, processor=modelpool.load_processor(stage="test")
+             )
+         else:
+             assert (
+                 len(modelpool.val_dataset_names) == 1
+             ), "Exactly one validation dataset is required."
+             self.dataset_name = dataset_name = modelpool.val_dataset_names[0]
+             dataset = modelpool.load_val_dataset(dataset_name)
+             dataset = CLIPDataset(
+                 dataset, processor=modelpool.load_processor(stage="test")
+             )
+         num_classes = get_num_classes(dataset_name)
+
+         test_loader = self.get_dataloader(dataset, stage="test")
+
+         if self.checkpoint_path is None:
+             lit_module = ERM_LitModule(
+                 model,
+                 metrics={
+                     "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
+                     "acc@5": Accuracy(
+                         task="multiclass", num_classes=num_classes, top_k=5
+                     ),
+                 },
+             )
+         else:
+             lit_module = ERM_LitModule.load_from_checkpoint(
+                 checkpoint_path=self.checkpoint_path,
+                 model=model,
+                 metrics={
+                     "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
+                     "acc@5": Accuracy(
+                         task="multiclass", num_classes=num_classes, top_k=5
+                     ),
+                 },
+             )
+
+         trainer = L.Trainer(
+             devices=1, num_nodes=1, logger=False, fast_dev_run=RuntimeConstants.debug
+         )
+
+         test_metrics = trainer.test(lit_module, dataloaders=test_loader)
+         log.info(f"Test metrics: {test_metrics}")
+         return model
+
+     def get_dataloader(self, dataset, stage: str):
+         assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
+         dataloader_kwargs = dict(self.dataloader_kwargs)
+         if "shuffle" not in dataloader_kwargs:
+             dataloader_kwargs["shuffle"] = stage == "train"
+         return DataLoader(dataset, **dataloader_kwargs)
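A hedged sketch of driving the new `ImageClassificationFineTuning` algorithm from Python rather than the Hydra CLI; the optimizer and scheduler targets below are illustrative, not necessarily the defaults shipped in `image_classification_finetune.yaml`:

```python
# Hedged sketch: construct the algorithm with the arguments shown in its
# __init__ above, then run it against a ResNetForImageClassificationPool.
from omegaconf import OmegaConf

from fusion_bench.method.classification import ImageClassificationFineTuning

algo = ImageClassificationFineTuning(
    max_epochs=10,
    max_steps=None,  # exactly one of max_epochs/max_steps may be set
    label_smoothing=0.1,
    optimizer=OmegaConf.create({"_target_": "torch.optim.AdamW", "lr": 1e-4}),
    lr_scheduler=OmegaConf.create(
        {"_target_": "torch.optim.lr_scheduler.CosineAnnealingLR", "T_max": 10}
    ),
    dataloader_kwargs=OmegaConf.create({"batch_size": 64, "num_workers": 4}),
)
# finetuned_model = algo.run(modelpool)  # modelpool: ResNetForImageClassificationPool
```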
fusion_bench/method/ensemble.py CHANGED
@@ -17,7 +17,21 @@ from fusion_bench.models.wrappers.ensemble import (
  log = logging.getLogger(__name__)


+ @auto_register_config
  class SimpleEnsembleAlgorithm(BaseAlgorithm):
+     def __init__(
+         self,
+         device_map: Optional[Mapping[int, Union[str, torch.device]]] = None,
+         **kwargs,
+     ):
+         """
+         Initializes the SimpleEnsembleAlgorithm with an optional device map.
+
+         Args:
+             device_map (Optional[Mapping[int, Union[str, torch.device]]], optional): A mapping from model index to device. Defaults to None.
+         """
+         super().__init__(**kwargs)
+
      @torch.no_grad()
      def run(self, modelpool: BaseModelPool | List[nn.Module]) -> EnsembleModule:
          """
@@ -30,9 +44,10 @@ class SimpleEnsembleAlgorithm(BaseAlgorithm):
              EnsembleModule: The ensembled model.
          """
          log.info(f"Running ensemble algorithm with {len(modelpool)} models")
-
          models = [modelpool.load_model(m) for m in modelpool.model_names]
-         ensemble = EnsembleModule(models=models)
+
+         log.info("creating ensemble module")
+         ensemble = EnsembleModule(models=models, device_map=self.device_map)
          return ensemble


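A hedged sketch of the new `device_map` argument, which is forwarded to `EnsembleModule` so ensemble members can be placed on different devices; the mapping below is illustrative:

```python
# Hedged sketch: device_map maps model index -> device for the ensemble members.
from fusion_bench.method.ensemble import SimpleEnsembleAlgorithm

algo = SimpleEnsembleAlgorithm(device_map={0: "cuda:0", 1: "cuda:1"})
# ensemble = algo.run(modelpool)  # modelpool: a BaseModelPool of the member models
```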
fusion_bench/method/linear/__init__.py CHANGED
@@ -2,5 +2,9 @@
  from .expo import ExPOAlgorithm
  from .linear_interpolation import LinearInterpolationAlgorithm
  from .llama_expo import ExPOAlgorithmForLlama
- from .simple_average_for_llama import SimpleAverageForLlama
- from .task_arithmetic_for_llama import TaskArithmeticForLlama
+ from .simple_average_for_causallm import SimpleAverageForCausalLM, SimpleAverageForLlama
+ from .task_arithmetic_for_causallm import (
+     TaskArithmeticForCausalLM,
+     TaskArithmeticForLlama,
+ )
+ from .ties_merging_for_causallm import TiesMergingForCausalLM
fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} RENAMED
@@ -18,16 +18,16 @@ log = get_rankzero_logger(__name__)


  @auto_register_config
- class SimpleAverageForLlama(BaseAlgorithm):
+ class SimpleAverageForCausalLM(BaseAlgorithm):
      R"""
      A simple averaging algorithm for LLama models. If `merge_backbone` is set to `True`, the backbone of the model will be averaged and the rest of the model will be loaded from the pre-trained model.

      Examples:
-         The following example demonstrates how to use the `SimpleAverageForLlama` algorithm to merge Mistral models.
+         The following example demonstrates how to use the `SimpleAverageForCausalLM` algorithm to merge Mistral models.

          ```bash
          fusion_bench \
-             method=linear/simple_average_for_llama \
+             method=linear/simple_average_for_causallm \
              method.model_save_path=outputs/simle_mixtral_exp_v4/simple_average \
              modelpool=CausalLMPool/simle_mixtral_exp_v4.yaml
          ```
@@ -35,7 +35,7 @@ class SimpleAverageForLlama(BaseAlgorithm):

      def __init__(
          self,
-         merge_backbone: bool,
+         merge_backbone: bool = False,
          model_save_path: Optional[str] = None,
          show_pbar: bool = False,
          **kwargs,
@@ -81,3 +81,7 @@ class SimpleAverageForLlama(BaseAlgorithm):
          with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
              f.write(model_card_str)
          return model
+
+
+ SimpleAverageForLlama = SimpleAverageForCausalLM
+ """Alias for SimpleAverageForCausalLM"""
+ """Alias for SimpleAverageForCausalLM"""