omnigenome-0.3.1a0-py3-none-any.whl → omnigenome-0.4.0a0-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registries.
Files changed (80)
  1. omnigenome/__init__.py +304 -266
  2. omnigenome-0.4.0a0.dist-info/METADATA +354 -0
  3. omnigenome-0.4.0a0.dist-info/RECORD +7 -0
  4. omnigenome/auto/__init__.py +0 -3
  5. omnigenome/auto/auto_bench/__init__.py +0 -11
  6. omnigenome/auto/auto_bench/auto_bench.py +0 -494
  7. omnigenome/auto/auto_bench/auto_bench_cli.py +0 -230
  8. omnigenome/auto/auto_bench/auto_bench_config.py +0 -216
  9. omnigenome/auto/auto_bench/config_check.py +0 -34
  10. omnigenome/auto/auto_train/__init__.py +0 -12
  11. omnigenome/auto/auto_train/auto_train.py +0 -429
  12. omnigenome/auto/auto_train/auto_train_cli.py +0 -222
  13. omnigenome/auto/bench_hub/__init__.py +0 -11
  14. omnigenome/auto/bench_hub/bench_hub.py +0 -25
  15. omnigenome/cli/__init__.py +0 -12
  16. omnigenome/cli/commands/__init__.py +0 -12
  17. omnigenome/cli/commands/base.py +0 -83
  18. omnigenome/cli/commands/bench/__init__.py +0 -12
  19. omnigenome/cli/commands/bench/bench_cli.py +0 -202
  20. omnigenome/cli/commands/rna/__init__.py +0 -12
  21. omnigenome/cli/commands/rna/rna_design.py +0 -177
  22. omnigenome/cli/omnigenome_cli.py +0 -128
  23. omnigenome/src/__init__.py +0 -11
  24. omnigenome/src/abc/__init__.py +0 -11
  25. omnigenome/src/abc/abstract_dataset.py +0 -641
  26. omnigenome/src/abc/abstract_metric.py +0 -114
  27. omnigenome/src/abc/abstract_model.py +0 -690
  28. omnigenome/src/abc/abstract_tokenizer.py +0 -269
  29. omnigenome/src/dataset/__init__.py +0 -16
  30. omnigenome/src/dataset/omni_dataset.py +0 -437
  31. omnigenome/src/lora/__init__.py +0 -12
  32. omnigenome/src/lora/lora_model.py +0 -300
  33. omnigenome/src/metric/__init__.py +0 -15
  34. omnigenome/src/metric/classification_metric.py +0 -184
  35. omnigenome/src/metric/metric.py +0 -199
  36. omnigenome/src/metric/ranking_metric.py +0 -142
  37. omnigenome/src/metric/regression_metric.py +0 -191
  38. omnigenome/src/misc/__init__.py +0 -3
  39. omnigenome/src/misc/utils.py +0 -503
  40. omnigenome/src/model/__init__.py +0 -19
  41. omnigenome/src/model/augmentation/__init__.py +0 -11
  42. omnigenome/src/model/augmentation/model.py +0 -219
  43. omnigenome/src/model/classification/__init__.py +0 -11
  44. omnigenome/src/model/classification/model.py +0 -638
  45. omnigenome/src/model/embedding/__init__.py +0 -11
  46. omnigenome/src/model/embedding/model.py +0 -263
  47. omnigenome/src/model/mlm/__init__.py +0 -11
  48. omnigenome/src/model/mlm/model.py +0 -177
  49. omnigenome/src/model/module_utils.py +0 -232
  50. omnigenome/src/model/regression/__init__.py +0 -11
  51. omnigenome/src/model/regression/model.py +0 -781
  52. omnigenome/src/model/regression/resnet.py +0 -483
  53. omnigenome/src/model/rna_design/__init__.py +0 -11
  54. omnigenome/src/model/rna_design/model.py +0 -476
  55. omnigenome/src/model/seq2seq/__init__.py +0 -11
  56. omnigenome/src/model/seq2seq/model.py +0 -44
  57. omnigenome/src/tokenizer/__init__.py +0 -16
  58. omnigenome/src/tokenizer/bpe_tokenizer.py +0 -226
  59. omnigenome/src/tokenizer/kmers_tokenizer.py +0 -247
  60. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +0 -249
  61. omnigenome/src/trainer/__init__.py +0 -14
  62. omnigenome/src/trainer/accelerate_trainer.py +0 -747
  63. omnigenome/src/trainer/hf_trainer.py +0 -75
  64. omnigenome/src/trainer/trainer.py +0 -591
  65. omnigenome/utility/__init__.py +0 -3
  66. omnigenome/utility/dataset_hub/__init__.py +0 -12
  67. omnigenome/utility/dataset_hub/dataset_hub.py +0 -178
  68. omnigenome/utility/ensemble.py +0 -324
  69. omnigenome/utility/hub_utils.py +0 -517
  70. omnigenome/utility/model_hub/__init__.py +0 -11
  71. omnigenome/utility/model_hub/model_hub.py +0 -232
  72. omnigenome/utility/pipeline_hub/__init__.py +0 -11
  73. omnigenome/utility/pipeline_hub/pipeline.py +0 -483
  74. omnigenome/utility/pipeline_hub/pipeline_hub.py +0 -129
  75. omnigenome-0.3.1a0.dist-info/METADATA +0 -224
  76. omnigenome-0.3.1a0.dist-info/RECORD +0 -78
  77. {omnigenome-0.3.1a0.dist-info → omnigenome-0.4.0a0.dist-info}/WHEEL +0 -0
  78. {omnigenome-0.3.1a0.dist-info → omnigenome-0.4.0a0.dist-info}/entry_points.txt +0 -0
  79. {omnigenome-0.3.1a0.dist-info → omnigenome-0.4.0a0.dist-info}/licenses/LICENSE +0 -0
  80. {omnigenome-0.3.1a0.dist-info → omnigenome-0.4.0a0.dist-info}/top_level.txt +0 -0
omnigenome/auto/auto_bench/auto_bench.py
@@ -1,494 +0,0 @@
-# -*- coding: utf-8 -*-
-# file: auto_bench.py
-# time: 11:54 14/04/2024
-# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
-# github: https://github.com/yangheng95
-# huggingface: https://huggingface.co/yangheng
-# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
-# Copyright (C) 2019-2024. All Rights Reserved.
-
-import os
-import time
-import warnings
-
-import findfile
-import torch
-from metric_visualizer import MetricVisualizer
-
-from transformers import TrainingArguments, Trainer as HFTrainer
-from ...src.abc.abstract_tokenizer import OmniTokenizer
-from ...src.lora.lora_model import OmniLoraModel
-from ...src.misc.utils import (
-    seed_everything,
-    fprint,
-    load_module_from_path,
-    check_bench_version,
-    clean_temp_checkpoint,
-)
-from ...src.trainer.trainer import Trainer
-from ...src.trainer.accelerate_trainer import AccelerateTrainer
-from ...utility.hub_utils import download_benchmark
-from ... import __version__ as omnigenome_version
-
-
-class AutoBench:
-    """
-    AutoBench is a class for automatically benchmarking genomic foundation models.
-
-    This class provides a comprehensive framework for evaluating genomic models
-    across multiple benchmarks and tasks. It handles loading benchmarks, models,
-    tokenizers, and running evaluations with proper metric tracking and result
-    visualization.
-
-    AutoBench supports various evaluation scenarios including:
-    - Single model evaluation across multiple benchmarks
-    - Multi-seed evaluation for robustness testing
-    - Different trainer backends (native, accelerate, huggingface)
-    - Automatic metric visualization and result tracking
-
-    Attributes:
-        benchmark (str): The name or path of the benchmark to use.
-        model_name_or_path (str): The name or path of the model to evaluate.
-        tokenizer: The tokenizer to use for evaluation.
-        autocast (str): The autocast precision to use ('fp16', 'bf16', etc.).
-        overwrite (bool): Whether to overwrite existing evaluation results.
-        trainer (str): The trainer to use ('native', 'accelerate', 'hf_trainer').
-        mv_path (str): Path to the metric visualizer file.
-        mv (MetricVisualizer): The metric visualizer instance.
-        bench_metadata: Metadata about the benchmark configuration.
-    """
-
-    def __init__(
-        self,
-        benchmark,
-        model_name_or_path,
-        tokenizer=None,
-        **kwargs,
-    ):
-        """
-        Initializes the AutoBench instance.
-
-        Args:
-            benchmark (str): The name or path of the benchmark to use.
-            model_name_or_path (str): The name or path of the model to evaluate.
-            tokenizer: The tokenizer to use. If None, it will be loaded from the model path.
-            **kwargs: Additional keyword arguments.
-                - autocast (str): The autocast precision to use ('fp16', 'bf16', etc.).
-                  Defaults to 'fp16'.
-                - overwrite (bool): Whether to overwrite existing evaluation results.
-                  Defaults to False.
-                - trainer (str): The trainer to use ('native', 'accelerate', 'hf_trainer').
-                  Defaults to 'native'.
-
-        Example:
-            >>> # Initialize with a benchmark and model
-            >>> bench = AutoBench("RGB", "model_name")
-
-            >>> # Initialize with custom settings
-            >>> bench = AutoBench("RGB", "model_name",
-            ...                   autocast="bf16", trainer="accelerate")
-        """
-        self.benchmark = benchmark.rstrip("/")
-        self.autocast = kwargs.pop("autocast", "fp16")
-        self.overwrite = kwargs.pop("overwrite", False)
-        self.trainer = kwargs.pop("trainer", "native")
-
-        self.model_name_or_path = model_name_or_path
-        self.tokenizer = tokenizer
-        if isinstance(self.model_name_or_path, str):
-            self.model_name_or_path = self.model_name_or_path.rstrip("/")
-            self.model_name = self.model_name_or_path.split("/")[-1]
-        else:
-            self.model_name = self.model_name_or_path.__class__.__name__
-        if isinstance(tokenizer, str):
-            self.tokenizer = tokenizer.rstrip("/")
-        os.makedirs("./autobench_evaluations", exist_ok=True)
-        time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime())
-        mv_name = f"{benchmark}-{self.model_name}"
-        self.mv_path = f"./autobench_evaluations/{mv_name}-{time_str}.mv"
-
-        mv_paths = findfile.find_files(
-            "./autobench_evaluations",
-            and_key=[benchmark, self.model_name, ".mv"],
-        )
-        if mv_paths and not self.overwrite:
-            self.mv = MetricVisualizer.load(mv_paths[-1])
-            self.mv.summary(round=4)
-        else:
-            self.mv = MetricVisualizer(self.mv_path)
-        if not os.path.exists(self.benchmark):
-            fprint(
-                "Benchmark:",
-                benchmark,
-                "does not exist. Search online for available benchmarks.",
-            )
-            self.benchmark = download_benchmark(self.benchmark)
-
-        # Import benchmark list
-        self.bench_metadata = load_module_from_path(
-            f"bench_metadata", f"{self.benchmark}/metadata.py"
-        )
-        check_bench_version(
-            self.bench_metadata.__omnigenome_version__, omnigenome_version
-        )
-        fprint("Loaded benchmarks: ", self.bench_metadata.bench_list)
-        self.bench_info()
-
-    def bench_info(self):
-        """
-        Prints and returns information about the current benchmark setup.
-
-        This method provides a comprehensive overview of the current
-        benchmark configuration, including benchmark details, model information,
-        and evaluation settings.
-
-        Returns:
-            str: A string containing benchmark information.
-
-        Example:
-            >>> info = bench.bench_info()
-            >>> print(info)
-        """
-        info = f"Benchmark Root: {self.benchmark}\n"
-        info += f"Benchmark List: {self.bench_metadata.bench_list}\n"
-        info += f"Model Name or Path: {self.model_name}\n"
-        info += f"Tokenizer: {self.tokenizer}\n"
-        info += f"Metric Visualizer Path: {self.mv_path}\n"
-        info += f"BenchConfig Details: {self.bench_metadata}\n"
-        fprint(info)
-        return info
-
-    def run(self, **kwargs):
-        """
-        Runs the benchmarking process.
-
-        This method iterates through the tasks in the benchmark, loads the corresponding
-        configurations, initializes the model, tokenizer, and datasets, and then
-        trains and evaluates the model. It supports multiple evaluation seeds and
-        various trainer backends.
-
-        Args:
-            **kwargs: Additional keyword arguments that will override the default
-                parameters in the benchmark configuration.
-
-        Example:
-            >>> # Run benchmarking with default settings
-            >>> bench.run()
-
-            >>> # Run with custom parameters
-            >>> bench.run(learning_rate=1e-4, batch_size=16)
-        """
-        bs_scale = kwargs.pop("bs_scale", 1)
-        # Import benchmark config
-        for _, bench in enumerate(self.bench_metadata.bench_list):
-            clean_temp_checkpoint(1)  # clean temp checkpoint older than 1 day
-            fprint(
-                ">" * 80,
-                f"\nRunning evaluation for task: {bench}",
-                "Progress: ",
-                _ + 1,
-                "/",
-                len(self.bench_metadata.bench_list),
-                f"{(_ + 1) * 100 / len(self.bench_metadata.bench_list)}%",
-            )
-            bench_config_path = findfile.find_file(
-                self.benchmark, and_key=f"{self.benchmark}.{bench}.config".split(".")
-            )
-            config = load_module_from_path("config", bench_config_path)
-            bench_config = config.bench_config
-            fprint(f"Loaded config for {bench} from {bench_config_path}")
-            fprint(bench_config)
-
-            # Init Tokenizer and Model
-            if not self.tokenizer:
-                tokenizer = OmniTokenizer.from_pretrained(
-                    self.model_name_or_path,
-                    trust_remote_code=bench_config.get("trust_remote_code", True),
-                    **bench_config,
-                )
-            else:
-                tokenizer = self.tokenizer
-
-            if not isinstance(bench_config["seeds"], list):
-                bench_config["seeds"] = [bench_config["seeds"]]
-
-            random_seeds = bench_config["seeds"]
-            for seed in random_seeds:
-                _kwargs = kwargs.copy()
-                for key, value in _kwargs.items():
-                    if key in bench_config:
-                        fprint(
-                            "Override",
-                            key,
-                            "with",
-                            value,
-                            "according to the input kwargs",
-                        )
-                        bench_config.update({key: value})
-
-                    else:
-                        warnings.warn(
-                            f"kwarg: {key} not found in bench_config while setting {key} = {value}"
-                        )
-                        bench_config.update({key: value})
-
-                for key, value in bench_config.items():
-                    if key in bench_config and key in _kwargs:
-                        _kwargs.pop(key)
-
-                fprint(
-                    f"AutoBench Config for {bench}:",
-                    "\n".join([f"{k}: {v}" for k, v in bench_config.items()]),
-                )
-                for key, value in _kwargs.items():
-                    if key in bench_config:
-                        fprint(
-                            "Override",
-                            key,
-                            "with",
-                            value,
-                            "according to the input kwargs",
-                        )
-                        bench_config.update({key: value})
-
-                    else:
-                        warnings.warn(
-                            f"kwarg: {key} not found in bench_config while setting {key} = {value}"
-                        )
-                        bench_config.update({key: value})
-
-                for key, value in bench_config.items():
-                    if key in bench_config and key in _kwargs:
-                        _kwargs.pop(key)
-
-                fprint(
-                    f"AutoBench Config for {bench}:",
-                    "\n".join([f"{k}: {v}" for k, v in bench_config.items()]),
-                )
-
-                batch_size = (
-                    bench_config["batch_size"] if "batch_size" in bench_config else 8
-                ) * bs_scale
-
-                record_name = f"{self.benchmark}-{bench}-{self.model_name}".split("/")[
-                    -1
-                ]
-                # check if the record exists
-                if record_name in self.mv.transpose() and len(
-                    list(self.mv.transpose()[record_name].values())[0]
-                ) >= len(bench_config["seeds"]):
-                    continue
-
-                seed_everything(seed)
-                if self.model_name_or_path:
-                    model_cls = bench_config["model_cls"]
-                    model = model_cls(
-                        self.model_name_or_path,
-                        tokenizer=tokenizer,
-                        label2id=bench_config.label2id,
-                        num_labels=bench_config["num_labels"],
-                        trust_remote_code=True,
-                        ignore_mismatched_sizes=True,
-                    )
-                else:
-                    raise ValueError(
-                        "model_name_or_path is not specified. Please provide a valid model name or path."
-                    )
-
-                fprint(f"\n{model}")
-
-                if kwargs.get("lora_config", None) is not None:
-                    fprint(
-                        "Applying LoRA to the model with config:", kwargs["lora_config"]
-                    )
-                    model = OmniLoraModel(model, **kwargs.get("lora_config", {}))
-
-                # Init Trainer
-                dataset_cls = bench_config["dataset_cls"]
-
-                if hasattr(model.config, "max_position_embeddings"):
-                    max_length = min(
-                        bench_config["max_length"],
-                        model.config.max_position_embeddings,
-                    )
-                else:
-                    max_length = bench_config["max_length"]
-
-                train_set = dataset_cls(
-                    data_source=bench_config["train_file"],
-                    tokenizer=tokenizer,
-                    label2id=bench_config["label2id"],
-                    max_length=max_length,
-                    structure_in=bench_config.get("structure_in", False),
-                    max_examples=bench_config.get("max_examples", None),
-                    shuffle=bench_config.get("shuffle", True),
-                    drop_long_seq=bench_config.get("drop_long_seq", False),
-                    **_kwargs,
-                )
-                test_set = dataset_cls(
-                    data_source=bench_config["test_file"],
-                    tokenizer=tokenizer,
-                    label2id=bench_config["label2id"],
-                    max_length=max_length,
-                    structure_in=bench_config.get("structure_in", False),
-                    max_examples=bench_config.get("max_examples", None),
-                    shuffle=False,
-                    drop_long_seq=bench_config.get("drop_long_seq", False),
-                    **_kwargs,
-                )
-                valid_set = dataset_cls(
-                    data_source=bench_config["valid_file"],
-                    tokenizer=tokenizer,
-                    label2id=bench_config["label2id"],
-                    max_length=max_length,
-                    structure_in=bench_config.get("structure_in", False),
-                    max_examples=bench_config.get("max_examples", None),
-                    shuffle=False,
-                    drop_long_seq=bench_config.get("drop_long_seq", False),
-                    **_kwargs,
-                )
-
-                if self.trainer == "hf_trainer":
-                    # Set up HuggingFace Trainer
-                    hf_kwargs = {
-                        k: v
-                        for k, v in kwargs.items()
-                        if hasattr(TrainingArguments, k) and k != "output_dir"
-                    }
-                    training_args = TrainingArguments(
-                        output_dir=f"./autobench_evaluations/{self.model_name}-{bench}",
-                        num_train_epochs=hf_kwargs.pop(
-                            "num_train_epochs", bench_config["epochs"]
-                        ),
-                        per_device_train_batch_size=hf_kwargs.pop(
-                            "batch_size", batch_size
-                        ),
-                        per_device_eval_batch_size=hf_kwargs.pop(
-                            "batch_size", batch_size
-                        ),
-                        gradient_accumulation_steps=hf_kwargs.pop(
-                            "gradient_accumulation_steps", 1
-                        ),
-                        learning_rate=hf_kwargs.pop("learning_rate", 2e-5),
-                        weight_decay=hf_kwargs.pop("weight_decay", 0),
-                        eval_strategy=hf_kwargs.pop("eval_strategy", "epoch"),
-                        save_strategy=hf_kwargs.pop("save_strategy", "epoch"),
-                        fp16=hf_kwargs.pop("fp16", True),
-                        remove_unused_columns=False,
-                        label_names=["labels"],
-                        **hf_kwargs,
-                    )
-
-                    valid_set = valid_set if len(valid_set) else test_set
-
-                    if len(bench_config["compute_metrics"]) > 1:
-                        fprint(
-                            "Multiple metrics not supported by HFTrainer, using the first one metric only."
-                        )
-                    trainer = HFTrainer(
-                        model=model,
-                        args=training_args,
-                        train_dataset=train_set,
-                        eval_dataset=valid_set,
-                        compute_metrics=(
-                            bench_config["compute_metrics"][0]
-                            if isinstance(bench_config["compute_metrics"], list)
-                            else bench_config["compute_metrics"]
-                        ),
-                    )
-
-                    # Train and evaluate
-                    eval_result = trainer.evaluate(
-                        valid_set if len(valid_set) else test_set
-                    )
-                    print(eval_result)
-                    train_result = trainer.train()
-                    eval_result = trainer.evaluate()
-                    test_result = trainer.evaluate(
-                        test_set if len(test_set) else valid_set
-                    )
-
-                    metrics = {
-                        "train": train_result.metrics,
-                        "eval": eval_result,
-                        "test": test_result,
-                    }
-                    fprint(metrics)
-                else:
-                    optimizer = torch.optim.AdamW(
-                        filter(lambda p: p.requires_grad, model.parameters()),
-                        lr=(
-                            bench_config["learning_rate"]
-                            if "learning_rate" in bench_config
-                            else 2e-5
-                        ),
-                        weight_decay=(
-                            bench_config["weight_decay"]
-                            if "weight_decay" in bench_config
-                            else 0
-                        ),
-                    )
-                    if self.trainer == "accelerate":
-                        trainer_cls = AccelerateTrainer
-                    else:
-                        trainer_cls = Trainer
-                    fprint(f"Using Trainer: {trainer_cls}")
-                    trainer = trainer_cls(
-                        model=model,
-                        train_dataset=train_set,
-                        eval_dataset=valid_set,
-                        test_dataset=test_set,
-                        batch_size=batch_size,
-                        patience=(
-                            bench_config["patience"]
-                            if "patience" in bench_config
-                            else 3
-                        ),
-                        epochs=bench_config["epochs"],
-                        gradient_accumulation_steps=bench_config.get(
-                            "gradient_accumulation_steps", 1
-                        ),
-                        optimizer=optimizer,
-                        loss_fn=(
-                            bench_config["loss_fn"]
-                            if "loss_fn" in bench_config
-                            else None
-                        ),
-                        compute_metrics=bench_config["compute_metrics"],
-                        seed=seed,
-                        autocast=self.autocast,
-                        **_kwargs,
-                    )
-                    metrics = trainer.train()
-
-                predictions = trainer.predictions
-
-                if bench_config.get("save_predictions", False):
-                    os.makedirs(f"predictions/{bench}", exist_ok=True)
-                    import numpy as np
-
-                    for split in predictions.keys():
-                        with open(
-                            f"predictions/{bench}/{split}.npy",
-                            "wb",
-                        ) as f:
-                            np.save(f, predictions[split])
-
-                if metrics:
-                    for key, value in metrics["test"][-1].items():
-                        try:
-                            value = float(value)
-                        except:
-                            pass  # ignore non-float values
-                        self.mv.log(f"{record_name}", f"{key}", value)
-                # for key, value in metrics['test'][-1].items():
-                #     self.mv.log(f'{record_name}', f'test_{key}', value)
-                # for i, valid_metrics in enumerate(metrics["valid"]):
-                #     for key, value in valid_metrics.items():
-                #         self.mv.log(f'{record_name}', f'valid_epoch_{i}_{key}', value)
-
-                self.mv.summary(round=4)
-                self.mv.dump(self.mv_path)
-                self.mv.to_csv(self.mv_path.replace(".mv", ".csv"))
-                del model, trainer, optimizer
-                torch.cuda.empty_cache()
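
For orientation, the deleted AutoBench class was the 0.3.1a0 benchmarking entry point: construct it with a benchmark name and a model path, then call run(). A minimal sketch assembled from the docstring examples in the deleted file above; "RGB" and "model_name" are the docstrings' own placeholders, and the top-level import is an assumption based on the old package layout (0.4.0a0 removes this module along with the rest of the omnigenome source tree, so this applies only to the 0.3.1a0 wheel):

# Sketch of the removed 0.3.1a0 API, reconstructed from the docstrings above.
# "RGB" and "model_name" are placeholders; the top-level export of AutoBench
# is assumed from the 0.3.1a0 package layout.
from omnigenome import AutoBench

# Downloads the benchmark if it is not found locally, then prints its metadata.
bench = AutoBench("RGB", "model_name", autocast="bf16", trainer="accelerate")

# Runs every task in the benchmark across the configured seeds; kwargs override
# the per-task bench_config defaults (unknown keys are added with a warning).
bench.run(learning_rate=1e-4, batch_size=16)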