project-llm-trainer 0.12.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_trainer/__init__.py +13 -0
- llm_trainer/base_trainer.py +683 -0
- llm_trainer/checkpoint.py +126 -0
- llm_trainer/dataset.py +335 -0
- llm_trainer/dpo_trainer.py +297 -0
- llm_trainer/ds_checkpoint.py +63 -0
- llm_trainer/eval.py +33 -0
- llm_trainer/generate_utils.py +450 -0
- llm_trainer/grpo_trainer.py +385 -0
- llm_trainer/log.py +65 -0
- llm_trainer/loss.py +268 -0
- llm_trainer/parallel.py +220 -0
- llm_trainer/partition_utils.py +219 -0
- llm_trainer/ppo_trainer.py +521 -0
- llm_trainer/scheduler.py +179 -0
- llm_trainer/sft_trainer.py +97 -0
- llm_trainer/tokenizer.py +162 -0
- llm_trainer/tools.py +116 -0
- llm_trainer/train_configs.py +324 -0
- llm_trainer/trainer.py +34 -0
- llm_trainer/utils.py +547 -0
- project_llm_trainer-0.12.3.data/scripts/calc_intermediate_size +15 -0
- project_llm_trainer-0.12.3.data/scripts/ddp_train +21 -0
- project_llm_trainer-0.12.3.data/scripts/ds_train +17 -0
- project_llm_trainer-0.12.3.data/scripts/plot_log +69 -0
- project_llm_trainer-0.12.3.data/scripts/plot_lr +45 -0
- project_llm_trainer-0.12.3.data/scripts/py_train +12 -0
- project_llm_trainer-0.12.3.data/scripts/smart_train +37 -0
- project_llm_trainer-0.12.3.dist-info/METADATA +9 -0
- project_llm_trainer-0.12.3.dist-info/RECORD +32 -0
- project_llm_trainer-0.12.3.dist-info/WHEEL +5 -0
- project_llm_trainer-0.12.3.dist-info/top_level.txt +1 -0
llm_trainer/grpo_trainer.py
ADDED
@@ -0,0 +1,385 @@
+from typing import Tuple, List, Callable, Optional
+import gc
+import torch
+from torch.utils.data import Dataset
+import torch.nn.functional as F
+
+from .base_trainer import BaseTrainer
+from .train_configs import TrainConfig
+from .dataset import RLDataset
+from .loss import GRPOLoss
+from .tools import TrainerTools
+from .generate_utils import batch_generate
+from .log import Logger
+from .utils import (
+    autocast,
+    left_pad_sequence,
+    log_softmax,
+    disable_dropout_in_model,
+    calc_position_ids
+)
+
+from .partition_utils import (
+    sync_model_params,
+    unwrap_model_for_generation
+)
+
+from .checkpoint import (
+    save_checkpoint,
+    save_steps,
+)
+
+class GRPOTrainer(BaseTrainer):
+    """
+    reward_func(prompt_ids, complete_ids, answer_ids) -> scores
+    """
+    def __init__(
+        self,
+        *,
+        train_config: TrainConfig,
+        reward_func: Callable[[List[torch.Tensor], torch.Tensor, List[Optional[torch.Tensor]]], List[float]],
+        eval_prompts: List[str]
+    ):
+        self.grpo_config = train_config.grpo_config
+        super().__init__(
+            train_config=train_config,
+            eval_prompts=eval_prompts
+        )
+
+        self.reward_func = reward_func
+        self.ref_model = self._init_ref_model()
+
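Reviewer note, not part of the package: the class docstring only gives the shape of reward_func. Below is a minimal sketch of a compatible reward function, assuming each answer entry is a 1-D tensor of target token ids (or None) and using a hypothetical PAD_ID constant in place of TrainerTools().tokenizer.pad. It scores a completion 1.0 when its non-pad tokens exactly match the answer, else 0.0.

    from typing import List, Optional
    import torch

    PAD_ID = 0  # assumed pad token id; in practice use TrainerTools().tokenizer.pad

    def exact_match_reward(
        prompt_ids: List[torch.Tensor],
        completion_ids: torch.Tensor,                  # [batch*group_size, max_gen_len]
        answer_ids: List[Optional[torch.Tensor]]
    ) -> List[float]:
        scores = []
        for completion, answer in zip(completion_ids, answer_ids):
            if answer is None:
                scores.append(0.0)
                continue
            generated = completion[completion != PAD_ID]        # strip padding
            answer = answer.to(generated.device)
            scores.append(1.0 if torch.equal(generated, answer) else 0.0)
        return scores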
+    def _init_ref_model(self):
+        # beta == 0: no ref_model is needed
+        if self.grpo_config.loss_beta == 0.0:
+            return None
+
+        ref_model = self._new_model(self.train_config)
+
+        ref_model.eval()
+        for param in ref_model.parameters():
+            param.requires_grad = False
+
+        ref_model, _ = TrainerTools().parallel.process(
+            model=ref_model,
+            optimizer=None,
+            kwargs=self._init_ref_model_args(),
+            save_instance=False
+        )
+
+        return ref_model
+
+    def _new_model(self, train_config: TrainConfig):
+        model = super()._new_model(train_config)
+        disable_dropout_in_model(model)
+        return model
+
+    def _init_loss(self):
+        criterion = GRPOLoss(
+            beta=self.grpo_config.loss_beta,
+            clip_eps_low=self.grpo_config.loss_clip_eps,
+            clip_eps_high=self.grpo_config.loss_clip_eps_high,
+            delta=self.grpo_config.loss_delta,
+            importance_sampling_level=self.grpo_config.loss_importance_sampling_level,
+            loss_type=self.grpo_config.loss_type,
+            gen_max_new_tokens=self.grpo_config.gen_max_new_tokens
+        )
+
+        return criterion, None
+
+    def _convert_train_args(self) -> Tuple[dict, dict, dict]:
+        parallel_kwargs, data_loader_kwargs, sampler_kwargs = super()._convert_train_args()
+        data_loader_kwargs.update({"collate_fn": lambda x: x})
+
+        return parallel_kwargs, data_loader_kwargs, sampler_kwargs
+
+    def _create_dataset(self, file_idx) -> Tuple[Dataset, str]:
+        file_path = self.train_config.file_dataset[file_idx]
+        return RLDataset(file_path), file_path
+
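Reviewer note, not part of the package: RLDataset is defined in dataset.py, which is not included in this hunk. Judging from how _generate_rollout_data consumes a batch below (item["prompt"], item["answer"]) and from the reward_func type hints, each record appears to be a dict of token-id tensors, roughly:

    import torch

    # assumed record layout, for illustration only
    sample = {
        "prompt": torch.tensor([1, 347, 92, 2]),   # prompt token ids
        "answer": torch.tensor([88, 415, 2]),      # reference answer token ids, or None
    }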
+    def _calc_loss(self, inputs, attention_mask, logits, labels): ...
+
+    def _compute_log_probs(
+        self,
+        model,
+        input_ids,
+        attention_mask
+    ):
+        position_ids = calc_position_ids(attention_mask)
+
+        # [batch_size, total_seq_len, vocab_size]
+        outputs = model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids
+        )
+
+        # [batch_size, total_seq_len - 1, vocab_size]
+        logits = outputs['logits'][:, :-1, :]
+        input_ids = input_ids[:, 1:]
+
+        # Compute and return the log probabilities for the selected tokens.
+        return log_softmax(logits, input_ids), outputs['aux_loss']
+
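Reviewer note, not part of the package: log_softmax is imported from .utils, which is not part of this hunk. Based on how its result is used (it has the same shape as the shifted input_ids and is later compared against completion_mask), it presumably gathers the log-probability of each realized next token. A standalone sketch of that assumed behavior:

    import torch
    import torch.nn.functional as F

    def gather_token_log_probs(logits: torch.Tensor, target_ids: torch.Tensor) -> torch.Tensor:
        # logits: [batch, seq_len, vocab_size]; target_ids: [batch, seq_len]
        log_probs = F.log_softmax(logits, dim=-1)
        # pick the log-probability assigned to each target token -> [batch, seq_len]
        return log_probs.gather(dim=-1, index=target_ids.unsqueeze(-1)).squeeze(-1)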
+    def _compute_group_relative_advantages(self, rewards):
+        group_size = self.grpo_config.group_size
+
+        # Reshape rewards to group by prompt
+        # [batch, group_size]
+        rewards_by_group = rewards.view(-1, group_size)
+
+        # Compute mean and standard deviation for each prompt group
+        # [batch]
+        group_means = rewards_by_group.mean(dim=1)
+        group_stds = rewards_by_group.std(dim=1)
+
+        # Expand the means and stds to match the original flat rewards tensor shape
+        # [batch*group_size]
+        expanded_means = group_means.repeat_interleave(group_size)
+        expanded_stds = group_stds.repeat_interleave(group_size)
+
+        # Normalize rewards to get advantages
+        # [batch*group_size]
+        advantages = (rewards - expanded_means) / (expanded_stds + 1e-4)
+
+        # [batch*group_size, 1]
+        return advantages.unsqueeze(1)  # Add dimension for token-wise operations
+
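Reviewer note, not part of the package: a tiny self-contained example of the normalization above, with group_size=4 and two prompts. Rewards are viewed as [2, 4], normalized per row against the row mean and standard deviation, and returned flattened with a trailing singleton dimension:

    import torch

    group_size = 4
    rewards = torch.tensor([1.0, 0.0, 1.0, 0.0,   # 4 completions for prompt A
                            0.5, 0.5, 1.0, 0.0])  # 4 completions for prompt B
    by_group = rewards.view(-1, group_size)
    means = by_group.mean(dim=1).repeat_interleave(group_size)
    stds = by_group.std(dim=1).repeat_interleave(group_size)
    advantages = ((rewards - means) / (stds + 1e-4)).unsqueeze(1)
    print(advantages.shape)  # torch.Size([8, 1])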
+    def _generate_completions(self, model, prompts, group_size: int):
+        pad_token_id = TrainerTools().tokenizer.pad
+        device = TrainerTools().parallel.device
+
+        # Left-pad so all prompts are aligned to the same length
+        # [batch, max_prompt_len]
+        prompt_ids = left_pad_sequence(prompts, padding_value=pad_token_id)
+        prompt_ids = prompt_ids.to(device)
+
+        prompt_len = prompt_ids.shape[1]
+
+        # [batch*group_size, max_prompt_len]
+        prompt_ids = prompt_ids.repeat_interleave(group_size, 0)
+        # [batch*group_size, max_prompt_len]
+        prompt_masks = prompt_ids != pad_token_id
+
+        # [batch*group_size, max_prompt_len+max_gen_len]
+        outputs, _ = batch_generate(
+            model=model,
+            tokens=prompt_ids,
+            attention_mask=prompt_masks,
+            max_new_tokens=self.grpo_config.gen_max_new_tokens,
+            temperature=self.grpo_config.gen_temperature,
+            k=self.grpo_config.gen_k,
+            p=self.grpo_config.gen_p,
+            device=device,
+            suppress_tokens=self.grpo_config.gen_suppress_tokens,
+            return_logits=False
+        )
+
+        # [batch*group_size, max_gen_len]
+        completion_ids = outputs[:, prompt_len:]
+        # [batch*group_size, max_gen_len]
+        completion_masks = (completion_ids != pad_token_id).int()
+
+        return prompt_ids, prompt_masks, completion_ids, completion_masks
+
+    def _generate_rollout_data(self, generate_model, batch_data: List[dict]):
+        prompts = [item["prompt"] for item in batch_data]
+        answers = [item["answer"] for item in batch_data]
+        group_size = self.grpo_config.group_size
+
+        # Use no_grad instead of inference_mode
+        # Fix: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal
+        with torch.no_grad():
+            # with torch.inference_mode():
+            prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_completions(generate_model, prompts, group_size)
+            input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+            attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+
+            old_log_probs, _ = self._compute_log_probs(generate_model, input_ids, attention_mask)
+
+            if self.ref_model:
+                ref_log_probs, _ = self._compute_log_probs(self.ref_model, input_ids, attention_mask)
+            else:
+                ref_log_probs = None
+
+        repeated_prompts = [p for p in prompts for _ in range(group_size)]
+        repeated_answers = [a for a in answers for _ in range(group_size)]
+
+        return {
+            'input_ids': input_ids,
+            'attention_mask': attention_mask,
+            'completion_mask': completion_mask,
+            'old_log_probs': old_log_probs,
+            'ref_log_probs': ref_log_probs,
+            'completion_ids': completion_ids,
+            'repeated_prompts': repeated_prompts,
+            'repeated_answers': repeated_answers,
+        }
+
+    def _maximize_grpo_objective(self, rollout_data):
+        device = TrainerTools().parallel.device
+
+        input_ids = rollout_data['input_ids']
+        attention_mask = rollout_data['attention_mask']
+        completion_mask = rollout_data['completion_mask']
+        old_log_probs = rollout_data['old_log_probs']
+        ref_log_probs = rollout_data['ref_log_probs']
+        completion_ids = rollout_data['completion_ids']
+        repeated_prompts = rollout_data['repeated_prompts']
+        repeated_answers = rollout_data['repeated_answers']
+
+        prompt_len = input_ids.shape[1] - completion_ids.shape[1]
+
+        # [batch*group_size]
+        rewards = torch.tensor(
+            self.reward_func(repeated_prompts, completion_ids, repeated_answers),
+            dtype=torch.float32,
+            device=device
+        )
+
+        # [batch*group_size, 1]
+        advantages = self._compute_group_relative_advantages(rewards)
+
+        # Compute current log probabilities
+        log_probs, aux_loss = self._compute_log_probs(self.train_model, input_ids, attention_mask)
+
+        pad_len = prompt_len - 1
+        if pad_len > 0:
+            padded_completion_mask = F.pad(completion_mask, (pad_len, 0), 'constant', 0)
+        else:
+            padded_completion_mask = completion_mask
+
+        assert padded_completion_mask.shape == log_probs.shape, \
+            f"Shape mismatch! Padded completion mask: {padded_completion_mask.shape}, Log probs: {log_probs.shape}"
+
+        loss = self.criterion(
+            log_probs=log_probs,
+            old_log_probs=old_log_probs,
+            ref_log_probs=ref_log_probs,
+            completion_mask=padded_completion_mask,
+            advantages=advantages
+        )
+
+        return loss, aux_loss, rewards
+
+    def train(self):
+        global_steps = 0
+        skipping_train = False
+        aux_loss_coef = self.train_config.loss_config.aux_loss_coef
+
+        for epoch in range(self.train_config.n_epochs):
+            if self.ref_model:
+                sync_model_params(
+                    _from=self.train_model,
+                    _to=self.ref_model,
+                    mixup_alpha=self.grpo_config.mixup_alpha
+                )
+
+            file_count = len(self.train_config.file_dataset)
+
+            for file_idx in range(file_count):
+                dataset, file_path = self._create_dataset(file_idx)
+
+                train_data_loader = TrainerTools().parallel.process_dataloader(
+                    dataset=dataset,
+                    data_loader_kwargs=self.data_loader_kwargs,
+                    sampler_kwargs=self.sampler_kwargs
+                )
+
+                last_ckpt_batch = 0
+                batch_count_per_file = len(train_data_loader)
+
+                TrainerTools().parallel.on_epoch_start(epoch)
+                self._on_file_start(epoch, file_path)
+
+                for batch, batch_data in enumerate(train_data_loader):
+                    global_steps += 1
+                    if global_steps < self.last_global_steps:
+                        skipping_train = True
+                        continue
+
+                    if skipping_train:
+                        TrainerTools().parallel.wait('skip train')
+                        skipping_train = False
+
+                    # start generate
+                    if TrainerTools().parallel.is_main_process:
+                        Logger.std_log(f'start generate for batch {batch}/{batch_count_per_file}')
+
+                    # generate rollout data
+                    with unwrap_model_for_generation(self.train_model) as generate_model:
+                        rollout_data = self._generate_rollout_data(generate_model, batch_data)
+                    # end generate
+
+                    torch.cuda.empty_cache()
+
+                    try:
+                        if TrainerTools().parallel.is_main_process:
+                            Logger.std_log(f'start train for batch {batch}/{batch_count_per_file}')
+
+                        for grpo_step in range(self.grpo_config.grpo_steps):
+                            with autocast(TrainerTools().parallel.device_type):
+                                loss, aux_loss, rewards = self._maximize_grpo_objective(rollout_data)
+                                if aux_loss_coef and aux_loss is not None:
+                                    aux_loss = aux_loss_coef * aux_loss
+                                else:
+                                    aux_loss = torch.tensor(0.0, device=loss.device, dtype=loss.dtype)
+
+                                total_loss = loss + aux_loss
+                            self._backward_loss(total_loss)
+                            self._apply_grad_clipping()
+                            self._apply_step()
+
+                            loss_accumulation = total_loss.detach().item()
+                            aux_loss_accumulation = aux_loss.detach().item()
+
+                            avg_loss, avg_aux_loss = self._avg_loss(
+                                losses=[
+                                    loss_accumulation,
+                                    aux_loss_accumulation
+                                ],
+                                gradient_accumulation_steps=1,
+                                batches_accumulated=1
+                            )
+
+                            self._log(
+                                keys={
+                                    'epoch': epoch,
+                                    'file': f'{file_idx + 1}/{file_count}',
+                                    'batch': f'{batch}/{batch_count_per_file}',
+                                    'grpo_step': grpo_step
+                                },
+                                values={
+                                    'loss': avg_loss,
+                                    'moe_aux_loss': avg_aux_loss,
+                                    'rewards': (rewards.sum() / rewards.size(0)).item(),
+                                }
+                            )
+                    except Exception as e:
+                        self._on_exception(e, epoch, batch)
+                    finally:
+                        save_steps(global_steps=global_steps, lr_scheduler=self.lr_scheduler)
+
+                        if (batch - last_ckpt_batch) >= self.train_config.eval_config.eval_batch_interval:
+                            save_checkpoint(model=self.train_model, optimizer=self.optimizer)
+                            last_ckpt_batch = batch
+                            self._on_batch_end(tag=f'epoch:{epoch}/batch:{batch}')
+
+                # After one file has been trained, free memory
+                del train_data_loader
+                del dataset
+                if hasattr(TrainerTools().parallel, '_sampler'):
+                    TrainerTools().parallel._sampler = None
+
+                gc.collect()
+                torch.cuda.empty_cache()
+
+            # end epoch
+            if not skipping_train:
+                save_steps(global_steps=global_steps, lr_scheduler=self.lr_scheduler)
+                save_checkpoint(model=self.train_model, optimizer=self.optimizer)
+
+            TrainerTools().parallel.on_epoch_end(epoch)
+            self._on_epoch_end(tag=f'epoch:{epoch}')
+
+        TrainerTools().parallel.destroy()
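Reviewer note, not part of the package: a rough sketch of how GRPOTrainer is wired together, reusing the exact_match_reward sketch from the earlier note. The fields of TrainConfig and its grpo_config are defined in train_configs.py, which is not shown in this section, so the config construction is left as a placeholder:

    from llm_trainer.grpo_trainer import GRPOTrainer

    train_config = build_train_config()            # hypothetical helper; see train_configs.py
    trainer = GRPOTrainer(
        train_config=train_config,
        reward_func=exact_match_reward,            # from the earlier reviewer note
        eval_prompts=['1+1=', 'The capital of France is']
    )
    trainer.train()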
llm_trainer/log.py
ADDED
@@ -0,0 +1,65 @@
+import time, os, atexit
+from io import TextIOWrapper
+from typing import Optional
+
+
+def _get_log_dir() -> str:
+    log_dir = os.environ.get('LOG_DIR', './log')
+    os.makedirs(log_dir, exist_ok=True)
+    return log_dir
+
+
+class Logger:
+    def __init__(self, log_file_name = None, log_dir = None):
+        self.log_file_name = log_file_name
+        self.log_file: Optional[TextIOWrapper] = None
+
+        if not log_dir:
+            self.log_dir = _get_log_dir()
+        else:
+            os.makedirs(log_dir, exist_ok=True)
+            self.log_dir = log_dir
+
+        self.flush_interval = int(os.environ.get('LOG_FLUSH_INTERVAL', '1'))
+        self.log_steps = 0
+
+    @staticmethod
+    def std_log(msg: str):
+        log_content = Logger._build_log(msg)
+        print(log_content)
+
+    def log(self, msg: str, log_to_console = True):
+        log_content = Logger._build_log(msg)
+
+        if log_to_console:
+            print(log_content)
+
+        if self._open_file():
+            self.log_file.write(f'{log_content}\n')
+            if self.log_steps % self.flush_interval == 0:
+                self.log_file.flush()
+
+        self.log_steps += 1
+        return self
+
+    def release(self):
+        if self.log_file:
+            self.log_file.close()
+            self.log_file = None
+
+    @staticmethod
+    def _build_log(msg: str):
+        cur_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+        return f'[{cur_time}] {msg}'
+
+    def _open_file(self) -> bool:
+        if not self.log_file_name:
+            return False
+
+        if self.log_file:
+            return True
+
+        self.log_file = open(os.path.join(self.log_dir, self.log_file_name), 'a', encoding='utf-8')
+        atexit.register(self.release)
+
+        return True
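Reviewer note, not part of the package: Logger usage follows directly from the code above. std_log only prints to the console; log() also appends to <log_dir>/<log_file_name> once a file name is given, flushing every LOG_FLUSH_INTERVAL writes (default 1), and the file handle is closed automatically by the registered atexit hook.

    from llm_trainer.log import Logger

    Logger.std_log('training started')               # console only

    logger = Logger(log_file_name='train.log')       # writes under $LOG_DIR (default ./log)
    logger.log('epoch 0 finished')                   # console + file
    logger.log('loss=1.23', log_to_console=False)    # file only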