fusion-bench 0.2.6__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +1 -0
- fusion_bench/compat/method/base_algorithm.py +7 -1
- fusion_bench/compat/modelpool/__init__.py +1 -1
- fusion_bench/compat/taskpool/__init__.py +1 -1
- fusion_bench/dataset/arc_agi/arc.py +5 -0
- fusion_bench/dataset/arc_agi/preprocess.py +1 -1
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +93 -3
- fusion_bench/dataset/llama/collate.py +62 -2
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/method/__init__.py +1 -1
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
- fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
- fusion_bench/method/linear/expo.py +39 -0
- fusion_bench/method/lm_finetune/__init__.py +1 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +90 -160
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +49 -139
- fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
- fusion_bench/method/pruning/llama_random_prune.py +2 -2
- fusion_bench/method/surgery/__init__.py +3 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/clip_classification.py +58 -5
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +9 -0
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/causal_lm/__init__.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +50 -9
- fusion_bench/models/surgery/surgerymodelwrapper.py +157 -0
- fusion_bench/models/utils.py +8 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
- fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +0 -2
- fusion_bench/programs/fabric_fusion_program.py +5 -1
- fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
- fusion_bench/utils/hydra_utils.py +22 -0
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/type.py +5 -3
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.7.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.7.dist-info}/RECORD +87 -47
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +1 -1
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +11 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +4 -2
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/nyuv2_config.yaml +5 -1
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
- fusion_bench_config/llama_weighted_average.yaml +0 -26
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.7.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.7.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.7.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.7.dist-info}/top_level.txt +0 -0
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
from typing import Literal, Optional, Union # noqa: F401
|
|
1
|
+
from typing import Dict, Literal, Optional, Union # noqa: F401
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
|
-
from torch import
|
|
4
|
+
from torch import nn
|
|
5
5
|
from tqdm.auto import tqdm
|
|
6
6
|
from transformers import LlamaForCausalLM, LlamaModel
|
|
7
7
|
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Implementation of the Layer-Wise AdaMerging+Surgery Algorithm.
|
|
3
|
+
|
|
4
|
+
For more details, please refer to:
|
|
5
|
+
|
|
6
|
+
- (ICLR 2024) Yang et al. AdaMerging: Adaptive Model Merging for Multi-Task Learning. http://arxiv.org/abs/2310.02575
|
|
7
|
+
- (ICML 2024) Yang et al. Representation Surgery for Multi-Task Model Merging. https://arxiv.org/abs/2402.02705
|
|
8
|
+
|
|
9
|
+
Basic Example:
|
|
10
|
+
|
|
11
|
+
```shell
|
|
12
|
+
fusion_bench \
|
|
13
|
+
method=surgery/adamerging_surgery \
|
|
14
|
+
modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8 \
|
|
15
|
+
taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
|
|
16
|
+
```
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import copy
|
|
20
|
+
import functools
|
|
21
|
+
import gc
|
|
22
|
+
import logging
|
|
23
|
+
from typing import TYPE_CHECKING, cast
|
|
24
|
+
|
|
25
|
+
import torch
|
|
26
|
+
import torch.nn.functional as F
|
|
27
|
+
from torch.utils.data import DataLoader
|
|
28
|
+
from tqdm import tqdm
|
|
29
|
+
from transformers import CLIPVisionModel
|
|
30
|
+
|
|
31
|
+
from fusion_bench.dataset.clip_dataset import CLIPDataset
|
|
32
|
+
from fusion_bench.method.adamerging.layer_wise_adamerging import (
|
|
33
|
+
LayerWiseAdaMergingAlgorithm,
|
|
34
|
+
)
|
|
35
|
+
from fusion_bench.method.adamerging.utils import get_memory_usage
|
|
36
|
+
from fusion_bench.mixins import CLIPClassificationMixin
|
|
37
|
+
from fusion_bench.modelpool import CLIPVisionModelPool
|
|
38
|
+
from fusion_bench.models.surgery.surgerymodelwrapper import SurgeryModelWrapper
|
|
39
|
+
from fusion_bench.models.wrappers.layer_wise_fusion import LayerWiseMergedModel
|
|
40
|
+
|
|
41
|
+
log = logging.getLogger(__name__)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class CLIPLayerWiseAdaMergingSurgeryAlgorithm(
|
|
45
|
+
CLIPClassificationMixin,
|
|
46
|
+
LayerWiseAdaMergingAlgorithm,
|
|
47
|
+
):
|
|
48
|
+
|
|
49
|
+
def on_test_time_adaptation_start(self):
|
|
50
|
+
"""
|
|
51
|
+
Here we load the CLIP processor and construct the zero-shot classification head for each task.
|
|
52
|
+
"""
|
|
53
|
+
self.setup_zero_shot_classification_head()
|
|
54
|
+
|
|
55
|
+
@functools.cache
|
|
56
|
+
def get_shuffled_test_loader_iter(self, task: str):
|
|
57
|
+
return super().get_shuffled_test_loader_iter(
|
|
58
|
+
task,
|
|
59
|
+
batch_size=self.config.batch_size,
|
|
60
|
+
num_workers=self.config.num_workers,
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
def run(self, modelpool: CLIPVisionModelPool, **kwargs):
|
|
64
|
+
"""
|
|
65
|
+
Run the Layer-Wise AdaMerging+Surgery Algorithm.
|
|
66
|
+
|
|
67
|
+
This method constructs the wrapped model and performs test-time adaptation if necessary. Then, it will perform surgery.
|
|
68
|
+
|
|
69
|
+
Args:
|
|
70
|
+
modelpool (ModelPool): The model pool containing the pretrained and fine-tuned models.
|
|
71
|
+
|
|
72
|
+
Returns:
|
|
73
|
+
LayerWiseMergedModel: The merged model after test-time adaptation.
|
|
74
|
+
"""
|
|
75
|
+
log.info("Fusing models using layer-wise adaptive merging.")
|
|
76
|
+
self.modelpool = modelpool
|
|
77
|
+
self.log_hyperparams(self.config)
|
|
78
|
+
|
|
79
|
+
# === Start of the AdaMerging Algorithm ===
|
|
80
|
+
with self.profile("construct the wrapped model"):
|
|
81
|
+
module = cast(
|
|
82
|
+
LayerWiseMergedModel[CLIPVisionModel],
|
|
83
|
+
self.construct_layer_wise_merged_model(modelpool),
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
if self.config.weights is not None:
|
|
87
|
+
# skip the test-time adaptation
|
|
88
|
+
merged_model = copy.deepcopy(module.merge_and_unload())
|
|
89
|
+
else:
|
|
90
|
+
with self.profile("test-time adaptation"):
|
|
91
|
+
module = self.test_time_adaptation(module)
|
|
92
|
+
if self.config.get("save_merging_weights", False):
|
|
93
|
+
self.save_merging_weights(
|
|
94
|
+
self.config.save_merging_weights, module.merge_weight
|
|
95
|
+
)
|
|
96
|
+
merged_model = copy.deepcopy(module.merge_and_unload())
|
|
97
|
+
|
|
98
|
+
# free memory
|
|
99
|
+
del module
|
|
100
|
+
gc.collect()
|
|
101
|
+
torch.cuda.empty_cache()
|
|
102
|
+
|
|
103
|
+
# === Start of the Surgery Algorithm ===
|
|
104
|
+
log.info("start performing Surgery")
|
|
105
|
+
alpha_model = SurgeryModelWrapper(
|
|
106
|
+
merged_model,
|
|
107
|
+
modelpool.model_names,
|
|
108
|
+
projection_dim=merged_model.config.projection_dim,
|
|
109
|
+
)
|
|
110
|
+
alpha_model = self.fabric.setup(alpha_model)
|
|
111
|
+
log.info(get_memory_usage("after freeing memory, the memory usage of GPU is:"))
|
|
112
|
+
|
|
113
|
+
optimizer = torch.optim.Adam(
|
|
114
|
+
alpha_model.collect_trainable_params(),
|
|
115
|
+
lr=1e-3,
|
|
116
|
+
betas=(0.9, 0.999),
|
|
117
|
+
weight_decay=0.0,
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
finetuned_models = {
|
|
121
|
+
model_name: modelpool.load_model(model_name)
|
|
122
|
+
for model_name in modelpool.model_names
|
|
123
|
+
}
|
|
124
|
+
for name, model in finetuned_models.items():
|
|
125
|
+
model.requires_grad_(False)
|
|
126
|
+
model = self.fabric.to_device(model)
|
|
127
|
+
model.eval()
|
|
128
|
+
|
|
129
|
+
for iteration in tqdm(
|
|
130
|
+
range(self.config.surgery_steps),
|
|
131
|
+
"surgery",
|
|
132
|
+
dynamic_ncols=True,
|
|
133
|
+
):
|
|
134
|
+
for dataset_name in modelpool.model_names:
|
|
135
|
+
batch = next(self.get_shuffled_test_loader_iter(dataset_name))
|
|
136
|
+
finetuned_feature = self.compute_features(
|
|
137
|
+
finetuned_models[dataset_name], batch[0]
|
|
138
|
+
)
|
|
139
|
+
features, _, _ = alpha_model.compute_surgery_features(
|
|
140
|
+
lambda model: self.compute_features(model, batch[0]),
|
|
141
|
+
dataset_name,
|
|
142
|
+
)
|
|
143
|
+
|
|
144
|
+
loss = F.l1_loss(features, finetuned_feature)
|
|
145
|
+
|
|
146
|
+
optimizer.zero_grad()
|
|
147
|
+
loss.backward()
|
|
148
|
+
optimizer.step()
|
|
149
|
+
|
|
150
|
+
if ((iteration + 1) % self.config.eval_iterations) == 0:
|
|
151
|
+
# print(list(alpha_model.collect_trainable_params()))
|
|
152
|
+
# Evaluate the surgery-wrapped model with the fusion_bench program's task pool
|
|
153
|
+
log.info(f"iteration: {iteration+1}")
|
|
154
|
+
self._program.evaluate_merged_model(self._program.taskpool, alpha_model)
|
|
155
|
+
|
|
156
|
+
log.info("test the result of Adamerging")
|
|
157
|
+
return merged_model
|
fusion_bench/mixins/__init__.py
CHANGED
|
@@ -10,10 +10,12 @@ _import_structure = {
|
|
|
10
10
|
"serialization": ["YAMLSerializationMixin", "BaseYAMLSerializableModel"],
|
|
11
11
|
"simple_profiler": ["SimpleProfilerMixin"],
|
|
12
12
|
"clip_classification": ["CLIPClassificationMixin"],
|
|
13
|
+
"fabric_training": ["FabricTrainingMixin"],
|
|
13
14
|
}
|
|
14
15
|
|
|
15
16
|
if TYPE_CHECKING:
|
|
16
17
|
from .clip_classification import CLIPClassificationMixin
|
|
18
|
+
from .fabric_training import FabricTrainingMixin
|
|
17
19
|
from .lightning_fabric import LightningFabricMixin
|
|
18
20
|
from .serialization import BaseYAMLSerializableModel, YAMLSerializationMixin
|
|
19
21
|
from .simple_profiler import SimpleProfilerMixin
|
|
@@ -2,7 +2,17 @@ import functools
|
|
|
2
2
|
import logging
|
|
3
3
|
import os
|
|
4
4
|
from copy import deepcopy
|
|
5
|
-
from typing import
|
|
5
|
+
from typing import ( # noqa: F401
|
|
6
|
+
TYPE_CHECKING,
|
|
7
|
+
Any,
|
|
8
|
+
Dict,
|
|
9
|
+
List,
|
|
10
|
+
Optional,
|
|
11
|
+
Tuple,
|
|
12
|
+
TypeVar,
|
|
13
|
+
Union,
|
|
14
|
+
cast,
|
|
15
|
+
)
|
|
6
16
|
|
|
7
17
|
import torch
|
|
8
18
|
from omegaconf import DictConfig
|
|
@@ -18,10 +28,12 @@ from fusion_bench.models.hf_clip import HFCLIPClassifier
|
|
|
18
28
|
from fusion_bench.tasks.clip_classification import get_classnames_and_templates
|
|
19
29
|
from fusion_bench.utils.data import InfiniteDataLoader
|
|
20
30
|
|
|
21
|
-
|
|
31
|
+
if TYPE_CHECKING:
|
|
32
|
+
from transformers.models.clip.modeling_clip import CLIPVisionTransformer
|
|
22
33
|
|
|
23
|
-
|
|
34
|
+
log = logging.getLogger(__name__)
|
|
24
35
|
|
|
36
|
+
# disable tokenizers parallelism by default to avoid deadlocks
|
|
25
37
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
26
38
|
|
|
27
39
|
|
|
@@ -175,13 +187,30 @@ class CLIPClassificationMixin(LightningFabricMixin):
|
|
|
175
187
|
|
|
176
188
|
def compute_logits(
|
|
177
189
|
self,
|
|
178
|
-
module: Union[nn.Module, CLIPVisionModel],
|
|
190
|
+
module: Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"],
|
|
179
191
|
images: torch.Tensor,
|
|
180
192
|
task: str,
|
|
193
|
+
image_embeds: Optional[torch.Tensor] = None,
|
|
181
194
|
) -> torch.Tensor:
|
|
195
|
+
"""
|
|
196
|
+
Compute the logits of the images for a given task.
|
|
197
|
+
|
|
198
|
+
Args:
|
|
199
|
+
module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The module to compute the logits.
|
|
200
|
+
images (torch.Tensor): The images to compute the logits.
|
|
201
|
+
task (str): The task to compute the logits.
|
|
202
|
+
image_embeds (Optional[torch.Tensor]): The precomputed image embeddings. If None, the image embeddings will be computed.
|
|
203
|
+
|
|
204
|
+
Returns:
|
|
205
|
+
torch.Tensor: The logits of the images.
|
|
206
|
+
"""
|
|
182
207
|
text_embeds = self.zeroshot_weights[task]
|
|
183
208
|
|
|
184
|
-
image_embeds = module(images)[1]
|
|
209
|
+
if image_embeds is None:
|
|
210
|
+
image_embeds = module(images)[1]
|
|
211
|
+
assert isinstance(
|
|
212
|
+
image_embeds, torch.Tensor
|
|
213
|
+
), f"`image_embeds` must be a tensor, but got {type(image_embeds)}"
|
|
185
214
|
image_embeds = self.visual_projection(image_embeds)
|
|
186
215
|
|
|
187
216
|
# normalize embeddings
|
|
@@ -194,3 +223,27 @@ class CLIPClassificationMixin(LightningFabricMixin):
|
|
|
194
223
|
logits_per_image = logits_per_text.t()
|
|
195
224
|
|
|
196
225
|
return logits_per_image
|
|
226
|
+
|
|
227
|
+
def compute_features(
|
|
228
|
+
self,
|
|
229
|
+
module: Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"],
|
|
230
|
+
images: torch.Tensor,
|
|
231
|
+
normalize: bool = True,
|
|
232
|
+
) -> torch.Tensor:
|
|
233
|
+
"""
|
|
234
|
+
Extracts image features using CLIP's vision encoder and visual projection.
|
|
235
|
+
|
|
236
|
+
Args:
|
|
237
|
+
module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The CLIP vision encoder module.
|
|
238
|
+
images (torch.Tensor): Input image batch to process.
|
|
239
|
+
normalize (bool): Whether to normalize the image embeddings.
|
|
240
|
+
|
|
241
|
+
Returns:
|
|
242
|
+
torch.Tensor: Normalized image embeddings with dimension matching CLIP's projection space (`projection_dim` in model config).
|
|
243
|
+
"""
|
|
244
|
+
image_embeds = module(images)[1]
|
|
245
|
+
image_embeds = self.visual_projection(image_embeds)
|
|
246
|
+
|
|
247
|
+
if normalize:
|
|
248
|
+
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
|
|
249
|
+
return image_embeds
|
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
from abc import abstractmethod
|
|
5
|
+
from typing import TYPE_CHECKING, Literal, Union
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from torch import Tensor, nn
|
|
9
|
+
from tqdm.auto import tqdm
|
|
10
|
+
|
|
11
|
+
from .lightning_fabric import LightningFabricMixin
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from lightning.fabric.wrappers import (
|
|
15
|
+
_FabricDataLoader,
|
|
16
|
+
_FabricModule,
|
|
17
|
+
_FabricOptimizer,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
log = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class FabricTrainingMixin(LightningFabricMixin):
|
|
24
|
+
"""
|
|
25
|
+
This is a general-purpose mixin for training a model with Lightning Fabric.
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
_latest_saved_checkpoint_global_step: int = -1
|
|
29
|
+
"""The global step index of the latest saved checkpoint."""
|
|
30
|
+
_expected_total_steps: int = None
|
|
31
|
+
"""The expected total number of steps of the entire training."""
|
|
32
|
+
is_training: bool
|
|
33
|
+
"""Whether the training is in progress. If set to False, the training will stop."""
|
|
34
|
+
epoch_idx: int
|
|
35
|
+
"""The epoch index, which is the number of epochs completed."""
|
|
36
|
+
global_step_idx: int
|
|
37
|
+
"""The global step index, which is the number of parameter update steps."""
|
|
38
|
+
max_epochs: int
|
|
39
|
+
"""Max number of epochs of the entire training."""
|
|
40
|
+
max_steps: int
|
|
41
|
+
"""Max number of parameter update steps of the entire training."""
|
|
42
|
+
max_steps_per_epoch: int
|
|
43
|
+
"""Max number of parameter update steps per epoch."""
|
|
44
|
+
gradient_clip_algorithm: Literal["value", "norm"]
|
|
45
|
+
"""The algorithm to clip gradients. Available options: 'value', 'norm'."""
|
|
46
|
+
gradient_clip_val: float
|
|
47
|
+
"""The value to clip gradients. If None, no clipping is applied."""
|
|
48
|
+
accumulate_grad_batches: int
|
|
49
|
+
"""The number of gradient accumulation steps. The effective global batch size is `the batch size per device` x `the number of devices` x `the number of gradient accumulation steps`."""
|
|
50
|
+
lr_scheduler_interval: Literal["step", "epoch"]
|
|
51
|
+
"""The interval to run the learning rate scheduler. Available options: 'step', 'epoch'."""
|
|
52
|
+
lr_scheduler_frequency: int
|
|
53
|
+
"""The frequency to run the learning rate scheduler."""
|
|
54
|
+
checkpoint_save_interval: Literal["step", "epoch"]
|
|
55
|
+
"""The interval to save the model checkpoint. Available options: 'step', 'epoch'."""
|
|
56
|
+
checkpoint_save_frequency: int
|
|
57
|
+
"""The frequency to save the model checkpoint."""
|
|
58
|
+
|
|
59
|
+
def clip_gradients_if_needed(self, model, optimizer):
|
|
60
|
+
"""
|
|
61
|
+
Clips gradients if the gradient clipping value is set.
|
|
62
|
+
|
|
63
|
+
Args:
|
|
64
|
+
model (nn.Module): The model whose gradients need to be clipped.
|
|
65
|
+
optimizer (torch.optim.Optimizer): The optimizer used for training.
|
|
66
|
+
"""
|
|
67
|
+
fabric = self.fabric
|
|
68
|
+
|
|
69
|
+
if self.gradient_clip_val is not None:
|
|
70
|
+
if self.gradient_clip_algorithm == "value":
|
|
71
|
+
fabric.clip_gradients(model, optimizer, clip_val=self.gradient_clip_val)
|
|
72
|
+
elif self.gradient_clip_algorithm == "norm":
|
|
73
|
+
fabric.clip_gradients(model, optimizer, max_norm=self.gradient_clip_val)
|
|
74
|
+
else:
|
|
75
|
+
raise ValueError(
|
|
76
|
+
f"Unknown gradient clip algorithm: {self.gradient_clip_algorithm}. Available options: 'value', 'norm'"
|
|
77
|
+
)
|
|
78
|
+
|
|
79
|
+
def compute_expected_total_steps(
|
|
80
|
+
self, train_dataloader: torch.utils.data.DataLoader
|
|
81
|
+
):
|
|
82
|
+
"""
|
|
83
|
+
Computes the expected total number of steps for the entire training.
|
|
84
|
+
|
|
85
|
+
Args:
|
|
86
|
+
train_dataloader (torch.utils.data.DataLoader): The dataloader for the training data.
|
|
87
|
+
"""
|
|
88
|
+
# compute expected total steps
|
|
89
|
+
self._expected_total_steps = []
|
|
90
|
+
if self.max_steps > 0:
|
|
91
|
+
self._expected_total_steps.append(self.max_steps)
|
|
92
|
+
if self.max_steps_per_epoch > 0 and self.max_epochs > 0:
|
|
93
|
+
self._expected_total_steps.append(
|
|
94
|
+
self.max_steps_per_epoch * self.max_epochs
|
|
95
|
+
)
|
|
96
|
+
if self.max_epochs > 0:
|
|
97
|
+
self._expected_total_steps.append(
|
|
98
|
+
len(train_dataloader) * self.max_epochs // self.accumulate_grad_batches
|
|
99
|
+
)
|
|
100
|
+
self._expected_total_steps = min(self._expected_total_steps)
|
|
101
|
+
log.info(f"Expected total steps: {self._expected_total_steps}")
|
|
102
|
+
|
|
103
|
+
@property
|
|
104
|
+
def expected_total_steps(self):
|
|
105
|
+
"""
|
|
106
|
+
The expected total number of steps of the entire training. You need to run `compute_expected_total_steps` method to compute this value before accessing it.
|
|
107
|
+
|
|
108
|
+
Raises:
|
|
109
|
+
ValueError: If the expected total steps have not been computed.
|
|
110
|
+
"""
|
|
111
|
+
if self._expected_total_steps is None:
|
|
112
|
+
raise ValueError(
|
|
113
|
+
"The expected total steps have not been computed. Run `compute_expected_total_steps` method."
|
|
114
|
+
)
|
|
115
|
+
else:
|
|
116
|
+
return self._expected_total_steps
|
|
117
|
+
|
|
118
|
+
def conditional_checkpoint_save(
|
|
119
|
+
self,
|
|
120
|
+
stage: Literal["end_of_step", "end_of_epoch", "end_of_training"],
|
|
121
|
+
*args,
|
|
122
|
+
**kwargs,
|
|
123
|
+
):
|
|
124
|
+
"""
|
|
125
|
+
Conditionally saves a checkpoint based on the current training stage.
|
|
126
|
+
|
|
127
|
+
Args:
|
|
128
|
+
stage (Literal["end_of_step", "end_of_epoch", "end_of_training"]): The current stage of training.
|
|
129
|
+
"""
|
|
130
|
+
if stage == "end_of_step":
|
|
131
|
+
if (
|
|
132
|
+
self.checkpoint_save_interval == "step"
|
|
133
|
+
and (self.global_step_idx + 1) % self.checkpoint_save_frequency == 0
|
|
134
|
+
):
|
|
135
|
+
save_path = os.path.join(
|
|
136
|
+
self.log_dir, "checkpoints", f"step={self.global_step_idx}.ckpt"
|
|
137
|
+
)
|
|
138
|
+
self.save_checkpoint(save_path, *args, **kwargs)
|
|
139
|
+
elif stage == "end_of_epoch":
|
|
140
|
+
if (
|
|
141
|
+
self.checkpoint_save_interval == "epoch"
|
|
142
|
+
and (self.epoch_idx + 1) % self.checkpoint_save_frequency == 0
|
|
143
|
+
):
|
|
144
|
+
save_path = os.path.join(
|
|
145
|
+
self.log_dir, "checkpoints", f"epoch={self.epoch_idx}.ckpt"
|
|
146
|
+
)
|
|
147
|
+
self.save_checkpoint(save_path, *args, **kwargs)
|
|
148
|
+
elif stage == "end_of_training":
|
|
149
|
+
# if the checkpoint has not been saved yet, save it
|
|
150
|
+
if self.global_step_idx > self._latest_saved_checkpoint_global_step:
|
|
151
|
+
save_path = os.path.join(
|
|
152
|
+
self.log_dir,
|
|
153
|
+
"checkpoints",
|
|
154
|
+
f"epoch={self.epoch_idx}_step={self.global_step_idx}.ckpt",
|
|
155
|
+
)
|
|
156
|
+
self.save_checkpoint(save_path, *args, **kwargs)
|
|
157
|
+
try:
|
|
158
|
+
os.symlink(
|
|
159
|
+
src=save_path,
|
|
160
|
+
dst=os.path.join(
|
|
161
|
+
self.log_dir, "checkpoints", "latest_model.ckpt"
|
|
162
|
+
),
|
|
163
|
+
target_is_directory=os.path.isdir(save_path),
|
|
164
|
+
)
|
|
165
|
+
except Exception as e:
|
|
166
|
+
log.error(f"Failed to create symlink: {e}")
|
|
167
|
+
else:
|
|
168
|
+
raise ValueError(
|
|
169
|
+
f"Unknown stage: {stage}. Available options: 'end_of_step', 'end_of_epoch', 'end_of_training'"
|
|
170
|
+
)
|
|
171
|
+
|
|
172
|
+
@abstractmethod
|
|
173
|
+
def save_checkpoint(self, path, **kwargs):
|
|
174
|
+
"""
|
|
175
|
+
Saves a checkpoint of the model.
|
|
176
|
+
|
|
177
|
+
Args:
|
|
178
|
+
path (str): The path where the checkpoint will be saved.
|
|
179
|
+
|
|
180
|
+
Raises:
|
|
181
|
+
NotImplementedError: If the method is not implemented.
|
|
182
|
+
"""
|
|
183
|
+
raise NotImplementedError("save_checkpoint method is not implemented")
|
|
184
|
+
|
|
185
|
+
def train(
|
|
186
|
+
self,
|
|
187
|
+
model: Union[nn.Module, "_FabricModule"],
|
|
188
|
+
optimizer: Union[torch.optim.Optimizer, "_FabricOptimizer"],
|
|
189
|
+
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
|
|
190
|
+
):
|
|
191
|
+
"""
|
|
192
|
+
Trains the model.
|
|
193
|
+
|
|
194
|
+
The global batch size is `the batch size per device` x `the number of devices` x `the number of gradient accumulation steps`.
|
|
195
|
+
|
|
196
|
+
Args:
|
|
197
|
+
model (Union[nn.Module, "_FabricModule"]): The model to be trained.
|
|
198
|
+
optimizer (Union[torch.optim.Optimizer, "_FabricOptimizer"]): The optimizer used for training.
|
|
199
|
+
lr_scheduler (torch.optim.lr_scheduler.LRScheduler): The learning rate scheduler.
|
|
200
|
+
"""
|
|
201
|
+
fabric = self.fabric
|
|
202
|
+
self.is_training = True
|
|
203
|
+
# number of parameter update iterations, not the number of batches
|
|
204
|
+
self.global_step_idx = 0
|
|
205
|
+
model.train()
|
|
206
|
+
optimizer.zero_grad()
|
|
207
|
+
for epoch_idx in tqdm(
|
|
208
|
+
range(self.max_epochs) if self.max_epochs > 0 else itertools.count(0),
|
|
209
|
+
"Training Epoch",
|
|
210
|
+
dynamic_ncols=True,
|
|
211
|
+
leave=False,
|
|
212
|
+
disable=not fabric.is_global_zero,
|
|
213
|
+
):
|
|
214
|
+
self.epoch_idx = epoch_idx
|
|
215
|
+
self.train_epoch(model, optimizer, lr_scheduler)
|
|
216
|
+
# run lr_scheduler at the end of the epoch if interval is set to "epoch"
|
|
217
|
+
if (
|
|
218
|
+
self.lr_scheduler_interval == "epoch"
|
|
219
|
+
and (epoch_idx + 1) % self.lr_scheduler_frequency == 0
|
|
220
|
+
):
|
|
221
|
+
lr_scheduler.step()
|
|
222
|
+
|
|
223
|
+
# save the model at the end of the epoch if interval is set to "epoch" and frequency is met
|
|
224
|
+
self.conditional_checkpoint_save(stage="end_of_epoch")
|
|
225
|
+
|
|
226
|
+
if not self.is_training:
|
|
227
|
+
break
|
|
228
|
+
|
|
229
|
+
optimizer.zero_grad()
|
|
230
|
+
# save the model at the end of training
|
|
231
|
+
self.conditional_checkpoint_save(stage="end_of_training")
|
|
232
|
+
return model
|
|
233
|
+
|
|
234
|
+
@abstractmethod
|
|
235
|
+
def train_epoch(
|
|
236
|
+
self,
|
|
237
|
+
model: Union[nn.Module, "_FabricModule"],
|
|
238
|
+
optimizer: Union[torch.optim.Optimizer, "_FabricOptimizer"],
|
|
239
|
+
lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
|
|
240
|
+
):
|
|
241
|
+
"""
|
|
242
|
+
Trains the model for one epoch.
|
|
243
|
+
|
|
244
|
+
Args:
|
|
245
|
+
model (Union[nn.Module, "_FabricModule"]): The model to be trained.
|
|
246
|
+
optimizer (Union[torch.optim.Optimizer, "_FabricOptimizer"]): The optimizer used for training.
|
|
247
|
+
lr_scheduler (torch.optim.lr_scheduler.LRScheduler): The learning rate scheduler.
|
|
248
|
+
|
|
249
|
+
Raises:
|
|
250
|
+
NotImplementedError: If the method is not implemented.
|
|
251
|
+
"""
|
|
252
|
+
raise NotImplementedError(
|
|
253
|
+
"Copy this as a template and implement your own train_epoch method"
|
|
254
|
+
)
|
|
255
|
+
fabric = self.fabric
|
|
256
|
+
|
|
257
|
+
accumulated_loss = 0
|
|
258
|
+
for step_idx, batch in enumerate(
|
|
259
|
+
pbar := tqdm(
|
|
260
|
+
self.train_dataloader,
|
|
261
|
+
desc="Training Batches",
|
|
262
|
+
dynamic_ncols=True,
|
|
263
|
+
leave=False,
|
|
264
|
+
disable=not fabric.is_global_zero,
|
|
265
|
+
)
|
|
266
|
+
):
|
|
267
|
+
is_accumulating = (step_idx + 1) % self.accumulate_grad_batches != 0
|
|
268
|
+
|
|
269
|
+
# disable gradient synchronization if accumulating gradients across steps for improved performance
|
|
270
|
+
with fabric.no_backward_sync(self.model, enabled=is_accumulating):
|
|
271
|
+
# use_cache=True is not compatible with gradient checkpointing, so we disable it here
|
|
272
|
+
output = self.compute_loss(batch)
|
|
273
|
+
loss = output["loss"] / self.accumulate_grad_batches
|
|
274
|
+
|
|
275
|
+
fabric.backward(loss)
|
|
276
|
+
accumulated_loss += loss.item()
|
|
277
|
+
|
|
278
|
+
# 1. update the model parameters if not accumulating gradients
|
|
279
|
+
# 2. step the lr_scheduler if interval is set to "step" and frequency is met
|
|
280
|
+
# 3. save the model if interval is set to "step" and frequency is met
|
|
281
|
+
# 4. log metrics
|
|
282
|
+
# 5. increase the global step index and reset the accumulated metrics
|
|
283
|
+
if not is_accumulating:
|
|
284
|
+
self.clip_gradients_if_needed(model, optimizer)
|
|
285
|
+
|
|
286
|
+
# run lr_scheduler at the end of the step if interval is set to "step"
|
|
287
|
+
if (
|
|
288
|
+
self.lr_scheduler_interval == "step"
|
|
289
|
+
and (self.global_step_idx + 1) % self.lr_scheduler_frequency == 0
|
|
290
|
+
):
|
|
291
|
+
lr_scheduler.step()
|
|
292
|
+
|
|
293
|
+
# update the model parameters and zero the gradients
|
|
294
|
+
optimizer.step()
|
|
295
|
+
optimizer.zero_grad()
|
|
296
|
+
|
|
297
|
+
metrics = {
|
|
298
|
+
"train/loss": accumulated_loss,
|
|
299
|
+
"train/lr": optimizer.param_groups[0]["lr"],
|
|
300
|
+
}
|
|
301
|
+
|
|
302
|
+
fabric.log_dict(metrics, step=self.global_step_idx)
|
|
303
|
+
pbar.set_postfix(metrics)
|
|
304
|
+
|
|
305
|
+
# save the model at the end of the step if interval is set to "step" and frequency is met
|
|
306
|
+
self.conditional_checkpoint_save(stage="end_of_step")
|
|
307
|
+
|
|
308
|
+
# break if max_steps_per_epoch is set, and exit epoch
|
|
309
|
+
if (
|
|
310
|
+
self.max_steps_per_epoch > 0
|
|
311
|
+
and step_idx + 1 >= self.max_steps_per_epoch
|
|
312
|
+
):
|
|
313
|
+
break
|
|
314
|
+
# break if max_steps is set, and exit training
|
|
315
|
+
if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
|
|
316
|
+
self.is_training = False
|
|
317
|
+
break
|
|
318
|
+
|
|
319
|
+
self.global_step_idx += 1
|
|
320
|
+
accumulated_loss = 0
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import functools
|
|
1
2
|
import logging
|
|
2
3
|
import os
|
|
3
4
|
from typing import TYPE_CHECKING, Any, List, Optional, TypeVar
|
|
@@ -13,6 +14,7 @@ from fusion_bench.utils.instantiate import instantiate
|
|
|
13
14
|
|
|
14
15
|
if TYPE_CHECKING:
|
|
15
16
|
import lightning.fabric.loggers.tensorboard
|
|
17
|
+
from lightning.fabric.strategies import FSDPStrategy
|
|
16
18
|
|
|
17
19
|
log = logging.getLogger(__name__)
|
|
18
20
|
|
|
@@ -32,6 +34,13 @@ def get_policy(*args: str) -> set:
|
|
|
32
34
|
return {import_object(arg) for arg in args}
|
|
33
35
|
|
|
34
36
|
|
|
37
|
+
def get_size_based_auto_wrap_policy(*args, **kwargs):
|
|
38
|
+
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
|
|
39
|
+
|
|
40
|
+
policy = functools.partial(size_based_auto_wrap_policy, *args, **kwargs)
|
|
41
|
+
return policy
|
|
42
|
+
|
|
43
|
+
|
|
35
44
|
class LightningFabricMixin:
|
|
36
45
|
"""
|
|
37
46
|
A mixin class for integrating Lightning Fabric into a project.
|
|
@@ -16,6 +16,7 @@ _import_structure = {
|
|
|
16
16
|
"HuggingFaceGPT2ClassificationPool",
|
|
17
17
|
"GPT2ForSequenceClassificationPool",
|
|
18
18
|
],
|
|
19
|
+
"seq_classification_lm": ["SeqenceClassificationModelPool"],
|
|
19
20
|
}
|
|
20
21
|
|
|
21
22
|
|
|
@@ -31,6 +32,7 @@ if TYPE_CHECKING:
|
|
|
31
32
|
from .nyuv2_modelpool import NYUv2ModelPool
|
|
32
33
|
from .PeftModelForSeq2SeqLM import PeftModelForSeq2SeqLMPool
|
|
33
34
|
from .seq2seq_lm import Seq2SeqLMPool
|
|
35
|
+
from .seq_classification_lm import SeqenceClassificationModelPool
|
|
34
36
|
|
|
35
37
|
else:
|
|
36
38
|
sys.modules[__name__] = LazyImporter(
|
|
@@ -1,2 +1,2 @@
|
|
|
1
1
|
# flake8: noqa F401
|
|
2
|
-
from .causal_lm import CausalLMBackbonePool, CausalLMPool
|
|
2
|
+
from .causal_lm import CausalLMBackbonePool, CausalLMPool, load_peft_causal_lm
|