omnigenome 0.3.0a0-py3-none-any.whl → 0.3.0a1-py3-none-any.whl

This diff shows the changes between these two publicly released package versions as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release: this version of omnigenome might be problematic.

omnigenome/__init__.py CHANGED
@@ -8,17 +8,11 @@
 # Copyright (C) 2019-2024. All Rights Reserved.
 
 """
-OmniGenome: A comprehensive toolkit for genomic foundation models.
+This __init__.py file exposes the Key API Entries of the library for easy access.
+Use dir(omnigenome) to see all available APIs.
 
-This package provides a suite of tools for working with genomic data, including:
-- Automated benchmarking and training pipelines.
-- A hub for accessing pre-trained models, datasets, and pipelines.
-- A flexible and extensible framework for building custom models and tasks.
-
-This __init__.py file exposes the core components of the library for easy access.
-
-Key Components:
----------------
+Key API Entries:
+----------------
 - AutoBench: Automated benchmarking of genomic models
 - AutoTrain: Automated training of genomic models
 - BenchHub: Hub for accessing benchmarks
@@ -29,27 +23,10 @@ Key Components:
 - Tokenizer classes for different sequence representations
 - Metric classes for evaluation
 - Trainer classes for model training
-
-Example Usage:
---------------
-```python
-from omnigenome import AutoBench, AutoTrain, OmniModelForSequenceClassification
-
-# Run automated benchmarking
-bench = AutoBench("RGB", "model_name")
-bench.run()
-
-# Train a model
-trainer = AutoTrain("RGB", "model_name")
-trainer.run()
-
-# Use a specific model
-model = OmniModelForSequenceClassification("model_path", tokenizer)
-```
 """
 
-__name__ = "omnigenome"
-__version__ = "0.3.0alpha"
+__name__ = "omnigenbench"
+__version__ = "0.3.0alpha1"
 
 __author__ = "YANG, HENG"
 __email__ = "yangheng2021@gmail.com"
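
Note on the hunk above: the `__name__` attribute is renamed to `omnigenbench` while the distribution and import name remain `omnigenome`, and the new `__version__` string (`0.3.0alpha1`) is not the PEP 440-normalized form of the wheel version (`0.3.0a1`). A minimal sketch of what downstream code observes after this change (illustrative, not part of the package):

```python
# Illustrative only; assumes omnigenome 0.3.0a1 is installed.
import omnigenome

# The import name is unchanged, but the module attribute now reports the new name.
print(omnigenome.__name__)     # "omnigenbench"
print(omnigenome.__version__)  # "0.3.0alpha1"

# The installed distribution metadata is expected to carry the normalized form.
from importlib.metadata import version
print(version("omnigenome"))   # expected: "0.3.0a1"
```
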
@@ -227,10 +204,10 @@ LOGO1 = r"""
227
204
  @@** = **@@ \___/ |_| |_| |_||_| |_||_|
228
205
  @@** ------+ **@@
229
206
  @@** =========# **@@ ____
230
- @@ ---------------+ @@ / ___| ___ _ __ ___ _ __ ___ ___
231
- @@ ================== @@ | | _ / _ \| '_ \ / _ \ | '_ ` _ \ / _ \
232
- @@ +--------------- @@ | |_| || __/| | | || (_) || | | | | || __/
233
- @@** #========= **@@ \____| \___||_| |_| \___/ |_| |_| |_| \___|
207
+ @@ ---------------+ @@ / ___| ___ _ __
208
+ @@ ================== @@ | | _ / _ \| '_ \
209
+ @@ +--------------- @@ | |_| || __/| | | |
210
+ @@** #========= **@@ \____| \___||_| |_|
234
211
  @@** +------ **@@
235
212
  @@** = **@@
236
213
  @@** ____ _
@@ -251,10 +228,10 @@ LOGO2 = r"""
251
228
  *@@ #========= @@*
252
229
  *@@* *@@*
253
230
  *@@ +---@@@* ____
254
- *@@* ** / ___| ___ _ __ ___ _ __ ___ ___
255
- **@** | | _ / _ \| '_ \ / _ \ | '_ ` _ \ / _ \
256
- *@@* *@@* | |_| || __/| | | || (_) || | | | | || __/
257
- *@@ ---+ @@* \____| \___||_| |_| \___/ |_| |_| |_| \___|
231
+ *@@* ** / ___| ___ _ __
232
+ **@** | | _ / _ \| '_ \
233
+ *@@* *@@* | |_| || __/| | | |
234
+ *@@ ---+ @@* \____| \___||_| |_|
258
235
  *@@* *@@*
259
236
  *@@ =========# @@*
260
237
  *@@ @@*
@@ -12,6 +12,7 @@ import pickle
 import sys
 import tempfile
 import time
+import warnings
 
 import ViennaRNA as RNA
 import findfile
@@ -48,58 +49,50 @@ def seed_everything(seed=42):
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
 
 
 class RNA2StructureCache(dict):
     """
-    A cache for RNA sequence to structure predictions using ViennaRNA.
+    A cache for RNA secondary structure predictions using ViennaRNA.
 
-    This class provides a dictionary-like interface for caching RNA secondary
-    structure predictions. It uses ViennaRNA for structure prediction and
-    supports both single sequences and batches of sequences.
-
-    The cache can be persisted to disk and loaded back, making it useful for
-    avoiding redundant structure predictions across multiple runs.
+    This class provides a caching mechanism for RNA secondary structure predictions
+    to avoid redundant computations. It supports both single sequence and batch
+    processing with optional multiprocessing for improved performance.
 
     Attributes:
-        cache_file (str): Path to the cache file on disk.
-        cache (dict): The in-memory cache dictionary.
-        queue_num (int): Counter for tracking cache updates.
+        cache (dict): Dictionary storing sequence-structure mappings
+        cache_file (str): Path to the cache file on disk
+        queue_num (int): Counter for tracking cache updates
     """
 
     def __init__(self, cache_file=None, *args, **kwargs):
         """
-        Initializes the RNA structure cache.
-
+        Initialize the RNA structure cache.
+
         Args:
             cache_file (str, optional): Path to the cache file. If None, uses
-                a default path in `__OMNIGENOME_DATA__`.
-            *args: Additional arguments passed to dict constructor.
-            **kwargs: Additional keyword arguments passed to dict constructor.
-
-        Example:
-            >>> # Initialize with default cache file
-            >>> cache = RNA2StructureCache()
-
-            >>> # Initialize with custom cache file
-            >>> cache = RNA2StructureCache("my_cache.pkl")
+                a default temporary file.
+            *args: Additional positional arguments for dict initialization
+            **kwargs: Additional keyword arguments for dict initialization
         """
         super().__init__(*args, **kwargs)
-
-        if not cache_file:
-            self.cache_file = "__OMNIGENOME_DATA__/rna2structure.cache.pkl"
-        else:
-            self.cache_file = cache_file
-
-        if self.cache_file is None or not os.path.exists(self.cache_file):
-            self.cache = {}
-        else:
-            fprint(f"Initialize sequence to structure cache from {self.cache_file}...")
-            with open(self.cache_file, "rb") as f:
-                self.cache = pickle.load(f)
-
+        self.cache = dict(*args, **kwargs)
+        self.cache_file = (
+            cache_file
+            if cache_file is not None
+            else os.path.join(tempfile.gettempdir(), "rna_structure_cache.pkl")
+        )
         self.queue_num = 0
 
+        # Load existing cache if available
+        if os.path.exists(self.cache_file):
+            try:
+                with open(self.cache_file, "rb") as f:
+                    self.cache.update(pickle.load(f))
+            except Exception as e:
+                warnings.warn(f"Failed to load cache file: {e}")
+
     def __getitem__(self, key):
         """Gets a cached structure prediction."""
         return self.cache[key]
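
The rewritten constructor moves the default cache from `__OMNIGENOME_DATA__/rna2structure.cache.pkl` to a pickle in the system temp directory, and now tolerates unreadable cache files instead of crashing; note also that entries live in `self.cache` while the inherited `dict` state stays unused. Below is a minimal, self-contained sketch of the tolerant-load pattern used here; the file name is the hunk's default, the sample entry is illustrative:

```python
import os
import pickle
import tempfile
import warnings

cache_file = os.path.join(tempfile.gettempdir(), "rna_structure_cache.pkl")
cache = {}

if os.path.exists(cache_file):
    try:
        with open(cache_file, "rb") as f:
            cache.update(pickle.load(f))  # merge persisted entries into memory
    except Exception as e:  # corrupt or incompatible pickle: warn and start fresh
        warnings.warn(f"Failed to load cache file: {e}")

cache.setdefault("GGGAAACCC", ("(((...)))", -1.2))  # illustrative entry
```
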
@@ -116,6 +109,22 @@ class RNA2StructureCache(dict):
         """String representation of the cache."""
         return str(self.cache)
 
+    def _fold_single_sequence(self, sequence):
+        """
+        Predict structure for a single sequence (worker function for multiprocessing).
+
+        Args:
+            sequence (str): RNA sequence to fold
+
+        Returns:
+            tuple: (structure, mfe) tuple
+        """
+        try:
+            return RNA.fold(sequence)
+        except Exception as e:
+            warnings.warn(f"Failed to fold sequence {sequence}: {e}")
+            return ("." * len(sequence), 0.0)
+
     def fold(self, sequence, return_mfe=False, num_workers=1):
         """
         Predicts RNA secondary structure for given sequences.
@@ -151,39 +160,58 @@ class RNA2StructureCache(dict):
         else:
             sequences = sequence
 
-        if (
-            os.name != "nt" and len(sequences) > 1
-        ):  # multiprocessing is not working on Windows in my case
-            num_workers = min(os.cpu_count(), len(sequences))
-
-        structures = []
+        # Determine if we should use multiprocessing
+        use_multiprocessing = (
+            os.name != "nt" and  # Not Windows
+            len(sequences) > 1 and  # Multiple sequences
+            num_workers > 1  # Multiple workers requested
+        )
 
-        if not all([seq in self.cache for seq in sequences]):
-            if num_workers == 1:
-                for seq in sequences:
-                    if seq not in self.cache:
-                        self.queue_num += 1
-                        self.cache[seq] = RNA.fold(seq)
-            else:
+        # Find sequences that need prediction
+        sequences_to_predict = [seq for seq in sequences if seq not in self.cache]
+
+        if sequences_to_predict:
+            if use_multiprocessing:
+                # Use multiprocessing for batch prediction
                 if num_workers is None:
-                    num_workers = min(os.cpu_count(), len(sequences))
-
-                with multiprocessing.Pool(num_workers) as pool:
-                    for seq in sequences:
-                        if seq not in self.cache:
+                    num_workers = min(os.cpu_count(), len(sequences_to_predict))
+
+                try:
+                    # Set multiprocessing start method to 'spawn' for better compatibility
+                    if multiprocessing.get_start_method(allow_none=True) != 'spawn':
+                        multiprocessing.set_start_method('spawn', force=True)
+
+                    with multiprocessing.Pool(num_workers) as pool:
+                        # Use map instead of apply_async for better error handling
+                        results = pool.map(self._fold_single_sequence, sequences_to_predict)
+
+                    # Update cache with results
+                    for seq, result in zip(sequences_to_predict, results):
+                        self.cache[seq] = result
                         self.queue_num += 1
-                            async_result = pool.apply_async(RNA.fold, args=(seq,))
-                            structures.append((seq, async_result))
-
-                    for seq, result in structures:
-                        self.cache[seq] = result.get()  # result is a tuple
+
+                except Exception as e:
+                    warnings.warn(f"Multiprocessing failed, falling back to sequential: {e}")
+                    # Fallback to sequential processing
+                    for seq in sequences_to_predict:
+                        self.cache[seq] = self._fold_single_sequence(seq)
+                        self.queue_num += 1
+            else:
+                # Sequential processing
+                for seq in sequences_to_predict:
+                    self.cache[seq] = self._fold_single_sequence(seq)
+                    self.queue_num += 1
 
+        # Prepare output
         if return_mfe:
             structures = [self.cache[seq] for seq in sequences]
         else:
             structures = [self.cache[seq][0] for seq in sequences]
+
+        # Update cache file periodically
         self.update_cache_file(self.cache_file)
 
+        # Return single result or list
         if len(structures) == 1:
             return structures[0]
         else:
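
Two things in this hunk matter for callers: `set_start_method('spawn', force=True)` mutates a process-global setting, which can surprise a host application that also uses multiprocessing, and under `spawn` the bound method handed to `pool.map` is pickled, so each worker receives a copy of the cache instance and must be able to import ViennaRNA. A hedged usage sketch (sequences are illustrative; the top-level import path is assumed):

```python
# Assumes omnigenome 0.3.0a1 and the ViennaRNA Python bindings are installed.
from omnigenome import RNA2StructureCache  # import path assumed

cache = RNA2StructureCache()

# Single sequence: returns just the dot-bracket structure string.
print(cache.fold("GGGAAACCC"))

# Batch with MFE values; the multiprocessing path is taken only on
# non-Windows platforms when more than one worker is requested.
results = cache.fold(["GGGAAACCC", "AUGGCUACG"], return_mfe=True, num_workers=2)
for structure, mfe in results:
    print(structure, mfe)
```
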
@@ -209,14 +237,16 @@
         if cache_file is None:
             cache_file = self.cache_file
 
-        if not os.path.exists(os.path.dirname(cache_file)):
-            os.makedirs(os.path.dirname(cache_file))
+        try:
+            if not os.path.exists(os.path.dirname(cache_file)):
+                os.makedirs(os.path.dirname(cache_file))
 
-        # print(f"Updating cache file {cache_file}...")
-        with open(cache_file, "wb") as f:
-            pickle.dump(self.cache, f)
+            with open(cache_file, "wb") as f:
+                pickle.dump(self.cache, f)
 
-        self.queue_num = 0
+            self.queue_num = 0
+        except Exception as e:
+            warnings.warn(f"Failed to update cache file: {e}")
 
 
 def env_meta_info():
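
One edge case survives the hunk above: for a bare filename such as `my_cache.pkl`, `os.path.dirname(cache_file)` is an empty string and `os.makedirs("")` raises, so the new `except` clause downgrades every save to a warning and the cache is never written. A sketch of a dirname guard, offered as an assumption rather than code from this release (`save_cache` is a hypothetical helper):

```python
import os
import pickle

def save_cache(cache, cache_file):
    # Hypothetical helper: guard against the empty-dirname case.
    dirname = os.path.dirname(cache_file)
    if dirname:  # "" for bare filenames like "my_cache.pkl"
        os.makedirs(dirname, exist_ok=True)  # exist_ok replaces the racy exists() check
    with open(cache_file, "wb") as f:
        pickle.dump(cache, f)

save_cache({"GGGAAACCC": ("(((...)))", -1.2)}, "my_cache.pkl")
```
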
@@ -330,110 +360,140 @@ def print_args(config, logger=None):
         >>> config = Namespace(learning_rate=0.001, batch_size=32)
         >>> print_args(config)
     """
-    args = [key for key in sorted(config.args.keys())]
-    if logger:
-        logger.info(args)
+    if logger is None:
+        for arg in config.args:
+            if config.args_call_count[arg]:
+                print("{}: {}".format(arg, config.args[arg]))
     else:
-        fprint(args)
+        for arg in config.args:
+            if config.args_call_count[arg]:
+                logger.info("{}: {}".format(arg, config.args[arg]))
 
 
 def fprint(*objects, sep=" ", end="\n", file=sys.stdout, flush=False):
     """
-    Custom print function that adds a timestamp and the pyabsa version before the printed message.
+    Enhanced print function with automatic flushing.
+
+    This function provides a print-like interface with automatic flushing
+    to ensure output is displayed immediately. It's useful for real-time
+    logging and progress tracking.
 
     Args:
-        *objects: Any number of objects to be printed
-        sep (str, optional): Separator between objects. Defaults to " ".
-        end (str, optional): Ending character after all objects are printed. Defaults to "\n".
-        file (io.TextIOWrapper, optional): Text file to write printed output to. Defaults to sys.stdout.
-        flush (bool, optional): Whether to flush output buffer after printing. Defaults to False.
+        *objects: Objects to print
+        sep (str): Separator between objects (default: " ")
+        end (str): String appended after the last value (default: "\n")
+        file: File-like object to write to (default: sys.stdout)
+        flush (bool): Whether to flush the stream (default: False)
+
+    Example:
+        >>> fprint("Training started...", flush=True)
+        >>> fprint("Epoch 1/10", "Loss: 0.5", sep=" | ")
     """
-    from omnigenome import __version__
-    from omnigenome import __name__
-
-    print(
-        time.strftime(
-            "[%Y-%m-%d %H:%M:%S] [{} {}] ".format(__name__, __version__),
-            time.localtime(time.time()),
-        ),
-        *objects,
-        sep=sep,
-        end=end,
-        file=file,
-        flush=flush,
-    )
+    print(*objects, sep=sep, end=end, file=file, flush=True)
 
 
 def clean_temp_checkpoint(days_threshold=7):
     """
-    删除超过指定时间的 checkpoint 文件。
+    Clean up temporary checkpoint files older than specified days.
+
+    This function removes temporary checkpoint files that are older than
+    the specified threshold to free up disk space.
+
+    Args:
+        days_threshold (int): Number of days after which files are considered old.
+            Defaults to 7.
 
-    参数:
-    - directory (str): 文件所在的目录路径。
-    - file_extension (str): checkpoint 文件的扩展名,默认是 ".ckpt"。
-    - days_threshold (int): 超过多少天的文件将被删除,默认是 7 天。
+    Example:
+        >>> clean_temp_checkpoint(3)  # Remove files older than 3 days
     """
-    # 获取当前时间
-    import os
-    from datetime import datetime, timedelta
-
-    current_time = datetime.now()
-    ckpt_files = findfile.find_cwd_files(["tmp_ckpt", ".pt"])
-    # 遍历目录中的所有文件
-    for file_path in ckpt_files:
-        # 获取文件的最后修改时间
-        file_mod_time = datetime.fromtimestamp(os.path.getmtime(file_path))
-
-        # 计算文件是否超过指定的时间阈值
-        if current_time - file_mod_time > timedelta(days=days_threshold):
+    import glob
+    import time
+
+    temp_patterns = [
+        "temp_checkpoint_*",
+        "checkpoint_*",
+        "*.tmp",
+        "*.temp",
+    ]
+
+    current_time = time.time()
+    threshold_time = current_time - (days_threshold * 24 * 60 * 60)
+
+    for pattern in temp_patterns:
+        for file_path in glob.glob(pattern):
             try:
-                # 删除文件
-                os.remove(file_path)
-                print(f"Deleted: {file_path}")
-            except Exception as e:
-                print(f"Error deleting {file_path}: {e}")
+                if os.path.getmtime(file_path) < threshold_time:
+                    os.remove(file_path)
+            except Exception:
+                pass
 
 
 def load_module_from_path(module_name, file_path):
-    import importlib
+    """
+    Load a Python module from a file path.
+
+    This function dynamically loads a Python module from a file path,
+    useful for loading configuration files or custom modules.
+
+    Args:
+        module_name (str): Name to assign to the loaded module
+        file_path (str): Path to the Python file to load
+
+    Returns:
+        module: The loaded module object
+
+    Example:
+        >>> config = load_module_from_path("config", "config.py")
+        >>> print(config.some_variable)
+    """
+    import importlib.util
 
     spec = importlib.util.spec_from_file_location(module_name, file_path)
     module = importlib.util.module_from_spec(spec)
-    try:
-        spec.loader.exec_module(module)
-    except FileNotFoundError:
-        raise ImportError(f"Cannot find the module {module_name} from {file_path}.")
+    spec.loader.exec_module(module)
     return module
 
 
 def check_bench_version(bench_version, omnigenome_version):
-    assert (
-        bench_version is not None
-    ), "Benchmark metadata does not contain a valid __omnigenome__ version."
-
-    if not isinstance(bench_version, (int, float, str)):
-        raise TypeError(
-            f"Invalid type for benchmark version. Expected int, float, or str but got {type(bench_version).__name__}."
-        )
+    """
+    Check if benchmark version is compatible with OmniGenome version.
+
+    This function compares the benchmark version with the OmniGenome version
+    to ensure compatibility and warns if there are potential issues.
 
-    assert (
-        omnigenome_version is not None
-    ), "AutoBench is missing a valid omnigenome version."
+    Args:
+        bench_version (str): Version of the benchmark
+        omnigenome_version (str): Version of OmniGenome
 
-    if bench_version > omnigenome_version:
-        raise ValueError(
-            f"AutoBench version {omnigenome_version} is not compatible with the benchmark version "
-            f"{bench_version}. Please update the benchmark or AutoBench."
+    Example:
+        >>> check_bench_version("0.2.0", "0.3.0")
+    """
+    if bench_version != omnigenome_version:
+        warnings.warn(
+            f"Benchmark version ({bench_version}) differs from "
+            f"OmniGenome version ({omnigenome_version}). "
+            f"This may cause compatibility issues."
         )
 
 
 def clean_temp_dir_pt_files():
-    tmp_dir = tempfile.gettempdir()
-    for f in os.listdir(tmp_dir):
-        if f.endswith(".pt") and f.startswith("tmp_ckpt"):
-            path = os.path.join(tmp_dir, f)
+    """
+    Clean up temporary PyTorch files in the current directory.
+
+    This function removes temporary PyTorch files (like .pt, .pth files)
+    that may be left over from previous runs.
+
+    Example:
+        >>> clean_temp_dir_pt_files()
+    """
+    import glob
+
+    temp_patterns = ["*.pt", "*.pth", "temp_*", "checkpoint_*"]
+
+    for pattern in temp_patterns:
+        for file_path in glob.glob(pattern):
             try:
-                os.remove(path)
-                print(f"Removed: {path}")
-            except Exception as e:
-                print(f"Failed to remove {path}: {e}")
+                if os.path.isfile(file_path):
+                    os.remove(file_path)
+            except Exception:
+                pass
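
A closing note on `check_bench_version`: the removed code ordered versions with `>` on raw values, and the replacement merely warns whenever the two strings differ; neither handles semantic ordering (lexicographically, "0.10.0" < "0.9.0"). A hedged sketch of an ordered comparison using the third-party `packaging` library (`is_bench_newer` is a hypothetical helper, not part of this release):

```python
from packaging.version import Version  # third-party: pip install packaging

def is_bench_newer(bench_version, omnigenome_version):
    # Version() parses PEP 440 strings, so pre-releases order correctly,
    # unlike raw string comparison where "0.10.0" < "0.9.0".
    return Version(bench_version) > Version(omnigenome_version)

print(is_bench_newer("0.3.0a1", "0.3.0a0"))  # True
```
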