fusion-bench 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +2 -0
- fusion_bench/compat/method/base_algorithm.py +7 -2
- fusion_bench/compat/modelpool/__init__.py +3 -2
- fusion_bench/compat/taskpool/__init__.py +1 -1
- fusion_bench/dataset/arc_agi/__init__.py +6 -1
- fusion_bench/dataset/arc_agi/arc.py +26 -7
- fusion_bench/dataset/arc_agi/arc_agi.py +156 -25
- fusion_bench/dataset/arc_agi/np_cache.py +0 -1
- fusion_bench/dataset/arc_agi/preprocess.py +51 -9
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +93 -3
- fusion_bench/dataset/llama/collate.py +72 -5
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/method/__init__.py +4 -1
- fusion_bench/method/adamerging/__init__.py +1 -1
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
- fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
- fusion_bench/method/linear/expo.py +39 -0
- fusion_bench/method/lm_finetune/__init__.py +1 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +122 -150
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +102 -157
- fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
- fusion_bench/method/pruning/llama_random_prune.py +2 -2
- fusion_bench/method/pruning/magnitude_diff_pruning.py +2 -1
- fusion_bench/method/rankone_moe/__init__.py +3 -0
- fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
- fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
- fusion_bench/method/simple_average.py +1 -1
- fusion_bench/method/surgery/__init__.py +3 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/clip_classification.py +60 -12
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +11 -2
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/causal_lm/__init__.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +50 -9
- fusion_bench/models/rankone_moe.py +410 -0
- fusion_bench/models/surgery/surgerymodelwrapper.py +157 -0
- fusion_bench/models/utils.py +8 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
- fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +0 -2
- fusion_bench/programs/fabric_fusion_program.py +5 -1
- fusion_bench/taskpool/__init__.py +10 -2
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
- fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +2 -1
- fusion_bench/utils/hydra_utils.py +22 -0
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/type.py +5 -3
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/RECORD +104 -57
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +1 -1
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +13 -6
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +17 -9
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +1 -0
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/nyuv2_config.yaml +5 -1
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
- fusion_bench_config/llama_weighted_average.yaml +0 -26
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/top_level.txt +0 -0
|
@@ -1,13 +1,22 @@
|
|
|
1
1
|
import functools
|
|
2
2
|
import logging
|
|
3
3
|
from copy import deepcopy
|
|
4
|
-
from typing import
|
|
4
|
+
from typing import ( # noqa: F401
|
|
5
|
+
Any,
|
|
6
|
+
Callable,
|
|
7
|
+
Dict,
|
|
8
|
+
Generic,
|
|
9
|
+
Iterator,
|
|
10
|
+
List,
|
|
11
|
+
Optional,
|
|
12
|
+
TypeVar,
|
|
13
|
+
)
|
|
5
14
|
|
|
6
15
|
import torch
|
|
7
16
|
from torch import Tensor, nn
|
|
8
17
|
from torch.func import functional_call
|
|
9
18
|
|
|
10
|
-
from fusion_bench.utils.type import StateDictType
|
|
19
|
+
from fusion_bench.utils.type import TorchModelType, StateDictType
|
|
11
20
|
|
|
12
21
|
__all__ = ["get_layer_wise_weights", "fuse_weights", "LayerWiseMergedModel"]
|
|
13
22
|
|
|
@@ -132,14 +141,14 @@ def fuse_weights(
|
|
|
132
141
|
}
|
|
133
142
|
|
|
134
143
|
|
|
135
|
-
class LayerWiseMergedModel(nn.Module):
|
|
144
|
+
class LayerWiseMergedModel(nn.Module, Generic[TorchModelType]):
|
|
136
145
|
_merged_state_dict: StateDictType = None
|
|
137
146
|
|
|
138
147
|
def __init__(
|
|
139
148
|
self,
|
|
140
149
|
layer_wise_weight: Tensor,
|
|
141
|
-
pretrained_model:
|
|
142
|
-
finetuned_models: List[
|
|
150
|
+
pretrained_model: TorchModelType,
|
|
151
|
+
finetuned_models: List[TorchModelType],
|
|
143
152
|
clamp_weights: bool = True,
|
|
144
153
|
tie_weights: bool = False,
|
|
145
154
|
strict: bool = True,
|
|
@@ -16,13 +16,13 @@ outputs = merged_model(inputs)
|
|
|
16
16
|
|
|
17
17
|
import functools
|
|
18
18
|
import logging
|
|
19
|
-
from typing import Any, Callable, Dict, Iterator, List, Optional # noqa: F401
|
|
19
|
+
from typing import Any, Callable, Dict, Generic, Iterator, List, Optional # noqa: F401
|
|
20
20
|
|
|
21
21
|
import torch
|
|
22
22
|
from torch import Tensor, nn
|
|
23
23
|
from torch.func import functional_call
|
|
24
24
|
|
|
25
|
-
from fusion_bench.utils.type import StateDictType
|
|
25
|
+
from fusion_bench.utils.type import TorchModelType, StateDictType
|
|
26
26
|
|
|
27
27
|
log = logging.getLogger(__name__)
|
|
28
28
|
|
|
@@ -157,14 +157,14 @@ def fuse_weights(
|
|
|
157
157
|
}
|
|
158
158
|
|
|
159
159
|
|
|
160
|
-
class TaskWiseMergedModel(nn.Module):
|
|
160
|
+
class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
|
|
161
161
|
_merged_state_dict: StateDictType = None
|
|
162
162
|
|
|
163
163
|
def __init__(
|
|
164
164
|
self,
|
|
165
165
|
task_wise_weight: Tensor,
|
|
166
|
-
pretrained_model:
|
|
167
|
-
finetuned_models: List[
|
|
166
|
+
pretrained_model: TorchModelType,
|
|
167
|
+
finetuned_models: List[TorchModelType],
|
|
168
168
|
clamp_weights: bool = True,
|
|
169
169
|
tie_weights: bool = False,
|
|
170
170
|
strict: bool = True,
|
fusion_bench/optim/__init__.py
CHANGED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
class NoSparseGradientError(Exception):
    """Signal that an optimizer received a sparse gradient it cannot handle.

    Args:
        optimizer_name: Name of the optimizer, used in the message.
        note: Optional qualifier for the special condition. When non-empty it
            is inserted into the message as ``" w/ <note> "``.
    """

    def __init__(self, optimizer_name: str, note: str = ""):
        if note:
            qualifier = f" w/ {note} "
        else:
            qualifier = " "
        self.note: str = qualifier
        self.message: str = (
            f"[-] {optimizer_name}{qualifier}does not support sparse gradient."
        )
        super().__init__(self.message)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ZeroParameterSizeError(Exception):
    """Signal that a parameter has size zero.

    The fixed message is also exposed as the ``message`` attribute.
    """

    def __init__(self):
        text = "[-] parameter size is 0"
        self.message: str = text
        super().__init__(text)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class NoClosureError(Exception):
    """Raised when an optimizer that needs a closure is stepped without one.

    Args:
        optimizer_name: Name of the optimizer, used in the message.
        note: Optional extra text appended after the message.
    """

    def __init__(self, optimizer_name: str, note: str = ""):
        # Separate the note from "closure." with a space. Previously it was glued
        # on directly ("closure.note"). A note that already starts with
        # whitespace is kept as-is so the separator is not doubled.
        if note and not note[0].isspace():
            note = f" {note}"
        self.message: str = f"[-] {optimizer_name} requires closure.{note}"
        super().__init__(self.message)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class NegativeLRError(Exception):
    """Raised when a learning rate is negative.

    Only negative values are an error (zero is not), so the message asks for a
    non-negative value and reports the offending one.

    Args:
        lr: The offending learning-rate value.
        lr_type: Name of the parameter (e.g. ``"min_lr"``). Defaults to
            ``"learning rate"`` when empty.
    """

    def __init__(self, lr: float, lr_type: str = ""):
        self.note: str = lr_type if lr_type else "learning rate"
        # The old text "must be positive. (-0.1 > 0)" wrongly demanded a
        # strictly positive value and printed a false inequality.
        self.message: str = f"[-] {self.note} must be non-negative. (got {lr})"
        super().__init__(self.message)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class NegativeStepError(Exception):
    """Raised when a step count is negative.

    Only negative values are an error (zero is not), so the message asks for a
    non-negative value and reports the offending one.

    Args:
        num_steps: The offending step value.
        step_type: Name of the parameter (e.g. ``"warmup_steps"``). Defaults
            to ``"step"`` when empty.
    """

    def __init__(self, num_steps: int, step_type: str = ""):
        self.note: str = step_type if step_type else "step"
        # The old text "must be positive. (-3 > 0)" wrongly demanded a strictly
        # positive value and printed a false inequality.
        self.message: str = f"[-] {self.note} must be non-negative. (got {num_steps})"
        super().__init__(self.message)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .linear_warmup import *
|
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Modified from pytorch_optimizer: https://github.com/kozistr/pytorch_optimizer/blob/main/pytorch_optimizer/lr_scheduler/linear_warmup.py
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import math
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from typing import List
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
from fusion_bench.optim.exception import NegativeLRError, NegativeStepError
|
|
13
|
+
|
|
14
|
+
__all__ = [
|
|
15
|
+
"BaseLinearWarmupScheduler",
|
|
16
|
+
"LinearWarmupScheduler",
|
|
17
|
+
"CosineDecayWithWarmup",
|
|
18
|
+
"PolySchedulerWithWarmup",
|
|
19
|
+
]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class BaseLinearWarmupScheduler(ABC):
    r"""Base class for LR schedulers with a linear warm-up phase.

    Each call to :meth:`step` computes the learning rate for the current
    ``step_t``, then increments ``step_t``. The rate for a given ``step_t`` is:

    * ``step_t < warmup_steps``: linear interpolation from ``init_lr``
      towards ``max_lr``;
    * ``step_t == warmup_steps``: ``max_lr``;
    * afterwards: whatever the subclass returns from :meth:`_step`.

    The value is written to every parameter group of ``optimizer`` (when an
    optimizer is given) and recorded in ``last_lr``.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer whose parameter groups
            receive the learning rate. May be ``None`` to only compute values.
        T_max (int): Total steps to train.
        max_lr (float): Maximum learning rate.
        min_lr (float): Minimum learning rate.
        init_lr (float): Initial learning rate.
        warmup_steps (int): Steps to warm up.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        T_max: int,
        max_lr: float,
        min_lr: float = 0.0,
        init_lr: float = 0.0,
        warmup_steps: int = 0,
    ):
        """
        Initialize the BaseLinearWarmupScheduler.

        Args:
            optimizer (torch.optim.Optimizer): Optimizer to apply the learning rate schedule, or ``None``.
            T_max (int): Total number of training steps.
            max_lr (float): Maximum learning rate.
            min_lr (float): Minimum learning rate.
            init_lr (float): Initial learning rate.
            warmup_steps (int): Number of steps for the warm-up phase.

        Raises:
            NegativeLRError: If any learning rate is negative.
            NegativeStepError: If ``T_max`` or ``warmup_steps`` is negative.
        """
        self.optimizer = optimizer
        self.total_steps = T_max
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.init_lr = init_lr
        self.warmup_steps = warmup_steps

        self.step_t: int = 0
        self.base_lrs: List[float] = []

        # Most recent learning rate, mirroring torch.optim.lr_scheduler's
        # ``_last_lr``. Exposed via get_lr() / get_last_lr().
        self.last_lr: List[float] = [init_lr]

        self.validate_parameters()

        self._init_lr()

    def validate_parameters(self):
        """
        Validate the parameters to ensure they are non-negative.

        Raises:
            NegativeLRError: If any of the learning rates are negative.
            NegativeStepError: If any of the step values are negative.
        """
        if self.min_lr < 0:
            raise NegativeLRError(self.min_lr, "min_lr")

        if self.max_lr < 0:
            raise NegativeLRError(self.max_lr, "max_lr")

        if self.init_lr < 0:
            raise NegativeLRError(self.init_lr, "init_lr")

        if self.total_steps < 0:
            raise NegativeStepError(self.total_steps, "T_max")

        if self.warmup_steps < 0:
            raise NegativeStepError(self.warmup_steps, "warmup_steps")

    def _init_lr(self):
        """
        Set every parameter group to ``init_lr`` before the first step.

        ``init_lr`` is used (rather than ``min_lr``) so the optimizer agrees with
        ``last_lr`` and with the value the first ``step()`` produces. A ``None``
        optimizer is tolerated, matching :meth:`step`.
        """
        self.base_lrs = []
        if self.optimizer is None:
            return
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = self.init_lr
            self.base_lrs.append(self.init_lr)

    def step(self):
        """
        Compute and apply the learning rate for the current step, then advance.

        Returns:
            float: The learning rate for the step that was just taken.
        """
        if self.step_t < self.warmup_steps:
            # warmup_steps > 0 is guaranteed here because step_t >= 0.
            value = (
                self.init_lr
                + (self.max_lr - self.init_lr) * self.step_t / self.warmup_steps
            )
        elif self.step_t == self.warmup_steps:
            value = self.max_lr
        else:
            value = self._step()

        self.step_t += 1

        if self.optimizer is not None:
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = value

        self.last_lr = [value]

        return value

    @abstractmethod
    def _step(self) -> float:  # pragma: no cover
        """
        Compute the post-warm-up learning rate for the current ``step_t``.

        Returns:
            float: The calculated learning rate.
        """
        raise NotImplementedError

    def get_lr(self) -> float:
        """
        Get the current learning rate.

        Returns:
            float: The most recently computed learning rate.
        """
        return self.last_lr[0]

    def get_last_lr(self) -> List[float]:
        """
        Return the most recent learning rate as a list.

        This mirrors ``torch.optim.lr_scheduler.LRScheduler.get_last_lr``, so
        tools written for torch schedulers can read these schedulers too.

        Returns:
            List[float]: A one-element list with the current learning rate.
        """
        return list(self.last_lr)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class LinearWarmupScheduler(BaseLinearWarmupScheduler):
    r"""Linear LR scheduler with linear warm-up.

    After warm-up the learning rate decreases linearly from ``max_lr`` (at
    ``warmup_steps``) to ``min_lr`` (at ``T_max``) and then stays at ``min_lr``.
    """

    def _step(self) -> float:
        """
        Calculate the learning rate for the current step using a linear decay.

        Returns:
            float: The calculated learning rate, never below ``min_lr``.
        """
        decay_steps = self.total_steps - self.warmup_steps
        if decay_steps <= 0:
            # No decay window (T_max <= warmup_steps): hold the floor instead of
            # dividing by zero or "decaying" with a negative denominator.
            return self.min_lr
        progress = (self.step_t - self.warmup_steps) / decay_steps
        # Clamp so that stepping past T_max does not push the LR below min_lr
        # (and eventually negative).
        progress = min(max(progress, 0.0), 1.0)
        return self.max_lr + (self.min_lr - self.max_lr) * progress
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
class CosineDecayWithWarmup(BaseLinearWarmupScheduler):
    r"""Cosine LR scheduler with linear warm-up.

    After warm-up the learning rate follows half a cosine period, from
    ``max_lr`` (at ``warmup_steps``) to ``min_lr`` (at ``T_max``). It then stays
    at ``min_lr`` instead of rising again.
    """

    def _step(self) -> float:
        """
        Calculate the learning rate for the current step using a cosine decay.

        Returns:
            float: The calculated learning rate, within ``[min_lr, max_lr]``.
        """
        decay_steps = self.total_steps - self.warmup_steps
        if decay_steps <= 0:
            # No decay window (T_max <= warmup_steps): avoid a division by zero.
            return self.min_lr
        progress = (self.step_t - self.warmup_steps) / decay_steps
        # Clamp so that stepping past T_max does not wrap the cosine back up.
        progress = min(max(progress, 0.0), 1.0)
        phase: float = progress * math.pi
        return self.min_lr + (self.max_lr - self.min_lr) * (math.cos(phase) + 1.0) / 2.0
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
class PolySchedulerWithWarmup(BaseLinearWarmupScheduler):
    r"""Polynomial-decay LR scheduler with linear warm-up.

    After warm-up the learning rate is

        ``min_lr + (max_lr - min_lr) * (1 - progress) ** poly_order``

    where ``progress = (step_t - warmup_steps) / (T_max - warmup_steps)`` is
    clamped to ``[0, 1]``. So it starts at ``max_lr`` right after warm-up,
    reaches ``min_lr`` at ``T_max``, and stays there.

    Args:
        poly_order (float): Exponent of the decay. Must be positive.
    """

    def __init__(self, optimizer, poly_order: float = 0.5, **kwargs):
        """
        Initialize the PolySchedulerWithWarmup.

        Args:
            optimizer (torch.optim.Optimizer): Optimizer to apply the learning rate schedule.
            poly_order (float): Order of the polynomial for the learning rate decay.
            kwargs: Additional arguments for the base class.

        Raises:
            ValueError: If poly_order is not positive.
        """
        self.poly_order = poly_order

        if poly_order <= 0:
            raise ValueError(f"[-] poly_order must be positive. {poly_order}")

        super().__init__(optimizer, **kwargs)

    def _step(self) -> float:
        """
        Calculate the learning rate for the current step using a polynomial decay.

        The previous formula, ``min_lr + (max_lr - min_lr) * (step_t - warmup_steps) ** poly_order``,
        grew without bound with the step count instead of decaying.

        Returns:
            float: The calculated learning rate, within ``[min_lr, max_lr]``.
        """
        decay_steps = self.total_steps - self.warmup_steps
        if decay_steps <= 0:
            # No decay window (T_max <= warmup_steps): avoid a division by zero.
            return self.min_lr
        progress = (self.step_t - self.warmup_steps) / decay_steps
        progress = min(max(progress, 0.0), 1.0)
        return (
            self.min_lr
            + (self.max_lr - self.min_lr) * (1.0 - progress) ** self.poly_order
        )
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .visualization import *
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module provides utilities for visualizing learning rate schedulers.
|
|
3
|
+
|
|
4
|
+
Functions:
|
|
5
|
+
simulate_scheduler(lr_scheduler, steps): Simulates the learning rate scheduler for a given number of steps.
|
|
6
|
+
plot_lr_schedulers(lr_schedulers, steps, titles): Plots the learning rates of one or more schedulers over a number of steps.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import TYPE_CHECKING, List, Union
|
|
10
|
+
|
|
11
|
+
import matplotlib.pyplot as plt
|
|
12
|
+
import torch
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from torch.optim.lr_scheduler import LRScheduler
|
|
16
|
+
|
|
17
|
+
__all__ = ["simulate_scheduler", "plot_lr_schedulers"]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def simulate_scheduler(lr_scheduler, steps: int):
    """
    Simulates the learning rate scheduler for a given number of steps.

    Works with schedulers whose ``step()`` returns the new learning rate (such
    as fusion_bench's warm-up schedulers). It also works with
    ``torch.optim.lr_scheduler.LRScheduler``, whose ``step()`` returns ``None``;
    in that case the value is read back with ``get_last_lr()``.

    Args:
        lr_scheduler: The learning rate scheduler object.
        steps (int): The number of steps to simulate.

    Returns:
        List[float]: The learning rate after each step (empty if ``steps <= 0``).
    """
    lrs = []
    for _ in range(steps):
        lr = lr_scheduler.step()
        if lr is None:
            # torch LR schedulers return None from step(); read the applied value.
            lr = lr_scheduler.get_last_lr()[0]
        lrs.append(lr)
    return lrs
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def plot_lr_schedulers(
    lr_schedulers: Union["LRScheduler", List["LRScheduler"]],
    steps: int,
    titles: Union[str, List[str]],
    show_plot: bool = True,
):
    """
    Plots the learning rates of one or more schedulers over a number of steps.

    Args:
        lr_schedulers (Union[LRScheduler, List[LRScheduler]]): One scheduler, or
            a list/tuple of schedulers. Any object accepted by
            ``simulate_scheduler`` works, including fusion_bench's warm-up
            schedulers, which do not subclass ``LRScheduler``.
        steps (int): The number of steps to simulate.
        titles (Union[str, List[str]]): One title per scheduler.
        show_plot (bool): Whether to call ``plt.show()``.

    Returns:
        fig, axes: The matplotlib figure and axes objects.

    Raises:
        ValueError: If no scheduler is given or the number of titles does not
            match the number of schedulers.
    """
    # Treat anything that is not an explicit list/tuple as a single scheduler.
    # The old isinstance(..., LRScheduler) check rejected non-torch schedulers.
    if not isinstance(lr_schedulers, (list, tuple)):
        lr_schedulers = [lr_schedulers]
    if isinstance(titles, str):
        titles = [titles]
    if not lr_schedulers:
        raise ValueError("At least one scheduler is required")
    if len(titles) != len(lr_schedulers):
        # zip() would otherwise silently drop schedulers without a title.
        raise ValueError(
            f"Expected one title per scheduler, got {len(titles)} titles "
            f"for {len(lr_schedulers)} schedulers"
        )

    num_plots = len(lr_schedulers)
    fig, axs = plt.subplots(num_plots, 1, figsize=(5, 3 * num_plots))
    if num_plots == 1:
        # plt.subplots returns a bare Axes for a single subplot.
        axs = [axs]

    for ax, scheduler, title in zip(axs, lr_schedulers, titles):
        lrs = simulate_scheduler(scheduler, steps)
        ax.plot(lrs, label=title)
        ax.set_title(title)
        ax.set_xlabel("Steps")
        ax.set_ylabel("Learning Rate")
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    if show_plot:
        plt.show()
    return fig, axs
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
# Example usage
|
|
81
|
+
if __name__ == "__main__":
    # Example usage: plot the three warm-up schedulers side by side.
    from fusion_bench.optim.lr_scheduler.linear_warmup import (
        CosineDecayWithWarmup,
        LinearWarmupScheduler,
        PolySchedulerWithWarmup,
    )

    def make_dummy_optimizer():
        # Each scheduler writes its LR into its optimizer's param groups, so
        # every scheduler gets its own optimizer to keep the runs independent.
        return torch.optim.SGD(
            [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))], lr=0.1
        )

    # The scheduler keyword is ``T_max``; passing ``t_max`` raised TypeError.
    linear_scheduler = LinearWarmupScheduler(
        make_dummy_optimizer(),
        T_max=100,
        max_lr=0.1,
        min_lr=0.01,
        init_lr=0.0,
        warmup_steps=10,
    )
    cosine_scheduler = CosineDecayWithWarmup(
        make_dummy_optimizer(),
        T_max=100,
        max_lr=0.1,
        min_lr=0.01,
        init_lr=0.0,
        warmup_steps=10,
    )
    poly_scheduler = PolySchedulerWithWarmup(
        make_dummy_optimizer(),
        T_max=100,
        max_lr=0.1,
        min_lr=0.01,
        init_lr=0.0,
        warmup_steps=40,
        poly_order=2.0,
    )

    # Plot the learning rates
    plot_lr_schedulers(
        [linear_scheduler, cosine_scheduler, poly_scheduler],
        steps=100,
        titles=[
            "Linear Warmup",
            "Cosine Decay with Warmup",
            "Poly Scheduler with Warmup",
        ],
    )
|
fusion_bench/optim/mezo.py
CHANGED
|
@@ -236,7 +236,11 @@ class FabricModelFusionProgram(
|
|
|
236
236
|
self.save_merged_model(merged_model)
|
|
237
237
|
if self.taskpool is not None:
|
|
238
238
|
report = self.evaluate_merged_model(self.taskpool, merged_model)
|
|
239
|
-
|
|
239
|
+
try:
|
|
240
|
+
print_json(report, print_type=False)
|
|
241
|
+
except Exception as e:
|
|
242
|
+
log.warning(f"Failed to pretty print the report: {e}")
|
|
243
|
+
print(report)
|
|
240
244
|
if self.report_save_path is not None:
|
|
241
245
|
# save report (Dict) to a file
|
|
242
246
|
# if the directory of `save_report` does not exists, create it
|
|
@@ -7,7 +7,11 @@ from fusion_bench.utils.lazy_imports import LazyImporter
|
|
|
7
7
|
|
|
8
8
|
_import_structure = {
|
|
9
9
|
"base_pool": ["BaseTaskPool"],
|
|
10
|
-
"clip_vision": [
|
|
10
|
+
"clip_vision": [
|
|
11
|
+
"CLIPVisionModelTaskPool",
|
|
12
|
+
"SparseWEMoECLIPVisionModelTaskPool",
|
|
13
|
+
"RankoneWEMoECLIPVisionModelTaskPool",
|
|
14
|
+
],
|
|
11
15
|
"dummy": ["DummyTaskPool"],
|
|
12
16
|
"gpt2_text_classification": ["GPT2TextClassificationTaskPool"],
|
|
13
17
|
"nyuv2_taskpool": ["NYUv2TaskPool"],
|
|
@@ -17,7 +21,11 @@ _import_structure = {
|
|
|
17
21
|
|
|
18
22
|
if TYPE_CHECKING:
|
|
19
23
|
from .base_pool import BaseTaskPool
|
|
20
|
-
from .clip_vision import
|
|
24
|
+
from .clip_vision import (
|
|
25
|
+
CLIPVisionModelTaskPool,
|
|
26
|
+
RankoneWEMoECLIPVisionModelTaskPool,
|
|
27
|
+
SparseWEMoECLIPVisionModelTaskPool,
|
|
28
|
+
)
|
|
21
29
|
from .dummy import DummyTaskPool
|
|
22
30
|
from .gpt2_text_classification import GPT2TextClassificationTaskPool
|
|
23
31
|
from .llama import LlamaTestGenerationTaskPool
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
from copy import deepcopy
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Any, Dict, Optional, Tuple
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torch import Tensor
|
|
7
|
+
from torch.utils.hooks import RemovableHandle
|
|
8
|
+
from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
|
|
9
|
+
from transformers.models.clip.modeling_clip import CLIPVisionTransformer
|
|
10
|
+
|
|
11
|
+
from fusion_bench.models.hf_clip import HFCLIPClassifier
|
|
12
|
+
from fusion_bench.models.rankone_moe import RankOneMoE
|
|
13
|
+
|
|
14
|
+
from .taskpool import CLIPVisionModelTaskPool
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class LayerWiseRoutingWeightSaver:
    """Forward hook that collects a module's outputs (routing weights) and saves them.

    Register an instance with ``module.register_forward_hook(saver)``. Each
    forward pass appends ``output.detach().cpu()`` to ``routing_weights``.
    :meth:`save_routing_weights` concatenates the collected tensors along
    dim 0 and writes them with ``torch.save``.

    Args:
        save_path: File to write in :meth:`save_routing_weights`. ``None``
            disables saving.
        max_num: Cap on the total number of rows (along dim 0) kept across all
            calls. ``None`` or a non-positive value keeps everything.
    """

    def __init__(self, save_path: Path, max_num: Optional[int] = None):
        self.save_path = save_path
        self.max_num = max_num
        self.routing_weights = []
        # Rows collected so far. The cap is measured in rows (samples), not in
        # the number of forward calls; the old code compared max_num with
        # len(self.routing_weights), i.e. a count of batches.
        self.num_collected = 0

    def __call__(self, module, input: Tuple[Tensor], output: Tensor):
        assert isinstance(output, Tensor), "Output is expected to be a Tensor"
        # Original author's note: (batch_size, num_tokens, num_experts).
        # Only dim 0 is relied on here.
        routing_weights = output.detach().cpu()
        if self.max_num is not None and self.max_num > 0:
            remaining = self.max_num - self.num_collected
            if remaining <= 0:
                return
            routing_weights = routing_weights[:remaining]
        self.routing_weights.append(routing_weights)
        self.num_collected += routing_weights.size(0)

    def save_routing_weights(self):
        """Concatenate the collected tensors and save them to ``save_path``.

        Does nothing when ``save_path`` is ``None``. When nothing was collected,
        a message is printed instead of crashing on ``torch.cat([])``.
        """
        if self.save_path is None:
            return
        if not self.routing_weights:
            print(f"No routing weights collected, skip saving to {self.save_path}")
            return
        routing_weights = torch.cat(self.routing_weights, dim=0)
        self.save_path.parent.mkdir(parents=True, exist_ok=True)
        print(f"Saving routing weights to {self.save_path}")
        torch.save(routing_weights, self.save_path)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class RankoneMoECLIPVisionModelTaskPool(CLIPVisionModelTaskPool):
    """CLIP vision task pool that can record routing weights of RankOneMoE layers.

    When ``layer_wise_routing_weights_save_path`` is set, each task's
    evaluation works as follows:

    * At the start, a :class:`LayerWiseRoutingWeightSaver` forward hook is
      attached to ``layer.mlp.gate`` of every vision encoder layer.
    * At the end, each hook saves to ``<save_path>/<task_name>/layer_<i>.pt``
      and is removed.

    Args:
        layer_wise_routing_weights_save_path: Directory for the routing weights.
            ``None`` disables recording.
        layer_wise_routing_weights_max_num: Passed as ``max_num`` to every
            LayerWiseRoutingWeightSaver.
        **kwargs: Forwarded to ``CLIPVisionModelTaskPool``.
    """

    # Hooks and their removable handles, keyed by encoder-layer index. Created
    # per instance in __init__: the former class-level ``{}`` defaults were a
    # single dict shared (and mutated) by every instance.
    _layer_wise_routing_weights_save_hooks: Dict[Any, LayerWiseRoutingWeightSaver]
    _layer_wise_routing_weights_save_hook_handles: Dict[Any, RemovableHandle]

    _config_mapping = CLIPVisionModelTaskPool._config_mapping | {
        "_layer_wise_routing_weights_save_path": "layer_wise_routing_weights_save_path",
    }

    def __init__(
        self,
        layer_wise_routing_weights_save_path: Optional[str],
        layer_wise_routing_weights_max_num: Optional[int] = None,
        **kwargs,
    ):
        self._layer_wise_routing_weights_save_hooks = {}
        self._layer_wise_routing_weights_save_hook_handles = {}
        # Raw (string) value; presumably read through ``_config_mapping`` when
        # the config is exported — confirm against the base class.
        self._layer_wise_routing_weights_save_path = (
            layer_wise_routing_weights_save_path
        )
        self.layer_wise_routing_weights_save_path = (
            Path(layer_wise_routing_weights_save_path)
            if layer_wise_routing_weights_save_path is not None
            else None
        )
        self.layer_wise_routing_weights_max_num = layer_wise_routing_weights_max_num
        super().__init__(**kwargs)

    def on_task_evaluation_begin(self, classifier: HFCLIPClassifier, task_name: str):
        """Attach a routing-weight saver hook to the MLP gate of every encoder layer."""
        super().on_task_evaluation_begin(classifier, task_name)
        if self.layer_wise_routing_weights_save_path is None:
            return

        vision_model = classifier.clip_model.vision_model
        assert isinstance(
            vision_model,
            (CLIPVisionTransformer, CLIPVisionModel),
        ), "Vision model is expected to be a CLIPVisionTransformer"
        if isinstance(vision_model, CLIPVisionModel):
            vision_model = vision_model.vision_model

        # Assign a forward hook to each layer's gate.
        for i, layer in enumerate(vision_model.encoder.layers):
            mlp = layer.mlp
            assert isinstance(
                mlp, RankOneMoE
            ), f"MLP is expected to be a RankOneMoE, but got {type(mlp)}"
            hook = LayerWiseRoutingWeightSaver(
                self.layer_wise_routing_weights_save_path / task_name / f"layer_{i}.pt",
                max_num=self.layer_wise_routing_weights_max_num,
            )
            self._layer_wise_routing_weights_save_hooks[i] = hook
            self._layer_wise_routing_weights_save_hook_handles[i] = (
                mlp.gate.register_forward_hook(hook)
            )

    def on_task_evaluation_end(self):
        """Save the collected routing weights, then remove and forget all hooks."""
        super().on_task_evaluation_end()
        if self.layer_wise_routing_weights_save_path is None:
            return
        try:
            for hook in self._layer_wise_routing_weights_save_hooks.values():
                hook.save_routing_weights()
        finally:
            # Always detach hooks, even if saving failed, so they do not keep
            # recording during later tasks; then drop the stale entries.
            for handle in self._layer_wise_routing_weights_save_hook_handles.values():
                handle.remove()
            self._layer_wise_routing_weights_save_hooks.clear()
            self._layer_wise_routing_weights_save_hook_handles.clear()
|