omnigenome 0.3.1a0__py3-none-any.whl → 0.3.3a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of omnigenome might be problematic.

Files changed (79)
  1. omnigenome/__init__.py +252 -266
  2. {omnigenome-0.3.1a0.dist-info → omnigenome-0.3.3a0.dist-info}/METADATA +9 -9
  3. omnigenome-0.3.3a0.dist-info/RECORD +7 -0
  4. omnigenome/auto/__init__.py +0 -3
  5. omnigenome/auto/auto_bench/__init__.py +0 -11
  6. omnigenome/auto/auto_bench/auto_bench.py +0 -494
  7. omnigenome/auto/auto_bench/auto_bench_cli.py +0 -230
  8. omnigenome/auto/auto_bench/auto_bench_config.py +0 -216
  9. omnigenome/auto/auto_bench/config_check.py +0 -34
  10. omnigenome/auto/auto_train/__init__.py +0 -12
  11. omnigenome/auto/auto_train/auto_train.py +0 -429
  12. omnigenome/auto/auto_train/auto_train_cli.py +0 -222
  13. omnigenome/auto/bench_hub/__init__.py +0 -11
  14. omnigenome/auto/bench_hub/bench_hub.py +0 -25
  15. omnigenome/cli/__init__.py +0 -12
  16. omnigenome/cli/commands/__init__.py +0 -12
  17. omnigenome/cli/commands/base.py +0 -83
  18. omnigenome/cli/commands/bench/__init__.py +0 -12
  19. omnigenome/cli/commands/bench/bench_cli.py +0 -202
  20. omnigenome/cli/commands/rna/__init__.py +0 -12
  21. omnigenome/cli/commands/rna/rna_design.py +0 -177
  22. omnigenome/cli/omnigenome_cli.py +0 -128
  23. omnigenome/src/__init__.py +0 -11
  24. omnigenome/src/abc/__init__.py +0 -11
  25. omnigenome/src/abc/abstract_dataset.py +0 -641
  26. omnigenome/src/abc/abstract_metric.py +0 -114
  27. omnigenome/src/abc/abstract_model.py +0 -690
  28. omnigenome/src/abc/abstract_tokenizer.py +0 -269
  29. omnigenome/src/dataset/__init__.py +0 -16
  30. omnigenome/src/dataset/omni_dataset.py +0 -437
  31. omnigenome/src/lora/__init__.py +0 -12
  32. omnigenome/src/lora/lora_model.py +0 -300
  33. omnigenome/src/metric/__init__.py +0 -15
  34. omnigenome/src/metric/classification_metric.py +0 -184
  35. omnigenome/src/metric/metric.py +0 -199
  36. omnigenome/src/metric/ranking_metric.py +0 -142
  37. omnigenome/src/metric/regression_metric.py +0 -191
  38. omnigenome/src/misc/__init__.py +0 -3
  39. omnigenome/src/misc/utils.py +0 -503
  40. omnigenome/src/model/__init__.py +0 -19
  41. omnigenome/src/model/augmentation/__init__.py +0 -11
  42. omnigenome/src/model/augmentation/model.py +0 -219
  43. omnigenome/src/model/classification/__init__.py +0 -11
  44. omnigenome/src/model/classification/model.py +0 -638
  45. omnigenome/src/model/embedding/__init__.py +0 -11
  46. omnigenome/src/model/embedding/model.py +0 -263
  47. omnigenome/src/model/mlm/__init__.py +0 -11
  48. omnigenome/src/model/mlm/model.py +0 -177
  49. omnigenome/src/model/module_utils.py +0 -232
  50. omnigenome/src/model/regression/__init__.py +0 -11
  51. omnigenome/src/model/regression/model.py +0 -781
  52. omnigenome/src/model/regression/resnet.py +0 -483
  53. omnigenome/src/model/rna_design/__init__.py +0 -11
  54. omnigenome/src/model/rna_design/model.py +0 -476
  55. omnigenome/src/model/seq2seq/__init__.py +0 -11
  56. omnigenome/src/model/seq2seq/model.py +0 -44
  57. omnigenome/src/tokenizer/__init__.py +0 -16
  58. omnigenome/src/tokenizer/bpe_tokenizer.py +0 -226
  59. omnigenome/src/tokenizer/kmers_tokenizer.py +0 -247
  60. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +0 -249
  61. omnigenome/src/trainer/__init__.py +0 -14
  62. omnigenome/src/trainer/accelerate_trainer.py +0 -747
  63. omnigenome/src/trainer/hf_trainer.py +0 -75
  64. omnigenome/src/trainer/trainer.py +0 -591
  65. omnigenome/utility/__init__.py +0 -3
  66. omnigenome/utility/dataset_hub/__init__.py +0 -12
  67. omnigenome/utility/dataset_hub/dataset_hub.py +0 -178
  68. omnigenome/utility/ensemble.py +0 -324
  69. omnigenome/utility/hub_utils.py +0 -517
  70. omnigenome/utility/model_hub/__init__.py +0 -11
  71. omnigenome/utility/model_hub/model_hub.py +0 -232
  72. omnigenome/utility/pipeline_hub/__init__.py +0 -11
  73. omnigenome/utility/pipeline_hub/pipeline.py +0 -483
  74. omnigenome/utility/pipeline_hub/pipeline_hub.py +0 -129
  75. omnigenome-0.3.1a0.dist-info/RECORD +0 -78
  76. {omnigenome-0.3.1a0.dist-info → omnigenome-0.3.3a0.dist-info}/WHEEL +0 -0
  77. {omnigenome-0.3.1a0.dist-info → omnigenome-0.3.3a0.dist-info}/entry_points.txt +0 -0
  78. {omnigenome-0.3.1a0.dist-info → omnigenome-0.3.3a0.dist-info}/licenses/LICENSE +0 -0
  79. {omnigenome-0.3.1a0.dist-info → omnigenome-0.3.3a0.dist-info}/top_level.txt +0 -0
omnigenome/src/trainer/accelerate_trainer.py
@@ -1,747 +0,0 @@
- # -*- coding: utf-8 -*-
- # file: trainer.py
- # time: 14:40 06/04/2024
- # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
- # github: https://github.com/yangheng95
- # huggingface: https://huggingface.co/yangheng
- # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
- # Copyright (C) 2019-2024. All Rights Reserved.
-
- import os
- import time
- import numpy as np
- from torch.utils.data import DataLoader
- from tqdm import tqdm
-
- import torch
-
- from ..misc.utils import env_meta_info, fprint, seed_everything
-
-
- def _infer_optimization_direction(metrics, prev_metrics):
-     """
-     Infer the optimization direction based on metric values.
-
-     This function analyzes the trend of metric values to determine whether
-     larger values are better (e.g., accuracy) or smaller values are better
-     (e.g., loss).
-
-     Args:
-         metrics (dict): Current metric values
-         prev_metrics (list): Previous metric values
-
-     Returns:
-         str: Either 'larger_is_better' or 'smaller_is_better'
-     """
-     larger_is_better_metrics = [
-         "accuracy",
-         "f1",
-         "recall",
-         "precision",
-         "roc_auc",
-         "pr_auc",
-         "score",
-         # ...
-     ]
-     smaller_is_better_metrics = [
-         "loss",
-         "error",
-         "mse",
-         "mae",
-         "r2",
-         "distance",
-         # ...
-     ]
-     for metric in larger_is_better_metrics:
-         if prev_metrics and metric in list(prev_metrics[0].keys())[0]:
-             return "larger_is_better"
-     for metric in smaller_is_better_metrics:
-         if prev_metrics and metric in list(prev_metrics[0].keys())[0]:
-             return "smaller_is_better"
-
-     fprint(
-         "Cannot determine the optimisation direction. Attempting inference from the metrics."
-     )
-     is_prev_increasing = np.mean(list(prev_metrics[0].values())[0]) < np.mean(
-         list(prev_metrics[-1].values())[0]
-     )
-     is_still_increasing = np.mean(list(prev_metrics[1].values())[0]) < np.mean(
-         list(metrics.values())[0]
-     )
-
-     if is_prev_increasing and is_still_increasing:
-         return "larger_is_better"
-
-     is_prev_decreasing = np.mean(list(prev_metrics[0].values())[0]) > np.mean(
-         list(prev_metrics[-1].values())[0]
-     )
-     is_still_decreasing = np.mean(list(prev_metrics[1].values())[0]) > np.mean(
-         list(metrics.values())[0]
-     )
-
-     if is_prev_decreasing and is_still_decreasing:
-         return "smaller_is_better"
-
-     return "larger_is_better" if is_prev_increasing else "smaller_is_better"
-
-
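# Illustrative sketch (editorial addition, not part of the removed file): the
# heuristic above first looks for a known metric name inside the first key of
# prev_metrics[0] and only falls back to the trend comparison when nothing
# matches. Assuming a history of two validation snapshots:
#
#     history = [{"f1": 0.71}, {"f1": 0.74}]
#     _infer_optimization_direction({"f1": 0.76}, history)  # -> "larger_is_better"
#
# "f1" is in the larger-is-better list, so the trend branch is never reached.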
- class AccelerateTrainer:
-     """
-     A distributed training trainer using HuggingFace Accelerate.
-
-     This trainer provides distributed training capabilities with automatic mixed precision,
-     gradient accumulation, and early stopping. It supports both single and multi-GPU
-     training with seamless integration with HuggingFace Accelerate.
-
-     Attributes:
-         model: The model to train
-         train_loader: DataLoader for training data
-         eval_loader: DataLoader for validation data
-         test_loader: DataLoader for test data
-         epochs: Number of training epochs
-         patience: Early stopping patience
-         gradient_accumulation_steps: Number of steps for gradient accumulation
-         optimizer: The optimizer for training
-         loss_fn: Loss function
-         compute_metrics: List of metric functions to compute
-         accelerator: HuggingFace Accelerate instance
-         metrics: Dictionary to store training metrics
-         predictions: Dictionary to store predictions
-
-     Example:
-         >>> from omnigenome.src.trainer import AccelerateTrainer
-         >>> trainer = AccelerateTrainer(
-         ...     model=model,
-         ...     train_dataset=train_dataset,
-         ...     eval_dataset=eval_dataset,
-         ...     epochs=10,
-         ...     batch_size=32,
-         ...     optimizer=optimizer
-         ... )
-         >>> metrics = trainer.train()
-     """
-
-     def __init__(
-         self,
-         model,
-         train_dataset: torch.utils.data.Dataset = None,
-         eval_dataset: torch.utils.data.Dataset = None,
-         test_dataset: torch.utils.data.Dataset = None,
-         epochs: int = 3,
-         batch_size: int = 8,
-         patience: int = -1,
-         gradient_accumulation_steps: int = 1,
-         optimizer: torch.optim.Optimizer = None,
-         loss_fn: torch.nn.Module = None,
-         compute_metrics: list | str = None,
-         seed: int = 42,
-         autocast: str = "float16",
-         **kwargs,
-     ):
-         """
-         Initialize the AccelerateTrainer.
-
-         Args:
-             model: The model to train
-             train_dataset (torch.utils.data.Dataset, optional): Training dataset
-             eval_dataset (torch.utils.data.Dataset, optional): Validation dataset
-             test_dataset (torch.utils.data.Dataset, optional): Test dataset
-             epochs (int, optional): Number of training epochs. Defaults to 3
-             batch_size (int, optional): Batch size for training. Defaults to 8
-             patience (int, optional): Early stopping patience. Defaults to -1 (no early stopping)
-             gradient_accumulation_steps (int, optional): Number of steps for gradient accumulation. Defaults to 1
-             optimizer (torch.optim.Optimizer, optional): Optimizer for training
-             loss_fn (torch.nn.Module, optional): Loss function
-             compute_metrics (list | str, optional): List of metric functions or single metric function
-             seed (int, optional): Random seed for reproducibility. Defaults to 42
-             autocast (str, optional): Mixed precision type. Options: 'float16', 'bfloat16', 'no'. Defaults to 'float16'
-             **kwargs: Additional keyword arguments
-         """
-         self.model = model
-
-         # DataLoaders
-         if kwargs.get("train_loader"):
-             self.train_loader = kwargs.get("train_loader")
-             self.eval_loader = kwargs.get("eval_loader", None)
-             self.test_loader = kwargs.get("test_loader", None)
-         else:
-             self.train_loader = (
-                 DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
-                 if train_dataset
-                 else None
-             )
-             self.eval_loader = (
-                 DataLoader(eval_dataset, batch_size=batch_size)
-                 if eval_dataset
-                 else None
-             )
-             self.test_loader = (
-                 DataLoader(test_dataset, batch_size=batch_size)
-                 if test_dataset
-                 else None
-             )
-
-         self.epochs = epochs
-         self.patience = patience
-         self.gradient_accumulation_steps = gradient_accumulation_steps
-         self.optimizer = optimizer
-         self.loss_fn = loss_fn
-         self.compute_metrics = (
-             compute_metrics if isinstance(compute_metrics, list) else [compute_metrics]
-         )
-         self.seed = seed
-         self._optimization_direction = None
-         self.trial_name = kwargs.get("trial_name", self.model.__class__.__name__)
-
-         # Determine mixed precision from `autocast` argument if desired
-         if autocast in ["float16", "fp16"]:
-             mp_setting = "fp16"
-         elif autocast in ["bfloat16", "bf16"]:
-             mp_setting = "bf16"
-         else:
-             mp_setting = "no"
-
-         # Prepare Accelerator
-         from accelerate import Accelerator
-         from accelerate import DistributedDataParallelKwargs
-
-         ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
-
-         self.accelerator = Accelerator(
-             mixed_precision=mp_setting, kwargs_handlers=[ddp_kwargs]
-         )
-
-         if self.loss_fn is not None:
-             self.model.set_loss_fn(self.loss_fn)
-
-         # Let accelerate handle preparation of the model, optimizer and dataloaders
-         to_prepare = [self.model]
-         if optimizer is not None:
-             to_prepare.append(optimizer)
-         if self.train_loader is not None:
-             to_prepare.append(self.train_loader)
-         if self.eval_loader is not None:
-             to_prepare.append(self.eval_loader)
-         if self.test_loader is not None:
-             to_prepare.append(self.test_loader)
-
-         prepared = self.accelerator.prepare(*to_prepare)
-         self.model = prepared[0]
-         idx = 1
-         if optimizer is not None:
-             self.optimizer = prepared[idx]
-             idx += 1
-         if self.train_loader is not None:
-             self.train_loader = prepared[idx]
-             idx += 1
-         if self.eval_loader is not None:
-             self.eval_loader = prepared[idx]
-             idx += 1
-         if self.test_loader is not None:
-             self.test_loader = prepared[idx]
-
-         self.metadata = env_meta_info()
-         self.metrics = {}
-
-         self.predictions = {}
-
-     def evaluate(self):
-         """
-         Evaluate the model on the validation dataset.
-
-         This method runs the model in evaluation mode and computes metrics
-         on the validation dataset. It handles distributed evaluation and
-         gathers results from all processes.
-
-         Returns:
-             dict: Dictionary containing evaluation metrics
-
-         Example:
-             >>> metrics = trainer.evaluate()
-             >>> print(f"Validation accuracy: {metrics['accuracy']:.4f}")
-         """
-         self.model.eval()
-         all_truth = []
-         all_preds = []
-
-         # Hide the progress bar on non-main processes
-         it = tqdm(
-             self.eval_loader,
-             desc="Evaluating",
-             disable=not self.accelerator.is_main_process,
-         )
-
-         with torch.no_grad():
-             for batch in it:
-                 output = self.accelerator.unwrap_model(self.model).predict(batch)
-                 predictions = output["predictions"]
-                 labels = batch["labels"]
-
-                 # Gather predictions and labels from all processes
-                 gathered_predictions = self.accelerator.gather(predictions)
-                 gathered_labels = self.accelerator.gather(labels)
-
-                 # Only process the gathered data on the main process
-                 if self.accelerator.is_main_process:
-                     gathered_predictions = (
-                         gathered_predictions.float().cpu().numpy(force=True)
-                     )
-                     gathered_labels = gathered_labels.float().cpu().numpy(force=True)
-                     all_preds.append(gathered_predictions)
-                     all_truth.append(gathered_labels)
-
-         # # Synchronize all processes
-         # self.accelerator.wait_for_everyone()
-
-         # Compute metrics only on the main process
-         if self.accelerator.is_main_process:
-             all_preds = np.concatenate(all_preds, axis=0)
-             all_truth = np.concatenate(all_truth, axis=0)
-
-             if not np.all(all_truth == -100):
-                 valid_metrics = {}
-                 for metric_func in self.compute_metrics:
-                     valid_metrics.update(metric_func(all_truth, all_preds))
-             else:
-                 valid_metrics = {
-                     "Validation labels predictions may be NaN. No metrics calculated.": 0
-                 }
-
-             # Print the metrics
-             fprint(valid_metrics)
-         else:
-             valid_metrics = None
-
-         self.predictions.update({"valid": {"pred": all_preds, "true": all_truth}})
-
-         return valid_metrics
-
-     def test(self):
-         """
-         Test the model on the test dataset.
-
-         This method runs the model in evaluation mode and computes metrics
-         on the test dataset. It handles distributed testing and gathers
-         results from all processes.
-
-         Returns:
-             dict: Dictionary containing test metrics
-
-         Example:
-             >>> metrics = trainer.test()
-             >>> print(f"Test accuracy: {metrics['accuracy']:.4f}")
-         """
-         self.model.eval()
-         all_truth = []
-         all_preds = []
-
-         it = tqdm(
-             self.test_loader,
-             desc="Testing",
-             disable=not self.accelerator.is_main_process,
-         )
-
-         with torch.no_grad():
-             for batch in it:
-                 output = self.accelerator.unwrap_model(self.model).predict(batch)
-                 predictions = output["predictions"]
-                 labels = batch["labels"]
-
-                 gathered_predictions = self.accelerator.gather(predictions)
-                 gathered_labels = self.accelerator.gather(labels)
-
-                 if self.accelerator.is_main_process:
-                     gathered_predictions = (
-                         gathered_predictions.float().cpu().numpy(force=True)
-                     )
-                     gathered_labels = gathered_labels.float().cpu().numpy(force=True)
-                     all_preds.append(gathered_predictions)
-                     all_truth.append(gathered_labels)
-
-         # # Synchronize all processes
-         # self.accelerator.wait_for_everyone()
-
-         # Compute metrics only on the main process
-         if self.accelerator.is_main_process:
-             all_preds = np.concatenate(all_preds, axis=0)
-             all_truth = np.concatenate(all_truth, axis=0)
-
-             if not np.all(all_truth == -100):
-                 test_metrics = {}
-                 for metric_func in self.compute_metrics:
-                     test_metrics.update(metric_func(all_truth, all_preds))
-             else:
-                 test_metrics = {
-                     "Test labels predictions may be NaN. No metrics calculated.": 0
-                 }
-             # Print the metrics
-             fprint(test_metrics)
-         else:
-             test_metrics = None
-
-         self.predictions.update({"test": {"pred": all_preds, "true": all_truth}})
-
-         return test_metrics
-
-     def train(self, path_to_save=None, **kwargs):
-         """
-         Train the model using distributed training.
-
-         This method performs the complete training loop with validation,
-         early stopping, and model checkpointing. It handles distributed
-         training across multiple GPUs and processes.
-
-         Args:
-             path_to_save (str, optional): Path to save the trained model
-             **kwargs: Additional keyword arguments for model saving
-
-         Returns:
-             dict: Dictionary containing training metrics
-
-         Example:
-             >>> metrics = trainer.train(path_to_save="./checkpoints/model")
-             >>> print(f"Best validation accuracy: {metrics['best_valid']['accuracy']:.4f}")
-         """
-         seed_everything(self.seed)
-         # Create the early-stopping flag on every process
-         early_stop_flag = torch.tensor(0, device=self.accelerator.device)
-
-         # Make sure all processes start in sync
-         self.accelerator.wait_for_everyone()
-
-         # Initial validation or test
-         if self.eval_loader is not None and len(self.eval_loader) > 0:
-             valid_metrics = self.evaluate()
-         else:
-             valid_metrics = self.test()
-
-         # Update metrics and save the model on the main process
-         if self.accelerator.is_main_process:
-             if self._is_metric_better(valid_metrics, stage="valid"):
-                 self._save_state_dict()
-                 early_stop_flag = torch.tensor(0, device=self.accelerator.device)
-
-         # Synchronize the early-stopping flag via all_gather
-         gathered_flags = self.accelerator.gather(early_stop_flag)
-         early_stop_flag = (
-             gathered_flags if gathered_flags.ndim == 0 else gathered_flags[0]
-         )  # Use the value from the main process
-
-         for epoch in range(self.epochs):
-             self.model.train()
-
-             train_it = tqdm(
-                 self.train_loader,
-                 desc=f"Epoch {epoch + 1}/{self.epochs} Loss",
-                 disable=not self.accelerator.is_main_process,
-             )
-             # Use accelerator.accumulate to control gradient accumulation
-             for step, batch in enumerate(train_it):
-                 train_loss = []
-
-                 with self.accelerator.accumulate(self.model):
-                     outputs = self.model(**batch)
-                     if "loss" not in outputs:
-                         # Generally, the model should return a loss in the outputs via OmniGenBench
-                         # For the Lora models, the loss is computed separately
-                         if hasattr(self.model, "loss_function") and callable(
-                             self.model.loss_function
-                         ):
-                             loss = self.model.loss_function(
-                                 outputs["logits"], outputs["labels"]
-                             )
-                         elif (
-                             hasattr(self.model, "model")
-                             and hasattr(self.model.model, "loss_function")
-                             and callable(self.model.model.loss_function)
-                         ):
-                             loss = self.model.model.loss_function(
-                                 outputs["logits"], outputs["labels"]
-                             )
-                         else:
-                             raise ValueError(
-                                 "The model does not have a loss function defined. "
-                                 "Please provide a loss function or ensure the model has one."
-                             )
-                     else:
-                         # If the model returns a loss directly
-                         loss = outputs["loss"]
-
-                     train_loss.append(loss.item() * self.gradient_accumulation_steps)
-                     train_it.set_description(
-                         f"Epoch {epoch + 1}/{self.epochs} Loss: {np.nanmean(train_loss):.4f}"
-                     )
-
-                     self.accelerator.backward(loss)
-
-                     self.optimizer.step()
-                     self.optimizer.zero_grad()
-
-             # Synchronize all processes before evaluation
-             self.accelerator.wait_for_everyone()
-
-             if self.eval_loader is not None and len(self.eval_loader) > 0:
-                 valid_metrics = self.evaluate()
-             else:
-                 valid_metrics = self.test()
-
-             # On the main process, update metrics and decide whether to early-stop
-             if self.accelerator.is_main_process:
-                 if self._is_metric_better(valid_metrics, stage="valid"):
-                     self._save_state_dict()
-                     early_stop_flag = torch.tensor(0, device=self.accelerator.device)
-                 else:
-                     early_stop_flag += 1
-
-             # Synchronize the early-stopping flag via all_gather
-             gathered_flags = self.accelerator.gather(early_stop_flag)
-             early_stop_flag = (
-                 gathered_flags if gathered_flags.ndim == 0 else gathered_flags[0]
-             )  # Use the value from the main process
-
-             # Check whether early stopping should trigger
-             if early_stop_flag.item() > self.patience:
-                 if self.accelerator.is_main_process:
-                     fprint(f"Early stopping at epoch {epoch + 1}.")
-                 break
-
-             # Save checkpoints only on the main process
-             if path_to_save and self.accelerator.is_main_process:
-                 _path_to_save = path_to_save + "_epoch_" + str(epoch + 1)
-                 if valid_metrics:
-                     for key, value in valid_metrics.items():
-                         _path_to_save += f"_seed_{self.seed}_{key}_{value:.4f}"
-                 self.save_model(_path_to_save, **kwargs)
-
-             # Make sure all processes are in sync before the next epoch
-             self.accelerator.wait_for_everyone()
-
-         # Final test using the best checkpoint
-         if self.test_loader is not None and len(self.test_loader) > 0:
-             self._load_state_dict()
-             self.accelerator.wait_for_everyone()  # Make sure loading has finished before testing
-             test_metrics = self.test()
-             if self.accelerator.is_main_process:
-                 self._is_metric_better(test_metrics, stage="test")
-
-         # Save the final model only on the main process
-         if path_to_save and self.accelerator.is_main_process:
-             _path_to_save = path_to_save + "_final"
-             if self.metrics.get("test"):
-                 for key, value in self.metrics["test"][-1].items():
-                     _path_to_save += f"_seed_{self.seed}_{key}_{value:.4f}"
-             self.save_model(_path_to_save, **kwargs)
-
-         self._remove_state_dict()
-
-         self.accelerator.free_memory()
-         del (
-             self.optimizer,
-             self.train_loader,
-             self.eval_loader,
-             self.test_loader,
-         )
-
-         return self.metrics
-
-     def _is_metric_better(self, metrics, stage="valid"):
-         """
-         Check if the current metrics are better than the best metrics so far.
-
-         Args:
-             metrics (dict): Current metrics
-             stage (str): Stage of evaluation ('valid' or 'test')
-
-         Returns:
-             bool: True if current metrics are better, False otherwise
-         """
-         # Compare metrics only on the main process
-         if not self.accelerator.is_main_process:
-             return False
-
-         assert stage in [
-             "valid",
-             "test",
-         ], "The metrics stage should be either 'valid' or 'test'."
-
-         prev_metrics = self.metrics.get(stage, None)
-         if stage not in self.metrics:
-             self.metrics.update({f"{stage}": [metrics]})
-         else:
-             self.metrics[f"{stage}"].append(metrics)
-
-         if "best_valid" not in self.metrics:
-             self.metrics.update({"best_valid": metrics})
-             return True
-
-         if prev_metrics is None:
-             return False
-
-         self._optimization_direction = (
-             _infer_optimization_direction(metrics, prev_metrics)
-             if self._optimization_direction is None
-             else self._optimization_direction
-         )
-
-         if self._optimization_direction == "larger_is_better":
-             if np.mean(list(metrics.values())[0]) > np.mean(
-                 list(self.metrics["best_valid"].values())[0]
-             ):
-                 self.metrics.update({"best_valid": metrics})
-                 return True
-         elif self._optimization_direction == "smaller_is_better":
-             if np.mean(list(metrics.values())[0]) < np.mean(
-                 list(self.metrics["best_valid"].values())[0]
-             ):
-                 self.metrics.update({"best_valid": metrics})
-                 return True
-
-         return False
-
-     def predict(self, data_loader):
-         """
-         Make predictions using the trained model.
-
-         Args:
-             data_loader: DataLoader containing data to predict on
-
-         Returns:
-             dict: Dictionary containing predictions
-         """
-         return self.accelerator.unwrap_model(self.model).predict(data_loader)
-
-     def get_model(self, **kwargs):
-         """
-         Get the trained model.
-
-         Args:
-             **kwargs: Additional keyword arguments
-
-         Returns:
-             The trained model
-         """
-         return self.model
-
-     def compute_metrics(self):
-         """
-         Compute metrics for evaluation.
-
-         This method should be implemented by subclasses to provide specific
-         metric computation logic.
-
-         Raises:
-             NotImplementedError: If compute_metrics method is not implemented
-         """
-         raise NotImplementedError(
-             "The compute_metrics() function should be implemented for your model."
-             " It should return a dictionary of metrics."
-         )
-
-     def save_model(self, path, overwrite=False, **kwargs):
-         """
-         Save the trained model.
-
-         Args:
-             path (str): Path to save the model
-             overwrite (bool, optional): Whether to overwrite existing files. Defaults to False
-             **kwargs: Additional keyword arguments for model saving
-         """
-         # Make sure only one process saves when running in distributed mode
-         if self.accelerator.is_main_process:
-             self.accelerator.unwrap_model(self.model).save(path, overwrite, **kwargs)
-
-     def _load_state_dict(self):
-         """Load the best model state dictionary."""
-         if hasattr(self, "_model_state_dict_path") and os.path.exists(
-             self._model_state_dict_path
-         ):
-             weights = torch.load(self._model_state_dict_path, map_location="cpu")
-             self.accelerator.unwrap_model(self.model).load_state_dict(weights)
-
-     def _save_state_dict(self):
-         """Save the current model state dictionary."""
-         if not hasattr(self, "_model_state_dict_path"):
-             from hashlib import sha256
-
-             time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime())
-             hash_digest = sha256(self.__repr__().encode("utf-8")).hexdigest()
-             self._model_state_dict_path = f"tmp_ckpt_{time_str}_{hash_digest}.pt"
-
-         if os.path.exists(self._model_state_dict_path):
-             os.remove(self._model_state_dict_path)
-
-         # Use accelerator to gather model weights on one process
-         if self.accelerator.is_main_process:
-             torch.save(
-                 self.accelerator.unwrap_model(self.model).state_dict(),
-                 self._model_state_dict_path,
-             )
-
-     def _remove_state_dict(self):
-         """Remove the temporary model state dictionary file."""
-         if not hasattr(self, "_model_state_dict_path"):
-             from hashlib import sha256
-
-             time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime())
-             hash_digest = sha256(self.__repr__().encode("utf-8")).hexdigest()
-             self._model_state_dict_path = f"tmp_ckpt_{time_str}_{hash_digest}.pt"
-
-         if (
-             os.path.exists(self._model_state_dict_path)
-             and self.accelerator.is_main_process
-         ):
-             os.remove(self._model_state_dict_path)
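
For reference, the Example in the class docstring above implies the following end-to-end usage of the removed trainer. This is a minimal editorial sketch against the 0.3.1a0 layout: model, train_dataset, eval_dataset, and metric_fn are hypothetical placeholders (an OmniGenome model exposing predict()/save(), torch Datasets yielding dicts with a "labels" key, and a callable mapping (y_true, y_pred) to a metric dict); they are not defined by this diff.

    import torch
    from omnigenome.src.trainer import AccelerateTrainer  # import path shown in the 0.3.1a0 docstring

    # Hypothetical placeholders: model, train_dataset, eval_dataset, metric_fn
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

    trainer = AccelerateTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        epochs=10,
        batch_size=32,
        patience=3,                   # early-stopping patience
        optimizer=optimizer,
        compute_metrics=[metric_fn],
        autocast="float16",           # mixed-precision mode: "float16", "bfloat16", or "no"
    )
    metrics = trainer.train(path_to_save="./checkpoints/model")
    print(metrics.get("best_valid"))  # best validation metrics tracked by _is_metric_better()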