fusion-bench 0.2.24__py3-none-any.whl → 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. fusion_bench/__init__.py +152 -42
  2. fusion_bench/dataset/__init__.py +27 -4
  3. fusion_bench/dataset/clip_dataset.py +2 -2
  4. fusion_bench/method/__init__.py +10 -1
  5. fusion_bench/method/classification/__init__.py +27 -2
  6. fusion_bench/method/classification/image_classification_finetune.py +214 -0
  7. fusion_bench/method/opcm/opcm.py +1 -0
  8. fusion_bench/method/pwe_moe/module.py +0 -2
  9. fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
  10. fusion_bench/mixins/__init__.py +2 -0
  11. fusion_bench/mixins/pyinstrument.py +174 -0
  12. fusion_bench/mixins/simple_profiler.py +106 -23
  13. fusion_bench/modelpool/__init__.py +2 -0
  14. fusion_bench/modelpool/base_pool.py +77 -14
  15. fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
  16. fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
  17. fusion_bench/models/__init__.py +35 -9
  18. fusion_bench/optim/__init__.py +40 -2
  19. fusion_bench/optim/lr_scheduler/__init__.py +27 -1
  20. fusion_bench/optim/muon.py +339 -0
  21. fusion_bench/programs/__init__.py +2 -0
  22. fusion_bench/programs/fabric_fusion_program.py +2 -2
  23. fusion_bench/programs/fusion_program.py +271 -0
  24. fusion_bench/tasks/clip_classification/__init__.py +15 -0
  25. fusion_bench/utils/__init__.py +167 -21
  26. fusion_bench/utils/lazy_imports.py +91 -12
  27. fusion_bench/utils/lazy_state_dict.py +55 -5
  28. fusion_bench/utils/misc.py +104 -13
  29. fusion_bench/utils/packages.py +4 -0
  30. fusion_bench/utils/path.py +7 -0
  31. fusion_bench/utils/pylogger.py +6 -0
  32. fusion_bench/utils/rich_utils.py +1 -0
  33. fusion_bench/utils/state_dict_arithmetic.py +935 -162
  34. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/METADATA +1 -1
  35. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/RECORD +48 -34
  36. fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
  37. fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
  38. fusion_bench_config/model_fusion.yaml +45 -0
  39. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
  40. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
  41. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
  42. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
  43. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
  44. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
  45. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/WHEEL +0 -0
  46. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/entry_points.txt +0 -0
  47. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/licenses/LICENSE +0 -0
  48. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/top_level.txt +0 -0
fusion_bench/__init__.py CHANGED
@@ -5,46 +5,156 @@
5
5
  # ██║ ╚██████╔╝███████║██║╚██████╔╝██║ ╚████║ ██████╔╝███████╗██║ ╚████║╚██████╗██║ ██║
6
6
  # ╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═════╝ ╚═╝ ╚═══╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝ ╚═════╝╚═╝ ╚═╝
7
7
  # flake8: noqa: F401
8
- from . import (
9
- constants,
10
- dataset,
11
- method,
12
- metrics,
13
- mixins,
14
- modelpool,
15
- models,
16
- optim,
17
- programs,
18
- taskpool,
19
- tasks,
20
- utils,
21
- )
8
+ import sys
9
+ from typing import TYPE_CHECKING
10
+
11
+ from fusion_bench.utils.lazy_imports import LazyImporter
12
+
13
+ from . import constants, metrics, optim, tasks
22
14
  from .constants import RuntimeConstants
23
- from .method import BaseAlgorithm, BaseModelFusionAlgorithm
24
- from .mixins import auto_register_config
25
- from .modelpool import BaseModelPool
26
- from .models import (
27
- create_default_model_card,
28
- load_model_card_template,
29
- save_pretrained_with_remote_code,
30
- separate_io,
31
- )
32
- from .programs import BaseHydraProgram
33
- from .taskpool import BaseTaskPool
34
- from .utils import (
35
- BoolStateDictType,
36
- LazyStateDict,
37
- StateDictType,
38
- TorchModelType,
39
- cache_with_joblib,
40
- get_rankzero_logger,
41
- import_object,
42
- instantiate,
43
- parse_dtype,
44
- print_parameters,
45
- seed_everything_by_time,
46
- set_default_cache_dir,
47
- set_print_function_call,
48
- set_print_function_call_permeanent,
49
- timeit_context,
50
- )
15
+ from .method import _available_algorithms
16
+
17
+ _extra_objects = {
18
+ "RuntimeConstants": RuntimeConstants,
19
+ "constants": constants,
20
+ "metrics": metrics,
21
+ "optim": optim,
22
+ "tasks": tasks,
23
+ }
24
+ _import_structure = {
25
+ "dataset": ["CLIPDataset"],
26
+ "method": _available_algorithms,
27
+ "mixins": [
28
+ "CLIPClassificationMixin",
29
+ "FabricTrainingMixin",
30
+ "HydraConfigMixin",
31
+ "LightningFabricMixin",
32
+ "OpenCLIPClassificationMixin",
33
+ "PyinstrumentProfilerMixin",
34
+ "SimpleProfilerMixin",
35
+ "YAMLSerializationMixin",
36
+ "auto_register_config",
37
+ ],
38
+ "modelpool": [
39
+ "AutoModelPool",
40
+ "BaseModelPool",
41
+ "CausalLMBackbonePool",
42
+ "CausalLMPool",
43
+ "CLIPVisionModelPool",
44
+ "GPT2ForSequenceClassificationPool",
45
+ "HuggingFaceGPT2ClassificationPool",
46
+ "NYUv2ModelPool",
47
+ "OpenCLIPVisionModelPool",
48
+ "PeftModelForSeq2SeqLMPool",
49
+ "ResNetForImageClassificationPool",
50
+ "Seq2SeqLMPool",
51
+ "SequenceClassificationModelPool",
52
+ ],
53
+ "models": [
54
+ "create_default_model_card",
55
+ "load_model_card_template",
56
+ "save_pretrained_with_remote_code",
57
+ "separate_load",
58
+ "separate_save",
59
+ ],
60
+ "programs": ["BaseHydraProgram", "FabricModelFusionProgram"],
61
+ "taskpool": [
62
+ "BaseTaskPool",
63
+ "CLIPVisionModelTaskPool",
64
+ "DummyTaskPool",
65
+ "GPT2TextClassificationTaskPool",
66
+ "LMEvalHarnessTaskPool",
67
+ "OpenCLIPVisionModelTaskPool",
68
+ "NYUv2TaskPool",
69
+ ],
70
+ "utils": [
71
+ "ArithmeticStateDict",
72
+ "BoolStateDictType",
73
+ "LazyStateDict",
74
+ "StateDictType",
75
+ "TorchModelType",
76
+ "cache_with_joblib",
77
+ "get_rankzero_logger",
78
+ "import_object",
79
+ "instantiate",
80
+ "parse_dtype",
81
+ "print_parameters",
82
+ "seed_everything_by_time",
83
+ "set_default_cache_dir",
84
+ "set_print_function_call",
85
+ "set_print_function_call_permeanent",
86
+ "timeit_context",
87
+ ],
88
+ }
89
+
90
+ if TYPE_CHECKING:
91
+ from .dataset import CLIPDataset
92
+ from .method import BaseAlgorithm, BaseModelFusionAlgorithm
93
+ from .mixins import (
94
+ CLIPClassificationMixin,
95
+ FabricTrainingMixin,
96
+ HydraConfigMixin,
97
+ LightningFabricMixin,
98
+ OpenCLIPClassificationMixin,
99
+ PyinstrumentProfilerMixin,
100
+ SimpleProfilerMixin,
101
+ YAMLSerializationMixin,
102
+ auto_register_config,
103
+ )
104
+ from .modelpool import (
105
+ AutoModelPool,
106
+ BaseModelPool,
107
+ CausalLMBackbonePool,
108
+ CausalLMPool,
109
+ CLIPVisionModelPool,
110
+ GPT2ForSequenceClassificationPool,
111
+ HuggingFaceGPT2ClassificationPool,
112
+ NYUv2ModelPool,
113
+ OpenCLIPVisionModelPool,
114
+ PeftModelForSeq2SeqLMPool,
115
+ ResNetForImageClassificationPool,
116
+ Seq2SeqLMPool,
117
+ SequenceClassificationModelPool,
118
+ )
119
+ from .models import (
120
+ create_default_model_card,
121
+ load_model_card_template,
122
+ save_pretrained_with_remote_code,
123
+ separate_load,
124
+ separate_save,
125
+ )
126
+ from .programs import BaseHydraProgram, FabricModelFusionProgram
127
+ from .taskpool import (
128
+ BaseTaskPool,
129
+ CLIPVisionModelTaskPool,
130
+ DummyTaskPool,
131
+ GPT2TextClassificationTaskPool,
132
+ LMEvalHarnessTaskPool,
133
+ NYUv2TaskPool,
134
+ OpenCLIPVisionModelTaskPool,
135
+ )
136
+ from .utils import (
137
+ ArithmeticStateDict,
138
+ BoolStateDictType,
139
+ LazyStateDict,
140
+ StateDictType,
141
+ TorchModelType,
142
+ cache_with_joblib,
143
+ get_rankzero_logger,
144
+ import_object,
145
+ instantiate,
146
+ parse_dtype,
147
+ print_parameters,
148
+ seed_everything_by_time,
149
+ set_default_cache_dir,
150
+ set_print_function_call,
151
+ set_print_function_call_permeanent,
152
+ timeit_context,
153
+ )
154
+ else:
155
+ sys.modules[__name__] = LazyImporter(
156
+ __name__,
157
+ globals()["__file__"],
158
+ _import_structure,
159
+ extra_objects=_extra_objects,
160
+ )
@@ -1,16 +1,20 @@
1
1
  # flake8: noqa F401
2
- from datasets import load_dataset
3
- from omegaconf import DictConfig, open_dict
2
+ import sys
3
+ from typing import TYPE_CHECKING
4
4
 
5
- from fusion_bench.utils import instantiate
5
+ from omegaconf import DictConfig, open_dict
6
6
 
7
- from .clip_dataset import CLIPDataset
7
+ from fusion_bench.utils.lazy_imports import LazyImporter
8
8
 
9
9
 
10
10
  def load_dataset_from_config(dataset_config: DictConfig):
11
11
  """
12
12
  Load the dataset from the configuration.
13
13
  """
14
+ from datasets import load_dataset
15
+
16
+ from fusion_bench.utils import instantiate
17
+
14
18
  assert hasattr(dataset_config, "type"), "Dataset type not specified"
15
19
  if dataset_config.type == "instantiate":
16
20
  return instantiate(dataset_config.object)
@@ -27,3 +31,22 @@ def load_dataset_from_config(dataset_config: DictConfig):
27
31
  return dataset
28
32
  else:
29
33
  raise ValueError(f"Unknown dataset type: {dataset_config.type}")
34
+
35
+
36
+ _extra_objects = {
37
+ "load_dataset_from_config": load_dataset_from_config,
38
+ }
39
+ _import_structure = {
40
+ "clip_dataset": ["CLIPDataset"],
41
+ }
42
+
43
+ if TYPE_CHECKING:
44
+ from .clip_dataset import CLIPDataset
45
+
46
+ else:
47
+ sys.modules[__name__] = LazyImporter(
48
+ __name__,
49
+ globals()["__file__"],
50
+ _import_structure,
51
+ extra_objects=_extra_objects,
52
+ )
@@ -6,7 +6,7 @@ from typing import Optional, Tuple
6
6
 
7
7
  import torch
8
8
  from torch.utils.data import Dataset
9
- from transformers import CLIPProcessor, ProcessorMixin
9
+ from transformers import BaseImageProcessor, CLIPProcessor, ProcessorMixin
10
10
 
11
11
  __all__ = ["CLIPDataset"]
12
12
 
@@ -60,7 +60,7 @@ class CLIPDataset(torch.utils.data.Dataset):
60
60
  raise ValueError("Each item should be a dictionary or a tuple of length 2")
61
61
  image = item["image"]
62
62
  if self.processor is not None:
63
- if isinstance(self.processor, ProcessorMixin):
63
+ if isinstance(self.processor, (ProcessorMixin, BaseImageProcessor)):
64
64
  # Apply the processor to the image to get the input tensor
65
65
  inputs = self.processor(images=[image], return_tensors="pt")[
66
66
  "pixel_values"
@@ -2,6 +2,7 @@
2
2
  import sys
3
3
  from typing import TYPE_CHECKING
4
4
 
5
+ from fusion_bench.utils import join_lists
5
6
  from fusion_bench.utils.lazy_imports import LazyImporter
6
7
 
7
8
  _import_structure = {
@@ -12,6 +13,8 @@ _import_structure = {
12
13
  "classification": [
13
14
  "ImageClassificationFineTuningForCLIP",
14
15
  "ContinualImageClassificationFineTuningForCLIP",
16
+ "ImageClassificationFineTuning",
17
+ "ImageClassificationFineTuning_Test",
15
18
  ],
16
19
  "lm_finetune": ["FullFinetuneSFT", "PeftFinetuneSFT", "BradleyTerryRewardModeling"],
17
20
  # analysis
@@ -131,7 +134,10 @@ _import_structure = {
131
134
  "ProgressivePruningForMixtral",
132
135
  ],
133
136
  }
134
-
137
+ _available_algorithms = join_lists(list(_import_structure.values()))
138
+ _extra_objects = {
139
+ "_available_algorithms": _available_algorithms,
140
+ }
135
141
 
136
142
  if TYPE_CHECKING:
137
143
  from .ada_svd import AdaSVDMergingForCLIPVisionModel
@@ -141,6 +147,8 @@ if TYPE_CHECKING:
141
147
  from .bitdelta import BitDeltaAlgorithm
142
148
  from .classification import (
143
149
  ContinualImageClassificationFineTuningForCLIP,
150
+ ImageClassificationFineTuning,
151
+ ImageClassificationFineTuning_Test,
144
152
  ImageClassificationFineTuningForCLIP,
145
153
  )
146
154
  from .concrete_subspace import (
@@ -252,4 +260,5 @@ else:
252
260
  __name__,
253
261
  globals()["__file__"],
254
262
  _import_structure,
263
+ extra_objects=_extra_objects,
255
264
  )
@@ -1,3 +1,28 @@
1
1
  # flake8: noqa F401
2
- from .clip_finetune import ImageClassificationFineTuningForCLIP
3
- from .continual_clip_finetune import ContinualImageClassificationFineTuningForCLIP
2
+ import sys
3
+ from typing import TYPE_CHECKING
4
+
5
+ from fusion_bench.utils.lazy_imports import LazyImporter
6
+
7
+ _import_structure = {
8
+ "clip_finetune": ["ImageClassificationFineTuningForCLIP"],
9
+ "continual_clip_finetune": ["ContinualImageClassificationFineTuningForCLIP"],
10
+ "image_classification_finetune": [
11
+ "ImageClassificationFineTuning",
12
+ "ImageClassificationFineTuning_Test",
13
+ ],
14
+ }
15
+
16
+ if TYPE_CHECKING:
17
+ from .clip_finetune import ImageClassificationFineTuningForCLIP
18
+ from .continual_clip_finetune import ContinualImageClassificationFineTuningForCLIP
19
+ from .image_classification_finetune import (
20
+ ImageClassificationFineTuning,
21
+ ImageClassificationFineTuning_Test,
22
+ )
23
+ else:
24
+ sys.modules[__name__] = LazyImporter(
25
+ __name__,
26
+ globals()["__file__"],
27
+ _import_structure,
28
+ )
@@ -0,0 +1,214 @@
1
+ import os
2
+ from typing import Optional
3
+
4
+ import lightning as L
5
+ import lightning.pytorch.callbacks as pl_callbacks
6
+ import torch
7
+ from lightning.pytorch.loggers import TensorBoardLogger
8
+ from lightning_utilities.core.rank_zero import rank_zero_only
9
+ from lit_learn.lit_modules import ERM_LitModule
10
+ from omegaconf import DictConfig
11
+ from torch import nn
12
+ from torch.utils.data import DataLoader
13
+ from torchmetrics.classification import Accuracy
14
+
15
+ from fusion_bench import (
16
+ BaseAlgorithm,
17
+ BaseModelPool,
18
+ RuntimeConstants,
19
+ auto_register_config,
20
+ get_rankzero_logger,
21
+ instantiate,
22
+ )
23
+ from fusion_bench.dataset import CLIPDataset
24
+ from fusion_bench.modelpool import ResNetForImageClassificationPool
25
+ from fusion_bench.tasks.clip_classification import get_num_classes
26
+
27
+ log = get_rankzero_logger(__name__)
28
+
29
+
30
+ @auto_register_config
31
+ class ImageClassificationFineTuning(BaseAlgorithm):
32
+ def __init__(
33
+ self,
34
+ max_epochs: Optional[int],
35
+ max_steps: Optional[int],
36
+ label_smoothing: float,
37
+ optimizer: DictConfig,
38
+ lr_scheduler: DictConfig,
39
+ dataloader_kwargs: DictConfig,
40
+ **kwargs,
41
+ ):
42
+ super().__init__(**kwargs)
43
+ assert (max_epochs is None) or (
44
+ max_steps is None or max_steps < 0
45
+ ), "Only one of max_epochs or max_steps should be set."
46
+ self.training_interval = "epoch" if max_epochs is not None else "step"
47
+ if self.training_interval == "epoch":
48
+ self.max_steps = -1
49
+ log.info(f"Training interval: {self.training_interval}")
50
+ log.info(f"Max epochs: {max_epochs}, max steps: {max_steps}")
51
+
52
+ def run(self, modelpool: ResNetForImageClassificationPool):
53
+ # load model and dataset
54
+ model = modelpool.load_pretrained_or_first_model()
55
+ assert isinstance(model, nn.Module), "Loaded model is not a nn.Module."
56
+
57
+ assert (
58
+ len(modelpool.train_dataset_names) == 1
59
+ ), "Exactly one training dataset is required."
60
+ self.dataset_name = dataset_name = modelpool.train_dataset_names[0]
61
+ num_classes = get_num_classes(dataset_name)
62
+ train_dataset = modelpool.load_train_dataset(dataset_name)
63
+ train_dataset = CLIPDataset(
64
+ train_dataset, processor=modelpool.load_processor(stage="train")
65
+ )
66
+ train_loader = self.get_dataloader(train_dataset, stage="train")
67
+ if modelpool.has_val_dataset:
68
+ val_dataset = modelpool.load_val_dataset(dataset_name)
69
+ val_dataset = CLIPDataset(
70
+ val_dataset, processor=modelpool.load_processor(stage="val")
71
+ )
72
+ val_loader = self.get_dataloader(val_dataset, stage="val")
73
+
74
+ # configure optimizer
75
+ optimizer = instantiate(self.optimizer, params=model.parameters())
76
+ if self.lr_scheduler is not None:
77
+ lr_scheduler = instantiate(self.lr_scheduler, optimizer=optimizer)
78
+ optimizer = {
79
+ "optimizer": optimizer,
80
+ "lr_scheduler": {
81
+ "scheduler": lr_scheduler,
82
+ "interval": self.training_interval,
83
+ "frequency": 1,
84
+ },
85
+ }
86
+ log.info(f"optimizer:\n{optimizer}")
87
+
88
+ lit_module = ERM_LitModule(
89
+ model,
90
+ optimizer,
91
+ objective=nn.CrossEntropyLoss(label_smoothing=self.label_smoothing),
92
+ metrics={
93
+ "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
94
+ "acc@5": Accuracy(task="multiclass", num_classes=num_classes, top_k=5),
95
+ },
96
+ )
97
+
98
+ log_dir = (
99
+ self._program.path.log_dir
100
+ if self._program is not None
101
+ else "outputs/lightning_logs"
102
+ )
103
+ trainer = L.Trainer(
104
+ max_epochs=self.max_epochs,
105
+ max_steps=self.max_steps,
106
+ accelerator="auto",
107
+ devices="auto",
108
+ callbacks=[
109
+ pl_callbacks.LearningRateMonitor(logging_interval="step"),
110
+ pl_callbacks.DeviceStatsMonitor(),
111
+ ],
112
+ logger=TensorBoardLogger(
113
+ save_dir=log_dir,
114
+ name="",
115
+ ),
116
+ fast_dev_run=RuntimeConstants.debug,
117
+ )
118
+
119
+ trainer.fit(
120
+ lit_module, train_dataloaders=train_loader, val_dataloaders=val_loader
121
+ )
122
+ model = lit_module.model
123
+ if rank_zero_only.rank == 0:
124
+ log.info(f"Saving the final model to {log_dir}/raw_checkpoints/final")
125
+ modelpool.save_model(
126
+ model,
127
+ path=os.path.join(
128
+ trainer.log_dir if trainer.log_dir is not None else log_dir,
129
+ "raw_checkpoints",
130
+ "final",
131
+ ),
132
+ )
133
+ return model
134
+
135
+ def get_dataloader(self, dataset, stage: str):
136
+ assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
137
+ dataloader_kwargs = dict(self.dataloader_kwargs)
138
+ if "shuffle" not in dataloader_kwargs:
139
+ dataloader_kwargs["shuffle"] = stage == "train"
140
+ return DataLoader(dataset, **dataloader_kwargs)
141
+
142
+
143
+ @auto_register_config
144
+ class ImageClassificationFineTuning_Test(BaseAlgorithm):
145
+ def __init__(self, checkpoint_path: str, dataloader_kwargs: DictConfig, **kwargs):
146
+ super().__init__(**kwargs)
147
+
148
+ def run(self, modelpool: BaseModelPool):
149
+ assert (
150
+ modelpool.has_val_dataset or modelpool.has_test_dataset
151
+ ), "No validation or test dataset found in the model pool."
152
+
153
+ # load model and dataset
154
+ model = modelpool.load_pretrained_or_first_model()
155
+ assert isinstance(model, nn.Module), "Loaded model is not a nn.Module."
156
+
157
+ if modelpool.has_test_dataset:
158
+ assert (
159
+ len(modelpool.test_dataset_names) == 1
160
+ ), "Exactly one test dataset is required."
161
+ self.dataset_name = dataset_name = modelpool.test_dataset_names[0]
162
+ dataset = modelpool.load_test_dataset(dataset_name)
163
+ dataset = CLIPDataset(
164
+ dataset, processor=modelpool.load_processor(stage="test")
165
+ )
166
+ else:
167
+ assert (
168
+ len(modelpool.val_dataset_names) == 1
169
+ ), "Exactly one validation dataset is required."
170
+ self.dataset_name = dataset_name = modelpool.val_dataset_names[0]
171
+ dataset = modelpool.load_val_dataset(dataset_name)
172
+ dataset = CLIPDataset(
173
+ dataset, processor=modelpool.load_processor(stage="test")
174
+ )
175
+ num_classes = get_num_classes(dataset_name)
176
+
177
+ test_loader = self.get_dataloader(dataset, stage="test")
178
+
179
+ if self.checkpoint_path is None:
180
+ lit_module = ERM_LitModule(
181
+ model,
182
+ metrics={
183
+ "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
184
+ "acc@5": Accuracy(
185
+ task="multiclass", num_classes=num_classes, top_k=5
186
+ ),
187
+ },
188
+ )
189
+ else:
190
+ lit_module = ERM_LitModule.load_from_checkpoint(
191
+ checkpoint_path=self.checkpoint_path,
192
+ model=model,
193
+ metrics={
194
+ "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
195
+ "acc@5": Accuracy(
196
+ task="multiclass", num_classes=num_classes, top_k=5
197
+ ),
198
+ },
199
+ )
200
+
201
+ trainer = L.Trainer(
202
+ devices=1, num_nodes=1, logger=False, fast_dev_run=RuntimeConstants.debug
203
+ )
204
+
205
+ test_metrics = trainer.test(lit_module, dataloaders=test_loader)
206
+ log.info(f"Test metrics: {test_metrics}")
207
+ return model
208
+
209
+ def get_dataloader(self, dataset, stage: str):
210
+ assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
211
+ dataloader_kwargs = dict(self.dataloader_kwargs)
212
+ if "shuffle" not in dataloader_kwargs:
213
+ dataloader_kwargs["shuffle"] = stage == "train"
214
+ return DataLoader(dataset, **dataloader_kwargs)
@@ -87,6 +87,7 @@ class OPCMForCLIP(
87
87
  # get the average model
88
88
  with self.profile("loading model"):
89
89
  merged_model = modelpool.load_model(model_names[0])
90
+ assert merged_model is not None, "Failed to load the first model"
90
91
 
91
92
  if self.evaluate_on_every_step:
92
93
  with self.profile("evaluating model"):
@@ -13,8 +13,6 @@ import torch.func
13
13
  from torch import Tensor, nn
14
14
  from torch.nn import functional as F
15
15
 
16
- from fusion_bench.utils import join_list
17
-
18
16
  log = logging.getLogger(__name__)
19
17
 
20
18
 
@@ -15,7 +15,7 @@ from fusion_bench.utils.state_dict_arithmetic import (
15
15
  state_dict_add,
16
16
  state_dict_binary_mask,
17
17
  state_dict_diff_abs,
18
- state_dict_hadmard_product,
18
+ state_dict_hadamard_product,
19
19
  state_dict_mul,
20
20
  state_dict_sub,
21
21
  state_dict_sum,
@@ -111,7 +111,7 @@ class TallMaskTaskArithmeticAlgorithm(
111
111
 
112
112
  with self.profile("compress and retrieve"):
113
113
  for model_name in modelpool.model_names:
114
- retrieved_task_vector = state_dict_hadmard_product(
114
+ retrieved_task_vector = state_dict_hadamard_product(
115
115
  tall_masks[model_name], multi_task_vector
116
116
  )
117
117
  retrieved_state_dict = state_dict_add(
@@ -11,6 +11,7 @@ _import_structure = {
11
11
  "hydra_config": ["HydraConfigMixin"],
12
12
  "lightning_fabric": ["LightningFabricMixin"],
13
13
  "openclip_classification": ["OpenCLIPClassificationMixin"],
14
+ "pyinstrument": ["PyinstrumentProfilerMixin"],
14
15
  "serialization": [
15
16
  "BaseYAMLSerializable",
16
17
  "YAMLSerializationMixin",
@@ -25,6 +26,7 @@ if TYPE_CHECKING:
25
26
  from .hydra_config import HydraConfigMixin
26
27
  from .lightning_fabric import LightningFabricMixin
27
28
  from .openclip_classification import OpenCLIPClassificationMixin
29
+ from .pyinstrument import PyinstrumentProfilerMixin
28
30
  from .serialization import (
29
31
  BaseYAMLSerializable,
30
32
  YAMLSerializationMixin,