fusion-bench 0.2.6__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. fusion_bench/compat/method/__init__.py +1 -0
  2. fusion_bench/compat/method/base_algorithm.py +7 -1
  3. fusion_bench/compat/modelpool/__init__.py +1 -1
  4. fusion_bench/compat/taskpool/__init__.py +1 -1
  5. fusion_bench/dataset/arc_agi/arc.py +5 -0
  6. fusion_bench/dataset/arc_agi/preprocess.py +1 -1
  7. fusion_bench/dataset/llama/__init__.py +1 -0
  8. fusion_bench/dataset/llama/alpaca.py +93 -3
  9. fusion_bench/dataset/llama/collate.py +62 -2
  10. fusion_bench/dataset/llama/metamathqa.py +50 -0
  11. fusion_bench/dataset/llama/preference_700k.py +70 -0
  12. fusion_bench/dataset/llama/stanford_shp.py +90 -0
  13. fusion_bench/dataset/llama/ultrachat.py +58 -0
  14. fusion_bench/dataset/llama/utils/__init__.py +0 -0
  15. fusion_bench/method/__init__.py +1 -1
  16. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
  17. fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
  18. fusion_bench/method/linear/expo.py +39 -0
  19. fusion_bench/method/lm_finetune/__init__.py +1 -0
  20. fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
  21. fusion_bench/method/lm_finetune/fullfinetune_sft.py +90 -160
  22. fusion_bench/method/lm_finetune/peftfinetune_sft.py +49 -139
  23. fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
  24. fusion_bench/method/pruning/llama_random_prune.py +2 -2
  25. fusion_bench/method/surgery/__init__.py +3 -0
  26. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
  27. fusion_bench/mixins/__init__.py +2 -0
  28. fusion_bench/mixins/clip_classification.py +58 -5
  29. fusion_bench/mixins/fabric_training.py +320 -0
  30. fusion_bench/mixins/lightning_fabric.py +9 -0
  31. fusion_bench/modelpool/__init__.py +2 -0
  32. fusion_bench/modelpool/causal_lm/__init__.py +1 -1
  33. fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
  34. fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
  35. fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
  36. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
  37. fusion_bench/models/chat_templates/__init__.py +1 -0
  38. fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
  39. fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
  40. fusion_bench/models/hf_clip.py +50 -9
  41. fusion_bench/models/surgery/surgerymodelwrapper.py +157 -0
  42. fusion_bench/models/utils.py +8 -0
  43. fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
  44. fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
  45. fusion_bench/optim/__init__.py +2 -0
  46. fusion_bench/optim/exception.py +47 -0
  47. fusion_bench/optim/lr_scheduler/__init__.py +1 -0
  48. fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
  49. fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
  50. fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
  51. fusion_bench/optim/mezo.py +0 -2
  52. fusion_bench/programs/fabric_fusion_program.py +5 -1
  53. fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
  54. fusion_bench/taskpool/llama/reward_model.py +157 -0
  55. fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
  56. fusion_bench/utils/hydra_utils.py +22 -0
  57. fusion_bench/utils/plot/__init__.py +0 -0
  58. fusion_bench/utils/plot/token.py +52 -0
  59. fusion_bench/utils/plot/token_notebook.py +127 -0
  60. fusion_bench/utils/type.py +5 -3
  61. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.7.dist-info}/METADATA +1 -1
  62. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.7.dist-info}/RECORD +87 -47
  63. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  64. fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
  65. fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
  66. fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
  67. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  68. fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
  69. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
  70. fusion_bench_config/fabric_model_fusion.yaml +1 -1
  71. fusion_bench_config/llama_full_finetune.yaml +19 -0
  72. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
  73. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +11 -4
  74. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +4 -2
  75. fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
  76. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
  77. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
  78. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
  79. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
  80. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
  81. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
  82. fusion_bench_config/nyuv2_config.yaml +5 -1
  83. fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
  84. fusion_bench_config/llama_weighted_average.yaml +0 -26
  85. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.7.dist-info}/LICENSE +0 -0
  86. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.7.dist-info}/WHEEL +0 -0
  87. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.7.dist-info}/entry_points.txt +0 -0
  88. {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.7.dist-info}/top_level.txt +0 -0
fusion_bench/method/linear/expo.py
@@ -6,6 +6,10 @@ Reference:
 """
 
 import logging
+from copy import deepcopy
+
+import torch
+from torch import nn
 
 from fusion_bench import BaseAlgorithm, BaseModelPool
 from fusion_bench.method import SimpleAverageAlgorithm
@@ -18,6 +22,41 @@ from fusion_bench.utils.state_dict_arithmetic import (
 log = logging.getLogger(__name__)
 
 
+def expo_merge(
+    sft_model: nn.Module,
+    rlhf_model: nn.Module,
+    extrapolation_factor: float,
+    inplace: bool = True,
+    enable_grad: bool = False,
+):
+    """
+    Minimal implementation of ExPO merge.
+
+    Args:
+        sft_model (nn.Module): The pretrained model (base model).
+        rlhf_model (nn.Module): The finetuned model (medium-aligned model).
+        extrapolation_factor (float): The extrapolation factor.
+        inplace (bool): Whether to perform the merge in-place. If not, the rlhf_model will be copied before merging.
+        enable_grad (bool): Whether to enable gradient computation during the merge.
+
+    Returns:
+        nn.Module: The merged model.
+    """
+
+    if not inplace:
+        rlhf_model = deepcopy(rlhf_model)
+
+    with torch.set_grad_enabled(enable_grad):
+        for (sft_name, sft_param), (rlhf_name, rlhf_param) in zip(
+            sft_model.named_parameters(), rlhf_model.named_parameters()
+        ):
+            assert sft_name == rlhf_name, f"Model mismatch: {sft_name} != {rlhf_name}"
+            rlhf_param.data = rlhf_param.data + extrapolation_factor * (
+                rlhf_param.data - sft_param.data
+            )
+    return rlhf_model
+
+
 class ExPOAlgorithm(BaseAlgorithm):
     R"""
     ExPO merge algorithm.
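The new expo_merge helper extrapolates every parameter past the RLHF checkpoint along the SFT-to-RLHF direction, i.e. theta_merged = theta_rlhf + alpha * (theta_rlhf - theta_sft). A minimal sketch of that update rule, applied by hand to two toy models (hypothetical stand-ins for the SFT and RLHF checkpoints, not part of the package):

    import torch
    from torch import nn

    # Toy stand-ins for two architecturally identical checkpoints.
    sft = nn.Linear(4, 2)
    rlhf = nn.Linear(4, 2)
    alpha = 0.5  # extrapolation_factor

    with torch.no_grad():
        for sft_param, rlhf_param in zip(sft.parameters(), rlhf.parameters()):
            # theta_merged = theta_rlhf + alpha * (theta_rlhf - theta_sft)
            rlhf_param.add_(alpha * (rlhf_param - sft_param))

    # `rlhf` now holds the extrapolated weights, mirroring the in-place
    # behaviour of expo_merge(..., inplace=True).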
fusion_bench/method/lm_finetune/__init__.py
@@ -1,2 +1,3 @@
+from .bradley_terry_rm import BradleyTerryRewardModeling
 from .fullfinetune_sft import FullFinetuneSFT
 from .peftfinetune_sft import PeftFinetuneSFT
fusion_bench/method/lm_finetune/bradley_terry_rm.py
@@ -0,0 +1,432 @@
+R"""
+This is basically the same as fullfinetune_sft.py, but with a different loss function.
+
+The dataset contains the following fields:
+
+- chosen_input_ids: The input token ids for the winner.
+- chosen_attention_mask: The attention mask for the winner.
+- rejected_input_ids: The input token ids for the loser.
+- rejected_attention_mask: The attention mask for the loser.
+
+"""
+
+import functools
+import itertools
+import logging
+import os
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, override
+
+import lightning as L
+import omegaconf
+import torch
+from lightning.fabric.strategies.fsdp import FSDPStrategy
+from lightning.fabric.utilities import rank_zero_only
+from omegaconf import DictConfig
+from torch import Tensor, nn
+from torch.utils.data import ConcatDataset, DataLoader
+from tqdm.auto import tqdm
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+from fusion_bench.dataset.llama.collate import bradley_terry_rm_collate
+from fusion_bench.method import BaseAlgorithm
+from fusion_bench.mixins import FabricTrainingMixin
+from fusion_bench.modelpool import SeqenceClassificationModelPool
+from fusion_bench.utils import instantiate
+from fusion_bench.utils.dtype import get_dtype
+
+if TYPE_CHECKING:
+    from lightning.fabric.wrappers import (
+        _FabricDataLoader,
+        _FabricModule,
+        _FabricOptimizer,
+    )
+    from transformers.models.llama.modeling_llama import LlamaForSequenceClassification
+
+log = logging.getLogger(__name__)
+
+
+class BradleyTerryRewardModeling(BaseAlgorithm, FabricTrainingMixin):
+
+    model: Union[nn.Module, "_FabricModule", "LlamaForSequenceClassification"]
+    optimizer: Union[torch.optim.Optimizer, "_FabricOptimizer"]
+    train_dataloader: Union[DataLoader, "_FabricDataLoader"]
+    lr_scheduler: torch.optim.lr_scheduler.LRScheduler
+
+    def __init__(
+        self,
+        optimizer: DictConfig,
+        lr_scheduler: Optional[DictConfig],
+        dataloader_kwargs: DictConfig,
+        max_epochs: int,
+        max_steps: int = -1,
+        max_steps_per_epoch: int = -1,
+        lr_scheduler_interval: Literal["epoch", "step"] = "step",
+        lr_scheduler_frequency: int = 1,
+        checkpoint_save_interval: Literal["epoch", "step"] = "epoch",
+        checkpoint_save_frequency: int = 1,
+        accumulate_grad_batches: int = 1,
+        gradient_clip_val: Optional[float] = None,
+        gradient_clip_algorithm: Literal["value", "norm"] = "norm",
+        save_optimizer_state: bool = False,
+        save_full_model: bool = False,
+        save_ckpt_type: Literal["lightning", "hf"] = "lightning",
+        ckpt_path: Optional[str] = None,
+        max_length: int = 6144,
+        fix_token_embedding: bool = True,
+        **kwargs,
+    ):
+        """
+        Class for reward modeling using Bradley-Terry model.
+
+        Args:
+            optimizer(DictConfig): Configuration for the optimizer.
+            lr_scheduler(DictConfig): Configuration for the learning rate scheduler.
+            dataloader_kwargs(DictConfig): Configuration for the dataloader, such as batch size, num_workers, etc.
+            max_epochs(int): Maximum number of epochs to train the model. If set to -1, the training will continue indefinitely or until max_steps is reached.
+            max_steps(int): Maximum number of steps to train the model. If set to -1, the training will continue indefinitely or until max_epochs is reached.
+            max_steps_per_epoch(int): Maximum number of steps to train the model in each epoch. If set to -1, the training will continue until the end of the epoch.
+            lr_scheduler_interval(str): Interval at which to run the learning rate scheduler. Available options: 'epoch', 'step'. If set to 'epoch', the scheduler will run at the end of each epoch. If set to 'step', the scheduler will run at the end of each step.
+            lr_scheduler_frequency(int): Frequency at which to run the learning rate scheduler. The scheduler will run every `lr_scheduler_frequency` epochs or steps, depending on the value of `lr_scheduler_interval`.
+            checkpoint_save_interval(str): Interval at which to save the model checkpoint. Available options: 'epoch', 'step'. If set to 'epoch', the model will be saved at the end of each epoch. If set to 'step', the model will be saved at the end of each step.
+            checkpoint_save_frequency(int): Frequency at which to save the model checkpoint. The model will be saved every `checkpoint_save_frequency` epochs or steps, depending on the value of `checkpoint_save_interval`.
+            accumulate_grad_batches(int): Number of batches to accumulate gradients across before updating the model parameters.
+            gradient_clip_val(float): Value to clip the gradients. If set to None, no gradient clipping will be applied.
+            gradient_clip_algorithm(str): Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.
+            save_optimizer_state(bool): Whether to save the optimizer and lr_scheduler state along with the model checkpoint.
+            save_full_model(bool): Whether to save the full model or only the trainable parameters in the model checkpoint.
+            save_ckpt_type (str): Type of checkpoint to save. Available options: 'lightning', 'hf'. If set to 'lightning', the checkpoint will be saved in the lightning format. If set to 'hf', the checkpoint will be saved in the huggingface format.
+            ckpt_path(str): Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded.
+            max_length(int): Maximum input length to consider. If the input length exceeds this value, it will be truncated.
+            fix_token_embedding(bool): Whether to fix the token embeddings during training. If set to True, the token embeddings will not be updated during training.
+        """
+        self._optimizer = optimizer
+        self._lr_scheduler = lr_scheduler
+        self.dataloader_kwargs = dataloader_kwargs
+        self.max_epochs = max_epochs
+        self.max_steps = max_steps
+        self.max_steps_per_epoch = max_steps_per_epoch
+        self.lr_scheduler_interval = lr_scheduler_interval
+        self.lr_scheduler_frequency = lr_scheduler_frequency
+        self.checkpoint_save_interval = checkpoint_save_interval
+        self.checkpoint_save_frequency = checkpoint_save_frequency
+        self.accumulate_grad_batches = accumulate_grad_batches
+        self.gradient_clip_val = gradient_clip_val
+        self.gradient_clip_algorithm = gradient_clip_algorithm
+        self.save_optimizer_state = save_optimizer_state
+        self.save_full_model = save_full_model
+        self.save_ckpt_type = save_ckpt_type
+        self.ckpt_path = ckpt_path
+        self.max_length = max_length
+        self.fix_token_embedding = fix_token_embedding
+        super().__init__(**kwargs)
+
+    def run(self, modelpool: SeqenceClassificationModelPool):
+        self.modelpool = modelpool
+        self.setup()
+        self.train(self.model, self.optimizer, self.lr_scheduler)
+        return self.model
+
+    def setup_model(self):
+        self.tokenizer = self.modelpool.load_tokenizer()
+        if self.tokenizer.pad_token_id is None:
+            self.tokenizer.pad_token_id = (
+                self.tokenizer.eos_token_id
+            )  #! make sure eos_token_id only show up at the end of the sequence
+
+        model = self.modelpool.load_pretrained_model()
+        self.model: "LlamaForSequenceClassification" = model
+
+        if model.config.pad_token_id is None:
+            model.config.pad_token_id = self.tokenizer.pad_token_id
+
+        if self.fix_token_embedding:
+            self.model.model.embed_tokens.requires_grad_(False)
+
+        if self.fabric.strategy == "fsdp" or isinstance(
+            self.fabric.strategy, FSDPStrategy
+        ):
+            # https://github.com/Lightning-AI/pytorch-lightning/issues/19267
+            self.model.gradient_checkpointing_enable(
+                gradient_checkpointing_kwargs={"use_reentrant": True}
+            )
+            self.use_cache = False
+        else:
+            self.use_cache = True
+        self.model_dtype = get_dtype(self.model)
+
+    def setup_data(self):
+        fabric = self.fabric
+        modelpool = self.modelpool
+        assert (
+            len(modelpool.train_dataset_names) > 0
+        ), "No training datasets found in modelpool."
+
+        train_datasets = [
+            modelpool.load_train_dataset(dataset_name)
+            for dataset_name in modelpool.train_dataset_names
+        ]
+        if len(train_datasets) > 1:
+            train_dataset = ConcatDataset(train_datasets)
+        else:
+            train_dataset = train_datasets[0]
+
+        self.train_dataset = train_dataset
+        self.train_dataloader = DataLoader(
+            train_dataset,
+            **self.dataloader_kwargs,
+            shuffle=True,
+            collate_fn=functools.partial(
+                bradley_terry_rm_collate,
+                pad_token_id=self.tokenizer.pad_token_id,
+            ),  # NOTE: different from SFT, uses bradley_terry_rm_collate
+        )
+        self.train_dataloader = fabric.setup_dataloaders(self.train_dataloader)
+
+    def configure_optimizer(self):
+        # compute expected total steps
+        self.compute_expected_total_steps(self.train_dataloader)
+
+        optimizer = instantiate(self._optimizer, self.model.parameters())
+        if self._lr_scheduler is not None:
+            for key, arg in self._lr_scheduler.items():
+                if arg == "_T_max_":
+                    log.info(
+                        f"Setting key `{key}` of lr_scheduler configuration to {self.expected_total_steps}"
+                    )
+                    self._lr_scheduler[key] = self.expected_total_steps
+            lr_scheduler: torch.optim.lr_scheduler.LRScheduler = instantiate(
+                self._lr_scheduler,
+                optimizer=optimizer,
+            )
+        else:
+            lr_scheduler = None
+        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
+
+    def setup(self):
+        fabric = self.fabric
+
+        self.setup_model()
+        self.setup_data()
+
+        optimizer = self.configure_optimizer()
+        optimizer, lr_scheduler = optimizer["optimizer"], optimizer["lr_scheduler"]
+
+        self.model, self.optimizer = fabric.setup(self.model, optimizer)
+        self.lr_scheduler = lr_scheduler
+
+    def compute_loss(self, batch: Dict[str, Union[Tensor, Any]]) -> Dict[str, Tensor]:
+        """
+        Maximize the likelihood of the winner over the loser using the Bradley-Terry model.
+
+        Args:
+            batch (Dict[str, Union[Tensor, Any]]): A dictionary containing the input token ids and attention masks for the winner and loser.
+        """
+        batch_size = batch["input_ids"].size(0)
+        assert batch_size % 2 == 0, "Batch size must be even."
+
+        outputs = self.model(
+            input_ids=batch["input_ids"],
+            attention_mask=batch["attention_mask"],
+            use_cache=self.use_cache,
+        )
+
+        rewards = outputs[0]
+        chosen_reward = rewards[: batch_size // 2]
+        rejected_rewards = rewards[batch_size // 2 :]
+        loss = -torch.log(torch.sigmoid(chosen_reward - rejected_rewards)).mean()
+
+        return {
+            "chosen_reward": chosen_reward,
+            "rejected_reward": rejected_rewards,
+            "loss": loss,
+        }
+
+    @override
+    def train_epoch(self, *args, **kwargs):
+        fabric = self.fabric
+
+        accumulated_loss = 0
+        accumulated_chosen_reward = 0
+        accumulated_rejected_reward = 0
+        for step_idx, batch in enumerate(
+            pbar := tqdm(
+                self.train_dataloader,
+                desc="Training Batches",
+                dynamic_ncols=True,
+                leave=False,
+                disable=not fabric.is_global_zero,
+            )
+        ):
+            is_accumulating = (step_idx + 1) % self.accumulate_grad_batches != 0
+
+            if self.max_length > 0 and batch["input_ids"].shape[1] > self.max_length:
+                log.warning(
+                    f"Input length exceeds max_length: {batch['input_ids'].shape[1]} > {self.max_length}. Truncating input."
+                )
+                batch["input_ids"] = batch["input_ids"][:, -self.max_length :]
+                batch["attention_mask"] = batch["attention_mask"][:, -self.max_length :]
+
+            # disable gradient synchronization if accumulating gradients across steps for improved performance
+            with fabric.no_backward_sync(self.model, enabled=is_accumulating):
+                # use_cache=True is not compatible with gradient checkpointing, so we disable it here
+                output = self.compute_loss(batch)
+                loss = output["loss"] / self.accumulate_grad_batches
+
+                fabric.backward(loss)
+
+                accumulated_loss += loss.item()
+                accumulated_chosen_reward += output["chosen_reward"].mean().item()
+                accumulated_rejected_reward += output["rejected_reward"].mean().item()
+
+            # 1. update the model parameters if not accumulating gradients
+            # 2. step the lr_scheduler if interval is set to "step" and frequency is met
+            # 3. save the model if interval is set to "step" and frequency is met
+            # 4. log metrics
+            # 5. increase the global step index
+            if not is_accumulating:
+                self.clip_gradients_if_needed(self.model, self.optimizer)
+
+                # run lr_scheduler at the end of the step if interval is set to "step"
+                if (
+                    self.lr_scheduler_interval == "step"
+                    and (self.global_step_idx + 1) % self.lr_scheduler_frequency == 0
+                ):
+                    self.lr_scheduler.step()
+
+                # update the model parameters and zero the gradients
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+
+                metrics = {
+                    "train/loss": accumulated_loss,
+                    "train/chosen_reward": accumulated_chosen_reward
+                    / self.accumulate_grad_batches,
+                    "train/rejected_reward": accumulated_rejected_reward
+                    / self.accumulate_grad_batches,
+                    "train/epoch_idx": self.epoch_idx,
+                    "train/lr": self.optimizer.param_groups[0]["lr"],
+                }
+                metrics["train/chosen_reward-rejected_reward"] = (
+                    metrics["train/chosen_reward"] - metrics["train/rejected_reward"]
+                )
+
+                fabric.log_dict(metrics, step=self.global_step_idx)
+                pbar.set_postfix(metrics)
+
+                # save the model at the end of the step if interval is set to "step" and frequency is met
+                self.conditional_checkpoint_save(stage="end_of_step")
+
+                # break if max_steps_per_epoch is set, and exit epoch
+                if (
+                    self.max_steps_per_epoch > 0
+                    and step_idx + 1 >= self.max_steps_per_epoch
+                ):
+                    break
+                # break if max_steps is set, and exit training
+                if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
+                    self.is_training = False
+                    break
+
+                self.global_step_idx += 1
+                accumulated_loss = 0
+                accumulated_chosen_reward = 0
+                accumulated_rejected_reward = 0
+
+    def save_checkpoint(
+        self,
+        path: Union[str, Path],
+        save_optimizer_state: Optional[bool] = None,
+        overwrite: bool = False,
+    ):
+        if not overwrite and os.path.exists(path):
+            return log.warning(f"Checkpoint already exists at {path}. Skipping save.")
+
+        fabric = self.fabric
+
+        if self.save_ckpt_type == "lightning":
+            state = {"model": self.model}
+
+            # save the optimizer and lr_scheduler state if needed
+            if self.save_optimizer_state and save_optimizer_state is not False:
+                state.update(
+                    {
+                        "optimizer": self.optimizer,
+                        "lr_scheduler": self.lr_scheduler,
+                        "global_step_idx": self.global_step_idx,
+                        "epoch_idx": self.epoch_idx,
+                    }
+                )
+
+            trainable_param_names = set(
+                name
+                for name, param in self.model.state_dict(keep_vars=True).items()
+                if param.requires_grad
+            )
+            filter = (
+                None
+                if self.save_full_model
+                else {"model": lambda k, p: k in trainable_param_names}
+            )
+
+            fabric.save(path, state=state, filter=filter)
+        else:
+            self.model.save_pretrained(path, is_main_process=fabric.is_global_zero)
+
+        self._latest_saved_checkpoint_global_step = self.global_step_idx
+
+    def load_checkpoint(self, path: Union[str, Path]):
+        fabric = self.fabric
+
+        state = {"model": self.model}
+
+        # save the optimizer and lr_scheduler state if needed
+        if self.save_optimizer_state:
+            state.update(
+                {
+                    "optimizer": self.optimizer,
+                    "lr_scheduler": self.lr_scheduler,
+                }
+            )
+
+        fabric.load(path, state)
+
+
+def load_checkpoint(
+    fabric: L.Fabric,
+    ckpt_path: Union[str, Path],
+    model: Union[nn.Module, "LlamaForSequenceClassification"],
+    strict: bool = True,
+    **state_components,
+):
+    """
+    Load a checkpoint into a model.
+    """
+    state = {"model": model}
+    state.update(state_components)
+    fabric.load(ckpt_path, state=state, strict=strict)
+
+
+if __name__ == "__main__":
+    # convert a checkpoint to hf format
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--base-model-path", type=str)
+    parser.add_argument("--ckpt-path", type=str)
+    parser.add_argument("--output-path", type=str)
+
+    args = parser.parse_args()
+
+    fabric = L.Fabric(devices=1, strategy="fsdp")
+    fabric.launch()
+
+    tokenizer = AutoTokenizer.from_pretrained(args.base_model_path)
+    tokenizer.save_pretrained(args.output_path)
+
+    model = AutoModelForSequenceClassification.from_pretrained(
+        args.base_model_path, num_labels=1, torch_dtype=torch.bfloat16
+    )
+    model = fabric.setup_module(model)
+    load_checkpoint(fabric, args.ckpt_path, model=model, strict=True)
+    model.save_pretrained(args.output_path)
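The compute_loss method above relies on bradley_terry_rm_collate stacking the chosen sequences in the first half of the batch and the rejected ones in the second half, then minimizes -log sigmoid(r_chosen - r_rejected) over the pairs. A minimal numeric sketch of that loss on hypothetical reward values (not taken from the package):

    import torch

    # Batch of 4 sequences = 2 preference pairs: chosen rewards first, rejected second.
    rewards = torch.tensor([1.2, 0.7, -0.3, 0.1])
    chosen_reward, rejected_reward = rewards[:2], rewards[2:]

    # Same form as compute_loss: -log(sigmoid(chosen - rejected)), averaged over pairs.
    loss = -torch.log(torch.sigmoid(chosen_reward - rejected_reward)).mean()
    print(loss)  # small when chosen rewards exceed rejected rewards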