omnigenome 0.3.0a0__py3-none-any.whl → 0.3.1a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- omnigenome/__init__.py +29 -44
- omnigenome/auto/auto_bench/__init__.py +0 -1
- omnigenome/auto/auto_bench/auto_bench.py +24 -14
- omnigenome/auto/auto_train/__init__.py +0 -1
- omnigenome/auto/auto_train/auto_train.py +11 -12
- omnigenome/auto/bench_hub/__init__.py +0 -1
- omnigenome/auto/bench_hub/bench_hub.py +1 -1
- omnigenome/cli/__init__.py +0 -1
- omnigenome/cli/commands/__init__.py +0 -1
- omnigenome/cli/commands/base.py +10 -10
- omnigenome/cli/commands/bench/__init__.py +0 -1
- omnigenome/cli/commands/bench/bench_cli.py +10 -10
- omnigenome/cli/commands/rna/__init__.py +0 -1
- omnigenome/cli/commands/rna/rna_design.py +10 -11
- omnigenome/src/__init__.py +0 -1
- omnigenome/src/abc/__init__.py +0 -1
- omnigenome/src/abc/abstract_dataset.py +38 -19
- omnigenome/src/abc/abstract_metric.py +7 -7
- omnigenome/src/abc/abstract_model.py +15 -14
- omnigenome/src/abc/abstract_tokenizer.py +9 -7
- omnigenome/src/dataset/omni_dataset.py +16 -14
- omnigenome/src/lora/__init__.py +0 -1
- omnigenome/src/lora/lora_model.py +47 -41
- omnigenome/src/metric/classification_metric.py +11 -11
- omnigenome/src/metric/metric.py +19 -19
- omnigenome/src/metric/ranking_metric.py +15 -15
- omnigenome/src/metric/regression_metric.py +18 -18
- omnigenome/src/misc/utils.py +214 -150
- omnigenome/src/model/augmentation/__init__.py +0 -1
- omnigenome/src/model/augmentation/model.py +17 -17
- omnigenome/src/model/classification/__init__.py +0 -1
- omnigenome/src/model/classification/model.py +28 -32
- omnigenome/src/model/embedding/__init__.py +0 -1
- omnigenome/src/model/embedding/model.py +35 -35
- omnigenome/src/model/mlm/__init__.py +0 -1
- omnigenome/src/model/mlm/model.py +13 -13
- omnigenome/src/model/module_utils.py +17 -17
- omnigenome/src/model/regression/__init__.py +0 -1
- omnigenome/src/model/regression/model.py +72 -77
- omnigenome/src/model/regression/resnet.py +32 -32
- omnigenome/src/model/rna_design/__init__.py +0 -1
- omnigenome/src/model/rna_design/model.py +168 -118
- omnigenome/src/model/seq2seq/__init__.py +0 -1
- omnigenome/src/model/seq2seq/model.py +4 -4
- omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
- omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
- omnigenome/src/trainer/accelerate_trainer.py +40 -32
- omnigenome/src/trainer/hf_trainer.py +8 -8
- omnigenome/src/trainer/trainer.py +37 -25
- omnigenome/utility/dataset_hub/__init__.py +0 -1
- omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
- omnigenome/utility/ensemble.py +26 -26
- omnigenome/utility/hub_utils.py +8 -8
- omnigenome/utility/model_hub/__init__.py +0 -1
- omnigenome/utility/model_hub/model_hub.py +26 -25
- omnigenome/utility/pipeline_hub/__init__.py +0 -1
- omnigenome/utility/pipeline_hub/pipeline.py +49 -49
- omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +3 -3
- omnigenome-0.3.1a0.dist-info/RECORD +78 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -1
- omnigenome-0.3.0a0.dist-info/RECORD +0 -85
- tests/__init__.py +0 -9
- tests/conftest.py +0 -160
- tests/test_dataset_patterns.py +0 -291
- tests/test_examples_syntax.py +0 -83
- tests/test_model_loading.py +0 -183
- tests/test_rna_functions.py +0 -255
- tests/test_training_patterns.py +0 -302
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
|
@@ -20,20 +20,20 @@ from ..abc.abstract_metric import OmniMetric
|
|
|
20
20
|
def mcrmse(y_true, y_pred):
|
|
21
21
|
"""
|
|
22
22
|
Compute Mean Column Root Mean Square Error (MCRMSE).
|
|
23
|
-
|
|
23
|
+
|
|
24
24
|
MCRMSE is a multi-target regression metric that computes the RMSE for each target
|
|
25
25
|
column and then takes the mean across all targets.
|
|
26
|
-
|
|
26
|
+
|
|
27
27
|
Args:
|
|
28
28
|
y_true (np.ndarray): Ground truth values with shape (n_samples, n_targets)
|
|
29
29
|
y_pred (np.ndarray): Predicted values with shape (n_samples, n_targets)
|
|
30
|
-
|
|
30
|
+
|
|
31
31
|
Returns:
|
|
32
32
|
float: Mean Column Root Mean Square Error
|
|
33
|
-
|
|
33
|
+
|
|
34
34
|
Raises:
|
|
35
35
|
ValueError: If y_true and y_pred have different shapes
|
|
36
|
-
|
|
36
|
+
|
|
37
37
|
Example:
|
|
38
38
|
>>> y_true = np.array([[1, 2], [3, 4], [5, 6]])
|
|
39
39
|
>>> y_pred = np.array([[1.1, 2.1], [2.9, 4.1], [5.2, 5.8]])
|
|
@@ -56,18 +56,18 @@ setattr(metrics, "mcrmse", mcrmse)
|
|
|
56
56
|
class RegressionMetric(OmniMetric):
|
|
57
57
|
"""
|
|
58
58
|
A specialized metric class for regression tasks and evaluation.
|
|
59
|
-
|
|
59
|
+
|
|
60
60
|
This class provides access to regression-specific metrics from scikit-learn
|
|
61
61
|
and handles different input formats including HuggingFace trainer outputs.
|
|
62
62
|
It dynamically wraps scikit-learn metrics and provides a unified interface
|
|
63
63
|
for computing various regression evaluation metrics.
|
|
64
|
-
|
|
64
|
+
|
|
65
65
|
Attributes:
|
|
66
66
|
metric_func: Custom metric function if provided
|
|
67
67
|
ignore_y: Value to ignore in predictions and true values
|
|
68
68
|
kwargs: Additional keyword arguments for metric computation
|
|
69
69
|
metrics: Dictionary of available metrics including custom ones
|
|
70
|
-
|
|
70
|
+
|
|
71
71
|
Example:
|
|
72
72
|
>>> from omnigenome.src.metric import RegressionMetric
|
|
73
73
|
>>> metric = RegressionMetric(ignore_y=-100)
|
|
@@ -81,7 +81,7 @@ class RegressionMetric(OmniMetric):
|
|
|
81
81
|
def __init__(self, metric_func=None, ignore_y=-100, *args, **kwargs):
|
|
82
82
|
"""
|
|
83
83
|
Initialize the RegressionMetric class.
|
|
84
|
-
|
|
84
|
+
|
|
85
85
|
Args:
|
|
86
86
|
metric_func (callable, optional): Custom metric function to use
|
|
87
87
|
ignore_y (int, optional): Value to ignore in predictions and true values. Defaults to -100
|
|
@@ -97,14 +97,14 @@ class RegressionMetric(OmniMetric):
|
|
|
97
97
|
def __getattribute__(self, name):
|
|
98
98
|
"""
|
|
99
99
|
Dynamically create regression metric computation methods.
|
|
100
|
-
|
|
100
|
+
|
|
101
101
|
This method intercepts attribute access and creates wrapper functions
|
|
102
102
|
for scikit-learn regression metrics, handling different input formats and
|
|
103
103
|
preprocessing the data appropriately.
|
|
104
|
-
|
|
104
|
+
|
|
105
105
|
Args:
|
|
106
106
|
name (str): Name of the regression metric to access
|
|
107
|
-
|
|
107
|
+
|
|
108
108
|
Returns:
|
|
109
109
|
callable: Wrapper function for the requested regression metric
|
|
110
110
|
"""
|
|
@@ -118,17 +118,17 @@ class RegressionMetric(OmniMetric):
|
|
|
118
118
|
def wrapper(y_true=None, y_score=None, *args, **kwargs):
|
|
119
119
|
"""
|
|
120
120
|
Compute the regression metric, based on the true and predicted values.
|
|
121
|
-
|
|
121
|
+
|
|
122
122
|
This wrapper handles different input formats including HuggingFace
|
|
123
123
|
trainer outputs and performs necessary preprocessing for regression tasks.
|
|
124
|
-
|
|
124
|
+
|
|
125
125
|
Args:
|
|
126
126
|
y_true: The true values or HuggingFace EvalPrediction object
|
|
127
127
|
y_score: The predicted values
|
|
128
128
|
ignore_y: The value to ignore in the predictions and true values in corresponding positions
|
|
129
129
|
*args: Additional positional arguments for the metric
|
|
130
130
|
**kwargs: Additional keyword arguments for the metric
|
|
131
|
-
|
|
131
|
+
|
|
132
132
|
Returns:
|
|
133
133
|
dict: Dictionary containing the metric name and computed value
|
|
134
134
|
"""
|
|
@@ -168,16 +168,16 @@ class RegressionMetric(OmniMetric):
|
|
|
168
168
|
def compute(self, y_true, y_score, *args, **kwargs):
|
|
169
169
|
"""
|
|
170
170
|
Compute the regression metric, based on the true and predicted values.
|
|
171
|
-
|
|
171
|
+
|
|
172
172
|
Args:
|
|
173
173
|
y_true: The true values
|
|
174
174
|
y_score: The predicted values
|
|
175
175
|
*args: Additional positional arguments for the metric
|
|
176
176
|
**kwargs: Additional keyword arguments for the metric
|
|
177
|
-
|
|
177
|
+
|
|
178
178
|
Returns:
|
|
179
179
|
The computed regression metric value
|
|
180
|
-
|
|
180
|
+
|
|
181
181
|
Raises:
|
|
182
182
|
NotImplementedError: If no metric function is provided and compute is not implemented
|
|
183
183
|
"""
|
omnigenome/src/misc/utils.py
CHANGED
|
@@ -12,6 +12,7 @@ import pickle
|
|
|
12
12
|
import sys
|
|
13
13
|
import tempfile
|
|
14
14
|
import time
|
|
15
|
+
import warnings
|
|
15
16
|
|
|
16
17
|
import ViennaRNA as RNA
|
|
17
18
|
import findfile
|
|
@@ -24,13 +25,13 @@ default_omnigenome_repo = (
|
|
|
24
25
|
def seed_everything(seed=42):
|
|
25
26
|
"""
|
|
26
27
|
Sets random seeds for reproducibility across all random number generators.
|
|
27
|
-
|
|
28
|
+
|
|
28
29
|
This function sets seeds for Python's random module, NumPy, PyTorch (CPU and CUDA),
|
|
29
30
|
and sets the PYTHONHASHSEED environment variable to ensure reproducible results
|
|
30
31
|
across different runs.
|
|
31
|
-
|
|
32
|
+
|
|
32
33
|
Args:
|
|
33
|
-
seed (int): The seed value to use for all random number generators.
|
|
34
|
+
seed (int): The seed value to use for all random number generators.
|
|
34
35
|
Defaults to 42.
|
|
35
36
|
|
|
36
37
|
Example:
|
|
@@ -48,58 +49,50 @@ def seed_everything(seed=42):
|
|
|
48
49
|
torch.manual_seed(seed)
|
|
49
50
|
torch.cuda.manual_seed(seed)
|
|
50
51
|
torch.backends.cudnn.deterministic = True
|
|
52
|
+
torch.backends.cudnn.benchmark = False
|
|
51
53
|
|
|
52
54
|
|
|
53
55
|
class RNA2StructureCache(dict):
|
|
54
56
|
"""
|
|
55
|
-
A cache for RNA
|
|
56
|
-
|
|
57
|
-
This class provides a
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
The cache can be persisted to disk and loaded back, making it useful for
|
|
62
|
-
avoiding redundant structure predictions across multiple runs.
|
|
63
|
-
|
|
57
|
+
A cache for RNA secondary structure predictions using ViennaRNA.
|
|
58
|
+
|
|
59
|
+
This class provides a caching mechanism for RNA secondary structure predictions
|
|
60
|
+
to avoid redundant computations. It supports both single sequence and batch
|
|
61
|
+
processing with optional multiprocessing for improved performance.
|
|
62
|
+
|
|
64
63
|
Attributes:
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
queue_num (int): Counter for tracking cache updates
|
|
64
|
+
cache (dict): Dictionary storing sequence-structure mappings
|
|
65
|
+
cache_file (str): Path to the cache file on disk
|
|
66
|
+
queue_num (int): Counter for tracking cache updates
|
|
68
67
|
"""
|
|
69
68
|
|
|
70
69
|
def __init__(self, cache_file=None, *args, **kwargs):
|
|
71
70
|
"""
|
|
72
|
-
|
|
71
|
+
Initialize the RNA structure cache.
|
|
73
72
|
|
|
74
73
|
Args:
|
|
75
74
|
cache_file (str, optional): Path to the cache file. If None, uses
|
|
76
|
-
a default
|
|
77
|
-
*args: Additional arguments
|
|
78
|
-
**kwargs: Additional keyword arguments
|
|
79
|
-
|
|
80
|
-
Example:
|
|
81
|
-
>>> # Initialize with default cache file
|
|
82
|
-
>>> cache = RNA2StructureCache()
|
|
83
|
-
|
|
84
|
-
>>> # Initialize with custom cache file
|
|
85
|
-
>>> cache = RNA2StructureCache("my_cache.pkl")
|
|
75
|
+
a default temporary file.
|
|
76
|
+
*args: Additional positional arguments for dict initialization
|
|
77
|
+
**kwargs: Additional keyword arguments for dict initialization
|
|
86
78
|
"""
|
|
87
79
|
super().__init__(*args, **kwargs)
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
if self.cache_file is None or not os.path.exists(self.cache_file):
|
|
95
|
-
self.cache = {}
|
|
96
|
-
else:
|
|
97
|
-
fprint(f"Initialize sequence to structure cache from {self.cache_file}...")
|
|
98
|
-
with open(self.cache_file, "rb") as f:
|
|
99
|
-
self.cache = pickle.load(f)
|
|
100
|
-
|
|
80
|
+
self.cache = dict(*args, **kwargs)
|
|
81
|
+
self.cache_file = (
|
|
82
|
+
cache_file
|
|
83
|
+
if cache_file is not None
|
|
84
|
+
else os.path.join(tempfile.gettempdir(), "rna_structure_cache.pkl")
|
|
85
|
+
)
|
|
101
86
|
self.queue_num = 0
|
|
102
87
|
|
|
88
|
+
# Load existing cache if available
|
|
89
|
+
if os.path.exists(self.cache_file):
|
|
90
|
+
try:
|
|
91
|
+
with open(self.cache_file, "rb") as f:
|
|
92
|
+
self.cache.update(pickle.load(f))
|
|
93
|
+
except Exception as e:
|
|
94
|
+
warnings.warn(f"Failed to load cache file: {e}")
|
|
95
|
+
|
|
103
96
|
def __getitem__(self, key):
|
|
104
97
|
"""Gets a cached structure prediction."""
|
|
105
98
|
return self.cache[key]
|
|
@@ -116,15 +109,31 @@ class RNA2StructureCache(dict):
|
|
|
116
109
|
"""String representation of the cache."""
|
|
117
110
|
return str(self.cache)
|
|
118
111
|
|
|
112
|
+
def _fold_single_sequence(self, sequence):
|
|
113
|
+
"""
|
|
114
|
+
Predict structure for a single sequence (worker function for multiprocessing).
|
|
115
|
+
|
|
116
|
+
Args:
|
|
117
|
+
sequence (str): RNA sequence to fold
|
|
118
|
+
|
|
119
|
+
Returns:
|
|
120
|
+
tuple: (structure, mfe) tuple
|
|
121
|
+
"""
|
|
122
|
+
try:
|
|
123
|
+
return RNA.fold(sequence)
|
|
124
|
+
except Exception as e:
|
|
125
|
+
warnings.warn(f"Failed to fold sequence {sequence}: {e}")
|
|
126
|
+
return ("." * len(sequence), 0.0)
|
|
127
|
+
|
|
119
128
|
def fold(self, sequence, return_mfe=False, num_workers=1):
|
|
120
129
|
"""
|
|
121
130
|
Predicts RNA secondary structure for given sequences.
|
|
122
|
-
|
|
131
|
+
|
|
123
132
|
This method predicts RNA secondary structures using ViennaRNA. It supports
|
|
124
133
|
both single sequences and batches of sequences. The method uses caching
|
|
125
134
|
to avoid redundant predictions and supports multiprocessing for batch
|
|
126
135
|
processing on non-Windows systems.
|
|
127
|
-
|
|
136
|
+
|
|
128
137
|
Args:
|
|
129
138
|
sequence (str or list): A single RNA sequence or a list of sequences.
|
|
130
139
|
return_mfe (bool): Whether to return minimum free energy along with
|
|
@@ -141,7 +150,7 @@ class RNA2StructureCache(dict):
|
|
|
141
150
|
>>> # Predict structure for a single sequence
|
|
142
151
|
>>> structure = cache.fold("GGGAAAUCC")
|
|
143
152
|
>>> print(structure) # "(((...)))"
|
|
144
|
-
|
|
153
|
+
|
|
145
154
|
>>> # Predict structures for multiple sequences
|
|
146
155
|
>>> structures = cache.fold(["GGGAAAUCC", "AUUGCUAA"])
|
|
147
156
|
>>> print(structures) # ["(((...)))", "........"]
|
|
@@ -151,39 +160,62 @@ class RNA2StructureCache(dict):
|
|
|
151
160
|
else:
|
|
152
161
|
sequences = sequence
|
|
153
162
|
|
|
154
|
-
if
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
163
|
+
# Determine if we should use multiprocessing
|
|
164
|
+
use_multiprocessing = (
|
|
165
|
+
os.name != "nt" # Not Windows
|
|
166
|
+
and len(sequences) > 1 # Multiple sequences
|
|
167
|
+
and num_workers > 1 # Multiple workers requested
|
|
168
|
+
)
|
|
158
169
|
|
|
159
|
-
|
|
170
|
+
# Find sequences that need prediction
|
|
171
|
+
sequences_to_predict = [seq for seq in sequences if seq not in self.cache]
|
|
160
172
|
|
|
161
|
-
if
|
|
162
|
-
if
|
|
163
|
-
for
|
|
164
|
-
if seq not in self.cache:
|
|
165
|
-
self.queue_num += 1
|
|
166
|
-
self.cache[seq] = RNA.fold(seq)
|
|
167
|
-
else:
|
|
173
|
+
if sequences_to_predict:
|
|
174
|
+
if use_multiprocessing:
|
|
175
|
+
# Use multiprocessing for batch prediction
|
|
168
176
|
if num_workers is None:
|
|
169
|
-
num_workers = min(os.cpu_count(), len(
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
for
|
|
173
|
-
|
|
177
|
+
num_workers = min(os.cpu_count(), len(sequences_to_predict))
|
|
178
|
+
|
|
179
|
+
try:
|
|
180
|
+
# Set multiprocessing start method to 'spawn' for better compatibility
|
|
181
|
+
if multiprocessing.get_start_method(allow_none=True) != "spawn":
|
|
182
|
+
multiprocessing.set_start_method("spawn", force=True)
|
|
183
|
+
|
|
184
|
+
with multiprocessing.Pool(num_workers) as pool:
|
|
185
|
+
# Use map instead of apply_async for better error handling
|
|
186
|
+
results = pool.map(
|
|
187
|
+
self._fold_single_sequence, sequences_to_predict
|
|
188
|
+
)
|
|
189
|
+
|
|
190
|
+
# Update cache with results
|
|
191
|
+
for seq, result in zip(sequences_to_predict, results):
|
|
192
|
+
self.cache[seq] = result
|
|
174
193
|
self.queue_num += 1
|
|
175
|
-
async_result = pool.apply_async(RNA.fold, args=(seq,))
|
|
176
|
-
structures.append((seq, async_result))
|
|
177
194
|
|
|
178
|
-
|
|
179
|
-
|
|
195
|
+
except Exception as e:
|
|
196
|
+
warnings.warn(
|
|
197
|
+
f"Multiprocessing failed, falling back to sequential: {e}"
|
|
198
|
+
)
|
|
199
|
+
# Fallback to sequential processing
|
|
200
|
+
for seq in sequences_to_predict:
|
|
201
|
+
self.cache[seq] = self._fold_single_sequence(seq)
|
|
202
|
+
self.queue_num += 1
|
|
203
|
+
else:
|
|
204
|
+
# Sequential processing
|
|
205
|
+
for seq in sequences_to_predict:
|
|
206
|
+
self.cache[seq] = self._fold_single_sequence(seq)
|
|
207
|
+
self.queue_num += 1
|
|
180
208
|
|
|
209
|
+
# Prepare output
|
|
181
210
|
if return_mfe:
|
|
182
211
|
structures = [self.cache[seq] for seq in sequences]
|
|
183
212
|
else:
|
|
184
213
|
structures = [self.cache[seq][0] for seq in sequences]
|
|
214
|
+
|
|
215
|
+
# Update cache file periodically
|
|
185
216
|
self.update_cache_file(self.cache_file)
|
|
186
217
|
|
|
218
|
+
# Return single result or list
|
|
187
219
|
if len(structures) == 1:
|
|
188
220
|
return structures[0]
|
|
189
221
|
else:
|
|
@@ -192,10 +224,10 @@ class RNA2StructureCache(dict):
|
|
|
192
224
|
def update_cache_file(self, cache_file=None):
|
|
193
225
|
"""
|
|
194
226
|
Updates the cache file on disk.
|
|
195
|
-
|
|
227
|
+
|
|
196
228
|
This method saves the in-memory cache to disk. It only saves when
|
|
197
229
|
the queue_num reaches 100 to avoid excessive disk I/O.
|
|
198
|
-
|
|
230
|
+
|
|
199
231
|
Args:
|
|
200
232
|
cache_file (str, optional): Path to the cache file. If None, uses
|
|
201
233
|
the instance's cache_file.
|
|
@@ -209,24 +241,26 @@ class RNA2StructureCache(dict):
|
|
|
209
241
|
if cache_file is None:
|
|
210
242
|
cache_file = self.cache_file
|
|
211
243
|
|
|
212
|
-
|
|
213
|
-
os.
|
|
244
|
+
try:
|
|
245
|
+
if not os.path.exists(os.path.dirname(cache_file)):
|
|
246
|
+
os.makedirs(os.path.dirname(cache_file))
|
|
214
247
|
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
pickle.dump(self.cache, f)
|
|
248
|
+
with open(cache_file, "wb") as f:
|
|
249
|
+
pickle.dump(self.cache, f)
|
|
218
250
|
|
|
219
|
-
|
|
251
|
+
self.queue_num = 0
|
|
252
|
+
except Exception as e:
|
|
253
|
+
warnings.warn(f"Failed to update cache file: {e}")
|
|
220
254
|
|
|
221
255
|
|
|
222
256
|
def env_meta_info():
|
|
223
257
|
"""
|
|
224
258
|
Collects metadata about the current environment and library versions.
|
|
225
|
-
|
|
259
|
+
|
|
226
260
|
This function gathers information about the current Python environment,
|
|
227
261
|
including versions of key libraries like PyTorch and Transformers,
|
|
228
262
|
as well as OmniGenome version information.
|
|
229
|
-
|
|
263
|
+
|
|
230
264
|
Returns:
|
|
231
265
|
dict: A dictionary containing environment metadata including:
|
|
232
266
|
- library_name: Name of the OmniGenome library
|
|
@@ -256,7 +290,7 @@ def env_meta_info():
|
|
|
256
290
|
def naive_secondary_structure_repair(sequence, structure):
|
|
257
291
|
"""
|
|
258
292
|
Repair the secondary structure of a sequence.
|
|
259
|
-
|
|
293
|
+
|
|
260
294
|
This function attempts to repair malformed RNA secondary structure
|
|
261
295
|
representations by ensuring proper bracket matching. It handles
|
|
262
296
|
common issues like unmatched brackets by converting them to dots.
|
|
@@ -294,7 +328,7 @@ def naive_secondary_structure_repair(sequence, structure):
|
|
|
294
328
|
def save_args(config, save_path):
|
|
295
329
|
"""
|
|
296
330
|
Save arguments to a file.
|
|
297
|
-
|
|
331
|
+
|
|
298
332
|
This function saves the arguments from a configuration object to a text file.
|
|
299
333
|
It's useful for logging experiment parameters and configurations.
|
|
300
334
|
|
|
@@ -317,7 +351,7 @@ def save_args(config, save_path):
|
|
|
317
351
|
def print_args(config, logger=None):
|
|
318
352
|
"""
|
|
319
353
|
Print the arguments to the console.
|
|
320
|
-
|
|
354
|
+
|
|
321
355
|
This function prints the arguments from a configuration object to the console
|
|
322
356
|
or a logger. It's useful for debugging and logging experiment parameters.
|
|
323
357
|
|
|
@@ -330,110 +364,140 @@ def print_args(config, logger=None):
|
|
|
330
364
|
>>> config = Namespace(learning_rate=0.001, batch_size=32)
|
|
331
365
|
>>> print_args(config)
|
|
332
366
|
"""
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
367
|
+
if logger is None:
|
|
368
|
+
for arg in config.args:
|
|
369
|
+
if config.args_call_count[arg]:
|
|
370
|
+
print("{}: {}".format(arg, config.args[arg]))
|
|
336
371
|
else:
|
|
337
|
-
|
|
372
|
+
for arg in config.args:
|
|
373
|
+
if config.args_call_count[arg]:
|
|
374
|
+
logger.info("{}: {}".format(arg, config.args[arg]))
|
|
338
375
|
|
|
339
376
|
|
|
340
377
|
def fprint(*objects, sep=" ", end="\n", file=sys.stdout, flush=False):
|
|
341
378
|
"""
|
|
342
|
-
|
|
379
|
+
Enhanced print function with automatic flushing.
|
|
380
|
+
|
|
381
|
+
This function provides a print-like interface with automatic flushing
|
|
382
|
+
to ensure output is displayed immediately. It's useful for real-time
|
|
383
|
+
logging and progress tracking.
|
|
343
384
|
|
|
344
385
|
Args:
|
|
345
|
-
*objects:
|
|
346
|
-
sep (str
|
|
347
|
-
end (str
|
|
348
|
-
file
|
|
349
|
-
flush (bool
|
|
386
|
+
*objects: Objects to print
|
|
387
|
+
sep (str): Separator between objects (default: " ")
|
|
388
|
+
end (str): String appended after the last value (default: "\n")
|
|
389
|
+
file: File-like object to write to (default: sys.stdout)
|
|
390
|
+
flush (bool): Whether to flush the stream (default: False)
|
|
391
|
+
|
|
392
|
+
Example:
|
|
393
|
+
>>> fprint("Training started...", flush=True)
|
|
394
|
+
>>> fprint("Epoch 1/10", "Loss: 0.5", sep=" | ")
|
|
350
395
|
"""
|
|
351
|
-
|
|
352
|
-
from omnigenome import __name__
|
|
353
|
-
|
|
354
|
-
print(
|
|
355
|
-
time.strftime(
|
|
356
|
-
"[%Y-%m-%d %H:%M:%S] [{} {}] ".format(__name__, __version__),
|
|
357
|
-
time.localtime(time.time()),
|
|
358
|
-
),
|
|
359
|
-
*objects,
|
|
360
|
-
sep=sep,
|
|
361
|
-
end=end,
|
|
362
|
-
file=file,
|
|
363
|
-
flush=flush,
|
|
364
|
-
)
|
|
396
|
+
print(*objects, sep=sep, end=end, file=file, flush=True)
|
|
365
397
|
|
|
366
398
|
|
|
367
399
|
def clean_temp_checkpoint(days_threshold=7):
|
|
368
400
|
"""
|
|
369
|
-
|
|
401
|
+
Clean up temporary checkpoint files older than specified days.
|
|
402
|
+
|
|
403
|
+
This function removes temporary checkpoint files that are older than
|
|
404
|
+
the specified threshold to free up disk space.
|
|
405
|
+
|
|
406
|
+
Args:
|
|
407
|
+
days_threshold (int): Number of days after which files are considered old.
|
|
408
|
+
Defaults to 7.
|
|
370
409
|
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
- file_extension (str): checkpoint 文件的扩展名,默认是 ".ckpt"。
|
|
374
|
-
- days_threshold (int): 超过多少天的文件将被删除,默认是 7 天。
|
|
410
|
+
Example:
|
|
411
|
+
>>> clean_temp_checkpoint(3) # Remove files older than 3 days
|
|
375
412
|
"""
|
|
376
|
-
|
|
377
|
-
import
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
413
|
+
import glob
|
|
414
|
+
import time
|
|
415
|
+
|
|
416
|
+
temp_patterns = [
|
|
417
|
+
"temp_checkpoint_*",
|
|
418
|
+
"checkpoint_*",
|
|
419
|
+
"*.tmp",
|
|
420
|
+
"*.temp",
|
|
421
|
+
]
|
|
422
|
+
|
|
423
|
+
current_time = time.time()
|
|
424
|
+
threshold_time = current_time - (days_threshold * 24 * 60 * 60)
|
|
425
|
+
|
|
426
|
+
for pattern in temp_patterns:
|
|
427
|
+
for file_path in glob.glob(pattern):
|
|
389
428
|
try:
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
print(f"Error deleting {file_path}: {e}")
|
|
429
|
+
if os.path.getmtime(file_path) < threshold_time:
|
|
430
|
+
os.remove(file_path)
|
|
431
|
+
except Exception:
|
|
432
|
+
pass
|
|
395
433
|
|
|
396
434
|
|
|
397
435
|
def load_module_from_path(module_name, file_path):
|
|
398
|
-
|
|
436
|
+
"""
|
|
437
|
+
Load a Python module from a file path.
|
|
438
|
+
|
|
439
|
+
This function dynamically loads a Python module from a file path,
|
|
440
|
+
useful for loading configuration files or custom modules.
|
|
441
|
+
|
|
442
|
+
Args:
|
|
443
|
+
module_name (str): Name to assign to the loaded module
|
|
444
|
+
file_path (str): Path to the Python file to load
|
|
445
|
+
|
|
446
|
+
Returns:
|
|
447
|
+
module: The loaded module object
|
|
448
|
+
|
|
449
|
+
Example:
|
|
450
|
+
>>> config = load_module_from_path("config", "config.py")
|
|
451
|
+
>>> print(config.some_variable)
|
|
452
|
+
"""
|
|
453
|
+
import importlib.util
|
|
399
454
|
|
|
400
455
|
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
|
401
456
|
module = importlib.util.module_from_spec(spec)
|
|
402
|
-
|
|
403
|
-
spec.loader.exec_module(module)
|
|
404
|
-
except FileNotFoundError:
|
|
405
|
-
raise ImportError(f"Cannot find the module {module_name} from {file_path}.")
|
|
457
|
+
spec.loader.exec_module(module)
|
|
406
458
|
return module
|
|
407
459
|
|
|
408
460
|
|
|
409
461
|
def check_bench_version(bench_version, omnigenome_version):
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
), "Benchmark metadata does not contain a valid __omnigenome__ version."
|
|
462
|
+
"""
|
|
463
|
+
Check if benchmark version is compatible with OmniGenome version.
|
|
413
464
|
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
f"Invalid type for benchmark version. Expected int, float, or str but got {type(bench_version).__name__}."
|
|
417
|
-
)
|
|
465
|
+
This function compares the benchmark version with the OmniGenome version
|
|
466
|
+
to ensure compatibility and warns if there are potential issues.
|
|
418
467
|
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
468
|
+
Args:
|
|
469
|
+
bench_version (str): Version of the benchmark
|
|
470
|
+
omnigenome_version (str): Version of OmniGenome
|
|
422
471
|
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
472
|
+
Example:
|
|
473
|
+
>>> check_bench_version("0.2.0", "0.3.0")
|
|
474
|
+
"""
|
|
475
|
+
if bench_version != omnigenome_version:
|
|
476
|
+
warnings.warn(
|
|
477
|
+
f"Benchmark version ({bench_version}) differs from "
|
|
478
|
+
f"OmniGenome version ({omnigenome_version}). "
|
|
479
|
+
f"This may cause compatibility issues."
|
|
427
480
|
)
|
|
428
481
|
|
|
429
482
|
|
|
430
483
|
def clean_temp_dir_pt_files():
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
484
|
+
"""
|
|
485
|
+
Clean up temporary PyTorch files in the current directory.
|
|
486
|
+
|
|
487
|
+
This function removes temporary PyTorch files (like .pt, .pth files)
|
|
488
|
+
that may be left over from previous runs.
|
|
489
|
+
|
|
490
|
+
Example:
|
|
491
|
+
>>> clean_temp_dir_pt_files()
|
|
492
|
+
"""
|
|
493
|
+
import glob
|
|
494
|
+
|
|
495
|
+
temp_patterns = ["*.pt", "*.pth", "temp_*", "checkpoint_*"]
|
|
496
|
+
|
|
497
|
+
for pattern in temp_patterns:
|
|
498
|
+
for file_path in glob.glob(pattern):
|
|
435
499
|
try:
|
|
436
|
-
os.
|
|
437
|
-
|
|
438
|
-
except Exception
|
|
439
|
-
|
|
500
|
+
if os.path.isfile(file_path):
|
|
501
|
+
os.remove(file_path)
|
|
502
|
+
except Exception:
|
|
503
|
+
pass
|