project-llm-trainer 0.4.15__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as published to a supported registry; it is provided for informational purposes only.

This version of project-llm-trainer has been flagged as potentially problematic.

Files changed (30)
  1. llm_trainer/checkpoint.py +0 -50
  2. llm_trainer/dpo_trainer.py +6 -3
  3. llm_trainer/eval.py +3 -30
  4. llm_trainer/generate_utils.py +9 -74
  5. llm_trainer/grpo_trainer.py +27 -28
  6. llm_trainer/loss.py +1 -1
  7. llm_trainer/partition_utils.py +146 -0
  8. llm_trainer/tokenizer.py +10 -10
  9. llm_trainer/tools.py +0 -2
  10. llm_trainer/train_configs.py +5 -25
  11. llm_trainer/trainer.py +28 -67
  12. llm_trainer/utils.py +0 -1
  13. {project_llm_trainer-0.4.15.dist-info → project_llm_trainer-0.5.1.dist-info}/METADATA +1 -1
  14. project_llm_trainer-0.5.1.dist-info/RECORD +33 -0
  15. llm_trainer/dcp.py +0 -93
  16. llm_trainer/ds_model_params.py +0 -72
  17. llm_trainer/fsdp_checkpoint.py +0 -52
  18. llm_trainer/fsdp_model_params.py +0 -39
  19. llm_trainer/model_params.py +0 -28
  20. llm_trainer/parallel_fsdp.py +0 -121
  21. project_llm_trainer-0.4.15.dist-info/RECORD +0 -38
  22. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/calc_intermediate_size +0 -0
  23. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/ddp_train +0 -0
  24. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/ds_train +0 -0
  25. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/plot_loss +0 -0
  26. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/plot_lr +0 -0
  27. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/py_train +0 -0
  28. {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/smart_train +0 -0
  29. {project_llm_trainer-0.4.15.dist-info → project_llm_trainer-0.5.1.dist-info}/WHEEL +0 -0
  30. {project_llm_trainer-0.4.15.dist-info → project_llm_trainer-0.5.1.dist-info}/top_level.txt +0 -0
llm_trainer/parallel_fsdp.py (removed in 0.5.1)
@@ -1,121 +0,0 @@
- from typing import Optional, Tuple
- import functools
- import torch
- from torch import nn
- from torch.distributed.fsdp import (
-     FullyShardedDataParallel as FSDP,
-     MixedPrecision,
-     ShardingStrategy,
-     BackwardPrefetch,
-     CPUOffload,
- )
-
- from torch.distributed.fsdp.wrap import (
-     size_based_auto_wrap_policy,
-     transformer_auto_wrap_policy,
-     always_wrap_policy,
-     enable_wrap,
-     wrap,
- )
-
- from .parallel import Parallel
-
- class FsdpParallel(Parallel):
-     def __init__(self):
-         super().__init__()
-
-     def process(
-         self,
-         model: nn.Module,
-         optimizer: torch.optim.Optimizer,
-         kwargs: Optional[dict] = None,
-         save_instance: bool = True
-     ) -> Tuple[nn.Module, torch.optim.Optimizer]:
-         """
-         :param model:
-         :param optimizer:
-         :param kwargs:
-             "wrap_policy_num_params" int: minimum number of parameters for size_based_auto_wrap_policy
-             "cpu_offload" bool: whether to use CPU offloading
-             "offload_params" bool: whether to offload parameters; only takes effect when cpu_offload is True
-         :param save_instance:
-         :return:
-         """
-
-         model.to(self.device)
-
-         if self._use_compile:
-             model = torch.compile(model)
-
-         if self._use_parallel:
-             if 'transformer_layer_cls' in kwargs:
-                 auto_wrap_policy = functools.partial(
-                     transformer_auto_wrap_policy,
-                     transformer_layer_cls=kwargs['transformer_layer_cls']
-                 )
-             elif 'wrap_policy_num_params' in kwargs:
-                 auto_wrap_policy = functools.partial(
-                     size_based_auto_wrap_policy,
-                     min_num_params=kwargs['wrap_policy_num_params']
-                 )
-             else:
-                 auto_wrap_policy = None
-
-             if 'cpu_offload' in kwargs:
-                 offload_params = False
-                 if 'offload_params' in kwargs:
-                     offload_params = kwargs['offload_params']
-
-                 # Optionally configure cpu_offload so that wrapped parameters are offloaded to CPU
-                 # when not involved in computation. This can further improve memory efficiency, at
-                 # the cost of data-transfer overhead between host and device.
-                 cpu_offload = CPUOffload(offload_params=offload_params)
-             else:
-                 cpu_offload = None
-
-             if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
-                 mixed_precision = MixedPrecision(
-                     param_dtype=torch.bfloat16,
-                     # Gradient communication precision.
-                     reduce_dtype=torch.bfloat16,
-                     # Buffer precision.
-                     buffer_dtype=torch.bfloat16,
-                 )
-             else:
-                 mixed_precision = None
-
-             raw_model = model
-
-             # device_mesh = init_device_mesh("cuda", (self.world_size,))
-             # model = FSDP(
-             #     model,
-             #     auto_wrap_policy=auto_wrap_policy,
-             #     mixed_precision=mixed_precision,
-             #     cpu_offload=cpu_offload,
-             #     device_id=torch.cuda.current_device(),
-             #     device_mesh=device_mesh
-             # )
-
-             model = FSDP(
-                 model,
-                 sharding_strategy=ShardingStrategy.FULL_SHARD,
-                 auto_wrap_policy=auto_wrap_policy,
-                 mixed_precision=mixed_precision,
-                 cpu_offload=cpu_offload,
-                 device_id=torch.cuda.current_device(),
-                 process_group=None,
-                 # use_orig_params=True,
-                 # backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # bit faster async comms, bit higher memory
-                 # limit_all_gathers=False,
-                 # forward_prefetch=True,
-             )
-         else:
-             model = model
-             raw_model = model
-
-         if save_instance:
-             self.raw_model = raw_model
-             self.model = model
-
-         return model, optimizer
-
-
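For context, here is a minimal usage sketch of the removed FsdpParallel, inferred from the deleted code above; it is not part of the package. It assumes the Parallel base class (still shipped as llm_trainer/parallel.py) initializes the distributed environment and populates self.device, self._use_compile and self._use_parallel, e.g. under torchrun, and MyBlock/MyModel are hypothetical stand-ins for real model classes.

import torch
from torch import nn

from llm_trainer.parallel_fsdp import FsdpParallel  # available in <= 0.4.15


class MyBlock(nn.Module):
    """Hypothetical transformer layer; the unit FSDP wraps under the class-based policy."""

    def __init__(self, dim: int = 256):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.ff(x))


class MyModel(nn.Module):
    """Hypothetical model built from MyBlock layers."""

    def __init__(self, dim: int = 256, n_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleList([MyBlock(dim) for _ in range(n_layers)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x


model = MyModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

parallel = FsdpParallel()
model, optimizer = parallel.process(
    model,
    optimizer,
    kwargs={
        # Wrap each MyBlock as its own FSDP unit (transformer_auto_wrap_policy)...
        "transformer_layer_cls": {MyBlock},
        # ...or wrap by size instead: "wrap_policy_num_params": 1_000_000,
        "cpu_offload": True,     # presence of the key enables CPUOffload
        "offload_params": True,  # offload parameters when not in use
    },
)

Note that process enables CPUOffload based only on the presence of the 'cpu_offload' key, falls back to no auto-wrap policy when neither 'transformer_layer_cls' nor 'wrap_policy_num_params' is given, and applies bf16 mixed precision automatically whenever the GPU supports it. Also, kwargs defaults to None but is indexed unguarded when parallelism is enabled, so callers had to pass a dict.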
project_llm_trainer-0.4.15.dist-info/RECORD (removed; superseded by project_llm_trainer-0.5.1.dist-info/RECORD)
@@ -1,38 +0,0 @@
- llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
- llm_trainer/checkpoint.py,sha256=AvUC1JLxuahKtg3VNW20VHIE3iIjpaMHIi_pyyDYVJ0,5043
- llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
- llm_trainer/dcp.py,sha256=PkD97DyrOtoTKn4FJsfL3VqAy4dxufgjdzJEz8-Cnoc,3635
- llm_trainer/dpo_trainer.py,sha256=o5lYxt6yVMCvoBqW_yTu9l6Ff-xjEu-CwdPVttu3H8E,11447
- llm_trainer/ds_checkpoint.py,sha256=wz48HoLBBt8QGO1tXfvJwrXoiGtPG_gjwHfEqARllso,2175
- llm_trainer/ds_model_params.py,sha256=Nwmv0YcBtO6ynC0dXallAD1rWkN22-elGfVjLaWp2Yg,2988
- llm_trainer/eval.py,sha256=NDm8PbXLch7xT81xPYPRCNrcrB_Xj5GDJSCxyVwUOp4,1524
- llm_trainer/fsdp_checkpoint.py,sha256=xsm71s9WeTaBvBvv6CbuGpwkmX3V6i3xmBcMTDfGxKc,1770
- llm_trainer/fsdp_model_params.py,sha256=MRjrs9zmMl-61a1l6188Ij5PSalzztOSp8E4evDvJXo,1541
- llm_trainer/generate_utils.py,sha256=tSbA_tLqSq5qJGHSOlPv5T3iRDZkbFg5ZvDAgJ_i_SE,17946
- llm_trainer/grpo_trainer.py,sha256=1gZXiL1pogLFecFQUGj9zCU_k66ryVjZciYyd8J5ph4,15998
- llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
- llm_trainer/loss.py,sha256=Yv3fsaVuZ5AhnGPJOr5vEMb_tM2urR6mCb4DBbrHHI8,6030
- llm_trainer/model_params.py,sha256=2f2W9KRCjyqSfEwxI3w5f6TPZaqq25WzY-nEc7aJxcs,970
- llm_trainer/parallel.py,sha256=DQu8GqEFxD99HQ6hKuIxxyKi-05dMO33eMhImYlPuOI,4468
- llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
- llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
- llm_trainer/parallel_fsdp.py,sha256=cQOdY8ou6m8OsR06PpFVn6GiyZlK9nefkcGyszUOIJk,4055
- llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
- llm_trainer/scheduler.py,sha256=Xz8HhwoRMjRe41sf_NHhpZfkTlEs0I2MYusvMY6hCVw,3531
- llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
- llm_trainer/tokenizer.py,sha256=A7TYYUbtPf75kjCvWP7yBui4xZBObMk2aPem62YpwpY,6776
- llm_trainer/tools.py,sha256=O45-20wRmh-nyTfU-U-XtjbKAoe7boEIsUvWT_NaKx4,3041
- llm_trainer/train_configs.py,sha256=HKzH3nfMT1-SW4Htwa0KqYtMd6FAJcthR5IEo6di8us,8168
- llm_trainer/trainer.py,sha256=95ARdNDfalhZ7Ug-fDj3qIhWEiZQeX9n5WANhijIRLE,27140
- llm_trainer/utils.py,sha256=-ivhMF0d999va13S1wt2uBvtVw8Nvr3uBzhaUFKL04Q,6826
- project_llm_trainer-0.4.15.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
- project_llm_trainer-0.4.15.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
- project_llm_trainer-0.4.15.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
- project_llm_trainer-0.4.15.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
- project_llm_trainer-0.4.15.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
- project_llm_trainer-0.4.15.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
- project_llm_trainer-0.4.15.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
- project_llm_trainer-0.4.15.dist-info/METADATA,sha256=5sveZ3kkRMVCz9dI5_NI64o9tFBVsJhHhun9vwzzL9Q,196
- project_llm_trainer-0.4.15.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
- project_llm_trainer-0.4.15.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
- project_llm_trainer-0.4.15.dist-info/RECORD,,