fusion-bench 0.2.30__py3-none-any.whl → 0.2.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. fusion_bench/__init__.py +6 -0
  2. fusion_bench/__main__.py +2 -2
  3. fusion_bench/constants/runtime.py +4 -1
  4. fusion_bench/dataset/__init__.py +2 -0
  5. fusion_bench/dataset/clip_dataset.py +4 -72
  6. fusion_bench/dataset/image_dataset.py +44 -18
  7. fusion_bench/method/base_algorithm.py +4 -0
  8. fusion_bench/method/classification/image_classification_finetune.py +1 -0
  9. fusion_bench/method/concrete_subspace/clip_concrete_tsvm.py +285 -0
  10. fusion_bench/method/dop/dop.py +0 -22
  11. fusion_bench/method/dop/dop_general.py +489 -0
  12. fusion_bench/method/dop/utils.py +24 -4
  13. fusion_bench/method/emr_merging/__init__.py +1 -0
  14. fusion_bench/method/emr_merging/emr_merging.py +53 -0
  15. fusion_bench/method/emr_merging/utils.py +162 -0
  16. fusion_bench/method/opcm/opcm.py +6 -2
  17. fusion_bench/method/opcm/opcm_general.py +356 -0
  18. fusion_bench/method/opcm/utils.py +1 -4
  19. fusion_bench/method/simple_average.py +52 -18
  20. fusion_bench/method/task_arithmetic/task_arithmetic.py +1 -1
  21. fusion_bench/method/task_singular_vector/TSVM.py +7 -6
  22. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +0 -1
  23. fusion_bench/mixins/lightning_fabric.py +110 -11
  24. fusion_bench/mixins/openclip_classification.py +155 -1
  25. fusion_bench/mixins/serialization.py +1 -1
  26. fusion_bench/modelpool/base_pool.py +37 -0
  27. fusion_bench/modelpool/convnext_for_image_classification.py +5 -2
  28. fusion_bench/modelpool/openclip_vision/modelpool.py +12 -3
  29. fusion_bench/models/hf_clip.py +20 -0
  30. fusion_bench/models/modulator/__init__.py +1 -0
  31. fusion_bench/models/modulator/base.py +123 -0
  32. fusion_bench/models/open_clip/modeling.py +61 -5
  33. fusion_bench/models/open_clip/utils.py +13 -2
  34. fusion_bench/models/parameter_dict.py +119 -29
  35. fusion_bench/models/utils.py +190 -2
  36. fusion_bench/models/wrappers/switch.py +90 -0
  37. fusion_bench/programs/base_program.py +6 -0
  38. fusion_bench/programs/fabric_fusion_program.py +4 -0
  39. fusion_bench/py.typed +1 -0
  40. fusion_bench/scripts/cli.py +25 -23
  41. fusion_bench/scripts/imgui.py +2 -2
  42. fusion_bench/scripts/webui.py +2 -2
  43. fusion_bench/taskpool/image_classification.py +270 -0
  44. fusion_bench/utils/__init__.py +20 -1
  45. fusion_bench/utils/data.py +1 -1
  46. fusion_bench/utils/dict.py +19 -0
  47. fusion_bench/utils/dtype.py +19 -0
  48. fusion_bench/utils/hydra_utils.py +75 -0
  49. fusion_bench/utils/misc.py +1 -0
  50. fusion_bench/utils/packages.py +4 -0
  51. fusion_bench/utils/parameters.py +33 -0
  52. fusion_bench/utils/rich_utils.py +42 -19
  53. fusion_bench/utils/state_dict_arithmetic.py +183 -1
  54. fusion_bench/utils/tensorboard.py +21 -3
  55. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/METADATA +3 -1
  56. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/RECORD +70 -53
  57. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/WHEEL +1 -1
  58. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/entry_points.txt +1 -1
  59. fusion_bench_config/README.md +9 -0
  60. fusion_bench_config/fabric/auto.yaml +1 -0
  61. fusion_bench_config/fabric/loggers/mlflow_logger.yaml +4 -0
  62. fusion_bench_config/hydra/default.yaml +3 -1
  63. fusion_bench_config/method/concrete_subspace/clip_concrete_tsvm.yaml +38 -0
  64. fusion_bench_config/method/dop/dop_general.yaml +33 -0
  65. fusion_bench_config/method/emr_merging/emr_merging.yaml +1 -0
  66. fusion_bench_config/method/opcm/opcm_general.yaml +18 -0
  67. fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224_8-tasks.yaml +15 -0
  68. fusion_bench_config/taskpool/ImageClassificationTaskPool/convnext-base-224_8-tasks.yaml +17 -0
  69. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/licenses/LICENSE +0 -0
  70. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/top_level.txt +0 -0
fusion_bench/__init__.py CHANGED
@@ -86,6 +86,9 @@ _import_structure = {
         "set_print_function_call",
         "set_print_function_call_permeanent",
         "timeit_context",
+        "initialize_hydra_config",
+        "get_default_config_path",
+        "get_hydra_output_dir",
     ],
 }
 
@@ -144,8 +147,11 @@ if TYPE_CHECKING:
         StateDictType,
         TorchModelType,
         cache_with_joblib,
+        get_default_config_path,
+        get_hydra_output_dir,
         get_rankzero_logger,
         import_object,
+        initialize_hydra_config,
         instantiate,
         parse_dtype,
         print_parameters,
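The three new exports are Hydra helpers surfaced at the package root. Their signatures are not part of this diff, so the sketch below only shows the import path, which resolves lazily through the `LazyImporter` machinery referenced above:

```python
# Resolution is deferred by fusion_bench's LazyImporter: the import itself is
# cheap, and the underlying module loads on first use of these symbols.
from fusion_bench import (
    get_default_config_path,
    get_hydra_output_dir,
    initialize_hydra_config,
)
```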
fusion_bench/__main__.py CHANGED
@@ -1,4 +1,4 @@
-from fusion_bench.scripts.cli import main
+from fusion_bench.scripts.cli import _hydra_main
 
 if __name__ == "__main__":
-    main()
+    _hydra_main()
fusion_bench/constants/runtime.py CHANGED
@@ -89,7 +89,10 @@ class RuntimeConstants:
         self._initialized = True
 
     debug = False
-    """Global debug flag for enabling verbose logging and debugging features."""
+    """
+    Global debug flag for enabling verbose logging and debugging features.
+    Use `RuntimeConstants().debug` instead of `RuntimeConstants.debug`.
+    """
 
     @property
     def cache_dir(self) -> Path:
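A minimal sketch of why the new docstring recommends instance access. The `_initialized` guard above suggests `RuntimeConstants` behaves as a singleton, so reading through an instance reflects any state set at initialization, while plain class access only sees the class-level default (the singleton behavior is inferred from the diff, not shown in full):

```python
from fusion_bench.constants.runtime import RuntimeConstants

# Class attribute access: always the declared default (False), even if the
# singleton instance later carries a different effective value.
print(RuntimeConstants.debug)

# Instance access goes through the shared singleton state, as the docstring
# recommends; this is the form calling code should use.
print(RuntimeConstants().debug)
```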
fusion_bench/dataset/__init__.py CHANGED
@@ -38,10 +38,12 @@ _extra_objects = {
 }
 _import_structure = {
     "clip_dataset": ["CLIPDataset"],
+    "image_dataset": ["ImageClassificationDataset"],
 }
 
 if TYPE_CHECKING:
     from .clip_dataset import CLIPDataset
+    from .image_dataset import ImageClassificationDataset
 
 else:
     sys.modules[__name__] = LazyImporter(
fusion_bench/dataset/clip_dataset.py CHANGED
@@ -2,80 +2,12 @@
 This module provides a class to convert a dataset whose object is a list of dictionaries with keys "image" and "label" to a dataset whose object is a tuple of tensors (inputs, label) for CLIP models.
 """
 
-from typing import Optional, Tuple
+from fusion_bench.utils import DeprecationWarningMeta
 
-import torch
-from torch.utils.data import Dataset
-from transformers import BaseImageProcessor, CLIPProcessor, ProcessorMixin
+from .image_dataset import ImageClassificationDataset
 
 __all__ = ["CLIPDataset"]
 
 
-class CLIPDataset(torch.utils.data.Dataset):
-    """
-    A dataset class for CLIP models that converts a dataset of dictionaries or tuples
-    into a format suitable for CLIP processing.
-
-    This class wraps an existing dataset and applies CLIP preprocessing to the images.
-    It expects each item in the dataset to be either a dictionary with 'image' and 'label' keys,
-    or a tuple/list of (image, label).
-
-    Args:
-        dataset: The original dataset to wrap.
-        processor (CLIPProcessor): The CLIP processor for preparing inputs. If None, no preprocessing is applied and raw images are returned.
-
-    Attributes:
-        dataset: The wrapped dataset.
-        processor (CLIPProcessor): The CLIP processor used for image preprocessing.
-    """
-
-    def __init__(self, dataset: Dataset, processor: Optional[CLIPProcessor] = None):
-        self.dataset = dataset
-        self.processor = processor
-
-    def __len__(self):
-        """Returns the number of items in the dataset."""
-        return len(self.dataset)
-
-    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
-        """
-        Retrieves and processes an item from the dataset.
-
-        Args:
-            idx (int): The index of the item to retrieve.
-
-        Returns:
-            tuple: A tuple containing the processed image tensor and the label.
-
-        Raises:
-            ValueError: If the item is neither a dictionary nor a tuple/list of length 2.
-        """
-        item = self.dataset[idx]
-        if isinstance(item, dict):
-            item = item
-        elif isinstance(item, (tuple, list)):
-            assert len(item) == 2, "Each item should be a tuple or list of length 2"
-            item = {"image": item[0], "label": item[1]}
-        else:
-            raise ValueError("Each item should be a dictionary or a tuple of length 2")
-        image = item["image"]
-        if self.processor is not None:
-            if isinstance(self.processor, (ProcessorMixin, BaseImageProcessor)):
-                # Apply the processor to the image to get the input tensor
-                image = image.convert("RGB")  # ensure image is in RGB format
-                inputs = self.processor(images=[image], return_tensors="pt")[
-                    "pixel_values"
-                ][0]
-            elif callable(self.processor):
-                inputs = self.processor(image)
-            else:
-                raise ValueError(
-                    "The processor should be a CLIPProcessor or a callable function"
-                )
-        else:
-            # if processor is None, return the raw image directly
-            inputs = image
-        # convert boolean label to int, this is for the case when the label is a binary classification task
-        if isinstance(item["label"], bool):
-            item["label"] = 1 if item["label"] else 0
-        return inputs, item["label"]
+class CLIPDataset(ImageClassificationDataset, metaclass=DeprecationWarningMeta):
+    pass
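With the class body moved to `image_dataset.py`, the old name survives only as a deprecation alias. A small sketch of the migration path, assuming `DeprecationWarningMeta` (imported above) emits a `DeprecationWarning` when the deprecated class is used:

```python
from fusion_bench.dataset import ImageClassificationDataset  # new name
from fusion_bench.dataset.clip_dataset import CLIPDataset  # deprecated alias

# The alias is a plain subclass, so existing isinstance/issubclass checks hold:
assert issubclass(CLIPDataset, ImageClassificationDataset)

# Old call sites keep working, but (assuming DeprecationWarningMeta warns on
# use) they now emit a DeprecationWarning; new code should construct
# ImageClassificationDataset directly.
```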
fusion_bench/dataset/image_dataset.py CHANGED
@@ -1,35 +1,39 @@
-from typing import Any, Callable, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union
 
+import torch
 from torch.utils.data import Dataset
+from transformers import BaseImageProcessor, ProcessorMixin
 
 
-class TransformedImageDataset(Dataset):
+class ImageClassificationDataset(Dataset):
     """
-    A dataset class for image classification tasks that applies a transform to images.
+    A dataset class for image classification models that converts a dataset of dictionaries or tuples
+    into a format suitable for model processing.
 
-    This class wraps an existing dataset and applies a specified transform to the images.
+    This class wraps an existing dataset and applies preprocessing to the images.
     It expects each item in the dataset to be either a dictionary with 'image' and 'label' keys,
     or a tuple/list of (image, label).
-
-    Args:
-        dataset: The original dataset to wrap.
-        transform (Callable): A function/transform to apply on the image.
-
-    Attributes:
-        dataset: The wrapped dataset.
-        transform (Callable): The transform to be applied to the images.
     """
 
-    def __init__(self, dataset: Dataset, transform: Callable):
-        super().__init__()
+    def __init__(
+        self,
+        dataset: Dataset,
+        processor: Optional[Union["ProcessorMixin", "BaseImageProcessor"]] = None,
+    ):
+        """
+        Args:
+            dataset (Dataset): The original dataset to wrap.
+            processor (Optional[Union[ProcessorMixin, BaseImageProcessor]]): The processor for preparing inputs.
+                If None, no preprocessing is applied and raw images are returned.
+        """
         self.dataset = dataset
-        self.transform = transform
+        self.processor = processor
 
     def __len__(self):
         """Returns the number of items in the dataset."""
         return len(self.dataset)
 
-    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
         """
         Retrieves and processes an item from the dataset.
 
@@ -37,11 +41,13 @@ class TransformedImageDataset(Dataset):
             idx (int): The index of the item to retrieve.
 
         Returns:
-            tuple: A tuple containing the processed image and the label.
+            tuple: A tuple containing the processed image tensor and the label.
 
         Raises:
             ValueError: If the item is neither a dictionary nor a tuple/list of length 2.
         """
+        # Standardize the item to a dictionary format
+        # {"image": ..., "label": ...}
         item = self.dataset[idx]
         if isinstance(item, dict):
             item = item
@@ -50,6 +56,26 @@ class TransformedImageDataset(Dataset):
             item = {"image": item[0], "label": item[1]}
         else:
             raise ValueError("Each item should be a dictionary or a tuple of length 2")
+
+        # Process the image using the provided processor, if any
         image = item["image"]
-        inputs = self.transform(image)
+        if self.processor is not None:
+            if isinstance(self.processor, (ProcessorMixin, BaseImageProcessor)):
+                # Apply the processor to the image to get the input tensor
+                image = image.convert("RGB")  # ensure image is in RGB format
+                inputs = self.processor(images=[image], return_tensors="pt")[
+                    "pixel_values"
+                ][0]
+            elif callable(self.processor):
+                inputs = self.processor(image)
+            else:
+                raise ValueError(
+                    "The processor should be a transformers Processor or a callable function"
+                )
+        else:
+            # if processor is None, return the raw image directly
+            inputs = image
+        # convert boolean label to int, this is for the case when the label is a binary classification task
+        if isinstance(item["label"], bool):
+            item["label"] = 1 if item["label"] else 0
         return inputs, item["label"]
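A minimal usage sketch of the renamed class; the toy PIL images and the CLIP checkpoint id are illustrative choices, not part of this diff:

```python
from PIL import Image
from torch.utils.data import DataLoader
from transformers import CLIPProcessor

from fusion_bench.dataset import ImageClassificationDataset

# Items may be {"image": ..., "label": ...} dicts or (image, label) tuples.
raw = [(Image.new("RGB", (224, 224)), 0), (Image.new("RGB", (224, 224)), 1)]

# CLIPProcessor is a ProcessorMixin, so the dataset applies it via the
# `images=` / `pixel_values` path shown above.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
ds = ImageClassificationDataset(raw, processor=processor)

loader = DataLoader(ds, batch_size=2)
images, labels = next(iter(loader))  # images: pixel tensor of shape [2, 3, 224, 224]
```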
fusion_bench/method/base_algorithm.py CHANGED
@@ -59,6 +59,10 @@ class BaseAlgorithm(BaseYAMLSerializable):
     core fusion logic in the `run` method, while optional lifecycle hooks allow for
     setup and cleanup operations.
 
+    If the model has a `_fusion_bench_target_modules` attribute, the algorithm will only fuse
+    the specified target modules. This is useful for models where only certain layers
+    should be fused (e.g., classification heads on top of a shared backbone are not merged).
+
     Attributes:
         _program: Optional program reference for algorithm execution context.
         _config_key (str): Configuration key used for YAML serialization, defaults to "method".
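A hedged sketch of how a model could opt only its shared layers into fusion via this attribute. The attribute name comes from the docstring above; whether it stores module names or module references is not shown in this diff, so a list of submodule names is assumed here:

```python
import torch.nn as nn


class BackboneWithHead(nn.Module):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 64))
        self.head = nn.Linear(64, num_classes)  # task-specific, should not be merged
        # Algorithms that honor this attribute would fuse only the shared
        # backbone and leave the classification head untouched.
        # (Assumption: the attribute holds submodule names.)
        self._fusion_bench_target_modules = ["backbone"]

    def forward(self, x):
        return self.head(self.backbone(x))
```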
fusion_bench/method/classification/image_classification_finetune.py CHANGED
@@ -173,6 +173,7 @@ class ImageClassificationFineTuning(BaseAlgorithm):
                 ),
             },
         )
+        lit_module.train()
 
         log_dir = (
             self._program.path.log_dir
fusion_bench/method/concrete_subspace/clip_concrete_tsvm.py ADDED
@@ -0,0 +1,285 @@
+import logging
+import os
+from copy import deepcopy
+from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional
+
+import torch
+from omegaconf import DictConfig
+from tqdm import tqdm
+
+from fusion_bench import (
+    BaseAlgorithm,
+    OpenCLIPClassificationMixin,
+    OpenCLIPVisionModelPool,
+    SimpleProfilerMixin,
+    StateDictType,
+    auto_register_config,
+    get_rankzero_logger,
+    instantiate,
+)
+from fusion_bench.method.adamerging.entropy_loss import entropy_loss
+from fusion_bench.method.task_singular_vector import TaskSingularVectorMerging
+from fusion_bench.method.task_singular_vector.utils import (
+    TSVM_utils,
+    check_parameterNamesMatch,
+    check_state_dicts_equal,
+    state_dict_to_vector,
+    vector_to_state_dict,
+)
+from fusion_bench.models.masks import MaskModel, mask_sparsity
+from fusion_bench.models.open_clip import (
+    ClassificationHead,
+    ImageClassifier,
+    ImageEncoder,
+)
+from fusion_bench.models.wrappers.task_wise_fusion import (
+    TaskWiseMergedModel,
+    get_task_wise_weights,
+)
+from fusion_bench.utils.devices import clear_cuda_cache
+from fusion_bench.utils.dtype import parse_dtype
+from fusion_bench.utils.parameters import print_parameters, print_trainable_parameters
+from fusion_bench.utils.rich_utils import print_config_yaml
+from fusion_bench.utils.state_dict_arithmetic import (
+    _validate_state_dict_same_keys,
+    state_dict_add,
+    state_dict_hadamard_product,
+    state_dict_mul,
+    state_dict_sub,
+)
+
+log = get_rankzero_logger(__name__)
+
+
+@auto_register_config
+class ConcreteTSVMForOpenCLIP(
+    OpenCLIPClassificationMixin,
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
+    def __init__(
+        self,
+        dataloader_kwargs: DictConfig,
+        optimizer: DictConfig,
+        lr_scheduler: DictConfig,
+        max_steps: int,
+        save_interval: int,
+        initial_logits: float,
+        temperature: float,
+        eval_mask_type: Literal["continuous", "discrete"],
+        mask_checkpoint: Optional[str],
+        merge_dtype: str,
+        clamp_weights: bool,
+        tie_weights: bool,
+        strict: bool,
+        skip_training: bool,
+        # === TSVM parameters ===
+        exclude_keys: Optional[List[str]],
+        alpha: float,
+        return_single_task_models: bool = True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if not return_single_task_models:
+            log.warning("return_single_task_models is forced to be True here.")
+        self.return_single_task_models = True
+
+    @torch.no_grad()
+    def setup_models(self):
+        """
+        Load the pre-trained model, task vectors, and construct the mask model.
+        """
+        merge_dtype = parse_dtype(self.merge_dtype)
+        modelpool = self.modelpool
+
+        # load the pre-trained model
+        pretrained_model = modelpool.load_pretrained_model()
+        self.set_clip_processor(stage="test", processor=pretrained_model.val_preprocess)
+
+        # construct mask model
+        mask_model = MaskModel(
+            pretrained_model, ignore_untrained_params=True, parameter_type="logits"
+        )
+        if merge_dtype is not None:
+            mask_model.to(merge_dtype)
+        mask_model.fill_(self.initial_logits)
+
+        if self.fabric.is_global_zero:
+            print("summary of mask model:")
+            print_parameters(mask_model)
+
+        if self.fabric.is_global_zero:
+            tsvm_algo = TaskSingularVectorMerging(
+                alpha=self.alpha,
+                exclude_keys=self.exclude_keys,
+                return_single_task_models=self.return_single_task_models,
+            )
+            tsvm_algo._fabric_instance = self.fabric
+            models = tsvm_algo.run(modelpool)
+
+            finetuned_models = [models[name] for name in modelpool.model_names]
+
+            task_wise_weight = get_task_wise_weights(
+                num_models=len(modelpool.model_names),
+                init_values=self.alpha,
+            )
+
+            # create a wrapped model
+            module = TaskWiseMergedModel(
+                task_wise_weight=task_wise_weight,
+                pretrained_model=pretrained_model,
+                finetuned_models=finetuned_models,
+                clamp_weights=self.clamp_weights,
+                tie_weights=self.tie_weights,
+                strict=self.strict,
+                task_vector_dtype=merge_dtype,
+            )
+            module = module.to(dtype=merge_dtype)
+
+            print("trainable parameter summary of merged model (TaskWiseMergedModel):")
+            print_trainable_parameters(module)
+        else:
+            module = None
+
+        with torch.no_grad():
+            self.fabric.barrier()
+            module = self.fabric.broadcast(module, src=0)
+
+        return module, mask_model
+
+    def train_mask(self, module: TaskWiseMergedModel, mask_model: MaskModel):
+        """
+        Train the mask model using the provided module.
+
+        This method configures the optimizer, sets up the mask model, and performs test-time adaptation to train the mask model.
+
+        Args:
+            module (TaskWiseMergedModel): The wrapped model with task-wise weights.
+            mask_model (MaskModel): The mask model to be trained.
+        """
+        config = self.config
+        merge_dtype = parse_dtype(self.merge_dtype)
+        log.info(f"Using merge dtype: {merge_dtype}")
+
+        optimizer: "torch.optim.Optimizer" = instantiate(
+            self.optimizer,
+            params=filter(lambda p: p.requires_grad, mask_model.parameters()),
+        )
+        print(f"{optimizer=}")
+        if self.lr_scheduler is not None:
+            lr_scheduler = instantiate(
+                self.lr_scheduler,
+                optimizer=optimizer,
+            )
+            print(f"{lr_scheduler=}")
+        else:
+            lr_scheduler = None
+
+        log.info("Setup models and optimizer with Fabric.")
+        mask_model, optimizer = self.fabric.setup(mask_model, optimizer)
+
+        log.info("Move the merged module to the correct device and disable gradients.")
+        module.requires_grad_(False)
+        module.to(mask_model.device)
+
+        mask_model.train()
+        optimizer.zero_grad()
+        for step_idx in (
+            pbar := tqdm(
+                range(self.config.max_steps if not self.is_debug_mode else 5),
+                ("[DEBUG MODE] " if self.is_debug_mode else "")
+                + "Concrete TSVM Test-Time Adaptation",
+                dynamic_ncols=True,
+                disable=not self.fabric.is_global_zero,
+            )
+        ):
+            metrics = {}
+            # sample a shared mask and merge weights
+            with self.profile("sample mask"):
+                mask = mask_model.sample_mask(
+                    mask_type="continuous", temperature=config.temperature
+                )
+                metrics["train/sparsity"] = mask_sparsity(mask)
+            with self.profile("merge weights"):
+                # rescale mask
+                for name, m in mask.items():
+                    mask[name] = m / torch.mean(m)
+                module.merge_weights(task_vector_mask=mask)
+
+            # ------ inner optimization goes here ------
+            # NOTE:
+            # Because the algorithmic parameters of TSVM are assumed to be chosen on a
+            # validation set, we do not need to perform inner optimization here, so the
+            # inner optimization step is skipped.
+            # ------------------------------------------
+
+            total_loss = None
+            for task in self.modelpool.model_names:
+                with self.profile("data loading"):
+                    batch = next(self.get_shuffled_test_loader_iter(task))
+                    # NOTE: The labels are not allowed to be used during test-time adaptation
+                    images = batch[0].to(dtype=merge_dtype)
+                with self.profile("forward pass"):
+                    logits = self.compute_logits(module, images, task)
+                    loss = entropy_loss(logits)
+                    total_loss = loss if total_loss is None else total_loss + loss
+
+            with self.profile("compute grad"):
+                self.fabric.backward(total_loss)
+
+            with self.profile("optimizer step"):
+                optimizer.step()
+                optimizer.zero_grad()
+
+                if lr_scheduler is not None:
+                    lr_scheduler.step()
+
+            metrics.update({"train/loss": loss.item()})
+            self.fabric.log_dict(metrics, step=step_idx)
+            pbar.set_postfix(metrics)
+
+            if (step_idx + 1) % self.config.save_interval == 0:
+                with self.profiler.profile("save checkpoint"):
+                    save_dir = os.path.join(self.fabric.logger.log_dir, "checkpoints")
+                    if not os.path.exists(save_dir):
+                        os.makedirs(save_dir, exist_ok=True)
+                    save_path = os.path.join(save_dir, f"mask_steps_{step_idx}.pt")
+                    print(f"saving checkpoint to {save_path}")
+                    state = {"model": mask_model}
+                    self.fabric.save(save_path, state)
+
+                    # Create or update a symbolic link to the latest checkpoint
+                    if self.fabric.is_global_zero:
+                        symlink_path = os.path.join(save_dir, "latest_checkpoint.pt")
+                        if os.path.exists(symlink_path):
+                            os.remove(symlink_path)
+                        os.link(os.path.abspath(save_path), symlink_path)
+
+        self.print_profile_summary()
+
+    def run(self, modelpool: OpenCLIPVisionModelPool):
+        self.modelpool = modelpool
+        merge_dtype = parse_dtype(self.merge_dtype)
+
+        with self.profile("setup models"):
+            module, mask_model = self.setup_models()
+            self.setup_zero_shot_classification_head(freeze=True, dtype=merge_dtype)
+
+        if self.mask_checkpoint is None:
+            if not self.skip_training:
+                clear_cuda_cache()
+                self.train_mask(module, mask_model=mask_model)
+        else:
+            if self.fabric.is_global_zero:
+                print("loading mask from checkpoint", self.mask_checkpoint)
+                self.fabric.load(self.mask_checkpoint, {"model": mask_model})
+
+        with torch.no_grad():
+            clear_cuda_cache()
+            mask = mask_model.sample_mask(
+                mask_type=self.eval_mask_type, temperature=self.temperature
+            )
+            # rescale mask
+            for name, m in mask.items():
+                mask[name] = m / torch.mean(m)
+            model = module.merge_and_unload(mask)
+        return model.to(dtype=torch.float32)
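For intuition about the mask sampling used in `train_mask` and `run` above: a "continuous" mask parameterized by logits and a temperature is characteristic of the binary Concrete (Gumbel-sigmoid) relaxation. The sketch below illustrates that relaxation in isolation; `MaskModel.sample_mask`'s actual internals are not shown in this diff:

```python
import torch


def sample_concrete_mask(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Differentiable relaxation of a Bernoulli mask (binary Concrete distribution).

    As temperature -> 0 the samples approach hard {0, 1} values; larger
    temperatures give smoother, near-uniform masks.
    """
    u = torch.rand_like(logits).clamp_(1e-8, 1 - 1e-8)
    noise = torch.log(u) - torch.log1p(-u)  # logistic (Gumbel-difference) noise
    return torch.sigmoid((logits + noise) / temperature)


# Example: logits initialized to a constant (cf. `initial_logits` above),
# followed by the same mean-rescaling applied in train_mask.
logits = torch.full((4, 4), 0.9, requires_grad=True)
mask = sample_concrete_mask(logits, temperature=0.5)
mask = mask / mask.mean()
```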
fusion_bench/method/dop/dop.py CHANGED
@@ -79,28 +79,6 @@ class ContinualDOPForCLIP(BaseAlgorithm, LightningFabricMixin):
         ), "The alpha should be in the range of [0, 1]"
         super().__init__(**kwargs)
 
-    def print_params(self, pretrained_model):
-        total_params = 0
-        linear_params = 0
-        linear_weight_params = 0
-        for module_name, module in pretrained_model.named_modules():
-            if not is_leaf_module(module):
-                continue
-            if isinstance(module, nn.Linear):
-                linear_params += sum(p.numel() for n, p in module.named_parameters())
-                linear_weight_params += sum(
-                    p.numel() for n, p in module.named_parameters() if "weight" in n
-                )
-            total_params += sum(p.numel() for p in module.parameters())
-
-        linear_ratio = linear_params / total_params * 100
-        linear_weight_ratio = linear_weight_params / total_params * 100
-        print(f"Total Parameters: {total_params}")
-        print(f"Linear Parameters: {linear_params}")
-        print(f"Linear Weight Parameters: {linear_weight_params}")
-        print(f"Linear Ratio: {linear_ratio:.2f}%")
-        print(f"Linear Weight Ratio: {linear_weight_ratio:.2f}%")
-
     def run(self, modelpool: BaseModelPool):
         if self.seed is not None:
             L.seed_everything(self.seed)