fusion-bench 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff shows the content of publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that public registry.
Files changed (105)
  1. fusion_bench/compat/method/__init__.py +2 -0
  2. fusion_bench/compat/method/base_algorithm.py +7 -2
  3. fusion_bench/compat/modelpool/__init__.py +3 -2
  4. fusion_bench/compat/taskpool/__init__.py +1 -1
  5. fusion_bench/dataset/arc_agi/__init__.py +6 -1
  6. fusion_bench/dataset/arc_agi/arc.py +26 -7
  7. fusion_bench/dataset/arc_agi/arc_agi.py +156 -25
  8. fusion_bench/dataset/arc_agi/np_cache.py +0 -1
  9. fusion_bench/dataset/arc_agi/preprocess.py +51 -9
  10. fusion_bench/dataset/llama/__init__.py +1 -0
  11. fusion_bench/dataset/llama/alpaca.py +93 -3
  12. fusion_bench/dataset/llama/collate.py +72 -5
  13. fusion_bench/dataset/llama/metamathqa.py +50 -0
  14. fusion_bench/dataset/llama/preference_700k.py +70 -0
  15. fusion_bench/dataset/llama/stanford_shp.py +90 -0
  16. fusion_bench/dataset/llama/ultrachat.py +58 -0
  17. fusion_bench/dataset/llama/utils/__init__.py +0 -0
  18. fusion_bench/method/__init__.py +4 -1
  19. fusion_bench/method/adamerging/__init__.py +1 -1
  20. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
  21. fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
  22. fusion_bench/method/linear/expo.py +39 -0
  23. fusion_bench/method/lm_finetune/__init__.py +1 -0
  24. fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
  25. fusion_bench/method/lm_finetune/fullfinetune_sft.py +122 -150
  26. fusion_bench/method/lm_finetune/peftfinetune_sft.py +102 -157
  27. fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
  28. fusion_bench/method/pruning/llama_random_prune.py +2 -2
  29. fusion_bench/method/pruning/magnitude_diff_pruning.py +2 -1
  30. fusion_bench/method/rankone_moe/__init__.py +3 -0
  31. fusion_bench/method/rankone_moe/clip_rankone_moe.py +160 -0
  32. fusion_bench/method/rankone_moe/rankone_moe.py +249 -0
  33. fusion_bench/method/simple_average.py +1 -1
  34. fusion_bench/method/surgery/__init__.py +3 -0
  35. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
  36. fusion_bench/mixins/__init__.py +2 -0
  37. fusion_bench/mixins/clip_classification.py +60 -12
  38. fusion_bench/mixins/fabric_training.py +320 -0
  39. fusion_bench/mixins/lightning_fabric.py +11 -2
  40. fusion_bench/modelpool/__init__.py +2 -0
  41. fusion_bench/modelpool/causal_lm/__init__.py +1 -1
  42. fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
  43. fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
  44. fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
  45. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
  46. fusion_bench/models/chat_templates/__init__.py +1 -0
  47. fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
  48. fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
  49. fusion_bench/models/hf_clip.py +50 -9
  50. fusion_bench/models/rankone_moe.py +410 -0
  51. fusion_bench/models/surgery/surgerymodelwrapper.py +157 -0
  52. fusion_bench/models/utils.py +8 -0
  53. fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
  54. fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
  55. fusion_bench/optim/__init__.py +2 -0
  56. fusion_bench/optim/exception.py +47 -0
  57. fusion_bench/optim/lr_scheduler/__init__.py +1 -0
  58. fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
  59. fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
  60. fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
  61. fusion_bench/optim/mezo.py +0 -2
  62. fusion_bench/programs/fabric_fusion_program.py +5 -1
  63. fusion_bench/taskpool/__init__.py +10 -2
  64. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  65. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +112 -0
  66. fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
  67. fusion_bench/taskpool/llama/reward_model.py +157 -0
  68. fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
  69. fusion_bench/tasks/flan_t5_text_generation/glue_load_dataset.py +2 -1
  70. fusion_bench/utils/hydra_utils.py +22 -0
  71. fusion_bench/utils/plot/__init__.py +0 -0
  72. fusion_bench/utils/plot/token.py +52 -0
  73. fusion_bench/utils/plot/token_notebook.py +127 -0
  74. fusion_bench/utils/type.py +5 -3
  75. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/METADATA +1 -1
  76. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/RECORD +104 -57
  77. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  78. fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
  79. fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
  80. fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
  81. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  82. fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
  83. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
  84. fusion_bench_config/fabric_model_fusion.yaml +1 -1
  85. fusion_bench_config/llama_full_finetune.yaml +19 -0
  86. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
  87. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +13 -6
  88. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +17 -9
  89. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +26 -0
  90. fusion_bench_config/method/regmean/clip_regmean.yaml +1 -0
  91. fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
  92. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
  93. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
  94. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
  95. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
  96. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
  97. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
  98. fusion_bench_config/nyuv2_config.yaml +5 -1
  99. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +18 -0
  100. fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
  101. fusion_bench_config/llama_weighted_average.yaml +0 -26
  102. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/LICENSE +0 -0
  103. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/WHEEL +0 -0
  104. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/entry_points.txt +0 -0
  105. {fusion_bench-0.2.5.dist-info → fusion_bench-0.2.7.dist-info}/top_level.txt +0 -0
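
The hunks that follow come from fusion_bench/method/lm_finetune/fullfinetune_sft.py (entry 25 above, +122 -150). The headline changes: FullFinetuneSFT now builds on FabricTrainingMixin instead of LightningFabricMixin, the per-epoch loop moves into an overridden train_epoch that scales the loss by accumulate_grad_batches, and the constructor gains save_ckpt_type, max_length, and fix_token_embedding. As orientation, here is a minimal, hypothetical driver. It is a sketch under stated assumptions, not a documented API: the keyword names are inferred from the attributes used in the hunks below, the optimizer is passed as a Hydra-style DictConfig because configure_optimizer() resolves it with fusion_bench.utils.instantiate, and the llama_alpaca_cleaned.yaml model pool config is assumed to carry a _target_ that instantiates to a CausalLMPool.

# Hypothetical driver for the revised FullFinetuneSFT (sketch only, not part of this diff).
from omegaconf import DictConfig, OmegaConf

from fusion_bench.method.lm_finetune.fullfinetune_sft import FullFinetuneSFT
from fusion_bench.utils import instantiate

if __name__ == "__main__":
    # Assumed: the shipped YAML resolves standalone via fusion_bench.utils.instantiate.
    modelpool = instantiate(
        OmegaConf.load(
            "fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml"
        )
    )
    algorithm = FullFinetuneSFT(
        optimizer=DictConfig({"_target_": "torch.optim.AdamW", "lr": 1e-5}),
        lr_scheduler=None,                # assumed: a missing scheduler is tolerated
        dataloader_kwargs={"batch_size": 1, "num_workers": 4},  # assumed keyword
        max_epochs=1,
        max_steps=-1,
        max_steps_per_epoch=-1,
        accumulate_grad_batches=8,
        lr_scheduler_interval="epoch",
        lr_scheduler_frequency=1,
        checkpoint_save_interval="epoch",
        checkpoint_save_frequency=1,
        gradient_clip_val=1.0,
        gradient_clip_algorithm="norm",
        save_optimizer_state=False,
        save_full_model=False,
        # Options introduced in 0.2.7:
        save_ckpt_type="hf",        # export checkpoints via save_pretrained()
        ckpt_path=None,
        max_length=4096,            # truncate over-long batches in train_epoch()
        fix_token_embedding=True,   # freeze model.model.embed_tokens
    )
    model = algorithm.run(modelpool)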
@@ -1,3 +1,4 @@
+ import functools
  import itertools
  import logging
  import os
@@ -13,11 +14,12 @@ from omegaconf import DictConfig
  from torch import nn
  from torch.utils.data import ConcatDataset, DataLoader
  from tqdm.auto import tqdm
+ from transformers import AutoModelForCausalLM, AutoTokenizer
  from typing_extensions import TYPE_CHECKING, override

  from fusion_bench import BaseAlgorithm, BaseModelPool
  from fusion_bench.dataset.llama.collate import padded_collate_sft
- from fusion_bench.mixins import LightningFabricMixin
+ from fusion_bench.mixins import FabricTrainingMixin
  from fusion_bench.modelpool import CausalLMPool
  from fusion_bench.utils import instantiate
  from fusion_bench.utils.dtype import get_dtype
@@ -33,7 +35,7 @@ if TYPE_CHECKING:
  log = logging.getLogger(__name__)


- class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
+ class FullFinetuneSFT(BaseAlgorithm, FabricTrainingMixin):

  model: Union[nn.Module, "_FabricModule", "LlamaForCausalLM"]
  optimizer: Union[torch.optim.Optimizer, "_FabricOptimizer"]
@@ -58,7 +60,10 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
  gradient_clip_algorithm: Literal["value", "norm"] = "norm",
  save_optimizer_state: bool = False,
  save_full_model: bool = False,
+ save_ckpt_type: Literal["lightning", "hf"] = "lightning",
  ckpt_path: Optional[str] = None,
+ max_length: int = 6144,
+ fix_token_embedding: bool = True,
  **kwargs,
  ):
  """
@@ -80,7 +85,10 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
  gradient_clip_algorithm(str): Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.
  save_optimizer_state(bool): Whether to save the optimizer and lr_scheduler state along with the model checkpoint.
  save_full_model(bool): Whether to save the full model or only the trainable parameters in the model checkpoint.
+ save_ckpt_type (str): Type of checkpoint to save. Available options: 'lightning', 'hf'. If set to 'lightning', the checkpoint will be saved in the lightning format. If set to 'hf', the checkpoint will be saved in the huggingface format.
  ckpt_path(str): Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded.
+ max_length(int): Maximum input length to consider. If the input length exceeds this value, it will be truncated.
+ fix_token_embedding(bool): Whether to fix the token embeddings during training. If set to True, the token embeddings will not be updated during training.
  """
  self._optimizer = optimizer
  self._lr_scheduler = lr_scheduler
@@ -97,18 +105,28 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
  self.gradient_clip_algorithm = gradient_clip_algorithm
  self.save_optimizer_state = save_optimizer_state
  self.save_full_model = save_full_model
+ self.save_ckpt_type = save_ckpt_type
  self.ckpt_path = ckpt_path
+ self.max_length = max_length
+ self.fix_token_embedding = fix_token_embedding
  super().__init__(**kwargs)

  def run(self, modelpool: CausalLMPool):
  self.modelpool = modelpool
  self.setup()
- self.train()
+ self.train(self.model, self.optimizer, self.lr_scheduler)
  return self.model

  def setup_model(self):
+ self.tokenizer = self.modelpool.load_tokenizer()
+ if self.tokenizer.pad_token_id is None:
+ self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
  model = self.modelpool.load_pretrained_model()
- self.model = model
+ self.model: "LlamaForCausalLM" = model
+
+ if self.fix_token_embedding:
+ self.model.model.embed_tokens.requires_grad_(False)

  if self.fabric.strategy == "fsdp" or isinstance(
  self.fabric.strategy, FSDPStrategy
@@ -117,21 +135,14 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
  self.model.gradient_checkpointing_enable(
  gradient_checkpointing_kwargs={"use_reentrant": True}
  )
+ self.use_cache = False
+ else:
+ self.use_cache = True
  self.model_dtype = get_dtype(self.model)

  def configure_optimizer(self):
  # compute expected total steps
- self.expected_total_steps = []
- if self.max_steps > 0:
- self.expected_total_steps.append(self.max_steps)
- if self.max_steps_per_epoch > 0 and self.max_epochs > 0:
- self.expected_total_steps.append(self.max_steps_per_epoch * self.max_epochs)
- if self.max_epochs > 0:
- self.expected_total_steps.append(
- len(self.train_dataloader) * self.max_epochs
- )
- self.expected_total_steps = min(self.expected_total_steps)
- log.info(f"Expected total steps: {self.expected_total_steps}")
+ self.compute_expected_total_steps(self.train_dataloader)

  optimizer = instantiate(self._optimizer, self.model.parameters())
  if self._lr_scheduler is not None:
@@ -170,7 +181,9 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
  train_dataset,
  **self.dataloader_kwargs,
  shuffle=True,
- collate_fn=padded_collate_sft,
+ collate_fn=functools.partial(
+ padded_collate_sft, pad_token_id=self.tokenizer.pad_token_id
+ ),
  )
  self.train_dataloader = fabric.setup_dataloaders(self.train_dataloader)

@@ -186,25 +199,15 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
  self.model, self.optimizer = fabric.setup(self.model, optimizer)
  self.lr_scheduler = lr_scheduler

- def _clip_gradients_if_needed(self):
+ @override
+ def train_epoch(self, *args, **kwargs):
  fabric = self.fabric

- if self.gradient_clip_val is not None:
- if self.gradient_clip_algorithm == "value":
- fabric.clip_gradients(self.model, clip_val=self.gradient_clip_val)
- elif self.gradient_clip_algorithm == "norm":
- fabric.clip_gradients(self.model, max_norm=self.gradient_clip_val)
- else:
- raise ValueError(
- f"Unknown gradient clip algorithm: {self.gradient_clip_algorithm}. Available options: 'value', 'norm'"
- )
-
- def train_epoch(self):
- fabric = self.fabric
+ accumulated_loss = 0
  for step_idx, batch in enumerate(
  pbar := tqdm(
  self.train_dataloader,
- desc="Training Steps",
+ desc="Training Batches",
  dynamic_ncols=True,
  leave=False,
  disable=not fabric.is_global_zero,
@@ -212,24 +215,30 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
  ):
  is_accumulating = (step_idx + 1) % self.accumulate_grad_batches != 0

+ if self.max_length > 0 and batch["input_ids"].shape[1] > self.max_length:
+ log.warning(
+ f"Input length exceeds max_length: {batch['input_ids'].shape[1]} > {self.max_length}. Truncating input."
+ )
+ batch["input_ids"] = batch["input_ids"][:, : self.max_length]
+ batch["attention_mask"] = batch["attention_mask"][:, : self.max_length]
+ batch["labels"] = batch["labels"][:, : self.max_length]
+
  # disable gradient synchronization if accumulating gradients across steps for improved performance
  with fabric.no_backward_sync(self.model, enabled=is_accumulating):
  # use_cache=True is not compatible with gradient checkpointing, so we disable it here
- output = self.model(**batch, use_cache=False)
- loss = output["loss"]
+ output = self.model(
+ input_ids=batch["input_ids"],
+ attention_mask=batch["attention_mask"],
+ labels=batch["labels"],
+ use_cache=self.use_cache,
+ )
+ loss = output["loss"] / self.accumulate_grad_batches

  fabric.backward(loss)
-
- metrics = {
- "train/loss": loss.item(),
- "train/epoch_idx": self.epoch_idx,
- "train/lr": self.optimizer.param_groups[0]["lr"],
- }
- fabric.log_dict(metrics, step=self.global_step_idx)
- pbar.set_postfix(metrics)
+ accumulated_loss += loss.item()

  if not is_accumulating:
- self._clip_gradients_if_needed()
+ self.clip_gradients_if_needed(self.model, self.optimizer)

  # run lr_scheduler at the end of the step if interval is set to "step"
  if (
@@ -242,104 +251,30 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
  self.optimizer.step()
  self.optimizer.zero_grad()

- # save the model at the end of the step if interval is set to "step" and frequency is met
- self._try_save_checkpoint(stage="end_of_step")
+ metrics = {
+ "train/loss": accumulated_loss,
+ "train/epoch_idx": self.epoch_idx,
+ "train/lr": self.optimizer.param_groups[0]["lr"],
+ }
+ fabric.log_dict(metrics, step=self.global_step_idx)
+ pbar.set_postfix(metrics)

- # break if max_steps_per_epoch is set, and exit epoch
- if (
- self.max_steps_per_epoch > 0
- and step_idx + 1 >= self.max_steps_per_epoch
- ):
- break
- # break if max_steps is set, and exit training
- if self.max_steps > 0 and self.global_step_idx >= self.max_steps:
- self.is_training = False
- break
+ # save the model at the end of the step if interval is set to "step" and frequency is met
+ self.conditional_checkpoint_save(stage="end_of_step")

- self.global_step_idx += 1
+ # break if max_steps_per_epoch is set, and exit epoch
+ if (
+ self.max_steps_per_epoch > 0
+ and step_idx + 1 >= self.max_steps_per_epoch
+ ):
+ break
+ # break if max_steps is set, and exit training
+ if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
+ self.is_training = False
+ break

- def train(self):
- fabric = self.fabric
- self.is_training = True
- self.global_step_idx = 0
- self.model.train()
- for epoch_idx in tqdm(
- range(self.max_epochs) if self.max_epochs > 0 else itertools.count(0),
- "Training Epoch",
- dynamic_ncols=True,
- leave=False,
- disable=not fabric.is_global_zero,
- ):
- self.epoch_idx = epoch_idx
- self.train_epoch()
- # run lr_scheduler at the end of the epoch if interval is set to "epoch"
- if (
- self.lr_scheduler_interval == "epoch"
- and (epoch_idx + 1) % self.lr_scheduler_frequency == 0
- ):
- self.lr_scheduler.step()
-
- # save the model at the end of the epoch if interval is set to "epoch" and frequency is met
- self._try_save_checkpoint(stage="end_of_epoch")
-
- if not self.is_training:
- break
-
- # save the model at the end of training
- self._try_save_checkpoint(stage="end_of_training")
-
- def _try_save_checkpoint(
- self, stage: Literal["end_of_step", "end_of_epoch", "end_of_training"]
- ):
- if stage == "end_of_step":
- if (
- self.checkpoint_save_interval == "step"
- and (self.global_step_idx + 1) % self.checkpoint_save_frequency == 0
- ):
- self.save_checkpoint(
- os.path.join(
- self.log_dir, "checkpoints", f"step={self.global_step_idx}.ckpt"
- )
- )
- elif stage == "end_of_epoch":
- if (
- self.checkpoint_save_interval == "epoch"
- and (self.epoch_idx + 1) % self.checkpoint_save_frequency == 0
- ):
- self.save_checkpoint(
- os.path.join(
- self.log_dir, "checkpoints", f"epoch={self.epoch_idx}.ckpt"
- )
- )
- elif stage == "end_of_training":
- # if the checkpoint has not been saved yet, save it
- if self.global_step_idx > self._latest_saved_checkpoint_global_step:
- self.save_checkpoint(
- os.path.join(
- self.log_dir,
- "checkpoints",
- f"epoch={self.epoch_idx}_step={self.global_step_idx}.ckpt",
- )
- )
- try:
- os.symlink(
- os.path.join(
- self.log_dir,
- "checkpoints",
- "latest_model.ckpt",
- ),
- os.path.join(
- self.log_dir,
- "checkpoints",
- f"epoch={self.epoch_idx}_step={self.global_step_idx}.ckpt",
- ),
- )
- except Exception as e:
- pass
- else:
- raise ValueError(
- f"Unknown stage: {stage}. Available options: 'end_of_step', 'end_of_epoch', 'end_of_training'"
- )
+ self.global_step_idx += 1
+ accumulated_loss = 0

  def save_checkpoint(
  self,
@@ -351,24 +286,36 @@ class FullFinetuneSFT(BaseAlgorithm, LightningFabricMixin):
  return log.warning(f"Checkpoint already exists at {path}. Skipping save.")

  fabric = self.fabric
- state = {"model": self.model}

- # save the optimizer and lr_scheduler state if needed
- if self.save_optimizer_state and save_optimizer_state is not False:
- state.update(
- {
- "optimizer": self.optimizer,
- "lr_scheduler": self.lr_scheduler,
- "global_step_idx": self.global_step_idx,
- "epoch_idx": self.epoch_idx,
- }
+ if self.save_ckpt_type == "lightning":
+ state = {"model": self.model}
+
+ # save the optimizer and lr_scheduler state if needed
+ if self.save_optimizer_state and save_optimizer_state is not False:
+ state.update(
+ {
+ "optimizer": self.optimizer,
+ "lr_scheduler": self.lr_scheduler,
+ "global_step_idx": self.global_step_idx,
+ "epoch_idx": self.epoch_idx,
+ }
+ )
+
+ trainable_param_names = set(
+ name
+ for name, param in self.model.state_dict(keep_vars=True).items()
+ if param.requires_grad
+ )
+ filter = (
+ None
+ if self.save_full_model
+ else {"model": lambda k, p: k in trainable_param_names}
  )

- filter = (
- None if self.save_full_model else {"model": lambda k, p: p.requires_grad}
- )
+ fabric.save(path, state=state, filter=filter)
+ else:
+ self.model.save_pretrained(path, is_main_process=fabric.is_global_zero)

- fabric.save(path, state=state, filter=filter)
  self._latest_saved_checkpoint_global_step = self.global_step_idx

  def load_checkpoint(self, path: Union[str, Path]):
@@ -401,3 +348,28 @@ def load_checkpoint(
  state = {"model": model}
  state.update(state_components)
  fabric.load(ckpt_path, state=state, strict=strict)
+
+
+ if __name__ == "__main__":
+ # convert a checkpoint to hf format
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--base-model-path", type=str)
+ parser.add_argument("--ckpt-path", type=str)
+ parser.add_argument("--output-path", type=str)
+
+ args = parser.parse_args()
+
+ fabric = L.Fabric(devices=1, strategy="fsdp")
+ fabric.launch()
+
+ tokenizer = AutoTokenizer.from_pretrained(args.base_model_path)
+ tokenizer.save_pretrained(args.output_path)
+
+ model = AutoModelForCausalLM.from_pretrained(
+ args.base_model_path, torch_dtype=torch.bfloat16
+ )
+ model = fabric.setup_module(model)
+ load_checkpoint(fabric, args.ckpt_path, model=model, strict=True)
+ model.save_pretrained(args.output_path)
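
Two smaller points from the hunks above are worth calling out. First, the save_checkpoint rewrite replaces the old filter lambda k, p: p.requires_grad with a set of trainable parameter names collected via state_dict(keep_vars=True). A plausible reason is that a plain state_dict() returns detached tensors whose requires_grad is always False, so the old filter could never match anything; the standalone sketch below (plain PyTorch, not fusion_bench code) demonstrates that behavior.

# Standalone sketch: why filtering a checkpoint by requires_grad needs keep_vars=True.
import torch
from torch import nn

model = nn.Linear(4, 2)
model.bias.requires_grad_(False)  # pretend the bias is frozen, like fixed token embeddings

# Old-style filter: state_dict() returns detached copies, so nothing ever requires grad.
old_kept = {k for k, p in model.state_dict().items() if p.requires_grad}

# New-style filter: collect trainable names with keep_vars=True, then match by key.
trainable_param_names = {
    k for k, p in model.state_dict(keep_vars=True).items() if p.requires_grad
}
new_kept = {k for k in model.state_dict() if k in trainable_param_names}

print(old_kept)  # set()
print(new_kept)  # {'weight'}

Second, the new __main__ block turns the module into a small export utility for converting a Lightning-format checkpoint back into a Hugging Face directory. Based on the argparse flags shown above, a hypothetical invocation (paths are placeholders) would be: python -m fusion_bench.method.lm_finetune.fullfinetune_sft --base-model-path <hf-base-model> --ckpt-path <lightning-ckpt> --output-path <hf-output-dir>.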