omnigenome 0.3.1a0__py3-none-any.whl → 0.3.4a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of omnigenome might be problematic. Click here for more details.

Files changed (79) hide show
  1. omnigenome/__init__.py +252 -266
  2. {omnigenome-0.3.1a0.dist-info → omnigenome-0.3.4a0.dist-info}/METADATA +9 -9
  3. omnigenome-0.3.4a0.dist-info/RECORD +7 -0
  4. omnigenome/auto/__init__.py +0 -3
  5. omnigenome/auto/auto_bench/__init__.py +0 -11
  6. omnigenome/auto/auto_bench/auto_bench.py +0 -494
  7. omnigenome/auto/auto_bench/auto_bench_cli.py +0 -230
  8. omnigenome/auto/auto_bench/auto_bench_config.py +0 -216
  9. omnigenome/auto/auto_bench/config_check.py +0 -34
  10. omnigenome/auto/auto_train/__init__.py +0 -12
  11. omnigenome/auto/auto_train/auto_train.py +0 -429
  12. omnigenome/auto/auto_train/auto_train_cli.py +0 -222
  13. omnigenome/auto/bench_hub/__init__.py +0 -11
  14. omnigenome/auto/bench_hub/bench_hub.py +0 -25
  15. omnigenome/cli/__init__.py +0 -12
  16. omnigenome/cli/commands/__init__.py +0 -12
  17. omnigenome/cli/commands/base.py +0 -83
  18. omnigenome/cli/commands/bench/__init__.py +0 -12
  19. omnigenome/cli/commands/bench/bench_cli.py +0 -202
  20. omnigenome/cli/commands/rna/__init__.py +0 -12
  21. omnigenome/cli/commands/rna/rna_design.py +0 -177
  22. omnigenome/cli/omnigenome_cli.py +0 -128
  23. omnigenome/src/__init__.py +0 -11
  24. omnigenome/src/abc/__init__.py +0 -11
  25. omnigenome/src/abc/abstract_dataset.py +0 -641
  26. omnigenome/src/abc/abstract_metric.py +0 -114
  27. omnigenome/src/abc/abstract_model.py +0 -690
  28. omnigenome/src/abc/abstract_tokenizer.py +0 -269
  29. omnigenome/src/dataset/__init__.py +0 -16
  30. omnigenome/src/dataset/omni_dataset.py +0 -437
  31. omnigenome/src/lora/__init__.py +0 -12
  32. omnigenome/src/lora/lora_model.py +0 -300
  33. omnigenome/src/metric/__init__.py +0 -15
  34. omnigenome/src/metric/classification_metric.py +0 -184
  35. omnigenome/src/metric/metric.py +0 -199
  36. omnigenome/src/metric/ranking_metric.py +0 -142
  37. omnigenome/src/metric/regression_metric.py +0 -191
  38. omnigenome/src/misc/__init__.py +0 -3
  39. omnigenome/src/misc/utils.py +0 -503
  40. omnigenome/src/model/__init__.py +0 -19
  41. omnigenome/src/model/augmentation/__init__.py +0 -11
  42. omnigenome/src/model/augmentation/model.py +0 -219
  43. omnigenome/src/model/classification/__init__.py +0 -11
  44. omnigenome/src/model/classification/model.py +0 -638
  45. omnigenome/src/model/embedding/__init__.py +0 -11
  46. omnigenome/src/model/embedding/model.py +0 -263
  47. omnigenome/src/model/mlm/__init__.py +0 -11
  48. omnigenome/src/model/mlm/model.py +0 -177
  49. omnigenome/src/model/module_utils.py +0 -232
  50. omnigenome/src/model/regression/__init__.py +0 -11
  51. omnigenome/src/model/regression/model.py +0 -781
  52. omnigenome/src/model/regression/resnet.py +0 -483
  53. omnigenome/src/model/rna_design/__init__.py +0 -11
  54. omnigenome/src/model/rna_design/model.py +0 -476
  55. omnigenome/src/model/seq2seq/__init__.py +0 -11
  56. omnigenome/src/model/seq2seq/model.py +0 -44
  57. omnigenome/src/tokenizer/__init__.py +0 -16
  58. omnigenome/src/tokenizer/bpe_tokenizer.py +0 -226
  59. omnigenome/src/tokenizer/kmers_tokenizer.py +0 -247
  60. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +0 -249
  61. omnigenome/src/trainer/__init__.py +0 -14
  62. omnigenome/src/trainer/accelerate_trainer.py +0 -747
  63. omnigenome/src/trainer/hf_trainer.py +0 -75
  64. omnigenome/src/trainer/trainer.py +0 -591
  65. omnigenome/utility/__init__.py +0 -3
  66. omnigenome/utility/dataset_hub/__init__.py +0 -12
  67. omnigenome/utility/dataset_hub/dataset_hub.py +0 -178
  68. omnigenome/utility/ensemble.py +0 -324
  69. omnigenome/utility/hub_utils.py +0 -517
  70. omnigenome/utility/model_hub/__init__.py +0 -11
  71. omnigenome/utility/model_hub/model_hub.py +0 -232
  72. omnigenome/utility/pipeline_hub/__init__.py +0 -11
  73. omnigenome/utility/pipeline_hub/pipeline.py +0 -483
  74. omnigenome/utility/pipeline_hub/pipeline_hub.py +0 -129
  75. omnigenome-0.3.1a0.dist-info/RECORD +0 -78
  76. {omnigenome-0.3.1a0.dist-info → omnigenome-0.3.4a0.dist-info}/WHEEL +0 -0
  77. {omnigenome-0.3.1a0.dist-info → omnigenome-0.3.4a0.dist-info}/entry_points.txt +0 -0
  78. {omnigenome-0.3.1a0.dist-info → omnigenome-0.3.4a0.dist-info}/licenses/LICENSE +0 -0
  79. {omnigenome-0.3.1a0.dist-info → omnigenome-0.3.4a0.dist-info}/top_level.txt +0 -0
@@ -1,503 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # file: utils.py
3
- # time: 14:45 06/04/2024
4
- # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
5
- # github: https://github.com/yangheng95
6
- # huggingface: https://huggingface.co/yangheng
7
- # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
8
- # Copyright (C) 2019-2024. All Rights Reserved.
9
- import multiprocessing
10
- import os
11
- import pickle
12
- import sys
13
- import tempfile
14
- import time
15
- import warnings
16
-
17
- import ViennaRNA as RNA
18
- import findfile
19
-
20
- default_omnigenome_repo = (
21
- "https://huggingface.co/spaces/yangheng/OmniGenomeLeaderboard/"
22
- )
23
-
24
-
25
def seed_everything(seed=42):
    """
    Seed every random number generator used by the library.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), then switches cuDNN to deterministic mode
    with autotuning disabled.

    ``PYTHONHASHSEED`` is exported as well. The interpreter reads that
    variable only at start-up, so the assignment influences child processes
    launched afterwards; it does not change string hashing in the process
    that calls this function.

    Args:
        seed (int): The seed value to use for all random number generators.
            Defaults to 42.

    Example:
        >>> seed_everything(42)
        >>> # random / numpy / torch draws are now repeatable
    """
    import random

    import numpy as np
    import torch

    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every GPU rather than only the current device,
    # and is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
53
-
54
-
55
def _fold_rna_sequence(sequence):
    """
    Fold one RNA sequence with ViennaRNA, degrading to an unpaired structure.

    Kept at module level so worker processes are handed only the sequence
    string; dispatching a bound method would pickle the whole cache object
    along with it.

    Args:
        sequence (str): RNA sequence to fold

    Returns:
        The ``RNA.fold`` result (structure, mfe), or an all-dots structure
        with an energy of 0.0 when folding raises.
    """
    try:
        return RNA.fold(sequence)
    except Exception as e:
        warnings.warn(f"Failed to fold sequence {sequence}: {e}")
        return ("." * len(sequence), 0.0)


class RNA2StructureCache(dict):
    """
    A cache for RNA secondary structure predictions using ViennaRNA.

    Predictions are memoised in ``self.cache`` and periodically pickled to
    ``cache_file`` so that later runs can reuse them. Item access
    (``cache[key]``), assignment, membership tests and ``len()`` all operate
    on ``self.cache``; the inherited ``dict`` storage only holds whatever was
    passed through ``*args``/``**kwargs``.

    Attributes:
        cache (dict): Dictionary storing sequence -> (structure, mfe) entries
        cache_file (str): Path to the cache file on disk
        queue_num (int): Number of predictions added since the last save
    """

    def __init__(self, cache_file=None, *args, **kwargs):
        """
        Initialize the RNA structure cache.

        Args:
            cache_file (str, optional): Path to the cache file. If None, uses
                ``rna_structure_cache.pkl`` in the system temp directory.
            *args: Additional positional arguments for dict initialization
            **kwargs: Additional keyword arguments for dict initialization
        """
        super().__init__(*args, **kwargs)
        self.cache = dict(*args, **kwargs)
        self.cache_file = (
            cache_file
            if cache_file is not None
            else os.path.join(tempfile.gettempdir(), "rna_structure_cache.pkl")
        )
        self.queue_num = 0

        # Load existing cache if available.
        # NOTE(security): unpickling executes arbitrary code and the default
        # location is the shared temp directory; only point ``cache_file`` at
        # files you trust.
        if os.path.exists(self.cache_file):
            try:
                with open(self.cache_file, "rb") as f:
                    self.cache.update(pickle.load(f))
            except Exception as e:
                warnings.warn(f"Failed to load cache file: {e}")

    def __getitem__(self, key):
        """Gets a cached structure prediction."""
        return self.cache[key]

    def __setitem__(self, key, value):
        """Sets a structure prediction in the cache."""
        self.cache[key] = value

    def __contains__(self, key):
        """Reports membership against the same store item access uses."""
        return key in self.cache

    def __len__(self):
        """Number of cached predictions."""
        return len(self.cache)

    def __str__(self):
        """String representation of the cache."""
        return str(self.cache)

    def __repr__(self):
        """String representation of the cache."""
        return str(self.cache)

    def _fold_single_sequence(self, sequence):
        """
        Predict structure for a single sequence.

        Args:
            sequence (str): RNA sequence to fold

        Returns:
            The (structure, mfe) result; see ``_fold_rna_sequence``.
        """
        return _fold_rna_sequence(sequence)

    def fold(self, sequence, return_mfe=False, num_workers=1):
        """
        Predicts RNA secondary structure for given sequences.

        Sequences already present in the cache are served from it; the rest
        are folded with ViennaRNA, in a process pool when more than one
        sequence is pending, more than one worker is requested and the
        platform is not Windows. If the pool fails, folding falls back to
        sequential processing.

        Args:
            sequence (str or list): A single RNA sequence or a list of sequences.
            return_mfe (bool): Whether to return minimum free energy along with
                structure. Defaults to False.
            num_workers (int): Number of worker processes for batch processing.
                Defaults to 1. Set to None for auto-detection.

        Returns:
            str or list: The predicted structure(s). If return_mfe is True,
            returns (structure, mfe) entries instead. A one-element list
            input yields a single result, not a list.

        Example:
            >>> cache = RNA2StructureCache()
            >>> structure = cache.fold("GGGAAAUCC")
            >>> structures = cache.fold(["GGGAAAUCC", "AUUGCUAA"])
        """
        sequences = sequence if isinstance(sequence, list) else [sequence]

        # Pending work: unique, in first-seen order, so a sequence repeated
        # within one batch is folded only once.
        sequences_to_predict = [
            seq for seq in dict.fromkeys(sequences) if seq not in self.cache
        ]

        # Resolve auto-detection before any comparison; comparing None with
        # an int raises TypeError.
        if num_workers is None:
            num_workers = os.cpu_count() or 1
        num_workers = min(num_workers, max(len(sequences_to_predict), 1))

        use_multiprocessing = (
            os.name != "nt"  # Not Windows
            and len(sequences_to_predict) > 1  # Several sequences pending
            and num_workers > 1  # Multiple workers requested
        )

        if sequences_to_predict:
            results = None
            if use_multiprocessing:
                try:
                    # A private 'spawn' context avoids overriding the
                    # process-wide start method of the host application.
                    context = multiprocessing.get_context("spawn")
                    with context.Pool(num_workers) as pool:
                        results = pool.map(_fold_rna_sequence, sequences_to_predict)
                except Exception as e:
                    warnings.warn(
                        f"Multiprocessing failed, falling back to sequential: {e}"
                    )

            if results is None:
                # Sequential processing (requested, or pool fallback)
                results = [
                    self._fold_single_sequence(seq) for seq in sequences_to_predict
                ]

            for seq, result in zip(sequences_to_predict, results):
                self.cache[seq] = result
            self.queue_num += len(sequences_to_predict)

        # Prepare output
        if return_mfe:
            structures = [self.cache[seq] for seq in sequences]
        else:
            structures = [self.cache[seq][0] for seq in sequences]

        # Persist periodically (no-op until enough new entries accumulate)
        self.update_cache_file(self.cache_file)

        # Return single result or list
        return structures[0] if len(structures) == 1 else structures

    def update_cache_file(self, cache_file=None):
        """
        Updates the cache file on disk.

        The in-memory cache is written only once at least 100 predictions
        have accumulated since the last save (``queue_num >= 100``); below
        that the call returns without touching the disk. Failures are
        reported as warnings and leave ``queue_num`` unchanged.

        Args:
            cache_file (str, optional): Path to the cache file. If None, uses
                the instance's cache_file.

        Example:
            >>> cache.update_cache_file()  # saves if >= 100 entries pending
        """
        if self.queue_num < 100:
            return

        if cache_file is None:
            cache_file = self.cache_file

        try:
            # A bare filename has an empty dirname; makedirs("") would raise.
            parent_dir = os.path.dirname(cache_file)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            with open(cache_file, "wb") as f:
                pickle.dump(self.cache, f)

            self.queue_num = 0
        except Exception as e:
            warnings.warn(f"Failed to update cache file: {e}")
254
-
255
-
256
def env_meta_info():
    """
    Report the library versions of the running environment.

    Imports are done lazily inside the function so that calling it is the
    only point at which PyTorch and Transformers need to be importable.

    Returns:
        dict: Environment metadata with the keys:
            - library_name: ``__name__`` of the top-level package
            - omnigenome_version: ``__version__`` of the top-level package
            - torch_version: PyTorch version followed by ``+cu<cuda>`` and
              ``+git<commit>`` taken from ``torch.version``
            - transformers_version: Transformers library version

    Example:
        >>> info = env_meta_info()
        >>> sorted(info)
        ['library_name', 'omnigenome_version', 'torch_version', 'transformers_version']
    """
    from transformers import __version__ as transformers_version
    from torch.version import __version__ as torch_version
    from torch.version import cuda as torch_cuda_version
    from torch.version import git_version
    from ... import __name__ as omnigenome_name
    from ... import __version__ as omnigenome_version

    torch_description = "{}+cu{}+git{}".format(
        torch_version, torch_cuda_version, git_version
    )

    meta = {}
    meta["library_name"] = omnigenome_name
    meta["omnigenome_version"] = omnigenome_version
    meta["torch_version"] = torch_description
    meta["transformers_version"] = transformers_version
    return meta
288
-
289
-
290
def naive_secondary_structure_repair(sequence, structure):
    """
    Repair the secondary structure of a sequence.

    Walks the dot-bracket string once, keeping every bracket that has a
    partner and replacing each unmatched "(" or ")" with ".". All other
    symbols are copied through unchanged, so the result has one symbol per
    processed position.

    Only positions present in both arguments are processed: the structure
    is paired with the sequence position by position, so the output is as
    long as the shorter of the two.

    Args:
        sequence (str): A string representing the sequence.
        structure (str): A string representing the secondary structure.

    Returns:
        str: A string representing the repaired secondary structure.

    Example:
        >>> naive_secondary_structure_repair("GGGAAAUCC", "(((...)).")
        '.((...)).'
    """
    repaired = []
    unmatched_open = []  # indices into ``repaired`` of still-open "("
    for symbol, _ in zip(structure, sequence):
        if symbol == "(":
            unmatched_open.append(len(repaired))
            repaired.append(symbol)
        elif symbol == ")":
            if unmatched_open:
                unmatched_open.pop()
                repaired.append(symbol)
            else:
                # Closing bracket with no partner
                repaired.append(".")
        else:
            repaired.append(symbol)

    # Whatever is left on the stack never found a closing partner.
    for position in unmatched_open:
        repaired[position] = "."
    return "".join(repaired)
326
-
327
-
328
def save_args(config, save_path):
    """
    Save arguments to a file.

    Writes one ``name: value`` line for every entry of ``config.args`` whose
    ``config.args_call_count`` value is truthy. The file is written as UTF-8
    and is closed even if a write fails.

    Args:
        config: An object exposing ``args`` (mapping of argument name to
            value) and ``args_call_count`` (mapping of argument name to a
            count/flag).
        save_path (str): A string representing the path of the file to be saved.

    Example:
        >>> save_args(config, "config.txt")
    """
    with open(save_path, mode="w", encoding="utf8") as f:
        for arg in config.args:
            if config.args_call_count[arg]:
                f.write("{}: {}\n".format(arg, config.args[arg]))
349
-
350
-
351
def print_args(config, logger=None):
    """
    Print the arguments to the console.

    Emits one ``name: value`` line for every entry of ``config.args`` whose
    ``config.args_call_count`` value is truthy, either through ``print`` or
    through ``logger.info``.

    Args:
        config: An object exposing ``args`` (mapping of argument name to
            value) and ``args_call_count`` (mapping of argument name to a
            count/flag).
        logger: A logger object. If None, prints to console.

    Example:
        >>> print_args(config)
        >>> print_args(config, logger=logging.getLogger(__name__))
    """
    # Pick the sink once so both destinations share a single loop.
    emit = print if logger is None else logger.info
    for arg in config.args:
        if config.args_call_count[arg]:
            emit("{}: {}".format(arg, config.args[arg]))
375
-
376
-
377
def fprint(*objects, sep=" ", end="\n", file=None, flush=False):
    """
    Enhanced print function with automatic flushing.

    Behaves like the built-in ``print`` but always flushes the stream, so
    output appears immediately.

    Args:
        *objects: Objects to print
        sep (str): Separator between objects (default: " ")
        end (str): String appended after the last value (default: "\\n")
        file: File-like object to write to. None (the default) means the
            current ``sys.stdout``, looked up at call time so that later
            redirection of ``sys.stdout`` is honoured.
        flush (bool): Accepted for signature compatibility with ``print``;
            the stream is flushed regardless of this value.

    Example:
        >>> fprint("Training started...")
        >>> fprint("Epoch 1/10", "Loss: 0.5", sep=" | ")
    """
    print(*objects, sep=sep, end=end, file=file, flush=True)
397
-
398
-
399
def clean_temp_checkpoint(days_threshold=7):
    """
    Clean up temporary checkpoint files older than specified days.

    Scans the current working directory for regular files whose names match
    ``temp_checkpoint_*``, ``checkpoint_*``, ``*.tmp`` or ``*.temp`` and
    removes those whose modification time is older than the threshold.
    Directories matching the patterns are left alone. A file that cannot be
    removed is reported with a warning and the scan continues.

    Args:
        days_threshold (int): Number of days after which files are considered old.
            Defaults to 7.

    Example:
        >>> clean_temp_checkpoint(3)  # Remove files older than 3 days
    """
    import glob
    import time

    temp_patterns = [
        "temp_checkpoint_*",
        "checkpoint_*",
        "*.tmp",
        "*.temp",
    ]

    threshold_time = time.time() - (days_threshold * 24 * 60 * 60)

    for pattern in temp_patterns:
        for file_path in glob.glob(pattern):
            try:
                # os.remove cannot delete directories; skip them explicitly
                # instead of relying on a swallowed error.
                if not os.path.isfile(file_path):
                    continue
                if os.path.getmtime(file_path) < threshold_time:
                    os.remove(file_path)
            except OSError as e:
                # Best-effort cleanup: report and move on.
                warnings.warn(f"Failed to remove temporary file {file_path}: {e}")
433
-
434
-
435
def load_module_from_path(module_name, file_path):
    """
    Load a Python module from a file path.

    Builds an import spec for ``file_path``, creates the module object and
    executes the file in it. The module is returned to the caller; it is not
    registered in ``sys.modules``.

    Args:
        module_name (str): Name to assign to the loaded module
        file_path (str): Path to the Python file to load

    Returns:
        module: The loaded module object

    Raises:
        ImportError: If no loader can be determined for ``file_path`` (for
            example, an unsupported file extension).

    Example:
        >>> config = load_module_from_path("config", "config.py")
        >>> print(config.some_variable)
    """
    import importlib.util

    spec = importlib.util.spec_from_file_location(module_name, file_path)
    # spec_from_file_location returns None when it cannot pick a loader;
    # fail with a clear message instead of an AttributeError further down.
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot load module {module_name!r} from {file_path!r}: "
            f"no suitable loader found"
        )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
459
-
460
-
461
def check_bench_version(bench_version, omnigenome_version):
    """
    Warn when the benchmark and OmniGenome versions are not identical.

    The two values are compared for plain equality; any difference triggers
    a single warning and nothing else happens.

    Args:
        bench_version (str): Version of the benchmark
        omnigenome_version (str): Version of OmniGenome

    Example:
        >>> check_bench_version("0.2.0", "0.3.0")  # emits a warning
    """
    if bench_version == omnigenome_version:
        return

    message = (
        f"Benchmark version ({bench_version}) differs from "
        f"OmniGenome version ({omnigenome_version}). "
        f"This may cause compatibility issues."
    )
    warnings.warn(message)
481
-
482
-
483
def clean_temp_dir_pt_files():
    """
    Clean up temporary PyTorch files in the current directory.

    Removes every regular file in the current working directory whose name
    matches ``*.pt``, ``*.pth``, ``temp_*`` or ``checkpoint_*``, regardless
    of age. Directories are left untouched. A file that cannot be removed is
    reported with a warning and the scan continues.

    Example:
        >>> clean_temp_dir_pt_files()
    """
    import glob

    temp_patterns = ["*.pt", "*.pth", "temp_*", "checkpoint_*"]

    for pattern in temp_patterns:
        for file_path in glob.glob(pattern):
            try:
                if os.path.isfile(file_path):
                    os.remove(file_path)
            except OSError as e:
                # Best-effort cleanup: report and move on.
                warnings.warn(f"Failed to remove temporary file {file_path}: {e}")
@@ -1,19 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # file: __init__.py
3
- # time: 14:08 06/04/2024
4
- # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
5
- # github: https://github.com/yangheng95
6
- # huggingface: https://huggingface.co/yangheng
7
- # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
8
- # Copyright (C) 2019-2024. All Rights Reserved.
9
- """
10
- This package contains model definitions for various tasks.
11
- """
12
-
13
- from .classification.model import *
14
- from .mlm.model import *
15
- from .regression.model import *
16
- from .seq2seq.model import *
17
- from .rna_design.model import *
18
- from .embedding.model import *
19
- from .augmentation.model import *
@@ -1,11 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # file: __init__.py
3
- # time: 19:06 22/09/2024
4
- # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
5
- # github: https://github.com/yangheng95
6
- # huggingface: https://huggingface.co/yangheng
7
- # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
8
- # Copyright (C) 2019-2024. All Rights Reserved.
9
- """
10
- This package contains modules for data augmentation.
11
- """