fusion-bench 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. fusion_bench/compat/method/__init__.py +2 -0
  2. fusion_bench/compat/method/base_algorithm.py +7 -2
  3. fusion_bench/compat/modelpool/__init__.py +3 -2
  4. fusion_bench/compat/taskpool/__init__.py +1 -1
  5. fusion_bench/dataset/arc_agi/__init__.py +6 -1
  6. fusion_bench/dataset/arc_agi/arc.py +26 -7
  7. fusion_bench/dataset/arc_agi/arc_agi.py +156 -25
  8. fusion_bench/dataset/arc_agi/np_cache.py +0 -1
  9. fusion_bench/dataset/arc_agi/preprocess.py +51 -9
  10. fusion_bench/dataset/llama/__init__.py +1 -0
  11. fusion_bench/dataset/llama/alpaca.py +93 -3
  12. fusion_bench/dataset/llama/collate.py +72 -5
  13. fusion_bench/dataset/llama/metamathqa.py +50 -0
  14. fusion_bench/dataset/llama/preference_700k.py +70 -0
  15. fusion_bench/dataset/llama/stanford_shp.py +90 -0
  16. fusion_bench/dataset/llama/ultrachat.py +58 -0
  17. fusion_bench/dataset/llama/utils/__init__.py +0 -0
  18. fusion_bench/method/__init__.py +4 -1
  19. fusion_bench/method/adamerging/__init__.py +1 -1
  20. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
  21. fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
  22. fusion_bench/method/linear/expo.py +39 -0
  23. fusion_bench/method/lm_finetune/__init__.py +1 -0
  24. fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
  25. fusion_bench/method/lm_finetune/fullfinetune_sft.py +122 -150
  26. fusion_bench/method/lm_finetune/peftfinetune_sft.py +102 -157
  27. fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
  28. fusion_bench/method/pruning/llama_random_prune.py +2 -2
  29. fusion_bench/method/pruning/magnitude_diff_pruning.py +2 -1
  30. fusion_bench/method/rankone_moe/__init__.py +3 -0
  31. fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
  32. fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
  33. fusion_bench/method/simple_average.py +1 -1
  34. fusion_bench/method/surgery/__init__.py +3 -0
  35. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
  36. fusion_bench/mixins/__init__.py +2 -0
  37. fusion_bench/mixins/clip_classification.py +60 -12
  38. fusion_bench/mixins/fabric_training.py +320 -0
  39. fusion_bench/mixins/lightning_fabric.py +11 -2
  40. fusion_bench/modelpool/__init__.py +2 -0
  41. fusion_bench/modelpool/causal_lm/__init__.py +1 -1
  42. fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
  43. fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
  44. fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
  45. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
  46. fusion_bench/models/chat_templates/__init__.py +1 -0
  47. fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
  48. fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
  49. fusion_bench/models/hf_clip.py +50 -9
  50. fusion_bench/models/rankone_moe.py +410 -0
  51. fusion_bench/models/surgery/surgerymodelwrapper.py +157 -0
  52. fusion_bench/models/utils.py +8 -0
  53. fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
  54. fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
  55. fusion_bench/optim/__init__.py +2 -0
  56. fusion_bench/optim/exception.py +47 -0
  57. fusion_bench/optim/lr_scheduler/__init__.py +1 -0
  58. fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
  59. fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
  60. fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
  61. fusion_bench/optim/mezo.py +0 -2
  62. fusion_bench/programs/fabric_fusion_program.py +5 -1
  63. fusion_bench/taskpool/__init__.py +10 -2
  64. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  65. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
  66. fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
  67. fusion_bench/taskpool/llama/reward_model.py +157 -0
  68. fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
  69. fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +2 -1
  70. fusion_bench/utils/hydra_utils.py +22 -0
  71. fusion_bench/utils/plot/__init__.py +0 -0
  72. fusion_bench/utils/plot/token.py +52 -0
  73. fusion_bench/utils/plot/token_notebook.py +127 -0
  74. fusion_bench/utils/type.py +5 -3
  75. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/METADATA +1 -1
  76. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/RECORD +104 -57
  77. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  78. fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
  79. fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
  80. fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
  81. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  82. fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
  83. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
  84. fusion_bench_config/fabric_model_fusion.yaml +1 -1
  85. fusion_bench_config/llama_full_finetune.yaml +19 -0
  86. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
  87. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +13 -6
  88. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +17 -9
  89. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
  90. fusion_bench_config/method/regmean/clip_regmean.yaml +1 -0
  91. fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
  92. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
  93. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
  94. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
  95. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
  96. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
  97. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
  98. fusion_bench_config/nyuv2_config.yaml +5 -1
  99. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
  100. fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
  101. fusion_bench_config/llama_weighted_average.yaml +0 -26
  102. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/LICENSE +0 -0
  103. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/WHEEL +0 -0
  104. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/entry_points.txt +0 -0
  105. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,22 @@
1
1
  import functools
2
2
  import logging
3
3
  from copy import deepcopy
4
- from typing import Any, Callable, Dict, Iterator, List, Optional # noqa: F401
4
+ from typing import ( # noqa: F401
5
+ Any,
6
+ Callable,
7
+ Dict,
8
+ Generic,
9
+ Iterator,
10
+ List,
11
+ Optional,
12
+ TypeVar,
13
+ )
5
14
 
6
15
  import torch
7
16
  from torch import Tensor, nn
8
17
  from torch.func import functional_call
9
18
 
10
- from fusion_bench.utils.type import StateDictType
19
+ from fusion_bench.utils.type import TorchModelType, StateDictType
11
20
 
12
21
  __all__ = ["get_layer_wise_weights", "fuse_weights", "LayerWiseMergedModel"]
13
22
 
@@ -132,14 +141,14 @@ def fuse_weights(
132
141
  }
133
142
 
134
143
 
135
- class LayerWiseMergedModel(nn.Module):
144
+ class LayerWiseMergedModel(nn.Module, Generic[TorchModelType]):
136
145
  _merged_state_dict: StateDictType = None
137
146
 
138
147
  def __init__(
139
148
  self,
140
149
  layer_wise_weight: Tensor,
141
- pretrained_model: nn.Module,
142
- finetuned_models: List[nn.Module],
150
+ pretrained_model: TorchModelType,
151
+ finetuned_models: List[TorchModelType],
143
152
  clamp_weights: bool = True,
144
153
  tie_weights: bool = False,
145
154
  strict: bool = True,
@@ -16,13 +16,13 @@ outputs = merged_model(inputs)
16
16
 
17
17
  import functools
18
18
  import logging
19
- from typing import Any, Callable, Dict, Iterator, List, Optional # noqa: F401
19
+ from typing import Any, Callable, Dict, Generic, Iterator, List, Optional # noqa: F401
20
20
 
21
21
  import torch
22
22
  from torch import Tensor, nn
23
23
  from torch.func import functional_call
24
24
 
25
- from fusion_bench.utils.type import StateDictType
25
+ from fusion_bench.utils.type import TorchModelType, StateDictType
26
26
 
27
27
  log = logging.getLogger(__name__)
28
28
 
@@ -157,14 +157,14 @@ def fuse_weights(
157
157
  }
158
158
 
159
159
 
160
- class TaskWiseMergedModel(nn.Module):
160
+ class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
161
161
  _merged_state_dict: StateDictType = None
162
162
 
163
163
  def __init__(
164
164
  self,
165
165
  task_wise_weight: Tensor,
166
- pretrained_model: nn.Module,
167
- finetuned_models: List[nn.Module],
166
+ pretrained_model: TorchModelType,
167
+ finetuned_models: List[TorchModelType],
168
168
  clamp_weights: bool = True,
169
169
  tie_weights: bool = False,
170
170
  strict: bool = True,
@@ -0,0 +1,2 @@
1
+ from . import exception, lr_scheduler
2
+ from .mezo import MeZO
@@ -0,0 +1,47 @@
1
class NoSparseGradientError(Exception):
    """Raised when an optimizer receives a sparse gradient it cannot handle.

    Args:
        optimizer_name: Name of the optimizer.
        note: Special condition to mention in the message (default '').
    """

    def __init__(self, optimizer_name: str, note: str = ""):
        # note is padded with spaces so it slots cleanly into the message
        self.note: str = f" w/ {note} " if note else " "
        self.message: str = (
            f"[-] {optimizer_name}{self.note}does not support sparse gradient."
        )
        super().__init__(self.message)
14
+
15
+
16
class ZeroParameterSizeError(Exception):
    """Raised when the total parameter size is zero."""

    def __init__(self) -> None:
        self.message: str = "[-] parameter size is 0"
        super().__init__(self.message)
22
+
23
+
24
class NoClosureError(Exception):
    """Raised when an optimizer requires a closure but none was supplied.

    Args:
        optimizer_name: Name of the optimizer.
        note: Extra text appended verbatim to the message (default '').
    """

    def __init__(self, optimizer_name: str, note: str = ""):
        message = f"[-] {optimizer_name} requires closure.{note}"
        self.message: str = message
        super().__init__(message)
30
+
31
+
32
class NegativeLRError(Exception):
    """Raised when a learning rate is negative.

    Args:
        lr: The offending learning-rate value.
        lr_type: Which learning rate it was (defaults to "learning rate").
    """

    def __init__(self, lr: float, lr_type: str = ""):
        self.note: str = lr_type or "learning rate"
        self.message: str = f"[-] {self.note} must be positive. ({lr} > 0)"
        super().__init__(self.message)
39
+
40
+
41
class NegativeStepError(Exception):
    """Raised when a step count is negative.

    Args:
        num_steps: The offending step count.
        step_type: Which step parameter it was (defaults to "step").
    """

    def __init__(self, num_steps: int, step_type: str = ""):
        self.note: str = step_type or "step"
        self.message: str = f"[-] {self.note} must be positive. ({num_steps} > 0)"
        super().__init__(self.message)
@@ -0,0 +1 @@
1
+ from .linear_warmup import *
@@ -0,0 +1,222 @@
1
+ """
2
+ Modified from pytorch_optimizer: https://github.com/kozistr/pytorch_optimizer/blob/main/pytorch_optimizer/lr_scheduler/linear_warmup.py
3
+ """
4
+
5
+ import math
6
+ from abc import ABC, abstractmethod
7
+ from typing import List
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+ from fusion_bench.optim.exception import NegativeLRError, NegativeStepError
13
+
14
+ __all__ = [
15
+ "BaseLinearWarmupScheduler",
16
+ "LinearWarmupScheduler",
17
+ "CosineDecayWithWarmup",
18
+ "PolySchedulerWithWarmup",
19
+ ]
20
+
21
+
22
class BaseLinearWarmupScheduler(ABC):
    r"""BaseLinearWarmupScheduler class.

    The LR Scheduler class based on this class has linear warmup strategy.

    Subclasses implement ``_step`` to define the post-warmup decay curve;
    this base class handles the warmup ramp and the bookkeeping.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer. It will set learning rate to all trainable parameters in optimizer.
        T_max (int): Total steps to train.
        max_lr (float): Maximum learning rate.
        min_lr (float): Minimum learning rate.
        init_lr (float): Initial learning rate.
        warmup_steps (int): Steps to warm-up.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        T_max: int,
        max_lr: float,
        min_lr: float = 0.0,
        init_lr: float = 0.0,
        warmup_steps: int = 0,
    ):
        """
        Initialize the BaseLinearWarmupScheduler.

        Args:
            optimizer (torch.optim.Optimizer): Optimizer to apply the learning rate schedule.
            T_max (int): Total number of training steps.
            max_lr (float): Maximum learning rate.
            min_lr (float): Minimum learning rate.
            init_lr (float): Initial learning rate.
            warmup_steps (int): Number of steps for the warm-up phase.
        """
        self.optimizer = optimizer
        self.total_steps = T_max
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.init_lr = init_lr
        self.warmup_steps = warmup_steps

        # step_t counts calls to step(); base_lrs is filled by _init_lr()
        self.step_t: int = 0
        self.base_lrs: List[float] = []

        # record current value in self._last_lr to match API from torch.optim.lr_scheduler
        self.last_lr: List[float] = [init_lr]

        self.validate_parameters()

        self._init_lr()

    def validate_parameters(self):
        """
        Validate the parameters to ensure they are non-negative.

        Raises:
            NegativeLRError: If any of the learning rates are negative.
            NegativeStepError: If any of the step values are negative.
        """
        if self.min_lr < 0:
            raise NegativeLRError(self.min_lr, "min_lr")

        if self.max_lr < 0:
            raise NegativeLRError(self.max_lr, "max_lr")

        if self.init_lr < 0:
            raise NegativeLRError(self.init_lr, "init_lr")

        if self.total_steps < 0:
            raise NegativeStepError(self.total_steps, "T_max")

        if self.warmup_steps < 0:
            raise NegativeStepError(self.warmup_steps, "warmup_steps")

    def _init_lr(self):
        """
        Initialize the learning rate for each parameter group in the optimizer.

        NOTE(review): param groups are seeded with ``min_lr`` here, while the
        first ``step()`` starts the warmup ramp from ``init_lr`` — this matches
        the upstream pytorch_optimizer implementation, but confirm it is the
        intended behavior when ``init_lr != min_lr``.
        """
        self.base_lrs = []
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = self.min_lr
            self.base_lrs.append(self.min_lr)

    def step(self):
        """
        Update the learning rate for the current step.

        Returns:
            float: The updated learning rate.
        """
        if self.step_t < self.warmup_steps:
            # linear warmup: interpolate from init_lr to max_lr
            value = (
                self.init_lr
                + (self.max_lr - self.init_lr) * self.step_t / self.warmup_steps
            )
        elif self.step_t == self.warmup_steps:
            # exactly at the end of warmup: pin to max_lr
            value = self.max_lr
        else:
            # past warmup: delegate to the subclass-defined decay curve
            value = self._step()

        self.step_t += 1

        if self.optimizer is not None:
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = value

        self.last_lr = [value]

        return value

    @abstractmethod
    def _step(self) -> float:  # pragma: no cover
        """
        Abstract method to calculate the learning rate for the current step.

        Returns:
            float: The calculated learning rate.
        """
        raise NotImplementedError

    def get_lr(self) -> float:
        """
        Get the current learning rate.

        Returns:
            float: The current learning rate.
        """
        return self.last_lr[0]
150
+
151
+
152
class LinearWarmupScheduler(BaseLinearWarmupScheduler):
    r"""Linearly decay the LR from ``max_lr`` to ``min_lr`` after linear warmup."""

    def _step(self) -> float:
        """
        Compute the linearly decayed learning rate for the current step.

        Returns:
            float: The learning rate after decay.
        """
        elapsed = self.step_t - self.warmup_steps
        decay_span = self.total_steps - self.warmup_steps
        # interpolate from max_lr (start of decay) down to min_lr (end of training)
        return self.max_lr + (self.min_lr - self.max_lr) * elapsed / decay_span
165
+
166
+
167
class CosineDecayWithWarmup(BaseLinearWarmupScheduler):
    r"""Cosine-anneal the LR from ``max_lr`` to ``min_lr`` after linear warmup."""

    def _step(self) -> float:
        """
        Compute the cosine-annealed learning rate for the current step.

        Returns:
            float: The learning rate after decay.
        """
        # fraction of the decay phase completed, mapped onto [0, pi]
        progress = (self.step_t - self.warmup_steps) / (
            self.total_steps - self.warmup_steps
        )
        phase: float = progress * math.pi
        # cosine goes 1 -> -1 over the phase, so (cos + 1) / 2 goes 1 -> 0
        return self.min_lr + (self.max_lr - self.min_lr) * (np.cos(phase) + 1.0) / 2.0
183
+
184
+
185
class PolySchedulerWithWarmup(BaseLinearWarmupScheduler):
    r"""Poly LR Scheduler.

    Args:
        poly_order (float): LR scheduler decreases with steps.
    """

    def __init__(self, optimizer, poly_order: float = 0.5, **kwargs):
        """
        Initialize the PolySchedulerWithWarmup.

        Args:
            optimizer (torch.optim.Optimizer): Optimizer to apply the learning rate schedule.
            poly_order (float): Order of the polynomial for the learning rate decay.
            kwargs: Additional arguments for the base class.

        Raises:
            ValueError: If poly_order is not positive.
        """
        # validate before storing so an invalid order never reaches the base class
        if poly_order <= 0:
            raise ValueError(f"[-] poly_order must be positive. {poly_order}")
        self.poly_order = poly_order

        super().__init__(optimizer, **kwargs)

    def _step(self) -> float:
        """
        Calculate the learning rate for the current step using a polynomial decay.

        Returns:
            float: The calculated learning rate.
        """
        # offset grows as (steps past warmup) ** poly_order
        growth = (self.step_t - self.warmup_steps) ** self.poly_order
        return self.min_lr + (self.max_lr - self.min_lr) * growth
@@ -0,0 +1 @@
1
+ from .visualization import *
@@ -0,0 +1,119 @@
1
+ """
2
+ This module provides utilities for visualizing learning rate schedulers.
3
+
4
+ Functions:
5
+ simulate_scheduler(lr_scheduler, steps): Simulates the learning rate scheduler for a given number of steps.
6
+ plot_lr_schedulers(lr_schedulers, steps, titles): Plots the learning rates of one or more schedulers over a number of steps.
7
+ """
8
+
9
+ from typing import TYPE_CHECKING, List, Union
10
+
11
+ import matplotlib.pyplot as plt
12
+ import torch
13
+
14
+ if TYPE_CHECKING:
15
+ from torch.optim.lr_scheduler import LRScheduler
16
+
17
+ __all__ = ["simulate_scheduler", "plot_lr_schedulers"]
18
+
19
+
20
def simulate_scheduler(lr_scheduler, steps: int):
    """
    Simulate the learning rate scheduler for a given number of steps.

    Args:
        lr_scheduler: Scheduler exposing a ``step()`` method that returns the new learning rate.
        steps (int): The number of steps to simulate.

    Returns:
        List[float]: The learning rate returned by each step, in order.
    """
    return [lr_scheduler.step() for _ in range(steps)]
36
+
37
+
38
def plot_lr_schedulers(
    lr_schedulers: Union["LRScheduler", List["LRScheduler"]],
    steps: int,
    titles: Union[str, List[str]],
    show_plot: bool = True,
):
    """
    Plot the learning-rate curves of one or more schedulers over a number of steps.

    Args:
        lr_schedulers (Union[LRScheduler, List[LRScheduler]]): One or more learning rate scheduler objects.
        steps (int): The number of steps to simulate.
        titles (Union[str, List[str]]): Titles for the plots.
        show_plot (bool): Whether to call ``plt.show()`` after drawing.

    Returns:
        fig, axes: The matplotlib figure and axes objects.
    """
    # Normalize the single-scheduler / single-title case to lists.
    if isinstance(lr_schedulers, torch.optim.lr_scheduler.LRScheduler):
        lr_schedulers = [lr_schedulers]
    if isinstance(titles, str):
        titles = [titles]

    num_plots = len(lr_schedulers)
    fig, axs = plt.subplots(num_plots, 1, figsize=(5, 3 * num_plots))
    if num_plots == 1:
        # plt.subplots returns a bare Axes (not an array) for a single row
        axs = [axs]

    for ax, scheduler, title in zip(axs, lr_schedulers, titles):
        ax.plot(simulate_scheduler(scheduler, steps), label=title)
        ax.set_title(title)
        ax.set_xlabel("Steps")
        ax.set_ylabel("Learning Rate")
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    if show_plot:
        plt.show()
    return fig, axs
78
+
79
+
80
# Example usage
if __name__ == "__main__":
    from fusion_bench.optim.lr_scheduler.linear_warmup import (
        CosineDecayWithWarmup,
        LinearWarmupScheduler,
        PolySchedulerWithWarmup,
    )

    # Dummy optimizer
    optimizer = torch.optim.SGD(
        [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))], lr=0.1
    )

    # Define the schedulers.
    # NOTE: the keyword is `T_max` (see BaseLinearWarmupScheduler.__init__);
    # the previous `t_max=` spelling raised a TypeError at runtime.
    linear_scheduler = LinearWarmupScheduler(
        optimizer, T_max=100, max_lr=0.1, min_lr=0.01, init_lr=0.0, warmup_steps=10
    )
    cosine_scheduler = CosineDecayWithWarmup(
        optimizer, T_max=100, max_lr=0.1, min_lr=0.01, init_lr=0.0, warmup_steps=10
    )
    poly_scheduler = PolySchedulerWithWarmup(
        optimizer,
        T_max=100,
        max_lr=0.1,
        min_lr=0.01,
        init_lr=0.0,
        warmup_steps=40,
        poly_order=2.0,
    )

    # Plot the learning rates
    plot_lr_schedulers(
        [linear_scheduler, cosine_scheduler, poly_scheduler],
        steps=100,
        titles=[
            "Linear Warmup",
            "Cosine Decay with Warmup",
            "Poly Scheduler with Warmup",
        ],
    )
@@ -5,8 +5,6 @@ import numpy as np
5
5
  import torch
6
6
  from torch.optim.optimizer import Optimizer
7
7
 
8
- from fusion_bench.utils import timeit_context
9
-
10
8
  log = logging.getLogger(__name__)
11
9
 
12
10
 
@@ -236,7 +236,11 @@ class FabricModelFusionProgram(
236
236
  self.save_merged_model(merged_model)
237
237
  if self.taskpool is not None:
238
238
  report = self.evaluate_merged_model(self.taskpool, merged_model)
239
- print_json(report, print_type=False)
239
+ try:
240
+ print_json(report, print_type=False)
241
+ except Exception as e:
242
+ log.warning(f"Failed to pretty print the report: {e}")
243
+ print(report)
240
244
  if self.report_save_path is not None:
241
245
  # save report (Dict) to a file
242
246
  # if the directory of `save_report` does not exists, create it
@@ -7,7 +7,11 @@ from fusion_bench.utils.lazy_imports import LazyImporter
7
7
 
8
8
  _import_structure = {
9
9
  "base_pool": ["BaseTaskPool"],
10
- "clip_vision": ["CLIPVisionModelTaskPool", "SparseWEMoECLIPVisionModelTaskPool"],
10
+ "clip_vision": [
11
+ "CLIPVisionModelTaskPool",
12
+ "SparseWEMoECLIPVisionModelTaskPool",
13
+ "RankoneWEMoECLIPVisionModelTaskPool",
14
+ ],
11
15
  "dummy": ["DummyTaskPool"],
12
16
  "gpt2_text_classification": ["GPT2TextClassificationTaskPool"],
13
17
  "nyuv2_taskpool": ["NYUv2TaskPool"],
@@ -17,7 +21,11 @@ _import_structure = {
17
21
 
18
22
  if TYPE_CHECKING:
19
23
  from .base_pool import BaseTaskPool
20
- from .clip_vision import CLIPVisionModelTaskPool, SparseWEMoECLIPVisionModelTaskPool
24
+ from .clip_vision import (
25
+ CLIPVisionModelTaskPool,
26
+ RankoneWEMoECLIPVisionModelTaskPool,
27
+ SparseWEMoECLIPVisionModelTaskPool,
28
+ )
21
29
  from .dummy import DummyTaskPool
22
30
  from .gpt2_text_classification import GPT2TextClassificationTaskPool
23
31
  from .llama import LlamaTestGenerationTaskPool
@@ -1,3 +1,4 @@
1
1
  # flake8: noqa F401
2
+ from .clip_rankone_moe_taskpool import RankoneMoECLIPVisionModelTaskPool
2
3
  from .clip_sparse_wemoe_taskpool import SparseWEMoECLIPVisionModelTaskPool
3
4
  from .taskpool import CLIPVisionModelTaskPool
@@ -0,0 +1,112 @@
1
+ from copy import deepcopy
2
+ from pathlib import Path
3
+ from typing import Any, Dict, Optional, Tuple
4
+
5
+ import torch
6
+ from torch import Tensor
7
+ from torch.utils.hooks import RemovableHandle
8
+ from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
9
+ from transformers.models.clip.modeling_clip import CLIPVisionTransformer
10
+
11
+ from fusion_bench.models.hf_clip import HFCLIPClassifier
12
+ from fusion_bench.models.rankone_moe import RankOneMoE
13
+
14
+ from .taskpool import CLIPVisionModelTaskPool
15
+
16
+
17
+ class LayerWiseRoutingWeightSaver:
18
+ def __init__(self, save_path: Path, max_num: Optional[int] = None):
19
+ self.save_path = save_path
20
+ self.max_num = max_num
21
+ self.routing_weights = []
22
+
23
+ def __call__(self, module, input: Tuple[Tensor], output: Tensor):
24
+ assert isinstance(output, Tensor), "Output is expected to be a Tensor"
25
+ # (batch_size, num_tokens, num_experts)
26
+ routing_weights = output.detach().cpu()
27
+ if self.max_num is not None and self.max_num > 0:
28
+ if len(self.routing_weights) > self.max_num:
29
+ return
30
+ elif routing_weights.size(0) + len(self.routing_weights) > self.max_num:
31
+ self.routing_weights.append(
32
+ routing_weights[: self.max_num - len(self.routing_weights)]
33
+ )
34
+ else:
35
+ self.routing_weights.append(routing_weights)
36
+ else:
37
+ self.routing_weights.append(routing_weights)
38
+
39
+ def save_routing_weights(self):
40
+ routing_weights = torch.cat(self.routing_weights, dim=0)
41
+ if self.save_path is not None:
42
+ self.save_path.parent.mkdir(parents=True, exist_ok=True)
43
+ print(f"Saving routing weights to {self.save_path}")
44
+ torch.save(routing_weights, self.save_path)
45
+
46
+
47
+ class RankoneMoECLIPVisionModelTaskPool(CLIPVisionModelTaskPool):
48
+
49
+ # hooks and handles for saving layer-wise routing weights
50
+ _layer_wise_routing_weights_save_hooks: Dict[Any, LayerWiseRoutingWeightSaver] = {}
51
+ _layer_wise_routing_weights_save_hook_handles: Dict[Any, RemovableHandle] = {}
52
+
53
+ _config_mapping = CLIPVisionModelTaskPool._config_mapping | {
54
+ "_layer_wise_routing_weights_save_path": "layer_wise_routing_weights_save_path",
55
+ }
56
+
57
+ def __init__(
58
+ self,
59
+ layer_wise_routing_weights_save_path: Optional[str],
60
+ layer_wise_routing_weights_max_num: Optional[int] = None,
61
+ **kwargs,
62
+ ):
63
+ # save path for layer-wise routing weights
64
+ self._layer_wise_routing_weights_save_path = (
65
+ layer_wise_routing_weights_save_path
66
+ )
67
+ self.layer_wise_routing_weights_save_path = (
68
+ Path(layer_wise_routing_weights_save_path)
69
+ if layer_wise_routing_weights_save_path is not None
70
+ else None
71
+ )
72
+ self.layer_wise_routing_weights_max_num = layer_wise_routing_weights_max_num
73
+ super().__init__(**kwargs)
74
+
75
+ def on_task_evaluation_begin(self, classifier: HFCLIPClassifier, task_name: str):
76
+ super().on_task_evaluation_begin(classifier, task_name)
77
+ if self.layer_wise_routing_weights_save_path is not None:
78
+ # setup hooks for saving layer-wise routing weights
79
+ assert isinstance(
80
+ classifier.clip_model.vision_model,
81
+ (CLIPVisionTransformer, CLIPVisionModel),
82
+ ), "Vision model is expected to be a CLIPVisionTransformer"
83
+ vision_model = classifier.clip_model.vision_model
84
+ if isinstance(vision_model, CLIPVisionModel):
85
+ vision_model = vision_model.vision_model
86
+ # assign forward hooks for each layer
87
+
88
+ for i, layer in enumerate(vision_model.encoder.layers):
89
+ mlp = layer.mlp
90
+ assert isinstance(
91
+ mlp,
92
+ (RankOneMoE),
93
+ ), f"MLP is expected to be a RankOneWeightEnsemblingMoE, but got {type(mlp)}"
94
+ # layer-wise routing weights
95
+ hook = LayerWiseRoutingWeightSaver(
96
+ self.layer_wise_routing_weights_save_path
97
+ / task_name
98
+ / f"layer_{i}.pt",
99
+ max_num=self.layer_wise_routing_weights_max_num,
100
+ )
101
+ self._layer_wise_routing_weights_save_hooks[i] = hook
102
+ self._layer_wise_routing_weights_save_hook_handles[i] = (
103
+ mlp.gate.register_forward_hook(hook)
104
+ )
105
+
106
+ def on_task_evaluation_end(self):
107
+ super().on_task_evaluation_end()
108
+ if self.layer_wise_routing_weights_save_path is not None:
109
+ # remove hooks for saving layer-wise routing weights
110
+ for i, handle in self._layer_wise_routing_weights_save_hook_handles.items():
111
+ self._layer_wise_routing_weights_save_hooks[i].save_routing_weights()
112
+ handle.remove()