project-llm-trainer 0.3 (project_llm_trainer-0.3-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic.
- llm_trainer/__init__.py +6 -0
- llm_trainer/checkpoint.py +161 -0
- llm_trainer/dataset.py +140 -0
- llm_trainer/dcp.py +93 -0
- llm_trainer/dpo_trainer.py +300 -0
- llm_trainer/ds_checkpoint.py +61 -0
- llm_trainer/eval.py +86 -0
- llm_trainer/generate_utils.py +424 -0
- llm_trainer/grpo_trainer.py +393 -0
- llm_trainer/log.py +16 -0
- llm_trainer/loss.py +171 -0
- llm_trainer/parallel.py +146 -0
- llm_trainer/parallel_ddp.py +39 -0
- llm_trainer/parallel_ds.py +45 -0
- llm_trainer/parallel_fsdp.py +115 -0
- llm_trainer/parallel_none.py +28 -0
- llm_trainer/scheduler.py +138 -0
- llm_trainer/sft_trainer.py +39 -0
- llm_trainer/tokenizer.py +166 -0
- llm_trainer/tools.py +102 -0
- llm_trainer/train_configs.py +445 -0
- llm_trainer/trainer.py +569 -0
- llm_trainer/utils.py +262 -0
- project_llm_trainer-0.3.data/scripts/calc_intermediate_size +15 -0
- project_llm_trainer-0.3.data/scripts/ddp_train +12 -0
- project_llm_trainer-0.3.data/scripts/ds_train +12 -0
- project_llm_trainer-0.3.data/scripts/plot_loss +39 -0
- project_llm_trainer-0.3.data/scripts/plot_lr +41 -0
- project_llm_trainer-0.3.data/scripts/py_train +12 -0
- project_llm_trainer-0.3.data/scripts/smart_train +28 -0
- project_llm_trainer-0.3.dist-info/METADATA +9 -0
- project_llm_trainer-0.3.dist-info/RECORD +34 -0
- project_llm_trainer-0.3.dist-info/WHEEL +5 -0
- project_llm_trainer-0.3.dist-info/top_level.txt +1 -0
llm_trainer/__init__.py
ADDED
@@ -0,0 +1,6 @@
+from .trainer import Trainer
+from .sft_trainer import SFTTrainer
+from .dpo_trainer import DPOTrainer
+from .grpo_trainer import GRPOTrainer
+from .tools import TrainerTools, FileDataset, estimate_data_size
+from .generate_utils import generate, streaming_generate
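Taken together with the file list above, these six lines define the package's public entry points: stage-specific trainers plus generation and data tooling. A minimal import sketch follows; only the exported names are confirmed by this diff, so no constructor arguments are guessed:

    # Import surface only; anything beyond these names would be an assumption.
    from llm_trainer import Trainer, SFTTrainer, DPOTrainer, GRPOTrainer
    from llm_trainer import TrainerTools, FileDataset, estimate_data_size
    from llm_trainer import generate, streaming_generate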
llm_trainer/checkpoint.py
ADDED
@@ -0,0 +1,161 @@
+import os
+from typing import Optional, Union, Tuple
+import torch
+from torch import nn
+from torch.optim import Optimizer
+
+from .parallel_ds import DsParallel
+from .parallel_fsdp import FsdpParallel
+from .parallel_ddp import DdpParallel
+from .scheduler import LRScheduler
+from .tools import TrainerTools
+
+try:
+    from .dcp import save_dcp, load_dcp, convert_dcp_to_pth
+except ImportError:
+    os.environ['ENABLE_DCP'] = "0"
+
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+
+# https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
+
+DEFAULT_CHECKPOINT_NAME = "checkpoint.pth"
+
+
+def _can_use_dcp(model: nn.Module) -> bool:
+    if os.environ.get('ENABLE_DCP', '1') != '1':
+        return False
+
+    # DCP saving is only usable under FSDP or DDP
+    if (isinstance(TrainerTools().parallel, FsdpParallel)
+            or isinstance(TrainerTools().parallel, DdpParallel)):
+        return True
+
+    return False
+
+
+def save_checkpoint(
+        model: nn.Module,
+        optimizer: Optional[Optimizer] = None,
+        suffix: Optional[str] = None
+):
+    if isinstance(TrainerTools().parallel, DsParallel):
+        from .ds_checkpoint import save_ds_checkpoint
+        save_ds_checkpoint(model, suffix)
+    elif _can_use_dcp(model):
+        save_dcp(model, optimizer, suffix)
+    else:
+        if isinstance(model, FSDP):
+            # Untested; see https://doc.hfai.high-flyer.cn/haiscale/haiscale_fsdp.html
+            # Open question: should rank0_only=True be used here?
+            with FSDP.summon_full_params(
+                    module=model,
+                    rank0_only=True,
+                    writeback=False,
+                    offload_to_cpu=True
+            ):
+                if TrainerTools().parallel.is_main_process:
+                    checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
+                    if suffix:
+                        checkpoint_name = f"{checkpoint_name}_{suffix}"
+
+                    ckpt = {'model_state_dict': model.state_dict()}
+
+                    if optimizer:
+                        ckpt.update({'optim_state_dict': optimizer.state_dict()})
+
+                    torch.save(ckpt, checkpoint_name)
+        else:
+            if TrainerTools().parallel.is_main_process:
+                checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
+                if suffix:
+                    checkpoint_name = f"{checkpoint_name}_{suffix}"
+
+                ckpt = {'model_state_dict': TrainerTools().parallel.raw_model.state_dict()}
+
+                if optimizer:
+                    ckpt.update({'optim_state_dict': optimizer.state_dict()})
+
+                torch.save(ckpt, checkpoint_name)
+
+
+def load_checkpoint(
+        model: nn.Module,
+        optimizer: Optional[Optimizer] = None,
+        device: Optional[Union[torch.device, str]] = None,
+        load_module_only: bool = False,
+        suffix: Optional[str] = None
+):
+    if isinstance(TrainerTools().parallel, DsParallel):
+        from .ds_checkpoint import load_ds_checkpoint
+        load_ds_checkpoint(model, load_module_only=load_module_only, suffix=suffix)
+    elif _can_use_dcp(model):
+        load_dcp(model, optimizer, suffix)
+    else:
+        checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
+        if suffix:
+            checkpoint_name = f"{checkpoint_name}_{suffix}"
+
+        if os.path.exists(checkpoint_name):
+            # Untested; the else branch below has been tested and also works under FSDP
+            if isinstance(model, FSDP):
+                with FSDP.summon_full_params(module=model):
+                    state_dict = torch.load(checkpoint_name, weights_only=True, map_location=device)
+                    model.load_state_dict(state_dict['model_state_dict'])
+
+                    if optimizer:
+                        optimizer.load_state_dict(state_dict['optim_state_dict'])
+            else:
+                state_dict = torch.load(checkpoint_name, weights_only=True, map_location=device)
+                model.load_state_dict(state_dict['model_state_dict'])
+
+                if optimizer:
+                    optimizer.load_state_dict(state_dict['optim_state_dict'])
+
+
+def load_checkpoint_for_eval(
+        model: nn.Module,
+        device: Optional[Union[torch.device, str]] = None,
+        suffix: Optional[str] = None
+):
+    if isinstance(TrainerTools().parallel, DsParallel):
+        from .ds_checkpoint import load_ds_checkpoint_for_eval
+        load_ds_checkpoint_for_eval(model)
+    elif _can_use_dcp(model):
+        checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
+
+        # load_dcp raises an error on CPU, so convert the checkpoint to a .pth file
+        # first, then load the .pth
+        # load_dcp(model, optimizer)
+        pth_name = os.environ.get('EVAL_CHECKPOINT_NAME', checkpoint_name)
+        if suffix:
+            pth_name = f'{pth_name}_{suffix}'
+
+        convert_dcp_to_pth(pth_name)
+
+        if os.path.exists(pth_name):
+            ckpt = torch.load(pth_name, map_location=device, weights_only=True)
+            model.load_state_dict(ckpt['app']['model_state_dict'])
+            # delete the temporary file after use
+            os.remove(pth_name)
+    else:
+        load_checkpoint(model, None, device, suffix=suffix)
+
+
+def save_steps(global_steps: int, lr_scheduler: Optional[LRScheduler] = None):
+    # For now, only the main process saves step state
+    if TrainerTools().parallel.is_main_process:
+        steps_checkpoint_name = f"{os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)}.steps"
+        lr_steps = lr_scheduler.cur_steps if lr_scheduler else 0
+        ckpt = {'global_steps': global_steps, 'lr_steps': lr_steps}
+        torch.save(ckpt, steps_checkpoint_name)
+
+
+def load_steps(
+        default_global_steps: int = 0,
+        default_lr_steps: int = 0
+) -> Tuple[Optional[int], Optional[int]]:
+    steps_checkpoint_name = f"{os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)}.steps"
+    if os.path.exists(steps_checkpoint_name):
+        ckpt = torch.load(steps_checkpoint_name, weights_only=True)
+        return ckpt['global_steps'], ckpt['lr_steps']
+
+    return default_global_steps, default_lr_steps
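In the non-DeepSpeed, non-DCP branches above, a checkpoint is just a torch.save of a two-key dict. The following standalone sketch reproduces that file layout without the package; the model, optimizer, and file name are placeholders:

    import torch
    from torch import nn

    model = nn.Linear(8, 8)                        # placeholder model
    optim = torch.optim.AdamW(model.parameters())  # placeholder optimizer

    # Same dict layout save_checkpoint writes in its fallback branch;
    # a suffix is appended to the file name as "_<suffix>".
    ckpt = {'model_state_dict': model.state_dict(),
            'optim_state_dict': optim.state_dict()}
    torch.save(ckpt, 'checkpoint.pth_step100')

    # Mirrors load_checkpoint's fallback: weights_only=True plus an explicit map_location.
    state = torch.load('checkpoint.pth_step100', weights_only=True, map_location='cpu')
    model.load_state_dict(state['model_state_dict'])
    optim.load_state_dict(state['optim_state_dict'])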
llm_trainer/dataset.py
ADDED
@@ -0,0 +1,140 @@
+import os.path
+
+import torch
+from torch.utils.data import Dataset
+import pickle
+
+from .tools import TrainerTools
+from .utils import extra_image_tag_and_repeat_image_tok
+
+
+def _try_load_pkl(file_path: str):
+    # Returns the unpickled object, or None if the file cannot be read as a pickle
+    try:
+        with open(file_path, 'rb') as f:
+            return pickle.load(f)
+    except Exception:
+        return None
+
+
+class TextDataset(Dataset):
+    """
+    Used in the pretrain stage.
+    """
+    def __init__(self, file_path, block_size, stride):
+        super().__init__()
+
+        self.input_ids = []
+
+        tokens = _try_load_pkl(file_path)
+        if not tokens:
+            cache_file = f'{file_path}.cache'
+            if os.path.exists(cache_file):
+                tokens = _try_load_pkl(cache_file)
+            else:
+                tokens = []
+                with open(file_path, 'r') as f:
+                    for line in f:
+                        tokens.extend(TrainerTools().tokenizer.encode(line))
+
+                with open(cache_file, 'wb') as f:
+                    pickle.dump(tokens, f)
+
+        # Sliding window over the token stream
+        for i in range(0, len(tokens) - block_size + 1, stride):
+            self.input_ids.append(tokens[i:i + block_size])
+
+    def __len__(self):
+        return len(self.input_ids)
+
+    def __getitem__(self, item):
+        return torch.tensor(self.input_ids[item]).long()
+
+
+class LineByLineTextDataset(Dataset):
+    """
+    Used in the SFT stage.
+    """
+    def __init__(self, file_path, max_len, tokens_per_image=-1):
+        super().__init__()
+
+        self.max_len = max_len
+        self.tokens_per_image = tokens_per_image
+        self.input_ids = []
+
+        tokens = _try_load_pkl(file_path)
+        if not tokens:
+            cache_file = f'{file_path}.cache'
+            if os.path.exists(cache_file):
+                tokens = _try_load_pkl(cache_file)
+            else:
+                tokens = []
+                with open(file_path, 'r') as f:
+                    for line in f:
+                        tokens.append(TrainerTools().tokenizer.encode(line))
+
+                with open(cache_file, 'wb') as f:
+                    pickle.dump(tokens, f)
+
+        self.input_ids = tokens
+
+    def __len__(self):
+        return len(self.input_ids)
+
+    def __getitem__(self, item):
+        inputs = self.input_ids[item]
+        if self.tokens_per_image != -1:
+            inputs, image_tag = extra_image_tag_and_repeat_image_tok(inputs, self.tokens_per_image)
+        else:
+            image_tag = None
+
+        inputs = inputs[:self.max_len]
+
+        return {'inputs': torch.tensor(inputs).long(), 'image_tag': image_tag}
+
+
+class DPODataset(Dataset):
+    def __init__(self, file_path, max_len):
+        self.max_len = max_len
+        self.chosen_ids = []
+        self.rejected_ids = []
+
+        # [{'chosen': xxx, 'rejected': xxx} ...]
+        tokens = _try_load_pkl(file_path)
+        for token in tokens:
+            self.chosen_ids.append(token['chosen'])
+            self.rejected_ids.append(token['rejected'])
+
+    def __len__(self):
+        return len(self.chosen_ids)
+
+    def __getitem__(self, item):
+        chosen_id = self.chosen_ids[item]
+        rejected_id = self.rejected_ids[item]
+
+        return {'chosen': chosen_id[:self.max_len], 'rejected': rejected_id[:self.max_len]}
+
+
+class GRPORolloutDataset(Dataset):
+    def __init__(self, file_path):
+        self.questions = []
+        self.answers = []
+
+        # [{'prompt': xxx, 'answer': xxx} ...]
+        tokens = _try_load_pkl(file_path)
+        for token in tokens:
+            self.questions.append(token['prompt'])
+            self.answers.append(token['answer'])
+
+    def __len__(self):
+        return len(self.questions)
+
+    def __getitem__(self, item):
+        question = self.questions[item]
+        answer = self.answers[item]
+
+        return {
+            'prompt': torch.tensor(question).long(),
+            'answer': torch.tensor(answer).long()
+        }
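The pretrain split in TextDataset is ordinary stride slicing, so overlapping blocks fall directly out of the range() arithmetic. A standalone illustration with made-up token ids:

    # Sketch of TextDataset's windowing; stride < block_size yields overlapping samples.
    tokens = list(range(10))   # stand-in token ids
    block_size, stride = 4, 2

    windows = [tokens[i:i + block_size]
               for i in range(0, len(tokens) - block_size + 1, stride)]
    # windows == [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]]

Note that a trailing remainder shorter than block_size is dropped, and an input shorter than block_size yields no samples at all.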
llm_trainer/dcp.py
ADDED
@@ -0,0 +1,93 @@
+import os
+from typing import Optional, Dict, Any
+from torch import nn
+from torch.optim import Optimizer
+import torch.distributed.checkpoint as dcp
+from torch.distributed.checkpoint.stateful import Stateful
+from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
+from torch.distributed.checkpoint.format_utils import dcp_to_torch_save, torch_save_to_dcp
+
+DEFAULT_CHECKPOINT_DIR = "checkpoint"
+
+class AppState(Stateful):
+    def __init__(self, model: nn.Module, optimizer: Optimizer):
+        self.model = model
+        self.optimizer = optimizer
+
+    def state_dict(self) -> Dict[str, Any]:
+        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
+        return {
+            'model_state_dict': model_state_dict,
+            'optim_state_dict': optimizer_state_dict
+        }
+
+    def load_state_dict(self, state_dict: Dict[str, Any]):
+        set_state_dict(
+            model=self.model,
+            optimizers=self.optimizer,
+            model_state_dict=state_dict['model_state_dict'],
+            optim_state_dict=state_dict['optim_state_dict']
+        )
+
+
+def save_dcp(
+        model: nn.Module,
+        optimizer: Optimizer,
+        suffix: Optional[str] = None
+):
+    checkpoint_id = os.environ.get('DIST_CHECKPOINT_DIR', DEFAULT_CHECKPOINT_DIR)
+    if suffix:
+        checkpoint_id = f"{checkpoint_id}_{suffix}"
+
+    state_dict = {'app': AppState(model, optimizer)}
+
+    # fs_storage_writer = dcp.FileSystemWriter(checkpoint_id, overwrite=True)
+    # dcp.save(state_dict=state_dict, storage_writer=fs_storage_writer)
+    dcp.save(state_dict=state_dict, checkpoint_id=checkpoint_id)
+
+
+def load_dcp(
+        model: nn.Module,
+        optimizer: Optional[Optimizer] = None,
+        suffix: Optional[str] = None
+):
+    checkpoint_id = os.environ.get('DIST_CHECKPOINT_DIR', DEFAULT_CHECKPOINT_DIR)
+    if suffix:
+        checkpoint_id = f"{checkpoint_id}_{suffix}"
+
+    if os.path.exists(checkpoint_id):
+        state_dict = {'app': AppState(model, optimizer)}
+        # AppState loads into state_dict, which is then applied to the model
+        dcp.load(state_dict=state_dict, checkpoint_id=checkpoint_id)
+
+        # if isinstance(model, FSDP):
+        #     state_dict = {'app': AppState(model, optimizer)}
+        #     # AppState loads into state_dict, which is then applied to the model
+        #     dcp.load(state_dict=state_dict, checkpoint_id=checkpoint_id)
+        # else:
+        #     state_dict = {"model_state_dict": model.state_dict()}
+        #
+        #     if optimizer:
+        #         state_dict.update({'optim_state_dict': optimizer.state_dict()})
+        #
+        #     # since no process group is initialized, DCP will disable any collectives.
+        #     # Load into state_dict, then apply it via model.load_state_dict
+        #     dcp.load(
+        #         state_dict=state_dict,
+        #         checkpoint_id=checkpoint_id,
+        #     )
+        #
+        #     model.load_state_dict(state_dict["model_state_dict"])
+        #     if optimizer:
+        #         optimizer.load_state_dict(state_dict["optim_state_dict"])
+
+def convert_dcp_to_pth(pth_path: str):
+    dcp_path = os.environ.get('DIST_CHECKPOINT_DIR', DEFAULT_CHECKPOINT_DIR)
+    if os.path.exists(dcp_path):
+        # convert a DCP checkpoint to torch.save format (assumes it was generated as above)
+        dcp_to_torch_save(dcp_path, pth_path)
+
+def convert_pth_to_dcp(pth_path: str):
+    if os.path.exists(pth_path):
+        # converts the torch.save file back to DCP format
+        torch_save_to_dcp(pth_path, os.environ.get('DIST_CHECKPOINT_DIR', DEFAULT_CHECKPOINT_DIR))