project-llm-trainer 0.4.15__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of project-llm-trainer might be problematic. See the original diff page on the registry for more details.

Files changed (30):
  1. llm_trainer/checkpoint.py +0 -50
  2. llm_trainer/dpo_trainer.py +6 -3
  3. llm_trainer/eval.py +3 -30
  4. llm_trainer/generate_utils.py +9 -74
  5. llm_trainer/grpo_trainer.py +27 -28
  6. llm_trainer/loss.py +1 -1
  7. llm_trainer/partition_utils.py +146 -0
  8. llm_trainer/tokenizer.py +10 -10
  9. llm_trainer/tools.py +0 -2
  10. llm_trainer/train_configs.py +5 -25
  11. llm_trainer/trainer.py +28 -67
  12. llm_trainer/utils.py +0 -1
  13. {project_llm_trainer-0.4.15.dist-info → project_llm_trainer-0.5.1.dist-info}/METADATA +1 -1
  14. project_llm_trainer-0.5.1.dist-info/RECORD +33 -0
  15. llm_trainer/dcp.py +0 -93
  16. llm_trainer/ds_model_params.py +0 -72
  17. llm_trainer/fsdp_checkpoint.py +0 -52
  18. llm_trainer/fsdp_model_params.py +0 -39
  19. llm_trainer/model_params.py +0 -28
  20. llm_trainer/parallel_fsdp.py +0 -121
  21. project_llm_trainer-0.4.15.dist-info/RECORD +0 -38
  22. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/calc_intermediate_size +0 -0
  23. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/ddp_train +0 -0
  24. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/ds_train +0 -0
  25. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/plot_loss +0 -0
  26. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/plot_lr +0 -0
  27. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/py_train +0 -0
  28. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/smart_train +0 -0
  29. {project_llm_trainer-0.4.15.dist-info → project_llm_trainer-0.5.1.dist-info}/WHEEL +0 -0
  30. {project_llm_trainer-0.4.15.dist-info → project_llm_trainer-0.5.1.dist-info}/top_level.txt +0 -0
llm_trainer/tokenizer.py CHANGED
@@ -26,8 +26,8 @@ class Tokenizer:
26
26
  self.text_user = '<user>'
27
27
  self.text_assistant = '<assistant>'
28
28
 
29
- self.text_reasoning_start = '<reasoning>'
30
- self.text_reasoning_end = '</reasoning>'
29
+ self.text_think_start = '<think>'
30
+ self.text_think_end = '</think>'
31
31
 
32
32
  self.text_answer_start = '<answer>'
33
33
  self.text_answer_end = '</answer>'
@@ -47,8 +47,8 @@ class Tokenizer:
47
47
  additional_special_tokens = [
48
48
  AddedToken(self.text_user, lstrip=False, rstrip=False),
49
49
  AddedToken(self.text_assistant, lstrip=False, rstrip=False),
50
- AddedToken(self.text_reasoning_start, lstrip=False, rstrip=False),
51
- AddedToken(self.text_reasoning_end, lstrip=False, rstrip=False),
50
+ AddedToken(self.text_think_start, lstrip=False, rstrip=False),
51
+ AddedToken(self.text_think_end, lstrip=False, rstrip=False),
52
52
  AddedToken(self.text_answer_start, lstrip=False, rstrip=False),
53
53
  AddedToken(self.text_answer_end, lstrip=False, rstrip=False),
54
54
  AddedToken(self.text_system, lstrip=False, rstrip=False),
@@ -69,8 +69,8 @@ class Tokenizer:
69
69
  self.user = self.tokenizer.convert_tokens_to_ids(self.text_user)
70
70
  self.assistant = self.tokenizer.convert_tokens_to_ids(self.text_assistant)
71
71
 
72
- self.reasoning_start = self.tokenizer.convert_tokens_to_ids(self.text_reasoning_start)
73
- self.reasoning_end = self.tokenizer.convert_tokens_to_ids(self.text_reasoning_end)
72
+ self.think_start = self.tokenizer.convert_tokens_to_ids(self.text_think_start)
73
+ self.think_end = self.tokenizer.convert_tokens_to_ids(self.text_think_end)
74
74
 
75
75
  self.answer_start = self.tokenizer.convert_tokens_to_ids(self.text_answer_start)
76
76
  self.answer_end = self.tokenizer.convert_tokens_to_ids(self.text_answer_end)
@@ -140,9 +140,9 @@ class Tokenizer:
140
140
  {"role":"user", "content":"hello?"},
141
141
  {"role":"assistant", "content":"hello"},
142
142
  {"role":"user", "content":"hello hello?"},
143
- {"role":"assistant", "reasoning":"thinking", "content":"hello hello"},
143
+ {"role":"assistant", "think":"thinking", "content":"hello hello"},
144
144
  ]
145
- <system>{system_prompt}</s><user>hello?</s><assistant>hello</s><user>hello hello?</s><assistant><reasoning>thinking</reasoning><answer>hello hello</answer></s>
145
+ <system>{system_prompt}</s><user>hello?</s><assistant>hello</s><user>hello hello?</s><assistant><think>thinking</think><answer>hello hello</answer></s>
146
146
  """
147
147
 
148
148
  chat_template = ''
@@ -154,8 +154,8 @@ class Tokenizer:
154
154
  if add_answer_tag_for_assistant and role == 'assistant':
155
155
  content = f"{self.text_answer_start}{content}{self.text_answer_end}"
156
156
 
157
- if 'reasoning' in conversation:
158
- content = f"{self.text_reasoning_start}{conversation['reasoning']}{self.text_reasoning_end}{content}"
157
+ if 'think' in conversation:
158
+ content = f"{self.text_think_start}{conversation['think']}{self.text_think_end}{content}"
159
159
 
160
160
  chat_template = f"{chat_template}{support_roles[role]}{content}{self.text_end}"
161
161
 
llm_trainer/tools.py CHANGED
@@ -3,7 +3,6 @@ from abc import ABC, abstractmethod
3
3
  import torch
4
4
  from .tokenizer import Tokenizer
5
5
  from .parallel_ds import DsParallel
6
- from .parallel_fsdp import FsdpParallel
7
6
  from .parallel_ddp import DdpParallel
8
7
  from .parallel_none import NoneParallel
9
8
  from .log import log
@@ -11,7 +10,6 @@ from .log import log
11
10
 
12
11
  parallel_types = {
13
12
  'ds': DsParallel,
14
- 'fsdp': FsdpParallel,
15
13
  'ddp': DdpParallel,
16
14
  'none': NoneParallel
17
15
  }
llm_trainer/train_configs.py CHANGED
@@ -1,8 +1,7 @@
1
- from typing import Optional, Union, Set, Type, Callable, List, Mapping, Any
1
+ from typing import Optional, Union, Callable, List, Mapping, Any
2
2
  from dataclasses import dataclass, field
3
3
 
4
4
  import torch
5
- from torch import nn
6
5
  from llm_model import ModelConfig, VLMConfig
7
6
  from .tools import FileDataset
8
7
 
@@ -33,6 +32,9 @@ class DsZeROConfig:
33
32
  reduce_bucket_size: Optional[Union[str, int]] = 5e8
34
33
  contiguous_gradients: Optional[bool] = True
35
34
 
35
+ @dataclass(kw_only=True)
36
+ class DsZero0Config(DsZeROConfig):
37
+ stage: int = field(default=0, init=False)
36
38
 
37
39
  @dataclass(kw_only=True)
38
40
  class DsZero1Config(DsZeROConfig):
@@ -84,26 +86,6 @@ class DsConfig:
84
86
  activation_checkpointing: Optional[DsActivationCheckpointingConfig] = None
85
87
 
86
88
 
87
- @dataclass(kw_only=True)
88
- class FsdpConfig:
89
- """
90
- fsdp训练模式配置项
91
- Args:
92
- transformer_layer_cls (`Set[Type[nn.Module]]`, *optional*, default is None):
93
- 提供transformer层的类
94
- wrap_policy_num_params (`int`, *optional*, default is -1):
95
- size_based_auto_wrap_policy的min_num_params参数,-1不生效该策略
96
- cpu_offload (`bool`, *optional*, default is False):
97
- 是否使用cpu卸载
98
- offload_params (`bool`, default is False):
99
- 是否卸载参数,在cpu_offload为True时生效
100
- """
101
- transformer_layer_cls: Optional[Set[Type[nn.Module]]] = None
102
- wrap_policy_num_params: int = -1
103
- cpu_offload: bool = False
104
- offload_params: bool = False
105
-
106
-
107
89
  @dataclass(kw_only=True)
108
90
  class DataLoaderConfig:
109
91
  """
@@ -157,6 +139,7 @@ class GRPOConfig:
157
139
  clip_eps: float = 0.2
158
140
  kl_weight: float = 0.01
159
141
  group_size: int = 12
142
+ mixup_alpha: float = 1.0
160
143
  gen_max_new_tokens: Optional[int] = None
161
144
  gen_temperature: Optional[float] = None
162
145
  gen_k: Optional[int] = None
@@ -210,8 +193,6 @@ class TrainConfig:
210
193
  每隔多少个batch进行模型eval
211
194
  lr_config (`LrConfig`):
212
195
  lr配置项
213
- fsdp_config: (`FsdpConfig`):
214
- fsdp训练模式配置项
215
196
  data_loader_config: (`DataLoaderConfig`):
216
197
  data loader配置项
217
198
  kd_config: (`KDConfig`, *Optional*, default is None):
@@ -231,7 +212,6 @@ class TrainConfig:
231
212
  lr_config: LrConfig = field(default_factory=LrConfig)
232
213
 
233
214
  ds_config: DsConfig = field(default_factory=DsConfig)
234
- fsdp_config: FsdpConfig = field(default_factory=FsdpConfig)
235
215
 
236
216
  kd_config: Optional[KDConfig] = None
237
217
  dpo_config: Optional[DPOConfig] = None
llm_trainer/trainer.py CHANGED
@@ -1,21 +1,18 @@
1
1
  import time
2
2
  from contextlib import nullcontext
3
- from typing import Optional, Tuple, List, Dict, Any, Union
3
+ from typing import Optional, Tuple, List, Dict, Any
4
4
 
5
5
  import torch
6
- from torch import nn
7
6
  import torch.distributed as dist
8
- from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
9
7
  from torch.utils.data import Dataset
10
8
  from llm_model import LlmModel, VlmModel
11
9
 
12
10
  from .parallel_ds import DsParallel
13
- from .parallel_fsdp import FsdpParallel
14
11
  from .tools import TrainerTools
15
12
  from .loss import LMLoss, KDLoss
16
13
  from .dataset import TextDataset
17
- from .model_params import copy_model_params
18
14
  from .eval import submit_gen_task
15
+ from .partition_utils import unwrap_model_for_generation
19
16
 
20
17
  from .train_configs import (
21
18
  TrainConfig,
@@ -78,7 +75,6 @@ class Trainer:
78
75
 
79
76
  self.train_model, self.optimizer = self._init_train_model_and_optim(initial_lr, parallel_kwargs, use_ds_optim)
80
77
  self.lr_scheduler = self._init_lr_scheduler(initial_lr)
81
- self.eval_model: Optional[nn.Module] = self._init_eval_model()
82
78
 
83
79
  self.criterion, self.kd_loss = self._init_loss()
84
80
 
@@ -86,9 +82,7 @@ class Trainer:
86
82
  device_type=TrainerTools().parallel.device_type,
87
83
  dtype=TrainerTools().dtype,
88
84
  enabled=True,
89
- # fsdp模式,需要将cache_enabled设置为false
90
- # https://www.zhihu.com/question/642793891
91
- cache_enabled=False if isinstance(self.train_model, FSDP) else None
85
+ cache_enabled=None
92
86
  ) if TrainerTools().use_amp else nullcontext()
93
87
 
94
88
  load_checkpoint(
@@ -176,12 +170,6 @@ class Trainer:
176
170
 
177
171
  return model, optim
178
172
 
179
- def _init_eval_model(self) -> Optional[nn.Module]:
180
- if TrainerTools().parallel.is_main_process:
181
- return self._new_model(self.train_config).to(device='cpu', dtype=TrainerTools().dtype)
182
-
183
- return None
184
-
185
173
  def _init_lr_scheduler(self, initial_lr: float) -> LRScheduler:
186
174
  if self.train_config.lr_config.enable_lr_scheduler:
187
175
  min_lr = self.train_config.lr_config.min_lr
@@ -313,13 +301,6 @@ class Trainer:
313
301
  activation_checkpointing['number_checkpoints'] = activation_checkpointing_config.number_checkpoints
314
302
 
315
303
  parallel_kwargs['activation_checkpointing'] = activation_checkpointing
316
- elif isinstance(TrainerTools().parallel, FsdpParallel) and self.train_config.fsdp_config:
317
- parallel_kwargs = {
318
- 'transformer_layer_cls': self.train_config.fsdp_config.transformer_layer_cls,
319
- 'wrap_policy_num_params': self.train_config.fsdp_config.wrap_policy_num_params,
320
- 'cpu_offload': self.train_config.fsdp_config.cpu_offload,
321
- 'offload_params': self.train_config.fsdp_config.offload_params
322
- }
323
304
 
324
305
  dataloader_args = self.train_config.data_loader_config
325
306
  data_loader_kwargs = {
@@ -441,54 +422,35 @@ class Trainer:
441
422
 
442
423
  raise e
443
424
 
444
- def _on_batch_end(
445
- self,
446
- tag: str
447
- ):
448
- copy_model_params(_from=self.train_model, _to=self.eval_model)
425
+ def _eval(self, tag: str):
426
+ with unwrap_model_for_generation(self.train_model) as generate_model:
427
+ if TrainerTools().parallel.is_main_process:
428
+ generate_model.eval()
429
+ eval_prompt, eval_image_tag = self._get_eval_data()
430
+
431
+ if isinstance(self.train_config, VLMConfig) and self.pixel_values_provider and eval_image_tag:
432
+ eval_pixel_values = self.pixel_values_provider([eval_image_tag])
433
+ else:
434
+ eval_pixel_values = None
435
+
436
+ submit_gen_task(
437
+ generate_model,
438
+ self.train_config.eval_config,
439
+ tag=tag,
440
+ prompt=eval_prompt,
441
+ pixel_values=eval_pixel_values,
442
+ max_position_embeddings=self.train_config.model_config.max_position_embeddings,
443
+ tokens_per_image=self.tokens_per_image
444
+ )
445
+ generate_model.train()
449
446
 
450
- if TrainerTools().parallel.is_main_process:
451
- eval_prompt, eval_image_tag = self._get_eval_data()
452
- if isinstance(self.train_config, VLMConfig) and self.pixel_values_provider and eval_image_tag:
453
- eval_pixel_values = self.pixel_values_provider([eval_image_tag])
454
- else:
455
- eval_pixel_values = None
456
-
457
- submit_gen_task(
458
- self.eval_model,
459
- self.train_config.eval_config,
460
- tag=f'sign:batch/{tag}',
461
- prompt=eval_prompt,
462
- pixel_values=eval_pixel_values,
463
- max_position_embeddings=self.train_config.model_config.max_position_embeddings,
464
- tokens_per_image=self.tokens_per_image
465
- )
466
447
  TrainerTools().parallel.wait()
467
448
 
468
- def _on_epoch_end(
469
- self,
470
- tag: str
471
- ):
472
- copy_model_params(_from=self.train_model, _to=self.eval_model)
473
-
474
- if TrainerTools().parallel.is_main_process:
475
- eval_prompt, eval_image_tag = self._get_eval_data()
476
- if isinstance(self.train_config, VLMConfig) and self.pixel_values_provider and eval_image_tag:
477
- eval_pixel_values = self.pixel_values_provider([eval_image_tag])
478
- else:
479
- eval_pixel_values = None
480
-
481
- submit_gen_task(
482
- self.eval_model,
483
- self.train_config.eval_config,
484
- tag=f'sign:epoch/{tag}',
485
- prompt=eval_prompt,
486
- pixel_values=eval_pixel_values,
487
- max_position_embeddings=self.train_config.model_config.max_position_embeddings,
488
- tokens_per_image=self.tokens_per_image
489
- )
449
+ def _on_batch_end(self, tag: str):
450
+ self._eval(f'sign:batch/{tag}')
490
451
 
491
- TrainerTools().parallel.wait()
452
+ def _on_epoch_end(self, tag: str):
453
+ self._eval(f'sign:epoch/{tag}')
492
454
 
493
455
  def _on_file_start(
494
456
  self,
@@ -574,7 +536,6 @@ class Trainer:
574
536
  if need_update_grad:
575
537
  loss_tensor = torch.tensor(loss_accumulation, device=TrainerTools().parallel.device)
576
538
 
577
- # todo check all_reduce??
578
539
  if TrainerTools().parallel.parallel_train:
579
540
  dist.all_reduce(loss_tensor, dist.ReduceOp.AVG)
580
541
 
llm_trainer/utils.py CHANGED
@@ -1,5 +1,4 @@
1
1
  import random
2
- from typing import Tuple, Optional
3
2
  import torch
4
3
  from torch.nn.utils.rnn import pad_sequence
5
4
  import torch.nn.functional as F
{project_llm_trainer-0.4.15.dist-info → project_llm_trainer-0.5.1.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: project_llm_trainer
3
- Version: 0.4.15
3
+ Version: 0.5.1
4
4
  Summary: LLM and VLM trainer
5
5
  Author: qibin
6
6
  Author-email: qibin0506@gmail.com
project_llm_trainer-0.5.1.dist-info/RECORD ADDED
@@ -0,0 +1,33 @@
1
+ llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
2
+ llm_trainer/checkpoint.py,sha256=xTmmQSJ_jQDVSTT3km1p_8eRrc7yE_dEsi92z9OX5ec,3251
3
+ llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
4
+ llm_trainer/dpo_trainer.py,sha256=wMREatLt0I8Ajdm_sI2U8Zj-IN1L6txP9s_tH1oI3-s,11431
5
+ llm_trainer/ds_checkpoint.py,sha256=wz48HoLBBt8QGO1tXfvJwrXoiGtPG_gjwHfEqARllso,2175
6
+ llm_trainer/eval.py,sha256=fjASCILU3fSPJxo9cP3rIXEEnkc5ZlUyHqXlZtUiHrw,888
7
+ llm_trainer/generate_utils.py,sha256=CbJ3mfAD6DkQ0GUHcJQ1AK02m-ocwmd-BPXEpiwvNNQ,14933
8
+ llm_trainer/grpo_trainer.py,sha256=qiC3KwxYPSB9UKqyk4eSRvORP3b6GM-2ozqI8u3QvI0,15568
9
+ llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
10
+ llm_trainer/loss.py,sha256=NZCQeUXnLSj__mmDflE8g89KgE0emAJXIab0IERCLno,6023
11
+ llm_trainer/parallel.py,sha256=DQu8GqEFxD99HQ6hKuIxxyKi-05dMO33eMhImYlPuOI,4468
12
+ llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
13
+ llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
14
+ llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
15
+ llm_trainer/partition_utils.py,sha256=xzv8kwlbKp3dai2pBwX89gN5ymeHk1bGbTkGru5H-UM,5167
16
+ llm_trainer/scheduler.py,sha256=Xz8HhwoRMjRe41sf_NHhpZfkTlEs0I2MYusvMY6hCVw,3531
17
+ llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
18
+ llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
19
+ llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
20
+ llm_trainer/train_configs.py,sha256=m57W71SI5VCCU9aJ_nJkB-3AJrSGiNXmV28rdpuYmLg,7332
21
+ llm_trainer/trainer.py,sha256=zTJVyY1cAjJdTkyXCOy2ZPVP18SOMLdWhD54Mz2JRe4,25314
22
+ llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
23
+ project_llm_trainer-0.5.1.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
24
+ project_llm_trainer-0.5.1.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
25
+ project_llm_trainer-0.5.1.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
26
+ project_llm_trainer-0.5.1.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
27
+ project_llm_trainer-0.5.1.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
28
+ project_llm_trainer-0.5.1.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
29
+ project_llm_trainer-0.5.1.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
30
+ project_llm_trainer-0.5.1.dist-info/METADATA,sha256=x-Bobn0EH7wyKznJydUeVLK9sdIrkBmDYDbEpyG4pKc,195
31
+ project_llm_trainer-0.5.1.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
32
+ project_llm_trainer-0.5.1.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
33
+ project_llm_trainer-0.5.1.dist-info/RECORD,,
llm_trainer/dcp.py DELETED
@@ -1,93 +0,0 @@
1
- import os
2
- from typing import Optional, Dict, Any
3
- from torch import nn
4
- from torch.optim import Optimizer
5
- import torch.distributed.checkpoint as dcp
6
- from torch.distributed.checkpoint.stateful import Stateful
7
- from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
8
- from torch.distributed.checkpoint.format_utils import dcp_to_torch_save, torch_save_to_dcp
9
-
10
- DEFAULT_CHECKPOINT_DIR = "checkpoint"
11
-
12
- class AppState(Stateful):
13
- def __init__(self, model: nn.Module, optimizer: Optimizer):
14
- self.model = model
15
- self.optimizer = optimizer
16
-
17
- def state_dict(self) -> Dict[str, Any]:
18
- model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
19
- return {
20
- 'model_state_dict': model_state_dict,
21
- 'optim_state_dict': optimizer_state_dict
22
- }
23
-
24
- def load_state_dict(self, state_dict: Dict[str, Any]):
25
- set_state_dict(
26
- model=self.model,
27
- optimizers=self.optimizer,
28
- model_state_dict=state_dict['model_state_dict'],
29
- optim_state_dict=state_dict['optim_state_dict']
30
- )
31
-
32
-
33
- def save_dcp(
34
- model: nn.Module,
35
- optimizer: Optimizer,
36
- suffix: Optional[str] = None
37
- ):
38
- checkpoint_id = os.environ.get('DIST_CHECKPOINT_DIR', DEFAULT_CHECKPOINT_DIR)
39
- if suffix:
40
- checkpoint_id = f"{checkpoint_id}_{suffix}"
41
-
42
- state_dict = {'app': AppState(model, optimizer)}
43
-
44
- # fs_storage_writer = dcp.FileSystemWriter(checkpoint_id, overwrite=True)
45
- # dcp.save(state_dict=state_dict, storage_writer=fs_storage_writer)
46
- dcp.save(state_dict=state_dict, checkpoint_id=checkpoint_id)
47
-
48
-
49
- def load_dcp(
50
- model: nn.Module,
51
- optimizer: Optional[Optimizer] = None,
52
- suffix: Optional[str] = None
53
- ):
54
- checkpoint_id = os.environ.get('DIST_CHECKPOINT_DIR', DEFAULT_CHECKPOINT_DIR)
55
- if suffix:
56
- checkpoint_id = f"{checkpoint_id}_{suffix}"
57
-
58
- if os.path.exists(checkpoint_id):
59
- state_dict = {'app': AppState(model, optimizer)}
60
- # AppState帮助加载到state_dict中, 然后加载到model中
61
- dcp.load(state_dict=state_dict, checkpoint_id=checkpoint_id)
62
-
63
- # if isinstance(model, FSDP):
64
- # state_dict = {'app': AppState(model, optimizer)}
65
- # # AppState帮助加载到state_dict中, 然后加载到model中
66
- # dcp.load(state_dict=state_dict, checkpoint_id=checkpoint_id)
67
- # else:
68
- # state_dict = {"model_state_dict": model.state_dict()}
69
- #
70
- # if optimizer:
71
- # state_dict.update({'optim_state_dict': optimizer.state_dict()})
72
- #
73
- # # since no progress group is initialized, DCP will disable any collectives.
74
- # # 加载到state_dict中,然后通过model.load_state_dict加载到model中
75
- # dcp.load(
76
- # state_dict=state_dict,
77
- # checkpoint_id=checkpoint_id,
78
- # )
79
- #
80
- # model.load_state_dict(state_dict["model_state_dict"])
81
- # if optimizer:
82
- # optimizer.load_state_dict(state_dict["optim_state_dict"])
83
-
84
- def convert_dcp_to_pth(pth_path: str):
85
- dcp_path = os.environ.get('DIST_CHECKPOINT_DIR', DEFAULT_CHECKPOINT_DIR)
86
- if os.path.exists(dcp_path):
87
- # convert dcp model to torch.save (assumes checkpoint was generated as above)
88
- dcp_to_torch_save(dcp_path, pth_path)
89
-
90
- def convert_pth_to_dcp(pth_path: str):
91
- if os.path.exists(pth_path):
92
- # converts the torch.save model back to DCP
93
- torch_save_to_dcp(pth_path, os.environ.get('DIST_CHECKPOINT_DIR', DEFAULT_CHECKPOINT_DIR))
llm_trainer/ds_model_params.py DELETED
@@ -1,72 +0,0 @@
1
- from typing import Optional
2
- from torch import nn
3
- import torch.distributed as dist
4
-
5
- from .tools import TrainerTools
6
-
7
- try:
8
- import deepspeed
9
- from deepspeed import DeepSpeedEngine
10
- from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
11
- except: ...
12
-
13
-
14
- def _get_ds_full_state_dict_on_rank0(model: DeepSpeedEngine) -> Optional[dict]:
15
- """
16
- 需要在所有rank上调用,然后只有rank0有值
17
- """
18
-
19
- if model.zero_optimization_stage() != 3:
20
- if TrainerTools().parallel.is_main_process:
21
- return {k: v.cpu().clone() for k, v in model.module.state_dict().items()}
22
- return None
23
-
24
- # --- ZeRO-3 ---
25
- # 只调用一次 GatheredParameters,传入所有参数
26
- with deepspeed.zero.GatheredParameters(model.parameters(), modifier_rank=0):
27
- if TrainerTools().parallel.is_main_process:
28
- # 在这个 'with' 代码块内,rank 0 上的 model.module 拥有完整的参数
29
- # 所以我们可以像操作普通模型一样直接调用 state_dict()
30
- full_state_dict = model.module.state_dict()
31
-
32
- # 将其克隆到 CPU 并返回
33
- return {k: v.cpu().clone() for k, v in full_state_dict.items()}
34
-
35
- # 其他 rank 执行到这里时,上下文结束,直接返回 None
36
- return None
37
-
38
- # # ZeRO-3
39
- # state_dict_on_rank_0 = {}
40
- # for param_name, param in model.module.named_parameters():
41
- # if hasattr(param, 'ds_id'):
42
- # with deepspeed.zero.GatheredParameters(param, modifier_rank=0):
43
- # if TrainerTools().parallel.is_main_process:
44
- # state_dict_on_rank_0[param_name] = param.data.to(torch.float32).cpu().clone()
45
- # else:
46
- # if TrainerTools().parallel.is_main_process:
47
- # state_dict_on_rank_0[param_name] = param.data.to(torch.float32).cpu().clone()
48
- #
49
- # return state_dict_on_rank_0 if TrainerTools().parallel.is_main_process else None
50
-
51
-
52
- def get_ds_model_params(model: nn.Module, only_rank0=False):
53
- """
54
- 从一个正在运行的 DeepSpeedEngine 中高效地提取完整的 FP32 state_dict,
55
- 兼容 ZeRO Stages 0, 1, 2, 3。
56
- 包含了对 ZeRO-3 中分片参数的正确处理。
57
- """
58
-
59
- assert isinstance(model, DeepSpeedEngine)
60
- state_dict = _get_ds_full_state_dict_on_rank0(model)
61
-
62
- # 现在,只有 rank 0 上的 state_dict 是一个有效的字典,其他 rank 上是 None。
63
- # 我们需要将其广播给所有进程。
64
- if not only_rank0 and TrainerTools().parallel.world_size > 1:
65
- # 准备一个列表,rank 0 有数据,其他 rank 是占位符
66
- object_list = [state_dict] if TrainerTools().parallel.is_main_process else [None]
67
- # 执行广播,这个操作是阻塞的,会同步所有进程
68
- dist.broadcast_object_list(object_list, src=0)
69
- # 所有进程从列表中获取广播后的 state_dict 副本
70
- state_dict = object_list[0]
71
-
72
- return state_dict
llm_trainer/fsdp_checkpoint.py DELETED
@@ -1,52 +0,0 @@
1
- import os
2
- from typing import Optional, Union, Tuple
3
- import torch
4
- from torch import nn
5
- from torch.optim import Optimizer
6
- from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
7
-
8
- from .tools import TrainerTools
9
-
10
- DEFAULT_CHECKPOINT_NAME = "checkpoint.pth"
11
-
12
- def save_fsdp_checkpoint(
13
- model: nn.Module,
14
- optimizer: Optional[Optimizer] = None,
15
- suffix: Optional[str] = None
16
- ):
17
- # 未经过测试 参考:https://doc.hfai.high-flyer.cn/haiscale/haiscale_fsdp.html
18
- # 是否使用rank0_only=True?
19
- with FSDP.summon_full_params(
20
- module=model,
21
- rank0_only=True,
22
- writeback=False,
23
- offload_to_cpu=True
24
- ):
25
- if TrainerTools().parallel.is_main_process:
26
- checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
27
- if suffix:
28
- checkpoint_name = f"{checkpoint_name}_{suffix}"
29
-
30
- ckpt = {'model_state_dict': model.state_dict()}
31
- if optimizer:
32
- ckpt.update({'optim_state_dict': optimizer.state_dict()})
33
-
34
- torch.save(ckpt, checkpoint_name)
35
-
36
-
37
- def load_fsdp_checkpoint(
38
- model: nn.Module,
39
- optimizer: Optional[Optimizer] = None,
40
- device: Optional[Union[torch.device, str]] = None,
41
- suffix: Optional[str] = None
42
- ):
43
- checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
44
- if suffix:
45
- checkpoint_name = f"{checkpoint_name}_{suffix}"
46
-
47
- with FSDP.summon_full_params(module=model):
48
- state_dict = torch.load(checkpoint_name, weights_only=True, map_location=device)
49
- model.load_state_dict(state_dict['model_state_dict'])
50
-
51
- if optimizer:
52
- optimizer.load_state_dict(state_dict['optim_state_dict'])
llm_trainer/fsdp_model_params.py DELETED
@@ -1,39 +0,0 @@
1
- from typing import Optional
2
- from torch import nn
3
- import torch.distributed as dist
4
-
5
- from .tools import TrainerTools
6
-
7
-
8
- def _get_fsdp_full_state_dict_on_rank0(model: nn.Module) -> Optional[dict]:
9
- """
10
- 可以在任意rank上调用,然后只有rank0有值
11
- """
12
-
13
- from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
14
- with FSDP.summon_full_params(model, writeback=False, offload_to_cpu=True):
15
- if TrainerTools().parallel.is_main_process:
16
- return {k: v.clone() for k, v in model.state_dict().items()}
17
-
18
- return None
19
-
20
-
21
- def get_fsdp_model_params(model: nn.Module, only_rank0=False):
22
- """
23
- 从一个 FSDP 包装的模型中高效地提取完整的 FP32 state_dict。
24
- 这个函数会聚合所有分片的参数,并确保所有 rank 都收到一个完整的副本。
25
- """
26
-
27
- state_dict = _get_fsdp_full_state_dict_on_rank0(model)
28
-
29
- # 现在,只有 rank 0 上的 state_dict 是一个有效的字典,其他 rank 上是 None。
30
- # 我们需要将其广播给所有进程。
31
- if not only_rank0 and TrainerTools().parallel.world_size > 1:
32
- # 准备一个列表,rank 0 有数据,其他 rank 是占位符
33
- object_list = [state_dict] if TrainerTools().parallel.is_main_process else [None]
34
- # 执行广播,这个操作是阻塞的,会同步所有进程
35
- dist.broadcast_object_list(object_list, src=0)
36
- # 所有进程从列表中获取广播后的 state_dict 副本
37
- state_dict = object_list[0]
38
-
39
- return state_dict
llm_trainer/model_params.py DELETED
@@ -1,28 +0,0 @@
1
- from typing import Optional
2
- from torch import nn
3
- from torch.nn.parallel import DistributedDataParallel as DDP
4
-
5
- from .tools import TrainerTools
6
- from .parallel_ds import DsParallel
7
- from .parallel_fsdp import FsdpParallel
8
-
9
- def copy_model_params(
10
- _from: nn.Module,
11
- _to: Optional[nn.Module]
12
- ):
13
- """
14
- 必须在所有rank上调用,非rank0, _to可以设置为None
15
- """
16
- if isinstance(TrainerTools().parallel, DsParallel):
17
- from .ds_model_params import get_ds_model_params
18
- state_dict = get_ds_model_params(_from, only_rank0=_to is None)
19
- elif isinstance(TrainerTools().parallel, FsdpParallel):
20
- from .fsdp_model_params import get_fsdp_model_params
21
- state_dict = get_fsdp_model_params(_from, only_rank0=_to is None)
22
- elif isinstance(_from, DDP):
23
- state_dict = _from.module.state_dict()
24
- else:
25
- state_dict = _from.state_dict()
26
-
27
- if _to and state_dict:
28
- _to.load_state_dict(state_dict)