fusion-bench 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry and is provided for informational purposes only.
Files changed (105)
  1. fusion_bench/compat/method/__init__.py +2 -0
  2. fusion_bench/compat/method/base_algorithm.py +7 -2
  3. fusion_bench/compat/modelpool/__init__.py +3 -2
  4. fusion_bench/compat/taskpool/__init__.py +1 -1
  5. fusion_bench/dataset/arc_agi/__init__.py +6 -1
  6. fusion_bench/dataset/arc_agi/arc.py +26 -7
  7. fusion_bench/dataset/arc_agi/arc_agi.py +156 -25
  8. fusion_bench/dataset/arc_agi/np_cache.py +0 -1
  9. fusion_bench/dataset/arc_agi/preprocess.py +51 -9
  10. fusion_bench/dataset/llama/__init__.py +1 -0
  11. fusion_bench/dataset/llama/alpaca.py +93 -3
  12. fusion_bench/dataset/llama/collate.py +72 -5
  13. fusion_bench/dataset/llama/metamathqa.py +50 -0
  14. fusion_bench/dataset/llama/preference_700k.py +70 -0
  15. fusion_bench/dataset/llama/stanford_shp.py +90 -0
  16. fusion_bench/dataset/llama/ultrachat.py +58 -0
  17. fusion_bench/dataset/llama/utils/__init__.py +0 -0
  18. fusion_bench/method/__init__.py +4 -1
  19. fusion_bench/method/adamerging/__init__.py +1 -1
  20. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
  21. fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
  22. fusion_bench/method/linear/expo.py +39 -0
  23. fusion_bench/method/lm_finetune/__init__.py +1 -0
  24. fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
  25. fusion_bench/method/lm_finetune/fullfinetune_sft.py +122 -150
  26. fusion_bench/method/lm_finetune/peftfinetune_sft.py +102 -157
  27. fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
  28. fusion_bench/method/pruning/llama_random_prune.py +2 -2
  29. fusion_bench/method/pruning/magnitude_diff_pruning.py +2 -1
  30. fusion_bench/method/rankone_moe/__init__.py +3 -0
  31. fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
  32. fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
  33. fusion_bench/method/simple_average.py +1 -1
  34. fusion_bench/method/surgery/__init__.py +3 -0
  35. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
  36. fusion_bench/mixins/__init__.py +2 -0
  37. fusion_bench/mixins/clip_classification.py +60 -12
  38. fusion_bench/mixins/fabric_training.py +320 -0
  39. fusion_bench/mixins/lightning_fabric.py +11 -2
  40. fusion_bench/modelpool/__init__.py +2 -0
  41. fusion_bench/modelpool/causal_lm/__init__.py +1 -1
  42. fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
  43. fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
  44. fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
  45. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
  46. fusion_bench/models/chat_templates/__init__.py +1 -0
  47. fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
  48. fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
  49. fusion_bench/models/hf_clip.py +50 -9
  50. fusion_bench/models/rankone_moe.py +410 -0
  51. fusion_bench/models/surgery/surgerymodelwrapper.py +157 -0
  52. fusion_bench/models/utils.py +8 -0
  53. fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
  54. fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
  55. fusion_bench/optim/__init__.py +2 -0
  56. fusion_bench/optim/exception.py +47 -0
  57. fusion_bench/optim/lr_scheduler/__init__.py +1 -0
  58. fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
  59. fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
  60. fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
  61. fusion_bench/optim/mezo.py +0 -2
  62. fusion_bench/programs/fabric_fusion_program.py +5 -1
  63. fusion_bench/taskpool/__init__.py +10 -2
  64. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  65. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
  66. fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
  67. fusion_bench/taskpool/llama/reward_model.py +157 -0
  68. fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
  69. fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +2 -1
  70. fusion_bench/utils/hydra_utils.py +22 -0
  71. fusion_bench/utils/plot/__init__.py +0 -0
  72. fusion_bench/utils/plot/token.py +52 -0
  73. fusion_bench/utils/plot/token_notebook.py +127 -0
  74. fusion_bench/utils/type.py +5 -3
  75. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/METADATA +1 -1
  76. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/RECORD +104 -57
  77. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  78. fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
  79. fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
  80. fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
  81. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  82. fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
  83. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
  84. fusion_bench_config/fabric_model_fusion.yaml +1 -1
  85. fusion_bench_config/llama_full_finetune.yaml +19 -0
  86. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
  87. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +13 -6
  88. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +17 -9
  89. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
  90. fusion_bench_config/method/regmean/clip_regmean.yaml +1 -0
  91. fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
  92. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
  93. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
  94. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
  95. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
  96. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
  97. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
  98. fusion_bench_config/nyuv2_config.yaml +5 -1
  99. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
  100. fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
  101. fusion_bench_config/llama_weighted_average.yaml +0 -26
  102. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/LICENSE +0 -0
  103. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/WHEEL +0 -0
  104. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/entry_points.txt +0 -0
  105. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,249 @@
+import logging
+from abc import abstractmethod
+from typing import cast  # noqa: F401
+
+import lightning as L
+import lightning.fabric.wrappers
+import torch
+from lightning.pytorch.profilers import SimpleProfiler
+from omegaconf import DictConfig
+from torch import Tensor
+from torch.utils.data import DataLoader
+from tqdm.autonotebook import tqdm
+
+from fusion_bench.compat.method.base_algorithm import ModelFusionAlgorithm
+from fusion_bench.compat.modelpool import ModelPool
+from fusion_bench.models.rankone_moe import RankOneMoE
+from fusion_bench.utils import timeit_context
+from fusion_bench.utils.parameters import print_parameters
+
+log = logging.getLogger(__name__)
+
+
+def entropy_loss(logits: Tensor) -> Tensor:
+    """
+    Compute the entropy loss of a set of logits.
+
+    Args:
+        logits (Tensor): The logits to compute the entropy loss of.
+
+    Returns:
+        Tensor: The entropy loss of the logits.
+    """
+    probs = torch.softmax(logits, dim=-1)
+    return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()
+
+
+class RankOneMoEAlgorithm(ModelFusionAlgorithm):
+    """
+    Algorithm for fusing models using RankOne-MoE (https://github.com/EnnengYang/RankOne-MoE).
+
+    This class provides methods for constructing the MoE model, performing test-time adaptation,
+    and running the fusion process.
+
+    Attributes:
+        _fabric (L.Fabric): The fabric for distributed training.
+        modelpool (ModelPool): The pool of models to be fused.
+        profiler (SimpleProfiler): The profiler for measuring performance.
+    """
+
+    _fabric: L.Fabric = None
+    modelpool: ModelPool = None
+
+    def __init__(self, algorithm_config: DictConfig):
+        """
+        Initialize the RankOneMoEAlgorithm with the given configuration.
+
+        Args:
+            algorithm_config (DictConfig): The configuration for the algorithm.
+        """
+        super().__init__(algorithm_config)
+
+        if self._fabric is None and torch.cuda.is_available():
+            self._fabric = L.Fabric(
+                devices=self.config.get("devices", 1),
+            )
+            self._fabric.launch()
+        else:
+            assert "No CUDA device available."
+        self.profiler = SimpleProfiler(
+            self.config.get("cache_dir", "outputs"), "we_moe_profiler.txt"
+        )
+
+    @abstractmethod
+    def load_checkpoint(self, model, checkpoint):
+        """
+        Load the checkpoint file.
+
+        Args:
+            model: The model to load the checkpoint into.
+            checkpoint: The checkpoint file to load.
+        """
+        pass
+
+    @abstractmethod
+    def save_checkpoint(self, model, checkpoint):
+        """
+        Save the checkpoint file.
+
+        Args:
+            model: The model to save the checkpoint from.
+            checkpoint: The checkpoint file to save.
+        """
+        pass
+
+    @abstractmethod
+    def construct_moe_model(self) -> RankOneMoE:
+        """
+        Construct the Mixture of Experts model using the models in the model pool.
+
+        Returns:
+            RankOne-MoE: The constructed MoE model.
+        """
+        pass
+
+    def on_test_time_adaptation_start(self):
+        """
+        Hook method called at the start of test-time adaptation.
+        """
+        pass
+
+    @abstractmethod
+    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
+        """
+        Get an iterator for the shuffled test data loader for a specific task.
+
+        Args:
+            task (str): The task for which to get the test data loader.
+
+        Returns:
+            DataLoader: The shuffled test data loader iterator.
+        """
+        pass
+
+    @abstractmethod
+    def compute_logits(self, module, batch, task) -> Tensor:
+        """
+        Compute the logits for a given batch and task.
+
+        Args:
+            module: The model module to use for computing logits.
+            batch: The batch of data.
+            task: The task for which to compute logits.
+
+        Returns:
+            Tensor: The computed logits.
+        """
+        pass
+
+    def test_time_adaptation(self, module: RankOneMoE):
+        """
+        Perform test-time adaptation for the given module.
+
+        Args:
+            module (RankOne-MoE): The MoE module to adapt.
+
+        Returns:
+            RankOne-MoE: The adapted MoE module.
+        """
+        self.on_test_time_adaptation_start()
+
+        # configure optimizer
+        if self.config.optimizer == "adam":
+            optimizer = torch.optim.Adam(
+                [p for p in module.parameters() if p.requires_grad], lr=self.config.lr
+            )
+        else:
+            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")
+
+        if self._fabric is not None:
+            module, optimizer = self._fabric.setup(module, optimizer)
+
+        module.train()
+
+        if self.config.get("fast_dev_run", False):
+            log.info("Running fast_dev_run, only one step")
+            pbar = tqdm(
+                range(1),
+                "Test-time adaptation",
+                dynamic_ncols=True,
+            )
+        else:
+            pbar = tqdm(
+                range(self.config.max_steps),
+                "Test-time adaptation",
+                dynamic_ncols=True,
+            )
+        for step_idx in pbar:
+            if self.config.use_grad_accumulate:
+                for task in self.modelpool.model_names:
+                    with self.profiler.profile("data time"):
+                        batch = next(self.get_shuffled_test_loader_iter(task))
+                    with self.profiler.profile("forward pass"):
+                        logits = self.compute_logits(module, batch, task)
+                        assert (
+                            logits.dim() == 2
+                        ), f"Expected logits to be 2D, got {logits.dim()}"
+                        loss = entropy_loss(logits)
+                    # .backward() accumulates when .zero_grad() wasn't called
+                    # this can save memory
+                    with self.profiler.profile("backward pass"):
+                        self._fabric.backward(loss, retain_graph=True)
+            else:
+                loss = 0
+                for task in self.modelpool.model_names:
+                    with self.profiler.profile("data time"):
+                        batch = next(self.get_shuffled_test_loader_iter(task))
+                    with self.profiler.profile("forward pass"):
+                        logits = self.compute_logits(module, batch, task)
+                        assert (
+                            logits.dim() == 2
+                        ), f"Expected logits to be 2D, got {logits.dim()}"
+                        loss = loss + entropy_loss(logits)
+                with self.profiler.profile("backward pass"):
+                    self._fabric.backward(loss, retain_graph=True)
+
+            with self.profiler.profile("optimizer step"):
+                optimizer.step()
+                optimizer.zero_grad()
+
+            # print([m for m in module.parameters() if m.requires_grad][0])
+
+        return module
+
+    def run(self, modelpool: ModelPool):
+        """
+        Run the RankOneMoEAlgorithm to fuse models using RankOne-MoE.
+
+        Args:
+            modelpool (ModelPool): The pool of models to be fused.
+
+        Returns:
+            RankOne-MoE: The fused RankOne MoE model.
+        """
+        log.info("Fusing models using RankOne-MoE modules.")
+        self.modelpool = modelpool
+
+        with timeit_context("upscaling models to a RankOne-MoE model"):
+            moe_model = self.construct_moe_model()
+            print_parameters(moe_model)
+
+        if self.config.get("checkpoint", False):
+            log.info(
+                f"load checkpoint from {self.config.checkpoint}, test-time adaptation will be skipped."
+            )
+            self.load_checkpoint(moe_model, self.config.checkpoint)
+        else:
+            with self.profiler.profile("test-time adaptation"):
+                moe_model = self.test_time_adaptation(moe_model)
+            if self.config.get("save_checkpoint", False):
+                log.info(f"save checkpoint to {self.config.save_checkpoint}")
+                self.save_checkpoint(moe_model, self.config.save_checkpoint)
+
+        if lightning.fabric.wrappers.is_wrapped(moe_model):
+            moe_model = lightning.fabric.wrappers._unwrap_objects(moe_model)

+        # enable sample-wise adaptation
+        moe_model.batch_reduce = False
+        print(self.profiler.summary())
+        return moe_model
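The test-time adaptation loop in the new `RankOneMoEAlgorithm` above optimizes only the parameters of the up-scaled MoE model that require gradients, by minimizing the Shannon entropy of the model's predictions on unlabeled test batches. A minimal, self-contained sketch of that objective and a single optimizer step follows; the `routing_weight` and `features` tensors are illustrative stand-ins, not part of FusionBench:

```python
import torch


def entropy_loss(logits: torch.Tensor) -> torch.Tensor:
    # Mean Shannon entropy of the softmax distribution; the small constant
    # guards against log(0), as in the entropy_loss added in this release.
    probs = torch.softmax(logits, dim=-1)
    return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()


# Toy setup: a single learnable routing vector scales fixed features to
# produce logits for a batch of 4 samples over 10 classes.
routing_weight = torch.nn.Parameter(torch.randn(10) * 0.1)
features = torch.randn(4, 10)
optimizer = torch.optim.Adam([routing_weight], lr=1e-3)

logits = features * routing_weight  # stand-in for compute_logits(module, batch, task)
loss = entropy_loss(logits)
loss.backward()  # gradients flow only into the routing parameter
optimizer.step()
optimizer.zero_grad()
```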
@@ -11,8 +11,8 @@ from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils.state_dict_arithmetic import (
     state_dict_add,
     state_dict_avg,
-    state_dict_mul,
     state_dict_div,
+    state_dict_mul,
 )
 from fusion_bench.utils.type import StateDictType
 
@@ -0,0 +1,3 @@
+from .clip_layer_wise_adamerging_surgery import (
+    CLIPLayerWiseAdaMergingSurgeryAlgorithm,
+)
@@ -0,0 +1,157 @@
+"""
+Implementation of the Layer-Wise AdaMerging+Surgery Algorithm.
+
+For more details, please refer to:
+
+- (ICLR 2024) Yang, et.al. AdaMerging: Adaptive Model Merging for Multi-Task Learning. http://arxiv.org/abs/2310.02575
+- (ICML 2024) Yang, et.al. Representation Surgery for Multi-Task Model Merging. https://arxiv.org/abs/2402.02705
+
+Basic Example:
+
+```shell
+fusion_bench \
+    method=surgery/adamerging_surgery \
+    modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8 \
+    taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
+```
+"""
+
+import copy
+import functools
+import gc
+import logging
+from typing import TYPE_CHECKING, cast
+
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers import CLIPVisionModel
+
+from fusion_bench.dataset.clip_dataset import CLIPDataset
+from fusion_bench.method.adamerging.layer_wise_adamerging import (
+    LayerWiseAdaMergingAlgorithm,
+)
+from fusion_bench.method.adamerging.utils import get_memory_usage
+from fusion_bench.mixins import CLIPClassificationMixin
+from fusion_bench.modelpool import CLIPVisionModelPool
+from fusion_bench.models.surgery.surgerymodelwrapper import SurgeryModelWrapper
+from fusion_bench.models.wrappers.layer_wise_fusion import LayerWiseMergedModel
+
+log = logging.getLogger(__name__)
+
+
+class CLIPLayerWiseAdaMergingSurgeryAlgorithm(
+    CLIPClassificationMixin,
+    LayerWiseAdaMergingAlgorithm,
+):
+
+    def on_test_time_adaptation_start(self):
+        """
+        Here we load the CLIP processor and construct the zero-shot classification head for each task.
+        """
+        self.setup_zero_shot_classification_head()
+
+    @functools.cache
+    def get_shuffled_test_loader_iter(self, task: str):
+        return super().get_shuffled_test_loader_iter(
+            task,
+            batch_size=self.config.batch_size,
+            num_workers=self.config.num_workers,
+        )
+
+    def run(self, modelpool: CLIPVisionModelPool, **kwargs):
+        """
+        Run the Layer-Wise AdaMerging+Surgery Algorithm.
+
+        This method constructs the wrapped model and performs test-time adaptation if necessary. Then, it will perform surgery.
+
+        Args:
+            modelpool (ModelPool): The model pool containing the pretrained and fine-tuned models.
+
+        Returns:
+            LayerWiseMergedModel: The merged model after test-time adaptation.
+        """
+        log.info("Fusing models using layer-wise adaptive merging.")
+        self.modelpool = modelpool
+        self.log_hyperparams(self.config)
+
+        # === Start of the AdaMerging Algorithm ===
+        with self.profile("construct the wrapped model"):
+            module = cast(
+                LayerWiseMergedModel[CLIPVisionModel],
+                self.construct_layer_wise_merged_model(modelpool),
+            )
+
+        if self.config.weights is not None:
+            # skip the test-time adaptation
+            merged_model = copy.deepcopy(module.merge_and_unload())
+        else:
+            with self.profile("test-time adaptation"):
+                module = self.test_time_adaptation(module)
+            if self.config.get("save_merging_weights", False):
+                self.save_merging_weights(
+                    self.config.save_merging_weights, module.merge_weight
+                )
+            merged_model = copy.deepcopy(module.merge_and_unload())
+
+        # free memory
+        del module
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        # === Start of the Surgery Algorithm ===
+        log.info("start performing Surgery")
+        alpha_model = SurgeryModelWrapper(
+            merged_model,
+            modelpool.model_names,
+            projection_dim=merged_model.config.projection_dim,
+        )
+        alpha_model = self.fabric.setup(alpha_model)
+        log.info(get_memory_usage("after freeing memory, the memory usage of GPU is:"))
+
+        optimizer = torch.optim.Adam(
+            alpha_model.collect_trainable_params(),
+            lr=1e-3,
+            betas=(0.9, 0.999),
+            weight_decay=0.0,
+        )
+
+        finetuned_models = {
+            model_name: modelpool.load_model(model_name)
+            for model_name in modelpool.model_names
+        }
+        for name, model in finetuned_models.items():
+            model.requires_grad_(False)
+            model = self.fabric.to_device(model)
+            model.eval()
+
+        for iteration in tqdm(
+            range(self.config.surgery_steps),
+            "surgery",
+            dynamic_ncols=True,
+        ):
+            for dataset_name in modelpool.model_names:
+                batch = next(self.get_shuffled_test_loader_iter(dataset_name))
+                finetuned_feature = self.compute_features(
+                    finetuned_models[dataset_name], batch[0]
+                )
+                features, _, _ = alpha_model.compute_surgery_features(
+                    lambda model: self.compute_features(model, batch[0]),
+                    dataset_name,
+                )
+
+                loss = F.l1_loss(features, finetuned_feature)
+
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+            if ((iteration + 1) % self.config.eval_iterations) == 0:
+                # print(list(alpha_model.collect_trainable_params()))
+                # Evaluate try to use the test module in fusion bench
+                log.info(f"iteration: {iteration+1}")
+                self._program.evaluate_merged_model(self._program.taskpool, alpha_model)
+
+        log.info("test the result of Adamerging")
+        return merged_model
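The surgery stage above keeps the merged weights fixed and trains a small set of task-specific parameters so that the merged model's features move toward each fine-tuned model's features under an L1 loss. A stripped-down sketch of that feature-alignment step follows; the `FeatureAdapter` module is a hypothetical stand-in for FusionBench's `SurgeryModelWrapper`, and the loop bounds, batch size, and feature dimension are illustrative:

```python
import torch
import torch.nn.functional as F


class FeatureAdapter(torch.nn.Module):
    # Hypothetical per-task adapter: a low-rank residual correction applied
    # to the merged model's features (not the actual SurgeryModelWrapper).
    def __init__(self, dim: int, rank: int = 16):
        super().__init__()
        self.down = torch.nn.Linear(dim, rank, bias=False)
        self.up = torch.nn.Linear(rank, dim, bias=False)
        torch.nn.init.zeros_(self.up.weight)  # start as an identity mapping

    def forward(self, merged_features: torch.Tensor) -> torch.Tensor:
        return merged_features + self.up(self.down(merged_features))


dim = 512
adapter = FeatureAdapter(dim)
optimizer = torch.optim.Adam(
    adapter.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.0
)

for _ in range(100):  # plays the role of surgery_steps in the config
    merged_features = torch.randn(8, dim)     # features from the merged model
    finetuned_features = torch.randn(8, dim)  # features from the task's fine-tuned model
    loss = F.l1_loss(adapter(merged_features), finetuned_features)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```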
@@ -10,10 +10,12 @@ _import_structure = {
     "serialization": ["YAMLSerializationMixin", "BaseYAMLSerializableModel"],
     "simple_profiler": ["SimpleProfilerMixin"],
     "clip_classification": ["CLIPClassificationMixin"],
+    "fabric_training": ["FabricTrainingMixin"],
 }
 
 if TYPE_CHECKING:
     from .clip_classification import CLIPClassificationMixin
+    from .fabric_training import FabricTrainingMixin
     from .lightning_fabric import LightningFabricMixin
     from .serialization import BaseYAMLSerializableModel, YAMLSerializationMixin
     from .simple_profiler import SimpleProfilerMixin
@@ -2,14 +2,24 @@ import functools
 import logging
 import os
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, cast  # noqa: F401
+from typing import (  # noqa: F401
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import torch
+from omegaconf import DictConfig
 from torch import nn
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
-from omegaconf import DictConfig
 
 from fusion_bench.dataset.clip_dataset import CLIPDataset
 from fusion_bench.mixins import LightningFabricMixin
@@ -18,10 +28,12 @@ from fusion_bench.models.hf_clip import HFCLIPClassifier
 from fusion_bench.tasks.clip_classification import get_classnames_and_templates
 from fusion_bench.utils.data import InfiniteDataLoader
 
-log = logging.getLogger(__name__)
+if TYPE_CHECKING:
+    from transformers.models.clip.modeling_clip import CLIPVisionTransformer
 
-TensorOrModule = TypeVar("TensorOrModule", torch.Tensor, torch.nn.Module, Any)
+log = logging.getLogger(__name__)
 
+# disable tokenizers parallelism by default to avoid deadlocks
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
@@ -43,12 +55,7 @@ class CLIPClassificationMixin(LightningFabricMixin):
     # a dict of zeroshot weights for each task, each key is the task name
     zeroshot_weights_cache_dir: str = "outputs/cache/clip_zeroshot_weights"
     zeroshot_weights: Dict[str, torch.Tensor] = {}
-
-    def __init__(self, algorithm_config: DictConfig) -> None:
-        super().__init__(algorithm_config)
-        self.whether_setup_zero_shot_classification_head = (
-            False  # We want to only do this once
-        )
+    whether_setup_zero_shot_classification_head = False
 
     @property
     def clip_processor(self):
@@ -180,13 +187,30 @@ class CLIPClassificationMixin(LightningFabricMixin):
 
     def compute_logits(
         self,
-        module: Union[nn.Module, CLIPVisionModel],
+        module: Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"],
         images: torch.Tensor,
         task: str,
+        image_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        """
+        Compute the logits of the images for a given task.
+
+        Args:
+            module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The module to compute the logits.
+            images (torch.Tensor): The images to compute the logits.
+            task (str): The task to compute the logits.
+            image_embeds (Optional[torch.Tensor]): The precomputed image embeddings. If None, the image embeddings will be computed.
+
+        Returns:
+            torch.Tensor: The logits of the images.
+        """
         text_embeds = self.zeroshot_weights[task]
 
-        image_embeds = module(images)[1]
+        if image_embeds is None:
+            image_embeds = module(images)[1]
+        assert isinstance(
+            image_embeds, torch.Tensor
+        ), f"`image_embeds` must be a tensor, but got {type(image_embeds)}"
         image_embeds = self.visual_projection(image_embeds)
 
         # normalize embeddings
@@ -199,3 +223,27 @@ class CLIPClassificationMixin(LightningFabricMixin):
         logits_per_image = logits_per_text.t()
 
         return logits_per_image
+
+    def compute_features(
+        self,
+        module: Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"],
+        images: torch.Tensor,
+        normalize: bool = True,
+    ) -> torch.Tensor:
+        """
+        Extracts image features using CLIP's vision encoder and visual projection.
+
+        Args:
+            module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The CLIP vision encoder module.
+            images (torch.Tensor): Input image batch to process.
+            normalize (bool): Whether to normalize the image embeddings.
+
+        Returns:
+            torch.Tensor: Normalized image embeddings with dimension matching CLIP's projection space (`projection_dim` in model config).
+        """
+        image_embeds = module(images)[1]
+        image_embeds = self.visual_projection(image_embeds)
+
+        if normalize:
+            image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+        return image_embeds
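The `CLIPClassificationMixin` changes above add a `compute_features` helper and let `compute_logits` accept precomputed image embeddings, so the vision encoder only needs to run once per batch when both features and logits are needed (as in the surgery loop). A framework-free sketch of the same projection, normalization, and cosine-similarity-logit flow follows; the dimensions, `visual_projection`, `text_embeds`, and `logit_scale` are stand-ins for the mixin's cached state, not its actual API:

```python
import torch

batch_size, hidden_dim, projection_dim, num_classes = 4, 768, 512, 10

# Stand-ins for the mixin's visual_projection, zeroshot_weights[task], and
# CLIP's exp(logit_scale).
visual_projection = torch.nn.Linear(hidden_dim, projection_dim, bias=False)
text_embeds = torch.nn.functional.normalize(
    torch.randn(num_classes, projection_dim), dim=-1
)
logit_scale = torch.tensor(100.0)


def compute_features(pooled_output: torch.Tensor, normalize: bool = True) -> torch.Tensor:
    # project the encoder's pooled output, then (optionally) L2-normalize
    image_embeds = visual_projection(pooled_output)
    if normalize:
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
    return image_embeds


def compute_logits(image_embeds: torch.Tensor) -> torch.Tensor:
    # cosine-similarity logits against the cached zero-shot text embeddings
    logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
    return logits_per_text.t()


pooled_output = torch.randn(batch_size, hidden_dim)  # stand-in for module(images)[1]
features = compute_features(pooled_output)
logits = compute_logits(features)  # reuses the already computed embeddings
assert logits.shape == (batch_size, num_classes)
```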