omnigenome 0.3.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of omnigenome might be problematic.

Files changed (85)
  1. omnigenome/__init__.py +281 -0
  2. omnigenome/auto/__init__.py +3 -0
  3. omnigenome/auto/auto_bench/__init__.py +12 -0
  4. omnigenome/auto/auto_bench/auto_bench.py +484 -0
  5. omnigenome/auto/auto_bench/auto_bench_cli.py +230 -0
  6. omnigenome/auto/auto_bench/auto_bench_config.py +216 -0
  7. omnigenome/auto/auto_bench/config_check.py +34 -0
  8. omnigenome/auto/auto_train/__init__.py +13 -0
  9. omnigenome/auto/auto_train/auto_train.py +430 -0
  10. omnigenome/auto/auto_train/auto_train_cli.py +222 -0
  11. omnigenome/auto/bench_hub/__init__.py +12 -0
  12. omnigenome/auto/bench_hub/bench_hub.py +25 -0
  13. omnigenome/cli/__init__.py +13 -0
  14. omnigenome/cli/commands/__init__.py +13 -0
  15. omnigenome/cli/commands/base.py +83 -0
  16. omnigenome/cli/commands/bench/__init__.py +13 -0
  17. omnigenome/cli/commands/bench/bench_cli.py +202 -0
  18. omnigenome/cli/commands/rna/__init__.py +13 -0
  19. omnigenome/cli/commands/rna/rna_design.py +178 -0
  20. omnigenome/cli/omnigenome_cli.py +128 -0
  21. omnigenome/src/__init__.py +12 -0
  22. omnigenome/src/abc/__init__.py +12 -0
  23. omnigenome/src/abc/abstract_dataset.py +622 -0
  24. omnigenome/src/abc/abstract_metric.py +114 -0
  25. omnigenome/src/abc/abstract_model.py +689 -0
  26. omnigenome/src/abc/abstract_tokenizer.py +267 -0
  27. omnigenome/src/dataset/__init__.py +16 -0
  28. omnigenome/src/dataset/omni_dataset.py +435 -0
  29. omnigenome/src/lora/__init__.py +13 -0
  30. omnigenome/src/lora/lora_model.py +294 -0
  31. omnigenome/src/metric/__init__.py +15 -0
  32. omnigenome/src/metric/classification_metric.py +184 -0
  33. omnigenome/src/metric/metric.py +199 -0
  34. omnigenome/src/metric/ranking_metric.py +142 -0
  35. omnigenome/src/metric/regression_metric.py +191 -0
  36. omnigenome/src/misc/__init__.py +3 -0
  37. omnigenome/src/misc/utils.py +439 -0
  38. omnigenome/src/model/__init__.py +19 -0
  39. omnigenome/src/model/augmentation/__init__.py +12 -0
  40. omnigenome/src/model/augmentation/model.py +219 -0
  41. omnigenome/src/model/classification/__init__.py +12 -0
  42. omnigenome/src/model/classification/model.py +642 -0
  43. omnigenome/src/model/embedding/__init__.py +12 -0
  44. omnigenome/src/model/embedding/model.py +263 -0
  45. omnigenome/src/model/mlm/__init__.py +12 -0
  46. omnigenome/src/model/mlm/model.py +177 -0
  47. omnigenome/src/model/module_utils.py +232 -0
  48. omnigenome/src/model/regression/__init__.py +12 -0
  49. omnigenome/src/model/regression/model.py +786 -0
  50. omnigenome/src/model/regression/resnet.py +483 -0
  51. omnigenome/src/model/rna_design/__init__.py +12 -0
  52. omnigenome/src/model/rna_design/model.py +426 -0
  53. omnigenome/src/model/seq2seq/__init__.py +12 -0
  54. omnigenome/src/model/seq2seq/model.py +44 -0
  55. omnigenome/src/tokenizer/__init__.py +16 -0
  56. omnigenome/src/tokenizer/bpe_tokenizer.py +226 -0
  57. omnigenome/src/tokenizer/kmers_tokenizer.py +247 -0
  58. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +249 -0
  59. omnigenome/src/trainer/__init__.py +14 -0
  60. omnigenome/src/trainer/accelerate_trainer.py +739 -0
  61. omnigenome/src/trainer/hf_trainer.py +75 -0
  62. omnigenome/src/trainer/trainer.py +579 -0
  63. omnigenome/utility/__init__.py +3 -0
  64. omnigenome/utility/dataset_hub/__init__.py +13 -0
  65. omnigenome/utility/dataset_hub/dataset_hub.py +178 -0
  66. omnigenome/utility/ensemble.py +324 -0
  67. omnigenome/utility/hub_utils.py +517 -0
  68. omnigenome/utility/model_hub/__init__.py +12 -0
  69. omnigenome/utility/model_hub/model_hub.py +231 -0
  70. omnigenome/utility/pipeline_hub/__init__.py +12 -0
  71. omnigenome/utility/pipeline_hub/pipeline.py +483 -0
  72. omnigenome/utility/pipeline_hub/pipeline_hub.py +129 -0
  73. omnigenome-0.3.0a0.dist-info/METADATA +224 -0
  74. omnigenome-0.3.0a0.dist-info/RECORD +85 -0
  75. omnigenome-0.3.0a0.dist-info/WHEEL +5 -0
  76. omnigenome-0.3.0a0.dist-info/entry_points.txt +3 -0
  77. omnigenome-0.3.0a0.dist-info/licenses/LICENSE +201 -0
  78. omnigenome-0.3.0a0.dist-info/top_level.txt +2 -0
  79. tests/__init__.py +9 -0
  80. tests/conftest.py +160 -0
  81. tests/test_dataset_patterns.py +291 -0
  82. tests/test_examples_syntax.py +83 -0
  83. tests/test_model_loading.py +183 -0
  84. tests/test_rna_functions.py +255 -0
  85. tests/test_training_patterns.py +302 -0
omnigenome/auto/auto_train/auto_train.py
@@ -0,0 +1,430 @@
+ # -*- coding: utf-8 -*-
+ # file: auto_bench.py
+ # time: 11:54 14/04/2024
+ # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
+ # github: https://github.com/yangheng95
+ # huggingface: https://huggingface.co/yangheng
+ # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
+ # Copyright (C) 2019-2024. All Rights Reserved.
+
+ import os
+ import time
+ import warnings
+
+ import findfile
+ import torch
+ from metric_visualizer import MetricVisualizer
+ from transformers import TrainingArguments, Trainer as HFTrainer
+
+ from ...src.lora.lora_model import OmniLoraModel
+ from ...src.abc.abstract_tokenizer import OmniTokenizer
+ from ...src.misc.utils import (
+     seed_everything,
+     fprint,
+     load_module_from_path,
+     clean_temp_checkpoint,
+ )
+ from ...src.trainer.accelerate_trainer import AccelerateTrainer
+ from ...src.trainer.trainer import Trainer
+
+ autotrain_evaluations = "./autotrain_evaluations"
+
+
+ class AutoTrain:
+     """
+     AutoTrain is a class for automatically training genomic foundation models on a given dataset.
+
+     This class provides a comprehensive framework for training genomic models
+     on various datasets with minimal configuration. It handles dataset loading,
+     model initialization, training configuration, and result tracking.
+
+     AutoTrain supports various training scenarios, including:
+     - Single dataset training with multiple seeds
+     - Different trainer backends (native, accelerate, huggingface)
+     - Automatic metric visualization and result tracking
+     - Configurable training parameters
+
+     Attributes:
+         dataset (str): The name or path of the dataset to use for training.
+         model_name_or_path (str): The name or path of the model to train.
+         tokenizer: The tokenizer to use for training.
+         autocast (str): The autocast precision to use ('fp16', 'bf16', etc.).
+         overwrite (bool): Whether to overwrite existing training results.
+         trainer (str): The trainer to use ('native', 'accelerate', 'hf_trainer').
+         mv_path (str): Path to the metric visualizer file.
+         mv (MetricVisualizer): The metric visualizer instance.
+     """
+
+     def __init__(
+         self,
+         dataset,
+         model_name_or_path,
+         tokenizer=None,
+         **kwargs,
+     ):
+         """
+         Initialize the AutoTrain instance.
+
+         Args:
+             dataset (str): The name or path of the dataset to use for training.
+             model_name_or_path (str): The name or path of the model to train.
+             tokenizer: The tokenizer to use. If None, it will be loaded from the model path.
+             **kwargs: Additional keyword arguments.
+                 - autocast (str): The autocast precision to use ('fp16', 'bf16', etc.).
+                   Defaults to 'fp16'.
+                 - overwrite (bool): Whether to overwrite existing training results.
+                   Defaults to False.
+                 - trainer (str): The trainer to use ('native', 'accelerate', 'hf_trainer').
+                   Defaults to 'accelerate'.
+
+         Example:
+             >>> # Initialize with a dataset and model
+             >>> trainer = AutoTrain("dataset_name", "model_name")
+
+             >>> # Initialize with custom settings
+             >>> trainer = AutoTrain("dataset_name", "model_name",
+             ...                     autocast="bf16", trainer="accelerate")
+         """
+         self.dataset = dataset.rstrip("/")
+         self.autocast = kwargs.pop("autocast", "fp16")
+         self.overwrite = kwargs.pop("overwrite", False)
+         self.trainer = kwargs.pop("trainer", "accelerate")
+
+         self.model_name_or_path = model_name_or_path
+         self.tokenizer = tokenizer
+         if isinstance(self.model_name_or_path, str):
+             self.model_name_or_path = self.model_name_or_path.rstrip("/")
+             self.model_name = self.model_name_or_path.split("/")[-1]
+         else:
+             self.model_name = self.model_name_or_path.__class__.__name__
+         if isinstance(tokenizer, str):
+             self.tokenizer = tokenizer.rstrip("/")
+         os.makedirs(autotrain_evaluations, exist_ok=True)
+         time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime())
+         mv_name = f"{dataset}-{self.model_name}"
+         self.mv_path = f"{autotrain_evaluations}/{mv_name}-{time_str}.mv"
+
+         mv_paths = findfile.find_files(
+             autotrain_evaluations,
+             [dataset, self.model_name, ".mv"],
+         )
+         if mv_paths and not self.overwrite:
+             self.mv = MetricVisualizer.load(mv_paths[-1])
+             self.mv.summary(round=4)
+         else:
+             self.mv = MetricVisualizer(self.mv_path)
+         self.bench_info()
+
+     def bench_info(self):
+         """
+         Print and return information about the current training setup.
+
+         This method provides a comprehensive overview of the current
+         training configuration, including dataset details, model information,
+         and training settings.
+
+         Returns:
+             str: A string containing training setup information.
+
+         Example:
+             >>> info = trainer.bench_info()
+             >>> print(info)
+         """
+         info = f"Dataset Root: {self.dataset}\n"
+         info += f"Model Name or Path: {self.model_name}\n"
+         info += f"Tokenizer: {self.tokenizer}\n"
+         info += f"Metric Visualizer Path: {self.mv_path}\n"
+         fprint(info)
+         return info
+
+     def run(self, **kwargs):
+         """
+         Run the training process.
+
+         This method loads the dataset configuration, initializes the model and
+         tokenizer, and runs training across multiple seeds. It supports various
+         training backends and automatic result tracking.
+
+         Args:
+             **kwargs: Additional keyword arguments that will override the default
+                 parameters in the dataset configuration.
+
+         Example:
+             >>> # Run training with default settings
+             >>> trainer.run()
+
+             >>> # Run with custom parameters
+             >>> trainer.run(learning_rate=1e-4, batch_size=16)
+         """
+         clean_temp_checkpoint(1)  # clean temp checkpoints older than 1 day
+
+         _kwargs = kwargs.copy()
+         bench_config_path = findfile.find_file(
+             self.dataset, f"{self.dataset}.config".split(".")
+         )
+         config = load_module_from_path("config", bench_config_path)
+         bench_config = config.bench_config
+         fprint(f"Loaded config for {self.dataset} from {bench_config_path}")
+         fprint(bench_config)
+
+         # Init Tokenizer and Model
+         if not self.tokenizer:
+             tokenizer = OmniTokenizer.from_pretrained(
+                 self.model_name_or_path, trust_remote_code=True
+             )
+         else:
+             tokenizer = self.tokenizer
+
+         if not isinstance(bench_config["seeds"], list):
+             bench_config["seeds"] = [bench_config["seeds"]]
+
+         random_seeds = bench_config["seeds"]
+         for seed in random_seeds:
+             for key, value in _kwargs.items():
+                 if key in bench_config:
+                     fprint(
+                         "Override", key, "with", value, "according to the input kwargs"
+                     )
+                     bench_config.update({key: value})
+
+                 else:
+                     warnings.warn(
+                         f"kwarg: {key} not found in bench_config while setting {key} = {value}"
+                     )
+                     bench_config.update({key: value})
+
+             for key, value in bench_config.items():
+                 if key in bench_config and key in _kwargs:
+                     _kwargs.pop(key)
+
+             fprint(
+                 f"AutoBench Config for {self.dataset}:",
+                 "\n".join([f"{k}: {v}" for k, v in bench_config.items()]),
+             )
+             for key, value in _kwargs.items():
+                 if key in bench_config:
+                     fprint(
+                         "Override", key, "with", value, "according to the input kwargs"
+                     )
+                     bench_config.update({key: value})
+
+                 else:
+                     warnings.warn(
+                         f"kwarg: {key} not found in bench_config while setting {key} = {value}"
+                     )
+                     bench_config.update({key: value})
+
+             for key, value in bench_config.items():
+                 if key in bench_config and key in _kwargs:
+                     _kwargs.pop(key)
+
+             fprint(
+                 f"AutoBench Config for {self.dataset}:",
+                 "\n".join([f"{k}: {v}" for k, v in bench_config.items()]),
+             )
+
+             batch_size = (
+                 bench_config["batch_size"] if "batch_size" in bench_config else 8
+             )
+
+             record_name = f"{os.path.basename(self.dataset)}-{self.model_name}".split(
+                 "/"
+             )[-1]
+             # check if the record exists
+             if record_name in self.mv.transpose() and len(
+                 list(self.mv.transpose()[record_name].values())[0]
+             ) >= len(bench_config["seeds"]):
+                 continue
+
+             seed_everything(seed)
+             if self.model_name_or_path:
+                 model_cls = bench_config["model_cls"]
+                 model = model_cls(
+                     self.model_name_or_path,
+                     tokenizer=tokenizer,
+                     label2id=bench_config.label2id,
+                     num_labels=bench_config["num_labels"],
+                     trust_remote_code=True,
+                     ignore_mismatched_sizes=True,
+                 )
+             else:
+                 raise ValueError(
+                     "model_name_or_path is not specified. Please provide a valid model name or path."
+                 )
+
+             if kwargs.get("lora_config", None) is not None:
+                 fprint("Applying LoRA to the model with config:", kwargs["lora_config"])
+                 model = OmniLoraModel(model, **kwargs.get("lora_config", {}))
+
+             # Init Trainer
+             dataset_cls = bench_config["dataset_cls"]
+
+             if hasattr(model.config, "max_position_embeddings"):
+                 max_length = min(
+                     bench_config["max_length"],
+                     model.config.max_position_embeddings,
+                 )
+             else:
+                 max_length = bench_config["max_length"]
+
+             train_set = dataset_cls(
+                 data_source=bench_config["train_file"],
+                 tokenizer=tokenizer,
+                 label2id=bench_config["label2id"],
+                 max_length=max_length,
+                 structure_in=bench_config.get("structure_in", False),
+                 max_examples=bench_config.get("max_examples", None),
+                 shuffle=bench_config.get("shuffle", True),
+                 drop_long_seq=bench_config.get("drop_long_seq", False),
+                 **_kwargs,
+             )
+             test_set = dataset_cls(
+                 data_source=bench_config["test_file"],
+                 tokenizer=tokenizer,
+                 label2id=bench_config["label2id"],
+                 max_length=max_length,
+                 structure_in=bench_config.get("structure_in", False),
+                 max_examples=bench_config.get("max_examples", None),
+                 shuffle=False,
+                 drop_long_seq=bench_config.get("drop_long_seq", False),
+                 **_kwargs,
+             )
+             valid_set = dataset_cls(
+                 data_source=bench_config["valid_file"],
+                 tokenizer=tokenizer,
+                 label2id=bench_config["label2id"],
+                 max_length=max_length,
+                 structure_in=bench_config.get("structure_in", False),
+                 max_examples=bench_config.get("max_examples", None),
+                 shuffle=False,
+                 drop_long_seq=bench_config.get("drop_long_seq", False),
+                 **_kwargs,
+             )
+
+             if self.trainer == "hf_trainer":
+                 # Set up HuggingFace Trainer
+                 hf_kwargs = {
+                     k: v
+                     for k, v in kwargs.items()
+                     if hasattr(TrainingArguments, k) and k != "output_dir"
+                 }
+                 training_args = TrainingArguments(
+                     output_dir=f"./autotrain_evaluations/{self.model_name}",
+                     num_train_epochs=hf_kwargs.pop(
+                         "num_train_epochs", bench_config["epochs"]
+                     ),
+                     per_device_train_batch_size=hf_kwargs.pop("batch_size", batch_size),
+                     per_device_eval_batch_size=hf_kwargs.pop("batch_size", batch_size),
+                     gradient_accumulation_steps=hf_kwargs.pop(
+                         "gradient_accumulation_steps", 1
+                     ),
+                     learning_rate=hf_kwargs.pop("learning_rate", 2e-5),
+                     weight_decay=hf_kwargs.pop("weight_decay", 0),
+                     eval_strategy=hf_kwargs.pop("eval_strategy", "epoch"),
+                     save_strategy=hf_kwargs.pop("save_strategy", "epoch"),
+                     fp16=hf_kwargs.pop("fp16", True),
+                     remove_unused_columns=False,
+                     label_names=["labels"],
+                     **hf_kwargs,
+                 )
+
+                 valid_set = valid_set if len(valid_set) else test_set
+
+                 if len(bench_config["compute_metrics"]) > 1:
+                     fprint(
+                         "Multiple metrics are not supported by HFTrainer; using the first metric only."
+                     )
+                 trainer = HFTrainer(
+                     model=model,
+                     args=training_args,
+                     train_dataset=train_set,
+                     eval_dataset=valid_set,
+                     compute_metrics=(
+                         bench_config["compute_metrics"][0]
+                         if isinstance(bench_config["compute_metrics"], list)
+                         else bench_config["compute_metrics"]
+                     ),
+                 )
+
+                 # Train and evaluate
+                 eval_result = trainer.evaluate(
+                     valid_set if len(valid_set) else test_set
+                 )
+                 print(eval_result)
+                 train_result = trainer.train()
+                 eval_result = trainer.evaluate()
+                 test_result = trainer.evaluate(test_set if len(test_set) else valid_set)
+
+                 metrics = {
+                     "train": train_result.metrics,
+                     "eval": eval_result,
+                     "test": test_result,
+                 }
+                 fprint(metrics)
+             else:
+                 optimizer = torch.optim.AdamW(
+                     filter(lambda p: p.requires_grad, model.parameters()),
+                     lr=(
+                         bench_config["learning_rate"]
+                         if "learning_rate" in bench_config
+                         else 2e-5
+                     ),
+                     weight_decay=(
+                         bench_config["weight_decay"]
+                         if "weight_decay" in bench_config
+                         else 0
+                     ),
+                 )
+                 if self.trainer == "accelerate":
+                     trainer_cls = AccelerateTrainer
+                 else:
+                     trainer_cls = Trainer
+                 fprint(f"Using Trainer: {trainer_cls}")
+                 trainer = trainer_cls(
+                     model=model,
+                     train_dataset=train_set,
+                     eval_dataset=valid_set,
+                     test_dataset=test_set,
+                     batch_size=batch_size,
+                     patience=(
+                         bench_config["patience"] if "patience" in bench_config else 3
+                     ),
+                     epochs=bench_config["epochs"],
+                     gradient_accumulation_steps=bench_config.get(
+                         "gradient_accumulation_steps", 1
+                     ),
+                     optimizer=optimizer,
+                     loss_fn=(
+                         bench_config["loss_fn"] if "loss_fn" in bench_config else None
+                     ),
+                     compute_metrics=bench_config["compute_metrics"],
+                     seed=seed,
+                     autocast=self.autocast,
+                     **_kwargs,
+                 )
+                 metrics = trainer.train()
+                 save_path = os.path.join(
+                     autotrain_evaluations, self.dataset, self.model_name
+                 )
+                 trainer.save_model(save_path)
+
+                 if metrics:
+                     for key, value in metrics["test"][-1].items():
+                         try:
+                             value = float(value)
+                         except:
+                             pass  # ignore non-float values
+                         self.mv.log(f"{record_name}", f"{key}", value)
+                 # for key, value in metrics['test'][-1].items():
+                 #     self.mv.log(f'{record_name}', f'test_{key}', value)
+                 # for i, valid_metrics in enumerate(metrics["valid"]):
+                 #     for key, value in valid_metrics.items():
+                 #         self.mv.log(f'{record_name}', f'valid_epoch_{i}_{key}', value)
+
+             self.mv.summary(round=4)
+             self.mv.dump(self.mv_path)
+             self.mv.to_csv(self.mv_path.replace(".mv", ".csv"))
+             del model, trainer, optimizer
+             torch.cuda.empty_cache()
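
For orientation, here is a minimal usage sketch of the AutoTrain class added above. The dataset folder and model ID are placeholders rather than names shipped with this release, and the dataset directory is assumed to contain the <dataset>.config file that run() locates via findfile.

    from omnigenome import AutoTrain

    # Placeholder dataset folder and model ID, for illustration only.
    auto_trainer = AutoTrain(
        dataset="my_rna_dataset",            # folder expected to contain my_rna_dataset.config
        model_name_or_path="org/placeholder-genomic-model",
        trainer="accelerate",                # or "native" / "hf_trainer"
        autocast="fp16",
        overwrite=False,
    )
    auto_trainer.run(learning_rate=2e-5, batch_size=8)  # kwargs override values from the dataset config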
omnigenome/auto/auto_train/auto_train_cli.py
@@ -0,0 +1,222 @@
+ # -*- coding: utf-8 -*-
+ # file: auto_bench_cli.py
+ # time: 19:18 05/02/2025
+ # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
+ # Homepage: https://yangheng95.github.io
+ # github: https://github.com/yangheng95
+ # huggingface: https://huggingface.co/yangheng
+ # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
+ # Copyright (C) 2019-2025. All Rights Reserved.
+
+ import argparse
+ import os
+ import platform
+ import sys
+ import time
+
+ from typing import Optional
+ from omnigenome import AutoTrain
+ from omnigenome.src.misc.utils import fprint
+
+
+ def train_command(args: Optional[list] = None):
+     """
+     Entry point for the OmniGenome auto-train command-line interface.
+
+     :param args: A list of command-line arguments. If None, `sys.argv` is used.
+     """
+
+     parser = create_parser()
+     parsed_args = parser.parse_args(args)
+
+     model_path = parsed_args.model
+     fprint(f"\n>> Starting evaluation for model: {model_path}")
+
+     # Special handling for multimolecule models
+     if "multimolecule" in model_path:
+         from multimolecule import RnaTokenizer, AutoModelForTokenPrediction
+
+         tokenizer = RnaTokenizer.from_pretrained(model_path)
+         model = AutoModelForTokenPrediction.from_pretrained(
+             model_path, trust_remote_code=True
+         ).base_model
+     else:
+         tokenizer = parsed_args.tokenizer
+         model = model_path
+
+     # Initialize AutoTraining
+     autobench = AutoTrain(
+         dataset=parsed_args.dataset,
+         model_name_or_path=model,
+         tokenizer=tokenizer,
+         overwrite=parsed_args.overwrite,
+         trainer=parsed_args.trainer,
+     )
+
+     # Run evaluation
+     autobench.run(**vars(parsed_args))
+
+
+ def create_parser() -> argparse.ArgumentParser:
+     """
+     Creates the argument parser for the auto-train CLI.
+
+     :return: An `argparse.ArgumentParser` instance.
+     """
+     parser = argparse.ArgumentParser(
+         description="Genomic Foundation Model Benchmark Suite (Single Model)",
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+     )
+     # Required argument
+     parser.add_argument(
+         "-d",
+         "--dataset",
+         type=str,
+         help="Path to the dataset and training configuration file.",
+     )
+     parser.add_argument(
+         "-t",
+         "--tokenizer",
+         type=str,
+         default=None,
+         help="Path to the tokenizer to use (HF tokenizer ID or local path).",
+     )
+
+     parser.add_argument(
+         "-m",
+         "--model",
+         type=str,
+         required=True,
+         help="Path to the model to evaluate (HF model ID or local path).",
+     )
+
+     # Optional arguments
+     parser.add_argument(
+         "--overwrite",
+         type=bool,
+         default=False,
+         help="Overwrite existing bench results, otherwise resume from training checkpoint.",
+     )
+     parser.add_argument(
+         "--bs_scale",
+         type=int,
+         default=1,
+         help="Batch size scale factor. To increase GPU memory utilization, set to 2 or 4, etc.",
+     )
+     parser.add_argument(
+         "--trainer",
+         type=str,
+         default="accelerate",
+         choices=["native", "accelerate", "hf_trainer"],
+         help="Trainer to use for training. \n"
+         "Use 'accelerate' for distributed training. Set to false to disable. "
+         "You can use 'accelerate config' to customize behavior.\n"
+         "Use 'hf_trainer' for Hugging Face Trainer. \n"
+         "Set to 'native' to use native PyTorch training loop.\n",
+     )
+     parser.add_argument(
+         "--autocast",
+         type=str,
+         default="fp16",
+         choices=["fp16", "fp32", "bf16", "fp8", "no"],
+         help="Automatic mixed precision training mode.",
+     )
+     return parser
+
+
+ def run_train():
+     """
+     Wrapper function to run the auto-train command.
+     """
+
+     fprint("Running AutoTraining, this may take a while, please be patient...")
+     fprint("You can find the logs in the 'autobench_logs' directory.")
+     fprint("You can find the metrics in the 'autobench_evaluations' directory.")
+     fprint(
+         "If you don't intend to use accelerate, please add '--trainer native' to the command."
+     )
+     fprint(
+         "If you want to alter accelerate's behavior, please refer to 'accelerate config' command."
+     )
+     fprint("If you encounter any issues, please report them on the GitHub repository.")
+     os.makedirs("autobench_logs", exist_ok=True)
+     time_str = time.strftime("%Y-%m-%d-%H-%M-%S")
+     log_file = f"autobench_logs/AutoBench-{time_str}.log"
+     from pathlib import Path
+
+     try:
+         mixed_precision = sys.argv[sys.argv.index("--autocast") + 1].lower()
+     except ValueError:
+         mixed_precision = "fp16"
+     file_path = Path(__file__).resolve()
+     if (
+         "--trainer" in sys.argv
+         and sys.argv[sys.argv.index("--trainer") + 1].lower() == "native"
+     ):
+         cmd_base = f'python "{file_path}" ' + " ".join(sys.argv[1:])
+     else:
+         cmd_base = (
+             f'accelerate launch --mixed_precision "{mixed_precision}" "{file_path}" '
+             + " ".join(sys.argv[1:])
+         )
+
+     # Use platform-specific tee commands:
+     if platform.system() == "Windows":
+         # On Windows, use PowerShell's tee-object.
+         # The command below launches PowerShell and passes the tee-object command.
+         # try:
+         #     cmd = f"{cmd_base} 2>&1 | powershell -Command Get-Content {log_file} -Wait"
+         # except Exception as e:
+         #     fprint(f"The log file cannot be saved due to Error: {e}")
+         #     fprint(
+         #         "If commands not allowed in PowerShell, "
+         #         "please run 'Set-ExecutionPolicy RemoteSigned' in PowerShell with Admin."
+         #     )
+         cmd = f"{cmd_base} 2>&1"
+     else:
+         # On Unix-like systems, use the standard tee command.
+         cmd = f"{cmd_base} 2>&1 | tee '{log_file}'"
+
+     # Execute the command.
+     sys.exit(os.system(cmd))
+
+     # # Regex matching tqdm progress-bar lines (adjust to the actual output)
+     # tqdm_pattern = re.compile(r'^.*\d+%\|.*\|\s+\d+/\d+\s+\[.*\]\s*$')
+     #
+     # last_tqdm_line = ''
+     #
+     # with open(log_file, 'w', encoding='utf-8') as log_file:
+     #     # Run the command and capture its output stream
+     #     proc = subprocess.Popen(
+     #         cmd_base,
+     #         shell=True,
+     #         stdout=subprocess.PIPE,
+     #         stderr=subprocess.STDOUT,
+     #         bufsize=1,
+     #         universal_newlines=True
+     #     )
+     #
+     #     # Process the output stream in real time
+     #     for line in proc.stdout:
+     #         line = line.rstrip()  # strip the trailing newline
+     #         if tqdm_pattern.match(line):
+     #             # keep only the latest tqdm line
+     #             last_tqdm_line = line + '\n'  # the newline must be added manually
+     #             # show the progress bar in place (overwrite mode)
+     #             sys.stdout.write('\r' + line)
+     #             sys.stdout.flush()
+     #         else:
+     #             # write to the log file and print normally
+     #             log_file.write(line + '\n')
+     #             print(line)
+     #
+     #     # After the command finishes, write the last tqdm progress bar
+     #     if last_tqdm_line:
+     #         log_file.write(last_tqdm_line)
+     #         sys.stdout.write('\n')  # final newline to avoid overwriting the bar
+     #
+     #     sys.exit(proc.returncode)
+
+
+ if __name__ == "__main__":
+     train_command()
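
Since train_command() accepts an explicit argument list, the same flow can also be driven programmatically without the accelerate launch wrapper in run_train(). A minimal sketch with placeholder model and dataset arguments, assuming the module path matches the file listing above:

    from omnigenome.auto.auto_train.auto_train_cli import train_command

    # Placeholder flags mirroring the options defined in create_parser().
    train_command([
        "-m", "org/placeholder-genomic-model",  # required: HF model ID or local path
        "-d", "my_rna_dataset",                 # dataset folder with its .config file
        "--trainer", "native",                  # native PyTorch loop instead of accelerate
        "--autocast", "fp16",
    ])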
omnigenome/auto/bench_hub/__init__.py
@@ -0,0 +1,12 @@
+ # -*- coding: utf-8 -*-
+ # file: __init__.py
+ # time: 18:28 11/04/2024
+ # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
+ # github: https://github.com/yangheng95
+ # huggingface: https://huggingface.co/yangheng
+ # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
+ # Copyright (C) 2019-2024. All Rights Reserved.
+ """
+ This package contains modules for the benchmark hub.
+ """
+
omnigenome/auto/bench_hub/bench_hub.py
@@ -0,0 +1,25 @@
+ # -*- coding: utf-8 -*-
+ # file: bench_hub.py
+ # time: 11:53 14/04/2024
+ # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
+ # github: https://github.com/yangheng95
+ # huggingface: https://huggingface.co/yangheng
+ # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
+ # Copyright (C) 2019-2024. All Rights Reserved.
+
+
+ class BenchHub:
+     """
+     A hub for accessing and managing benchmarks.
+
+     This class is intended to provide a centralized way to list, download,
+     and inspect available benchmarks for OmniGenome.
+     """
+
+     def __init__(self):
+         """Initializes the BenchHub."""
+         pass
+
+     def run(self):
+         """Placeholder for running functionality related to the benchmark hub."""
+         pass