omnigenome 0.3.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (85)
  1. omnigenome/__init__.py +281 -0
  2. omnigenome/auto/__init__.py +3 -0
  3. omnigenome/auto/auto_bench/__init__.py +12 -0
  4. omnigenome/auto/auto_bench/auto_bench.py +484 -0
  5. omnigenome/auto/auto_bench/auto_bench_cli.py +230 -0
  6. omnigenome/auto/auto_bench/auto_bench_config.py +216 -0
  7. omnigenome/auto/auto_bench/config_check.py +34 -0
  8. omnigenome/auto/auto_train/__init__.py +13 -0
  9. omnigenome/auto/auto_train/auto_train.py +430 -0
  10. omnigenome/auto/auto_train/auto_train_cli.py +222 -0
  11. omnigenome/auto/bench_hub/__init__.py +12 -0
  12. omnigenome/auto/bench_hub/bench_hub.py +25 -0
  13. omnigenome/cli/__init__.py +13 -0
  14. omnigenome/cli/commands/__init__.py +13 -0
  15. omnigenome/cli/commands/base.py +83 -0
  16. omnigenome/cli/commands/bench/__init__.py +13 -0
  17. omnigenome/cli/commands/bench/bench_cli.py +202 -0
  18. omnigenome/cli/commands/rna/__init__.py +13 -0
  19. omnigenome/cli/commands/rna/rna_design.py +178 -0
  20. omnigenome/cli/omnigenome_cli.py +128 -0
  21. omnigenome/src/__init__.py +12 -0
  22. omnigenome/src/abc/__init__.py +12 -0
  23. omnigenome/src/abc/abstract_dataset.py +622 -0
  24. omnigenome/src/abc/abstract_metric.py +114 -0
  25. omnigenome/src/abc/abstract_model.py +689 -0
  26. omnigenome/src/abc/abstract_tokenizer.py +267 -0
  27. omnigenome/src/dataset/__init__.py +16 -0
  28. omnigenome/src/dataset/omni_dataset.py +435 -0
  29. omnigenome/src/lora/__init__.py +13 -0
  30. omnigenome/src/lora/lora_model.py +294 -0
  31. omnigenome/src/metric/__init__.py +15 -0
  32. omnigenome/src/metric/classification_metric.py +184 -0
  33. omnigenome/src/metric/metric.py +199 -0
  34. omnigenome/src/metric/ranking_metric.py +142 -0
  35. omnigenome/src/metric/regression_metric.py +191 -0
  36. omnigenome/src/misc/__init__.py +3 -0
  37. omnigenome/src/misc/utils.py +439 -0
  38. omnigenome/src/model/__init__.py +19 -0
  39. omnigenome/src/model/augmentation/__init__.py +12 -0
  40. omnigenome/src/model/augmentation/model.py +219 -0
  41. omnigenome/src/model/classification/__init__.py +12 -0
  42. omnigenome/src/model/classification/model.py +642 -0
  43. omnigenome/src/model/embedding/__init__.py +12 -0
  44. omnigenome/src/model/embedding/model.py +263 -0
  45. omnigenome/src/model/mlm/__init__.py +12 -0
  46. omnigenome/src/model/mlm/model.py +177 -0
  47. omnigenome/src/model/module_utils.py +232 -0
  48. omnigenome/src/model/regression/__init__.py +12 -0
  49. omnigenome/src/model/regression/model.py +786 -0
  50. omnigenome/src/model/regression/resnet.py +483 -0
  51. omnigenome/src/model/rna_design/__init__.py +12 -0
  52. omnigenome/src/model/rna_design/model.py +426 -0
  53. omnigenome/src/model/seq2seq/__init__.py +12 -0
  54. omnigenome/src/model/seq2seq/model.py +44 -0
  55. omnigenome/src/tokenizer/__init__.py +16 -0
  56. omnigenome/src/tokenizer/bpe_tokenizer.py +226 -0
  57. omnigenome/src/tokenizer/kmers_tokenizer.py +247 -0
  58. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +249 -0
  59. omnigenome/src/trainer/__init__.py +14 -0
  60. omnigenome/src/trainer/accelerate_trainer.py +739 -0
  61. omnigenome/src/trainer/hf_trainer.py +75 -0
  62. omnigenome/src/trainer/trainer.py +579 -0
  63. omnigenome/utility/__init__.py +3 -0
  64. omnigenome/utility/dataset_hub/__init__.py +13 -0
  65. omnigenome/utility/dataset_hub/dataset_hub.py +178 -0
  66. omnigenome/utility/ensemble.py +324 -0
  67. omnigenome/utility/hub_utils.py +517 -0
  68. omnigenome/utility/model_hub/__init__.py +12 -0
  69. omnigenome/utility/model_hub/model_hub.py +231 -0
  70. omnigenome/utility/pipeline_hub/__init__.py +12 -0
  71. omnigenome/utility/pipeline_hub/pipeline.py +483 -0
  72. omnigenome/utility/pipeline_hub/pipeline_hub.py +129 -0
  73. omnigenome-0.3.0a0.dist-info/METADATA +224 -0
  74. omnigenome-0.3.0a0.dist-info/RECORD +85 -0
  75. omnigenome-0.3.0a0.dist-info/WHEEL +5 -0
  76. omnigenome-0.3.0a0.dist-info/entry_points.txt +3 -0
  77. omnigenome-0.3.0a0.dist-info/licenses/LICENSE +201 -0
  78. omnigenome-0.3.0a0.dist-info/top_level.txt +2 -0
  79. tests/__init__.py +9 -0
  80. tests/conftest.py +160 -0
  81. tests/test_dataset_patterns.py +291 -0
  82. tests/test_examples_syntax.py +83 -0
  83. tests/test_model_loading.py +183 -0
  84. tests/test_rna_functions.py +255 -0
  85. tests/test_training_patterns.py +302 -0
@@ -0,0 +1,484 @@
+ # -*- coding: utf-8 -*-
+ # file: auto_bench.py
+ # time: 11:54 14/04/2024
+ # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
+ # github: https://github.com/yangheng95
+ # huggingface: https://huggingface.co/yangheng
+ # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
+ # Copyright (C) 2019-2024. All Rights Reserved.
+
+ import os
+ import time
+ import warnings
+
+ import findfile
+ import torch
+ from metric_visualizer import MetricVisualizer
+
+ from transformers import TrainingArguments, Trainer as HFTrainer
+ from ...src.abc.abstract_tokenizer import OmniTokenizer
+ from ...src.lora.lora_model import OmniLoraModel
+ from ...src.misc.utils import (
+     seed_everything,
+     fprint,
+     load_module_from_path,
+     check_bench_version,
+     clean_temp_checkpoint,
+ )
+ from ...src.trainer.trainer import Trainer
+ from ...src.trainer.accelerate_trainer import AccelerateTrainer
+ from ...utility.hub_utils import download_benchmark
+ from ... import __version__ as omnigenome_version
+
+
+ class AutoBench:
+     """
+     AutoBench is a class for automatically benchmarking genomic foundation models.
+
+     This class provides a comprehensive framework for evaluating genomic models
+     across multiple benchmarks and tasks. It handles loading benchmarks, models,
+     and tokenizers, and runs evaluations with proper metric tracking and result
+     visualization.
+
+     AutoBench supports various evaluation scenarios, including:
+     - Single-model evaluation across multiple benchmarks
+     - Multi-seed evaluation for robustness testing
+     - Different trainer backends (native, accelerate, huggingface)
+     - Automatic metric visualization and result tracking
+
+     Attributes:
+         benchmark (str): The name or path of the benchmark to use.
+         model_name_or_path (str): The name or path of the model to evaluate.
+         tokenizer: The tokenizer to use for evaluation.
+         autocast (str): The autocast precision to use ('fp16', 'bf16', etc.).
+         overwrite (bool): Whether to overwrite existing evaluation results.
+         trainer (str): The trainer to use ('native', 'accelerate', 'hf_trainer').
+         mv_path (str): Path to the metric visualizer file.
+         mv (MetricVisualizer): The metric visualizer instance.
+         bench_metadata: Metadata about the benchmark configuration.
+     """
+
+     def __init__(
+         self,
+         benchmark,
+         model_name_or_path,
+         tokenizer=None,
+         **kwargs,
+     ):
+         """
+         Initializes the AutoBench instance.
+
+         Args:
+             benchmark (str): The name or path of the benchmark to use.
+             model_name_or_path (str): The name or path of the model to evaluate.
+             tokenizer: The tokenizer to use. If None, it will be loaded from the model path.
+             **kwargs: Additional keyword arguments.
+                 - autocast (str): The autocast precision to use ('fp16', 'bf16', etc.).
+                   Defaults to 'fp16'.
+                 - overwrite (bool): Whether to overwrite existing evaluation results.
+                   Defaults to False.
+                 - trainer (str): The trainer to use ('native', 'accelerate', 'hf_trainer').
+                   Defaults to 'native'.
+
+         Example:
+             >>> # Initialize with a benchmark and model
+             >>> bench = AutoBench("RGB", "model_name")
+
+             >>> # Initialize with custom settings
+             >>> bench = AutoBench("RGB", "model_name",
+             ...                   autocast="bf16", trainer="accelerate")
+         """
+         self.benchmark = benchmark.rstrip("/")
+         self.autocast = kwargs.pop("autocast", "fp16")
+         self.overwrite = kwargs.pop("overwrite", False)
+         self.trainer = kwargs.pop("trainer", "native")
+
+         self.model_name_or_path = model_name_or_path
+         self.tokenizer = tokenizer
+         if isinstance(self.model_name_or_path, str):
+             self.model_name_or_path = self.model_name_or_path.rstrip("/")
+             self.model_name = self.model_name_or_path.split("/")[-1]
+         else:
+             self.model_name = self.model_name_or_path.__class__.__name__
+         if isinstance(tokenizer, str):
+             self.tokenizer = tokenizer.rstrip("/")
+         os.makedirs("./autobench_evaluations", exist_ok=True)
+         time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime())
+         mv_name = f"{benchmark}-{self.model_name}"
+         self.mv_path = f"./autobench_evaluations/{mv_name}-{time_str}.mv"
+
+         # Resume from an existing metric-visualizer record unless overwrite is set.
+         mv_paths = findfile.find_files(
+             "./autobench_evaluations",
+             and_key=[benchmark, self.model_name, ".mv"],
+         )
+         if mv_paths and not self.overwrite:
+             self.mv = MetricVisualizer.load(mv_paths[-1])
+             self.mv.summary(round=4)
+         else:
+             self.mv = MetricVisualizer(self.mv_path)
+         if not os.path.exists(self.benchmark):
+             fprint(
+                 "Benchmark:",
+                 benchmark,
+                 "does not exist. Searching online for available benchmarks.",
+             )
+             self.benchmark = download_benchmark(self.benchmark)
+
+         # Import the benchmark list from the benchmark's metadata module.
+         self.bench_metadata = load_module_from_path(
+             "bench_metadata", f"{self.benchmark}/metadata.py"
+         )
+         check_bench_version(
+             self.bench_metadata.__omnigenome_version__, omnigenome_version
+         )
+         fprint("Loaded benchmarks:", self.bench_metadata.bench_list)
+         self.bench_info()
+
+     def bench_info(self):
+         """
+         Prints and returns information about the current benchmark setup.
+
+         This method provides a comprehensive overview of the current
+         benchmark configuration, including benchmark details, model information,
+         and evaluation settings.
+
+         Returns:
+             str: A string containing benchmark information.
+
+         Example:
+             >>> info = bench.bench_info()
+             >>> print(info)
+         """
+         info = f"Benchmark Root: {self.benchmark}\n"
+         info += f"Benchmark List: {self.bench_metadata.bench_list}\n"
+         info += f"Model Name or Path: {self.model_name}\n"
+         info += f"Tokenizer: {self.tokenizer}\n"
+         info += f"Metric Visualizer Path: {self.mv_path}\n"
+         info += f"BenchConfig Details: {self.bench_metadata}\n"
+         fprint(info)
+         return info
+
+     def run(self, **kwargs):
+         """
+         Runs the benchmarking process.
+
+         This method iterates through the tasks in the benchmark, loads the corresponding
+         configurations, initializes the model, tokenizer, and datasets, and then
+         trains and evaluates the model. It supports multiple evaluation seeds and
+         various trainer backends.
+
+         Args:
+             **kwargs: Additional keyword arguments that will override the default
+                 parameters in the benchmark configuration.
+
+         Example:
+             >>> # Run benchmarking with default settings
+             >>> bench.run()
+
+             >>> # Run with custom parameters
+             >>> bench.run(learning_rate=1e-4, batch_size=16)
+         """
+         bs_scale = kwargs.pop("bs_scale", 1)
+         # Import benchmark config for each task and run it.
+         for i, bench in enumerate(self.bench_metadata.bench_list):
+             clean_temp_checkpoint(1)  # clean temp checkpoints older than 1 day
+             fprint(
+                 ">" * 80,
+                 f"\nRunning evaluation for task: {bench}",
+                 "Progress:",
+                 i + 1,
+                 "/",
+                 len(self.bench_metadata.bench_list),
+                 f"{(i + 1) * 100 / len(self.bench_metadata.bench_list)}%",
+             )
+             bench_config_path = findfile.find_file(
+                 self.benchmark, and_key=f"{self.benchmark}.{bench}.config".split(".")
+             )
+             config = load_module_from_path("config", bench_config_path)
+             bench_config = config.bench_config
+             fprint(f"Loaded config for {bench} from {bench_config_path}")
+             fprint(bench_config)
+
+             # Init tokenizer: use the provided one, or load it from the model path.
+             if not self.tokenizer:
+                 tokenizer = OmniTokenizer.from_pretrained(
+                     self.model_name_or_path,
+                     trust_remote_code=bench_config.get("trust_remote_code", True),
+                     **bench_config,
+                 )
+             else:
+                 tokenizer = self.tokenizer
+
+             if not isinstance(bench_config["seeds"], list):
+                 bench_config["seeds"] = [bench_config["seeds"]]
+
+             random_seeds = bench_config["seeds"]
+             for seed in random_seeds:
+                 _kwargs = kwargs.copy()
+                 # Override the benchmark config with user-supplied kwargs;
+                 # unknown keys are added with a warning.
+                 for key, value in _kwargs.items():
+                     if key in bench_config:
+                         fprint(
+                             "Override", key, "with", value, "according to the input kwargs"
+                         )
+                     else:
+                         warnings.warn(
+                             f"kwarg: {key} not found in bench_config while setting {key} = {value}"
+                         )
+                     bench_config.update({key: value})
+
+                 # Drop overridden keys from _kwargs so they are not passed twice below.
+                 for key in list(_kwargs):
+                     if key in bench_config:
+                         _kwargs.pop(key)
+
+                 fprint(
+                     f"AutoBench Config for {bench}:",
+                     "\n".join([f"{k}: {v}" for k, v in bench_config.items()]),
+                 )
+
+                 batch_size = (
+                     bench_config["batch_size"] if "batch_size" in bench_config else 8
+                 ) * bs_scale
+
+                 record_name = f"{self.benchmark}-{bench}-{self.model_name}".split("/")[-1]
+                 # Skip this task if a complete record already exists for all seeds.
+                 if record_name in self.mv.transpose() and len(
+                     list(self.mv.transpose()[record_name].values())[0]
+                 ) >= len(bench_config["seeds"]):
+                     continue
+
+                 seed_everything(seed)
+                 if self.model_name_or_path:
+                     model_cls = bench_config["model_cls"]
+                     model = model_cls(
+                         self.model_name_or_path,
+                         tokenizer=tokenizer,
+                         label2id=bench_config["label2id"],
+                         num_labels=bench_config["num_labels"],
+                         trust_remote_code=True,
+                         ignore_mismatched_sizes=True,
+                     )
+                 else:
+                     raise ValueError(
+                         "model_name_or_path is not specified. Please provide a valid model name or path."
+                     )
+
+                 fprint(f"\n{model}")
+
+                 # Optionally wrap the model with LoRA adapters.
+                 if kwargs.get("lora_config", None) is not None:
+                     fprint("Applying LoRA to the model with config:", kwargs["lora_config"])
+                     model = OmniLoraModel(model, **kwargs.get("lora_config", {}))
+
+                 # Init datasets
+                 dataset_cls = bench_config["dataset_cls"]
+
+                 # Cap max_length at the model's positional capacity, if known.
+                 if hasattr(model.config, "max_position_embeddings"):
+                     max_length = min(
+                         bench_config["max_length"],
+                         model.config.max_position_embeddings,
+                     )
+                 else:
+                     max_length = bench_config["max_length"]
+
+                 train_set = dataset_cls(
+                     data_source=bench_config["train_file"],
+                     tokenizer=tokenizer,
+                     label2id=bench_config["label2id"],
+                     max_length=max_length,
+                     structure_in=bench_config.get("structure_in", False),
+                     max_examples=bench_config.get("max_examples", None),
+                     shuffle=bench_config.get("shuffle", True),
+                     drop_long_seq=bench_config.get("drop_long_seq", False),
+                     **_kwargs,
+                 )
+                 test_set = dataset_cls(
+                     data_source=bench_config["test_file"],
+                     tokenizer=tokenizer,
+                     label2id=bench_config["label2id"],
+                     max_length=max_length,
+                     structure_in=bench_config.get("structure_in", False),
+                     max_examples=bench_config.get("max_examples", None),
+                     shuffle=False,
+                     drop_long_seq=bench_config.get("drop_long_seq", False),
+                     **_kwargs,
+                 )
+                 valid_set = dataset_cls(
+                     data_source=bench_config["valid_file"],
+                     tokenizer=tokenizer,
+                     label2id=bench_config["label2id"],
+                     max_length=max_length,
+                     structure_in=bench_config.get("structure_in", False),
+                     max_examples=bench_config.get("max_examples", None),
+                     shuffle=False,
+                     drop_long_seq=bench_config.get("drop_long_seq", False),
+                     **_kwargs,
+                 )
+
+                 if self.trainer == "hf_trainer":
+                     # Set up the HuggingFace Trainer
+                     hf_kwargs = {
+                         k: v
+                         for k, v in kwargs.items()
+                         if hasattr(TrainingArguments, k) and k != "output_dir"
+                     }
+                     # Use the same batch size for training and evaluation.
+                     hf_batch_size = hf_kwargs.pop("batch_size", batch_size)
+                     training_args = TrainingArguments(
+                         output_dir=f"./autobench_evaluations/{self.model_name}-{bench}",
+                         num_train_epochs=hf_kwargs.pop(
+                             "num_train_epochs", bench_config["epochs"]
+                         ),
+                         per_device_train_batch_size=hf_batch_size,
+                         per_device_eval_batch_size=hf_batch_size,
+                         gradient_accumulation_steps=hf_kwargs.pop(
+                             "gradient_accumulation_steps", 1
+                         ),
+                         learning_rate=hf_kwargs.pop("learning_rate", 2e-5),
+                         weight_decay=hf_kwargs.pop("weight_decay", 0),
+                         eval_strategy=hf_kwargs.pop("eval_strategy", "epoch"),
+                         save_strategy=hf_kwargs.pop("save_strategy", "epoch"),
+                         fp16=hf_kwargs.pop("fp16", True),
+                         remove_unused_columns=False,
+                         label_names=["labels"],
+                         **hf_kwargs,
+                     )
+
+                     valid_set = valid_set if len(valid_set) else test_set
+
+                     if len(bench_config["compute_metrics"]) > 1:
+                         fprint(
+                             "Multiple metrics are not supported by HFTrainer; using the first metric only."
+                         )
+                     trainer = HFTrainer(
+                         model=model,
+                         args=training_args,
+                         train_dataset=train_set,
+                         eval_dataset=valid_set,
+                         compute_metrics=(
+                             bench_config["compute_metrics"][0]
+                             if isinstance(bench_config["compute_metrics"], list)
+                             else bench_config["compute_metrics"]
+                         ),
+                     )
+
+                     # Evaluate once before training, then train and evaluate.
+                     eval_result = trainer.evaluate(
+                         valid_set if len(valid_set) else test_set
+                     )
+                     print(eval_result)
+                     train_result = trainer.train()
+                     eval_result = trainer.evaluate()
+                     test_result = trainer.evaluate(
+                         test_set if len(test_set) else valid_set
+                     )
+
+                     metrics = {
+                         "train": train_result.metrics,
+                         "eval": eval_result,
+                         "test": test_result,
+                     }
+                     fprint(metrics)
+                 else:
+                     optimizer = torch.optim.AdamW(
+                         filter(lambda p: p.requires_grad, model.parameters()),
+                         lr=bench_config.get("learning_rate", 2e-5),
+                         weight_decay=bench_config.get("weight_decay", 0),
+                     )
+                     if self.trainer == "accelerate":
+                         trainer_cls = AccelerateTrainer
+                     else:
+                         trainer_cls = Trainer
+                     fprint(f"Using Trainer: {trainer_cls}")
+                     trainer = trainer_cls(
+                         model=model,
+                         train_dataset=train_set,
+                         eval_dataset=valid_set,
+                         test_dataset=test_set,
+                         batch_size=batch_size,
+                         patience=bench_config.get("patience", 3),
+                         epochs=bench_config["epochs"],
+                         gradient_accumulation_steps=bench_config.get(
+                             "gradient_accumulation_steps", 1
+                         ),
+                         optimizer=optimizer,
+                         loss_fn=bench_config.get("loss_fn", None),
+                         compute_metrics=bench_config["compute_metrics"],
+                         seed=seed,
+                         autocast=self.autocast,
+                         **_kwargs,
+                     )
+                     metrics = trainer.train()
+
+                     predictions = trainer.predictions
+
+                     # Optionally persist raw predictions per split as .npy files.
+                     if bench_config.get("save_predictions", False):
+                         os.makedirs(f"predictions/{bench}", exist_ok=True)
+                         import numpy as np
+
+                         for split in predictions.keys():
+                             with open(f"predictions/{bench}/{split}.npy", "wb") as f:
+                                 np.save(f, predictions[split])
+
+                     # Log the final test metrics for this seed.
+                     if metrics:
+                         for key, value in metrics["test"][-1].items():
+                             try:
+                                 value = float(value)
+                             except (TypeError, ValueError):
+                                 pass  # ignore non-float values
+                             self.mv.log(f"{record_name}", f"{key}", value)
+
+                     self.mv.summary(round=4)
+                     self.mv.dump(self.mv_path)
+                     self.mv.to_csv(self.mv_path.replace(".mv", ".csv"))
+                     del model, trainer, optimizer
+                     torch.cuda.empty_cache()
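
To put the pieces above together, here is a minimal usage sketch of the AutoBench API. It assumes AutoBench is re-exported from the package root (which the top-level __init__.py suggests); the benchmark name "RGB" comes from the docstring examples, and the model id below is a placeholder, not a verified default:

    from omnigenome import AutoBench  # assumes the root package re-exports AutoBench

    # "RGB" is the benchmark used in the docstring examples;
    # the model id is an illustrative placeholder.
    bench = AutoBench(
        "RGB",
        "your-org/your-genomic-model",
        autocast="fp16",      # or "bf16"
        trainer="native",     # or "accelerate" / "hf_trainer"
        overwrite=False,
    )
    bench.bench_info()                            # print the resolved benchmark setup
    bench.run(learning_rate=1e-4, batch_size=16)  # kwargs override bench_config values

Per-task, per-seed results accumulate in ./autobench_evaluations/ as .mv and .csv files via MetricVisualizer; re-running with overwrite=False resumes from the latest record and skips tasks that already have results for all seeds.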