project-llm-trainer 0.4.15__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff shows the content of publicly released package versions as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Potentially problematic release: this version of project-llm-trainer might be problematic.
- llm_trainer/checkpoint.py +0 -50
- llm_trainer/dpo_trainer.py +6 -3
- llm_trainer/eval.py +3 -30
- llm_trainer/generate_utils.py +9 -74
- llm_trainer/grpo_trainer.py +27 -28
- llm_trainer/loss.py +1 -1
- llm_trainer/partition_utils.py +146 -0
- llm_trainer/tokenizer.py +10 -10
- llm_trainer/tools.py +0 -2
- llm_trainer/train_configs.py +5 -25
- llm_trainer/trainer.py +28 -67
- llm_trainer/utils.py +0 -1
- {project_llm_trainer-0.4.15.dist-info → project_llm_trainer-0.5.1.dist-info}/METADATA +1 -1
- project_llm_trainer-0.5.1.dist-info/RECORD +33 -0
- llm_trainer/dcp.py +0 -93
- llm_trainer/ds_model_params.py +0 -72
- llm_trainer/fsdp_checkpoint.py +0 -52
- llm_trainer/fsdp_model_params.py +0 -39
- llm_trainer/model_params.py +0 -28
- llm_trainer/parallel_fsdp.py +0 -121
- project_llm_trainer-0.4.15.dist-info/RECORD +0 -38
- {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.4.15.dist-info → project_llm_trainer-0.5.1.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.4.15.dist-info → project_llm_trainer-0.5.1.dist-info}/top_level.txt +0 -0
llm_trainer/parallel_fsdp.py
DELETED
@@ -1,121 +0,0 @@
-from typing import Optional, Tuple
-import functools
-import torch
-from torch import nn
-from torch.distributed.fsdp import (
-    FullyShardedDataParallel as FSDP,
-    MixedPrecision,
-    ShardingStrategy,
-    BackwardPrefetch,
-    CPUOffload,
-)
-
-from torch.distributed.fsdp.wrap import (
-    size_based_auto_wrap_policy,
-    transformer_auto_wrap_policy,
-    always_wrap_policy,
-    enable_wrap,
-    wrap,
-)
-
-from .parallel import Parallel
-
-class FsdpParallel(Parallel):
-    def __init__(self):
-        super().__init__()
-
-    def process(
-        self,
-        model: nn.Module,
-        optimizer: torch.optim.Optimizer,
-        kwargs: Optional[dict] = None,
-        save_instance: bool = True
-    ) -> Tuple[nn.Module, torch.optim.Optimizer]:
-        """
-        :param model:
-        :param optimizer:
-        :param kwargs:
-            "wrap_policy_num_params" int minimum number of parameters for size_based_auto_wrap_policy
-            "cpu_offload" bool whether to use CPU offload
-            "offload_params" bool whether to offload parameters; only effective when cpu_offload is True
-        :param save_instance
-        :return:
-        """
-
-        model.to(self.device)
-
-        if self._use_compile:
-            model = torch.compile(model)
-
-        if self._use_parallel:
-            if 'transformer_layer_cls' in kwargs:
-                auto_wrap_policy = functools.partial(
-                    transformer_auto_wrap_policy,
-                    transformer_layer_cls=kwargs['transformer_layer_cls']
-                )
-            elif 'wrap_policy_num_params' in kwargs:
-                auto_wrap_policy = functools.partial(
-                    size_based_auto_wrap_policy,
-                    min_num_params=kwargs['wrap_policy_num_params']
-                )
-            else:
-                auto_wrap_policy = None
-
-            if 'cpu_offload' in kwargs:
-                offload_params = False
-                if 'offload_params' in kwargs:
-                    offload_params = kwargs['offload_params']
-
-                # Optionally configure cpu_offload so that wrapped parameters are offloaded to the CPU
-                # when they are not being used in computation. This can further improve memory efficiency,
-                # at the cost of data-transfer overhead between host and device.
-                cpu_offload = CPUOffload(offload_params=offload_params)
-            else:
-                cpu_offload = None
-
-            if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
-                mixed_precision = MixedPrecision(
-                    param_dtype=torch.bfloat16,
-                    # Gradient communication precision.
-                    reduce_dtype=torch.bfloat16,
-                    # Buffer precision.
-                    buffer_dtype=torch.bfloat16,
-                )
-            else:
-                mixed_precision = None
-
-            raw_model = model
-
-            # device_mesh = init_device_mesh("cuda", (self.world_size,))
-            # model = FSDP(
-            #     model,
-            #     auto_wrap_policy=auto_wrap_policy,
-            #     mixed_precision=mixed_precision,
-            #     cpu_offload=cpu_offload,
-            #     device_id=torch.cuda.current_device(),
-            #     device_mesh=device_mesh
-            # )
-
-            model = FSDP(
-                model,
-                sharding_strategy=ShardingStrategy.FULL_SHARD,
-                auto_wrap_policy=auto_wrap_policy,
-                mixed_precision=mixed_precision,
-                cpu_offload=cpu_offload,
-                device_id=torch.cuda.current_device(),
-                process_group=None,
-                # use_orig_params=True,
-                # backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # bit faster async comms, bit higher memory
-                # limit_all_gathers=False,
-                # forward_prefetch=True,
-            )
-        else:
-            model = model
-            raw_model = model
-
-        if save_instance:
-            self.raw_model = raw_model
-            self.model = model
-
-        return model, optimizer
-
-
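For context on the removed module, the sketch below is not part of the package and is written only from the signature and docstring shown above; the example model, optimizer, and kwargs keys are illustrative assumptions, and the import path no longer exists in 0.5.1.

import torch
from torch import nn
from llm_trainer.parallel_fsdp import FsdpParallel  # module removed in 0.5.1

# Hypothetical model and optimizer, purely for illustration.
model = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=512, nhead=8), num_layers=6)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

parallel = FsdpParallel()
model, optimizer = parallel.process(
    model,
    optimizer,
    kwargs={
        # Wrap each transformer layer in its own FSDP unit (transformer_auto_wrap_policy).
        "transformer_layer_cls": {nn.TransformerEncoderLayer},
        # Offload wrapped parameters to CPU when not in use (CPUOffload).
        "cpu_offload": True,
        "offload_params": True,
    },
)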
project_llm_trainer-0.4.15.dist-info/RECORD
DELETED
@@ -1,38 +0,0 @@
-llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
-llm_trainer/checkpoint.py,sha256=AvUC1JLxuahKtg3VNW20VHIE3iIjpaMHIi_pyyDYVJ0,5043
-llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
-llm_trainer/dcp.py,sha256=PkD97DyrOtoTKn4FJsfL3VqAy4dxufgjdzJEz8-Cnoc,3635
-llm_trainer/dpo_trainer.py,sha256=o5lYxt6yVMCvoBqW_yTu9l6Ff-xjEu-CwdPVttu3H8E,11447
-llm_trainer/ds_checkpoint.py,sha256=wz48HoLBBt8QGO1tXfvJwrXoiGtPG_gjwHfEqARllso,2175
-llm_trainer/ds_model_params.py,sha256=Nwmv0YcBtO6ynC0dXallAD1rWkN22-elGfVjLaWp2Yg,2988
-llm_trainer/eval.py,sha256=NDm8PbXLch7xT81xPYPRCNrcrB_Xj5GDJSCxyVwUOp4,1524
-llm_trainer/fsdp_checkpoint.py,sha256=xsm71s9WeTaBvBvv6CbuGpwkmX3V6i3xmBcMTDfGxKc,1770
-llm_trainer/fsdp_model_params.py,sha256=MRjrs9zmMl-61a1l6188Ij5PSalzztOSp8E4evDvJXo,1541
-llm_trainer/generate_utils.py,sha256=tSbA_tLqSq5qJGHSOlPv5T3iRDZkbFg5ZvDAgJ_i_SE,17946
-llm_trainer/grpo_trainer.py,sha256=1gZXiL1pogLFecFQUGj9zCU_k66ryVjZciYyd8J5ph4,15998
-llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
-llm_trainer/loss.py,sha256=Yv3fsaVuZ5AhnGPJOr5vEMb_tM2urR6mCb4DBbrHHI8,6030
-llm_trainer/model_params.py,sha256=2f2W9KRCjyqSfEwxI3w5f6TPZaqq25WzY-nEc7aJxcs,970
-llm_trainer/parallel.py,sha256=DQu8GqEFxD99HQ6hKuIxxyKi-05dMO33eMhImYlPuOI,4468
-llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
-llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
-llm_trainer/parallel_fsdp.py,sha256=cQOdY8ou6m8OsR06PpFVn6GiyZlK9nefkcGyszUOIJk,4055
-llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
-llm_trainer/scheduler.py,sha256=Xz8HhwoRMjRe41sf_NHhpZfkTlEs0I2MYusvMY6hCVw,3531
-llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
-llm_trainer/tokenizer.py,sha256=A7TYYUbtPf75kjCvWP7yBui4xZBObMk2aPem62YpwpY,6776
-llm_trainer/tools.py,sha256=O45-20wRmh-nyTfU-U-XtjbKAoe7boEIsUvWT_NaKx4,3041
-llm_trainer/train_configs.py,sha256=HKzH3nfMT1-SW4Htwa0KqYtMd6FAJcthR5IEo6di8us,8168
-llm_trainer/trainer.py,sha256=95ARdNDfalhZ7Ug-fDj3qIhWEiZQeX9n5WANhijIRLE,27140
-llm_trainer/utils.py,sha256=-ivhMF0d999va13S1wt2uBvtVw8Nvr3uBzhaUFKL04Q,6826
-project_llm_trainer-0.4.15.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
-project_llm_trainer-0.4.15.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
-project_llm_trainer-0.4.15.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
-project_llm_trainer-0.4.15.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
-project_llm_trainer-0.4.15.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
-project_llm_trainer-0.4.15.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
-project_llm_trainer-0.4.15.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
-project_llm_trainer-0.4.15.dist-info/METADATA,sha256=5sveZ3kkRMVCz9dI5_NI64o9tFBVsJhHhun9vwzzL9Q,196
-project_llm_trainer-0.4.15.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
-project_llm_trainer-0.4.15.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
-project_llm_trainer-0.4.15.dist-info/RECORD,,
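The RECORD entries above follow the standard wheel format "path,sha256=<urlsafe-base64 digest without padding>,<size in bytes>". Below is a minimal sketch of checking one such entry against an installed tree; it is a hypothetical helper for illustration, not something shipped by project-llm-trainer.

import base64
import hashlib
from pathlib import Path

def verify_record_entry(root: Path, line: str) -> bool:
    # The entry for RECORD itself ("...RECORD,,") carries no hash or size and is skipped.
    path, hash_field, size = line.rsplit(",", 2)
    if not hash_field:
        return True
    data = (root / path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode()
    return hash_field == f"sha256={digest}" and int(size) == len(data)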
{project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/calc_intermediate_size
RENAMED
File without changes
{project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/ddp_train
RENAMED
File without changes
{project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/ds_train
RENAMED
File without changes
{project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/plot_loss
RENAMED
File without changes
{project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/plot_lr
RENAMED
File without changes
{project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/py_train
RENAMED
File without changes
{project_llm_trainer-0.4.15.data → project_llm_trainer-0.5.1.data}/scripts/smart_train
RENAMED
File without changes
{project_llm_trainer-0.4.15.dist-info → project_llm_trainer-0.5.1.dist-info}/WHEEL
RENAMED
File without changes
{project_llm_trainer-0.4.15.dist-info → project_llm_trainer-0.5.1.dist-info}/top_level.txt
RENAMED
File without changes