omnigenome 0.3.0a1__py3-none-any.whl → 0.3.1a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- omnigenome/__init__.py +16 -8
- omnigenome/auto/auto_bench/__init__.py +0 -1
- omnigenome/auto/auto_bench/auto_bench.py +24 -14
- omnigenome/auto/auto_train/__init__.py +0 -1
- omnigenome/auto/auto_train/auto_train.py +11 -12
- omnigenome/auto/bench_hub/__init__.py +0 -1
- omnigenome/auto/bench_hub/bench_hub.py +1 -1
- omnigenome/cli/__init__.py +0 -1
- omnigenome/cli/commands/__init__.py +0 -1
- omnigenome/cli/commands/base.py +10 -10
- omnigenome/cli/commands/bench/__init__.py +0 -1
- omnigenome/cli/commands/bench/bench_cli.py +10 -10
- omnigenome/cli/commands/rna/__init__.py +0 -1
- omnigenome/cli/commands/rna/rna_design.py +10 -11
- omnigenome/src/__init__.py +0 -1
- omnigenome/src/abc/__init__.py +0 -1
- omnigenome/src/abc/abstract_dataset.py +38 -19
- omnigenome/src/abc/abstract_metric.py +7 -7
- omnigenome/src/abc/abstract_model.py +15 -14
- omnigenome/src/abc/abstract_tokenizer.py +9 -7
- omnigenome/src/dataset/omni_dataset.py +16 -14
- omnigenome/src/lora/__init__.py +0 -1
- omnigenome/src/lora/lora_model.py +47 -41
- omnigenome/src/metric/classification_metric.py +11 -11
- omnigenome/src/metric/metric.py +19 -19
- omnigenome/src/metric/ranking_metric.py +15 -15
- omnigenome/src/metric/regression_metric.py +18 -18
- omnigenome/src/misc/utils.py +40 -36
- omnigenome/src/model/augmentation/__init__.py +0 -1
- omnigenome/src/model/augmentation/model.py +17 -17
- omnigenome/src/model/classification/__init__.py +0 -1
- omnigenome/src/model/classification/model.py +28 -32
- omnigenome/src/model/embedding/__init__.py +0 -1
- omnigenome/src/model/embedding/model.py +35 -35
- omnigenome/src/model/mlm/__init__.py +0 -1
- omnigenome/src/model/mlm/model.py +13 -13
- omnigenome/src/model/module_utils.py +17 -17
- omnigenome/src/model/regression/__init__.py +0 -1
- omnigenome/src/model/regression/model.py +72 -77
- omnigenome/src/model/regression/resnet.py +32 -32
- omnigenome/src/model/rna_design/__init__.py +0 -1
- omnigenome/src/model/rna_design/model.py +65 -58
- omnigenome/src/model/seq2seq/__init__.py +0 -1
- omnigenome/src/model/seq2seq/model.py +4 -4
- omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
- omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
- omnigenome/src/trainer/accelerate_trainer.py +40 -32
- omnigenome/src/trainer/hf_trainer.py +8 -8
- omnigenome/src/trainer/trainer.py +37 -25
- omnigenome/utility/dataset_hub/__init__.py +0 -1
- omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
- omnigenome/utility/ensemble.py +26 -26
- omnigenome/utility/hub_utils.py +8 -8
- omnigenome/utility/model_hub/__init__.py +0 -1
- omnigenome/utility/model_hub/model_hub.py +26 -25
- omnigenome/utility/pipeline_hub/__init__.py +0 -1
- omnigenome/utility/pipeline_hub/pipeline.py +49 -49
- omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +2 -2
- omnigenome-0.3.1a0.dist-info/RECORD +78 -0
- omnigenome-0.3.0a1.dist-info/RECORD +0 -78
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -0
omnigenome/src/metric/metric.py
CHANGED
|
@@ -20,20 +20,20 @@ from ..abc.abstract_metric import OmniMetric
|
|
|
20
20
|
def mcrmse(y_true, y_pred):
|
|
21
21
|
"""
|
|
22
22
|
Compute Mean Column Root Mean Square Error (MCRMSE).
|
|
23
|
-
|
|
23
|
+
|
|
24
24
|
MCRMSE is a multi-target regression metric that computes the RMSE for each target
|
|
25
25
|
column and then takes the mean across all targets.
|
|
26
|
-
|
|
26
|
+
|
|
27
27
|
Args:
|
|
28
28
|
y_true (np.ndarray): Ground truth values with shape (n_samples, n_targets)
|
|
29
29
|
y_pred (np.ndarray): Predicted values with shape (n_samples, n_targets)
|
|
30
|
-
|
|
30
|
+
|
|
31
31
|
Returns:
|
|
32
32
|
float: Mean Column Root Mean Square Error
|
|
33
|
-
|
|
33
|
+
|
|
34
34
|
Raises:
|
|
35
35
|
ValueError: If y_true and y_pred have different shapes
|
|
36
|
-
|
|
36
|
+
|
|
37
37
|
Example:
|
|
38
38
|
>>> y_true = np.array([[1, 2], [3, 4], [5, 6]])
|
|
39
39
|
>>> y_pred = np.array([[1.1, 2.1], [2.9, 4.1], [5.2, 5.8]])
|
|
@@ -57,18 +57,18 @@ class Metric(OmniMetric):
|
|
|
57
57
|
"""
|
|
58
58
|
A flexible metric class that provides access to all scikit-learn metrics
|
|
59
59
|
and custom metrics for evaluation.
|
|
60
|
-
|
|
60
|
+
|
|
61
61
|
This class dynamically wraps scikit-learn metrics and provides a unified
|
|
62
62
|
interface for computing various evaluation metrics. It handles different
|
|
63
63
|
input formats including HuggingFace trainer outputs and supports
|
|
64
64
|
custom metric functions.
|
|
65
|
-
|
|
65
|
+
|
|
66
66
|
Attributes:
|
|
67
67
|
metric_func: Custom metric function if provided
|
|
68
68
|
ignore_y: Value to ignore in predictions and true values
|
|
69
69
|
kwargs: Additional keyword arguments for metric computation
|
|
70
70
|
metrics: Dictionary of available metrics including custom ones
|
|
71
|
-
|
|
71
|
+
|
|
72
72
|
Example:
|
|
73
73
|
>>> from omnigenome.src.metric import Metric
|
|
74
74
|
>>> metric = Metric(ignore_y=-100)
|
|
@@ -82,7 +82,7 @@ class Metric(OmniMetric):
|
|
|
82
82
|
def __init__(self, metric_func=None, ignore_y=-100, *args, **kwargs):
|
|
83
83
|
"""
|
|
84
84
|
Initialize the Metric class.
|
|
85
|
-
|
|
85
|
+
|
|
86
86
|
Args:
|
|
87
87
|
metric_func (callable, optional): Custom metric function to use
|
|
88
88
|
ignore_y (int, optional): Value to ignore in predictions and true values. Defaults to -100
|
|
@@ -98,14 +98,14 @@ class Metric(OmniMetric):
|
|
|
98
98
|
def __getattribute__(self, name):
|
|
99
99
|
"""
|
|
100
100
|
Dynamically create metric computation methods.
|
|
101
|
-
|
|
101
|
+
|
|
102
102
|
This method intercepts attribute access and creates wrapper functions
|
|
103
103
|
for scikit-learn metrics, handling different input formats and
|
|
104
104
|
preprocessing the data appropriately.
|
|
105
|
-
|
|
105
|
+
|
|
106
106
|
Args:
|
|
107
107
|
name (str): Name of the metric to access
|
|
108
|
-
|
|
108
|
+
|
|
109
109
|
Returns:
|
|
110
110
|
callable: Wrapper function for the requested metric
|
|
111
111
|
"""
|
|
@@ -119,20 +119,20 @@ class Metric(OmniMetric):
|
|
|
119
119
|
def wrapper(y_true=None, y_score=None, *args, **kwargs):
|
|
120
120
|
"""
|
|
121
121
|
Compute the metric, based on the true and predicted values.
|
|
122
|
-
|
|
122
|
+
|
|
123
123
|
This wrapper handles different input formats including HuggingFace
|
|
124
124
|
trainer outputs and performs necessary preprocessing.
|
|
125
|
-
|
|
125
|
+
|
|
126
126
|
Args:
|
|
127
127
|
y_true: The true values or HuggingFace EvalPrediction object
|
|
128
128
|
y_score: The predicted values
|
|
129
129
|
ignore_y: The value to ignore in the predictions and true values in corresponding positions
|
|
130
130
|
*args: Additional positional arguments for the metric
|
|
131
131
|
**kwargs: Additional keyword arguments for the metric
|
|
132
|
-
|
|
132
|
+
|
|
133
133
|
Returns:
|
|
134
134
|
dict: Dictionary containing the metric name and computed value
|
|
135
|
-
|
|
135
|
+
|
|
136
136
|
Raises:
|
|
137
137
|
ValueError: If neither y_true nor y_score is provided
|
|
138
138
|
"""
|
|
@@ -176,16 +176,16 @@ class Metric(OmniMetric):
|
|
|
176
176
|
def compute(self, y_true, y_score, *args, **kwargs):
|
|
177
177
|
"""
|
|
178
178
|
Compute the metric, based on the true and predicted values.
|
|
179
|
-
|
|
179
|
+
|
|
180
180
|
Args:
|
|
181
181
|
y_true: The true values
|
|
182
182
|
y_score: The predicted values
|
|
183
183
|
*args: Additional positional arguments for the metric
|
|
184
184
|
**kwargs: Additional keyword arguments for the metric
|
|
185
|
-
|
|
185
|
+
|
|
186
186
|
Returns:
|
|
187
187
|
The computed metric value
|
|
188
|
-
|
|
188
|
+
|
|
189
189
|
Raises:
|
|
190
190
|
NotImplementedError: If no metric function is provided and compute is not implemented
|
|
191
191
|
"""
|
omnigenome/src/metric/ranking_metric.py
CHANGED
|
@@ -20,16 +20,16 @@ from ..abc.abstract_metric import OmniMetric
|
|
|
20
20
|
class RankingMetric(OmniMetric):
|
|
21
21
|
"""
|
|
22
22
|
A specialized metric class for ranking tasks and evaluation.
|
|
23
|
-
|
|
23
|
+
|
|
24
24
|
This class provides access to ranking-specific metrics from scikit-learn
|
|
25
25
|
and handles different input formats including HuggingFace trainer outputs.
|
|
26
26
|
It dynamically wraps scikit-learn metrics and provides a unified interface
|
|
27
27
|
for computing various ranking evaluation metrics.
|
|
28
|
-
|
|
28
|
+
|
|
29
29
|
Attributes:
|
|
30
30
|
metric_func: Custom metric function if provided
|
|
31
31
|
ignore_y: Value to ignore in predictions and true values
|
|
32
|
-
|
|
32
|
+
|
|
33
33
|
Example:
|
|
34
34
|
>>> from omnigenome.src.metric import RankingMetric
|
|
35
35
|
>>> metric = RankingMetric(ignore_y=-100)
|
|
@@ -43,7 +43,7 @@ class RankingMetric(OmniMetric):
|
|
|
43
43
|
def __init__(self, *args, **kwargs):
|
|
44
44
|
"""
|
|
45
45
|
Initialize the RankingMetric class.
|
|
46
|
-
|
|
46
|
+
|
|
47
47
|
Args:
|
|
48
48
|
*args: Additional positional arguments passed to parent class
|
|
49
49
|
**kwargs: Additional keyword arguments passed to parent class
|
|
@@ -53,17 +53,17 @@ class RankingMetric(OmniMetric):
|
|
|
53
53
|
def __getattr__(self, name):
|
|
54
54
|
"""
|
|
55
55
|
Dynamically create ranking metric computation methods.
|
|
56
|
-
|
|
56
|
+
|
|
57
57
|
This method intercepts attribute access and creates wrapper functions
|
|
58
58
|
for scikit-learn ranking metrics, handling different input formats and
|
|
59
59
|
preprocessing the data appropriately.
|
|
60
|
-
|
|
60
|
+
|
|
61
61
|
Args:
|
|
62
62
|
name (str): Name of the ranking metric to access
|
|
63
|
-
|
|
63
|
+
|
|
64
64
|
Returns:
|
|
65
65
|
callable: Wrapper function for the requested ranking metric
|
|
66
|
-
|
|
66
|
+
|
|
67
67
|
Raises:
|
|
68
68
|
AttributeError: If the requested metric is not found
|
|
69
69
|
"""
|
|
@@ -74,17 +74,17 @@ class RankingMetric(OmniMetric):
|
|
|
74
74
|
def wrapper(y_true=None, y_score=None, *args, **kwargs):
|
|
75
75
|
"""
|
|
76
76
|
Compute the ranking metric, based on the true and predicted values.
|
|
77
|
-
|
|
77
|
+
|
|
78
78
|
This wrapper handles different input formats including HuggingFace
|
|
79
79
|
trainer outputs and performs necessary preprocessing for ranking tasks.
|
|
80
|
-
|
|
80
|
+
|
|
81
81
|
Args:
|
|
82
82
|
y_true: The true values or HuggingFace EvalPrediction object
|
|
83
83
|
y_score: The predicted values (scores for ranking)
|
|
84
84
|
ignore_y: The value to ignore in the predictions and true values in corresponding positions
|
|
85
85
|
*args: Additional positional arguments for the metric
|
|
86
86
|
**kwargs: Additional keyword arguments for the metric
|
|
87
|
-
|
|
87
|
+
|
|
88
88
|
Returns:
|
|
89
89
|
dict: Dictionary containing the metric name and computed value
|
|
90
90
|
"""
|
|
@@ -121,19 +121,19 @@ class RankingMetric(OmniMetric):
|
|
|
121
121
|
def compute(self, y_true, y_score, *args, **kwargs):
|
|
122
122
|
"""
|
|
123
123
|
Compute the ranking metric, based on the true and predicted values.
|
|
124
|
-
|
|
124
|
+
|
|
125
125
|
This method should be implemented by subclasses to provide specific
|
|
126
126
|
ranking metric computation logic.
|
|
127
|
-
|
|
127
|
+
|
|
128
128
|
Args:
|
|
129
129
|
y_true: The true values
|
|
130
130
|
y_score: The predicted values (scores for ranking)
|
|
131
131
|
*args: Additional positional arguments for the metric
|
|
132
132
|
**kwargs: Additional keyword arguments for the metric
|
|
133
|
-
|
|
133
|
+
|
|
134
134
|
Returns:
|
|
135
135
|
The computed ranking metric value
|
|
136
|
-
|
|
136
|
+
|
|
137
137
|
Raises:
|
|
138
138
|
NotImplementedError: If compute method is not implemented in the child class
|
|
139
139
|
"""
|
omnigenome/src/metric/regression_metric.py
CHANGED
|
@@ -20,20 +20,20 @@ from ..abc.abstract_metric import OmniMetric
|
|
|
20
20
|
def mcrmse(y_true, y_pred):
|
|
21
21
|
"""
|
|
22
22
|
Compute Mean Column Root Mean Square Error (MCRMSE).
|
|
23
|
-
|
|
23
|
+
|
|
24
24
|
MCRMSE is a multi-target regression metric that computes the RMSE for each target
|
|
25
25
|
column and then takes the mean across all targets.
|
|
26
|
-
|
|
26
|
+
|
|
27
27
|
Args:
|
|
28
28
|
y_true (np.ndarray): Ground truth values with shape (n_samples, n_targets)
|
|
29
29
|
y_pred (np.ndarray): Predicted values with shape (n_samples, n_targets)
|
|
30
|
-
|
|
30
|
+
|
|
31
31
|
Returns:
|
|
32
32
|
float: Mean Column Root Mean Square Error
|
|
33
|
-
|
|
33
|
+
|
|
34
34
|
Raises:
|
|
35
35
|
ValueError: If y_true and y_pred have different shapes
|
|
36
|
-
|
|
36
|
+
|
|
37
37
|
Example:
|
|
38
38
|
>>> y_true = np.array([[1, 2], [3, 4], [5, 6]])
|
|
39
39
|
>>> y_pred = np.array([[1.1, 2.1], [2.9, 4.1], [5.2, 5.8]])
|
|
@@ -56,18 +56,18 @@ setattr(metrics, "mcrmse", mcrmse)
|
|
|
56
56
|
class RegressionMetric(OmniMetric):
|
|
57
57
|
"""
|
|
58
58
|
A specialized metric class for regression tasks and evaluation.
|
|
59
|
-
|
|
59
|
+
|
|
60
60
|
This class provides access to regression-specific metrics from scikit-learn
|
|
61
61
|
and handles different input formats including HuggingFace trainer outputs.
|
|
62
62
|
It dynamically wraps scikit-learn metrics and provides a unified interface
|
|
63
63
|
for computing various regression evaluation metrics.
|
|
64
|
-
|
|
64
|
+
|
|
65
65
|
Attributes:
|
|
66
66
|
metric_func: Custom metric function if provided
|
|
67
67
|
ignore_y: Value to ignore in predictions and true values
|
|
68
68
|
kwargs: Additional keyword arguments for metric computation
|
|
69
69
|
metrics: Dictionary of available metrics including custom ones
|
|
70
|
-
|
|
70
|
+
|
|
71
71
|
Example:
|
|
72
72
|
>>> from omnigenome.src.metric import RegressionMetric
|
|
73
73
|
>>> metric = RegressionMetric(ignore_y=-100)
|
|
@@ -81,7 +81,7 @@ class RegressionMetric(OmniMetric):
|
|
|
81
81
|
def __init__(self, metric_func=None, ignore_y=-100, *args, **kwargs):
|
|
82
82
|
"""
|
|
83
83
|
Initialize the RegressionMetric class.
|
|
84
|
-
|
|
84
|
+
|
|
85
85
|
Args:
|
|
86
86
|
metric_func (callable, optional): Custom metric function to use
|
|
87
87
|
ignore_y (int, optional): Value to ignore in predictions and true values. Defaults to -100
|
|
@@ -97,14 +97,14 @@ class RegressionMetric(OmniMetric):
|
|
|
97
97
|
def __getattribute__(self, name):
|
|
98
98
|
"""
|
|
99
99
|
Dynamically create regression metric computation methods.
|
|
100
|
-
|
|
100
|
+
|
|
101
101
|
This method intercepts attribute access and creates wrapper functions
|
|
102
102
|
for scikit-learn regression metrics, handling different input formats and
|
|
103
103
|
preprocessing the data appropriately.
|
|
104
|
-
|
|
104
|
+
|
|
105
105
|
Args:
|
|
106
106
|
name (str): Name of the regression metric to access
|
|
107
|
-
|
|
107
|
+
|
|
108
108
|
Returns:
|
|
109
109
|
callable: Wrapper function for the requested regression metric
|
|
110
110
|
"""
|
|
@@ -118,17 +118,17 @@ class RegressionMetric(OmniMetric):
|
|
|
118
118
|
def wrapper(y_true=None, y_score=None, *args, **kwargs):
|
|
119
119
|
"""
|
|
120
120
|
Compute the regression metric, based on the true and predicted values.
|
|
121
|
-
|
|
121
|
+
|
|
122
122
|
This wrapper handles different input formats including HuggingFace
|
|
123
123
|
trainer outputs and performs necessary preprocessing for regression tasks.
|
|
124
|
-
|
|
124
|
+
|
|
125
125
|
Args:
|
|
126
126
|
y_true: The true values or HuggingFace EvalPrediction object
|
|
127
127
|
y_score: The predicted values
|
|
128
128
|
ignore_y: The value to ignore in the predictions and true values in corresponding positions
|
|
129
129
|
*args: Additional positional arguments for the metric
|
|
130
130
|
**kwargs: Additional keyword arguments for the metric
|
|
131
|
-
|
|
131
|
+
|
|
132
132
|
Returns:
|
|
133
133
|
dict: Dictionary containing the metric name and computed value
|
|
134
134
|
"""
|
|
@@ -168,16 +168,16 @@ class RegressionMetric(OmniMetric):
|
|
|
168
168
|
def compute(self, y_true, y_score, *args, **kwargs):
|
|
169
169
|
"""
|
|
170
170
|
Compute the regression metric, based on the true and predicted values.
|
|
171
|
-
|
|
171
|
+
|
|
172
172
|
Args:
|
|
173
173
|
y_true: The true values
|
|
174
174
|
y_score: The predicted values
|
|
175
175
|
*args: Additional positional arguments for the metric
|
|
176
176
|
**kwargs: Additional keyword arguments for the metric
|
|
177
|
-
|
|
177
|
+
|
|
178
178
|
Returns:
|
|
179
179
|
The computed regression metric value
|
|
180
|
-
|
|
180
|
+
|
|
181
181
|
Raises:
|
|
182
182
|
NotImplementedError: If no metric function is provided and compute is not implemented
|
|
183
183
|
"""
|
omnigenome/src/misc/utils.py
CHANGED
|
@@ -25,13 +25,13 @@ default_omnigenome_repo = (
|
|
|
25
25
|
def seed_everything(seed=42):
|
|
26
26
|
"""
|
|
27
27
|
Sets random seeds for reproducibility across all random number generators.
|
|
28
|
-
|
|
28
|
+
|
|
29
29
|
This function sets seeds for Python's random module, NumPy, PyTorch (CPU and CUDA),
|
|
30
30
|
and sets the PYTHONHASHSEED environment variable to ensure reproducible results
|
|
31
31
|
across different runs.
|
|
32
|
-
|
|
32
|
+
|
|
33
33
|
Args:
|
|
34
|
-
seed (int): The seed value to use for all random number generators.
|
|
34
|
+
seed (int): The seed value to use for all random number generators.
|
|
35
35
|
Defaults to 42.
|
|
36
36
|
|
|
37
37
|
Example:
|
|
@@ -55,11 +55,11 @@ def seed_everything(seed=42):
|
|
|
55
55
|
class RNA2StructureCache(dict):
|
|
56
56
|
"""
|
|
57
57
|
A cache for RNA secondary structure predictions using ViennaRNA.
|
|
58
|
-
|
|
58
|
+
|
|
59
59
|
This class provides a caching mechanism for RNA secondary structure predictions
|
|
60
60
|
to avoid redundant computations. It supports both single sequence and batch
|
|
61
61
|
processing with optional multiprocessing for improved performance.
|
|
62
|
-
|
|
62
|
+
|
|
63
63
|
Attributes:
|
|
64
64
|
cache (dict): Dictionary storing sequence-structure mappings
|
|
65
65
|
cache_file (str): Path to the cache file on disk
|
|
@@ -69,7 +69,7 @@ class RNA2StructureCache(dict):
|
|
|
69
69
|
def __init__(self, cache_file=None, *args, **kwargs):
|
|
70
70
|
"""
|
|
71
71
|
Initialize the RNA structure cache.
|
|
72
|
-
|
|
72
|
+
|
|
73
73
|
Args:
|
|
74
74
|
cache_file (str, optional): Path to the cache file. If None, uses
|
|
75
75
|
a default temporary file.
|
|
@@ -112,10 +112,10 @@ class RNA2StructureCache(dict):
|
|
|
112
112
|
def _fold_single_sequence(self, sequence):
|
|
113
113
|
"""
|
|
114
114
|
Predict structure for a single sequence (worker function for multiprocessing).
|
|
115
|
-
|
|
115
|
+
|
|
116
116
|
Args:
|
|
117
117
|
sequence (str): RNA sequence to fold
|
|
118
|
-
|
|
118
|
+
|
|
119
119
|
Returns:
|
|
120
120
|
tuple: (structure, mfe) tuple
|
|
121
121
|
"""
|
|
@@ -128,12 +128,12 @@ class RNA2StructureCache(dict):
|
|
|
128
128
|
def fold(self, sequence, return_mfe=False, num_workers=1):
|
|
129
129
|
"""
|
|
130
130
|
Predicts RNA secondary structure for given sequences.
|
|
131
|
-
|
|
131
|
+
|
|
132
132
|
This method predicts RNA secondary structures using ViennaRNA. It supports
|
|
133
133
|
both single sequences and batches of sequences. The method uses caching
|
|
134
134
|
to avoid redundant predictions and supports multiprocessing for batch
|
|
135
135
|
processing on non-Windows systems.
|
|
136
|
-
|
|
136
|
+
|
|
137
137
|
Args:
|
|
138
138
|
sequence (str or list): A single RNA sequence or a list of sequences.
|
|
139
139
|
return_mfe (bool): Whether to return minimum free energy along with
|
|
@@ -150,7 +150,7 @@ class RNA2StructureCache(dict):
|
|
|
150
150
|
>>> # Predict structure for a single sequence
|
|
151
151
|
>>> structure = cache.fold("GGGAAAUCC")
|
|
152
152
|
>>> print(structure) # "(((...)))"
|
|
153
|
-
|
|
153
|
+
|
|
154
154
|
>>> # Predict structures for multiple sequences
|
|
155
155
|
>>> structures = cache.fold(["GGGAAAUCC", "AUUGCUAA"])
|
|
156
156
|
>>> print(structures) # ["(((...)))", "........"]
|
|
@@ -162,36 +162,40 @@ class RNA2StructureCache(dict):
|
|
|
162
162
|
|
|
163
163
|
# Determine if we should use multiprocessing
|
|
164
164
|
use_multiprocessing = (
|
|
165
|
-
os.name != "nt"
|
|
166
|
-
len(sequences) > 1
|
|
167
|
-
num_workers > 1 # Multiple workers requested
|
|
165
|
+
os.name != "nt" # Not Windows
|
|
166
|
+
and len(sequences) > 1 # Multiple sequences
|
|
167
|
+
and num_workers > 1 # Multiple workers requested
|
|
168
168
|
)
|
|
169
169
|
|
|
170
170
|
# Find sequences that need prediction
|
|
171
171
|
sequences_to_predict = [seq for seq in sequences if seq not in self.cache]
|
|
172
|
-
|
|
172
|
+
|
|
173
173
|
if sequences_to_predict:
|
|
174
174
|
if use_multiprocessing:
|
|
175
175
|
# Use multiprocessing for batch prediction
|
|
176
176
|
if num_workers is None:
|
|
177
177
|
num_workers = min(os.cpu_count(), len(sequences_to_predict))
|
|
178
|
-
|
|
178
|
+
|
|
179
179
|
try:
|
|
180
180
|
# Set multiprocessing start method to 'spawn' for better compatibility
|
|
181
|
-
if multiprocessing.get_start_method(allow_none=True) !=
|
|
182
|
-
multiprocessing.set_start_method(
|
|
183
|
-
|
|
181
|
+
if multiprocessing.get_start_method(allow_none=True) != "spawn":
|
|
182
|
+
multiprocessing.set_start_method("spawn", force=True)
|
|
183
|
+
|
|
184
184
|
with multiprocessing.Pool(num_workers) as pool:
|
|
185
185
|
# Use map instead of apply_async for better error handling
|
|
186
|
-
results = pool.map(
|
|
187
|
-
|
|
186
|
+
results = pool.map(
|
|
187
|
+
self._fold_single_sequence, sequences_to_predict
|
|
188
|
+
)
|
|
189
|
+
|
|
188
190
|
# Update cache with results
|
|
189
191
|
for seq, result in zip(sequences_to_predict, results):
|
|
190
192
|
self.cache[seq] = result
|
|
191
193
|
self.queue_num += 1
|
|
192
|
-
|
|
194
|
+
|
|
193
195
|
except Exception as e:
|
|
194
|
-
warnings.warn(
|
|
196
|
+
warnings.warn(
|
|
197
|
+
f"Multiprocessing failed, falling back to sequential: {e}"
|
|
198
|
+
)
|
|
195
199
|
# Fallback to sequential processing
|
|
196
200
|
for seq in sequences_to_predict:
|
|
197
201
|
self.cache[seq] = self._fold_single_sequence(seq)
|
|
@@ -207,7 +211,7 @@ class RNA2StructureCache(dict):
|
|
|
207
211
|
structures = [self.cache[seq] for seq in sequences]
|
|
208
212
|
else:
|
|
209
213
|
structures = [self.cache[seq][0] for seq in sequences]
|
|
210
|
-
|
|
214
|
+
|
|
211
215
|
# Update cache file periodically
|
|
212
216
|
self.update_cache_file(self.cache_file)
|
|
213
217
|
|
|
@@ -220,10 +224,10 @@ class RNA2StructureCache(dict):
|
|
|
220
224
|
def update_cache_file(self, cache_file=None):
|
|
221
225
|
"""
|
|
222
226
|
Updates the cache file on disk.
|
|
223
|
-
|
|
227
|
+
|
|
224
228
|
This method saves the in-memory cache to disk. It only saves when
|
|
225
229
|
the queue_num reaches 100 to avoid excessive disk I/O.
|
|
226
|
-
|
|
230
|
+
|
|
227
231
|
Args:
|
|
228
232
|
cache_file (str, optional): Path to the cache file. If None, uses
|
|
229
233
|
the instance's cache_file.
|
|
@@ -252,11 +256,11 @@ class RNA2StructureCache(dict):
|
|
|
252
256
|
def env_meta_info():
|
|
253
257
|
"""
|
|
254
258
|
Collects metadata about the current environment and library versions.
|
|
255
|
-
|
|
259
|
+
|
|
256
260
|
This function gathers information about the current Python environment,
|
|
257
261
|
including versions of key libraries like PyTorch and Transformers,
|
|
258
262
|
as well as OmniGenome version information.
|
|
259
|
-
|
|
263
|
+
|
|
260
264
|
Returns:
|
|
261
265
|
dict: A dictionary containing environment metadata including:
|
|
262
266
|
- library_name: Name of the OmniGenome library
|
|
@@ -286,7 +290,7 @@ def env_meta_info():
|
|
|
286
290
|
def naive_secondary_structure_repair(sequence, structure):
|
|
287
291
|
"""
|
|
288
292
|
Repair the secondary structure of a sequence.
|
|
289
|
-
|
|
293
|
+
|
|
290
294
|
This function attempts to repair malformed RNA secondary structure
|
|
291
295
|
representations by ensuring proper bracket matching. It handles
|
|
292
296
|
common issues like unmatched brackets by converting them to dots.
|
|
@@ -324,7 +328,7 @@ def naive_secondary_structure_repair(sequence, structure):
|
|
|
324
328
|
def save_args(config, save_path):
|
|
325
329
|
"""
|
|
326
330
|
Save arguments to a file.
|
|
327
|
-
|
|
331
|
+
|
|
328
332
|
This function saves the arguments from a configuration object to a text file.
|
|
329
333
|
It's useful for logging experiment parameters and configurations.
|
|
330
334
|
|
|
@@ -347,7 +351,7 @@ def save_args(config, save_path):
|
|
|
347
351
|
def print_args(config, logger=None):
|
|
348
352
|
"""
|
|
349
353
|
Print the arguments to the console.
|
|
350
|
-
|
|
354
|
+
|
|
351
355
|
This function prints the arguments from a configuration object to the console
|
|
352
356
|
or a logger. It's useful for debugging and logging experiment parameters.
|
|
353
357
|
|
|
@@ -373,7 +377,7 @@ def print_args(config, logger=None):
|
|
|
373
377
|
def fprint(*objects, sep=" ", end="\n", file=sys.stdout, flush=False):
|
|
374
378
|
"""
|
|
375
379
|
Enhanced print function with automatic flushing.
|
|
376
|
-
|
|
380
|
+
|
|
377
381
|
This function provides a print-like interface with automatic flushing
|
|
378
382
|
to ensure output is displayed immediately. It's useful for real-time
|
|
379
383
|
logging and progress tracking.
|
|
@@ -395,7 +399,7 @@ def fprint(*objects, sep=" ", end="\n", file=sys.stdout, flush=False):
|
|
|
395
399
|
def clean_temp_checkpoint(days_threshold=7):
|
|
396
400
|
"""
|
|
397
401
|
Clean up temporary checkpoint files older than specified days.
|
|
398
|
-
|
|
402
|
+
|
|
399
403
|
This function removes temporary checkpoint files that are older than
|
|
400
404
|
the specified threshold to free up disk space.
|
|
401
405
|
|
|
@@ -431,7 +435,7 @@ def clean_temp_checkpoint(days_threshold=7):
|
|
|
431
435
|
def load_module_from_path(module_name, file_path):
|
|
432
436
|
"""
|
|
433
437
|
Load a Python module from a file path.
|
|
434
|
-
|
|
438
|
+
|
|
435
439
|
This function dynamically loads a Python module from a file path,
|
|
436
440
|
useful for loading configuration files or custom modules.
|
|
437
441
|
|
|
@@ -457,7 +461,7 @@ def load_module_from_path(module_name, file_path):
|
|
|
457
461
|
def check_bench_version(bench_version, omnigenome_version):
|
|
458
462
|
"""
|
|
459
463
|
Check if benchmark version is compatible with OmniGenome version.
|
|
460
|
-
|
|
464
|
+
|
|
461
465
|
This function compares the benchmark version with the OmniGenome version
|
|
462
466
|
to ensure compatibility and warns if there are potential issues.
|
|
463
467
|
|
|
@@ -479,7 +483,7 @@ def check_bench_version(bench_version, omnigenome_version):
|
|
|
479
483
|
def clean_temp_dir_pt_files():
|
|
480
484
|
"""
|
|
481
485
|
Clean up temporary PyTorch files in the current directory.
|
|
482
|
-
|
|
486
|
+
|
|
483
487
|
This function removes temporary PyTorch files (like .pt, .pth files)
|
|
484
488
|
that may be left over from previous runs.
|
|
485
489
|
|
omnigenome/src/model/augmentation/model.py
CHANGED
|
@@ -24,12 +24,12 @@ import autocuda
|
|
|
24
24
|
class OmniModelForAugmentation(torch.nn.Module):
|
|
25
25
|
"""
|
|
26
26
|
Data augmentation model for genomic sequences using masked language modeling.
|
|
27
|
-
|
|
27
|
+
|
|
28
28
|
This model uses a pre-trained masked language model to generate augmented
|
|
29
29
|
versions of genomic sequences by randomly masking tokens and predicting
|
|
30
30
|
replacements. It's useful for expanding training datasets and improving
|
|
31
31
|
model generalization.
|
|
32
|
-
|
|
32
|
+
|
|
33
33
|
Attributes:
|
|
34
34
|
tokenizer: Tokenizer for processing genomic sequences
|
|
35
35
|
model: Pre-trained masked language model
|
|
@@ -38,7 +38,7 @@ class OmniModelForAugmentation(torch.nn.Module):
|
|
|
38
38
|
max_length: Maximum sequence length for tokenization
|
|
39
39
|
k: Number of augmented instances to generate per sequence
|
|
40
40
|
"""
|
|
41
|
-
|
|
41
|
+
|
|
42
42
|
def __init__(
|
|
43
43
|
self,
|
|
44
44
|
model_name_or_path=None,
|
|
@@ -50,7 +50,7 @@ class OmniModelForAugmentation(torch.nn.Module):
|
|
|
50
50
|
):
|
|
51
51
|
"""
|
|
52
52
|
Initialize the augmentation model.
|
|
53
|
-
|
|
53
|
+
|
|
54
54
|
Args:
|
|
55
55
|
model_name_or_path (str): Path or model name for loading the pre-trained model
|
|
56
56
|
noise_ratio (float): The proportion of tokens to mask in each sequence for augmentation (default: 0.15)
|
|
@@ -82,10 +82,10 @@ class OmniModelForAugmentation(torch.nn.Module):
|
|
|
82
82
|
def load_sequences_from_file(self, input_file):
|
|
83
83
|
"""
|
|
84
84
|
Load sequences from a JSON file.
|
|
85
|
-
|
|
85
|
+
|
|
86
86
|
Args:
|
|
87
87
|
input_file (str): Path to the input JSON file containing sequences
|
|
88
|
-
|
|
88
|
+
|
|
89
89
|
Returns:
|
|
90
90
|
list: List of sequences loaded from the file
|
|
91
91
|
"""
|
|
@@ -98,10 +98,10 @@ class OmniModelForAugmentation(torch.nn.Module):
|
|
|
98
98
|
def apply_noise_to_sequence(self, seq):
|
|
99
99
|
"""
|
|
100
100
|
Apply noise to a single sequence by randomly masking tokens.
|
|
101
|
-
|
|
101
|
+
|
|
102
102
|
Args:
|
|
103
103
|
seq (str): Input genomic sequence
|
|
104
|
-
|
|
104
|
+
|
|
105
105
|
Returns:
|
|
106
106
|
str: Sequence with randomly masked tokens
|
|
107
107
|
"""
|
|
@@ -114,10 +114,10 @@ class OmniModelForAugmentation(torch.nn.Module):
|
|
|
114
114
|
def augment_sequence(self, seq):
|
|
115
115
|
"""
|
|
116
116
|
Perform augmentation on a single sequence by predicting masked tokens.
|
|
117
|
-
|
|
117
|
+
|
|
118
118
|
Args:
|
|
119
119
|
seq (str): Input genomic sequence with masked tokens
|
|
120
|
-
|
|
120
|
+
|
|
121
121
|
Returns:
|
|
122
122
|
str: Augmented sequence with predicted tokens replacing masked tokens
|
|
123
123
|
"""
|
|
@@ -145,11 +145,11 @@ class OmniModelForAugmentation(torch.nn.Module):
|
|
|
145
145
|
def augment(self, seq, k=None):
|
|
146
146
|
"""
|
|
147
147
|
Generate multiple augmented instances for a single sequence.
|
|
148
|
-
|
|
148
|
+
|
|
149
149
|
Args:
|
|
150
150
|
seq (str): Input genomic sequence
|
|
151
151
|
k (int, optional): Number of augmented instances to generate (default: None, uses self.k)
|
|
152
|
-
|
|
152
|
+
|
|
153
153
|
Returns:
|
|
154
154
|
list: List of augmented sequences
|
|
155
155
|
"""
|
|
@@ -163,10 +163,10 @@ class OmniModelForAugmentation(torch.nn.Module):
|
|
|
163
163
|
def augment_sequences(self, sequences):
|
|
164
164
|
"""
|
|
165
165
|
Augment a list of sequences by applying noise and performing MLM-based predictions.
|
|
166
|
-
|
|
166
|
+
|
|
167
167
|
Args:
|
|
168
168
|
sequences (list): List of genomic sequences to augment
|
|
169
|
-
|
|
169
|
+
|
|
170
170
|
Returns:
|
|
171
171
|
list: List of all augmented sequences
|
|
172
172
|
"""
|
|
@@ -179,7 +179,7 @@ class OmniModelForAugmentation(torch.nn.Module):
|
|
|
179
179
|
def save_augmented_sequences(self, augmented_sequences, output_file):
|
|
180
180
|
"""
|
|
181
181
|
Save augmented sequences to a JSON file.
|
|
182
|
-
|
|
182
|
+
|
|
183
183
|
Args:
|
|
184
184
|
augmented_sequences (list): List of augmented sequences to save
|
|
185
185
|
output_file (str): Path to the output JSON file
|
|
@@ -191,10 +191,10 @@ class OmniModelForAugmentation(torch.nn.Module):
|
|
|
191
191
|
def augment_from_file(self, input_file, output_file):
|
|
192
192
|
"""
|
|
193
193
|
Main function to handle the augmentation process from a file input to a file output.
|
|
194
|
-
|
|
194
|
+
|
|
195
195
|
This method loads sequences from an input file, augments them using the MLM model,
|
|
196
196
|
and saves the augmented sequences to an output file.
|
|
197
|
-
|
|
197
|
+
|
|
198
198
|
Args:
|
|
199
199
|
input_file (str): Path to the input file containing sequences
|
|
200
200
|
output_file (str): Path to the output file where augmented sequences will be saved
|