omnigenome 0.3.0a1__py3-none-any.whl → 0.3.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. omnigenome/__init__.py +16 -8
  2. omnigenome/auto/auto_bench/__init__.py +0 -1
  3. omnigenome/auto/auto_bench/auto_bench.py +24 -14
  4. omnigenome/auto/auto_train/__init__.py +0 -1
  5. omnigenome/auto/auto_train/auto_train.py +11 -12
  6. omnigenome/auto/bench_hub/__init__.py +0 -1
  7. omnigenome/auto/bench_hub/bench_hub.py +1 -1
  8. omnigenome/cli/__init__.py +0 -1
  9. omnigenome/cli/commands/__init__.py +0 -1
  10. omnigenome/cli/commands/base.py +10 -10
  11. omnigenome/cli/commands/bench/__init__.py +0 -1
  12. omnigenome/cli/commands/bench/bench_cli.py +10 -10
  13. omnigenome/cli/commands/rna/__init__.py +0 -1
  14. omnigenome/cli/commands/rna/rna_design.py +10 -11
  15. omnigenome/src/__init__.py +0 -1
  16. omnigenome/src/abc/__init__.py +0 -1
  17. omnigenome/src/abc/abstract_dataset.py +38 -19
  18. omnigenome/src/abc/abstract_metric.py +7 -7
  19. omnigenome/src/abc/abstract_model.py +15 -14
  20. omnigenome/src/abc/abstract_tokenizer.py +9 -7
  21. omnigenome/src/dataset/omni_dataset.py +16 -14
  22. omnigenome/src/lora/__init__.py +0 -1
  23. omnigenome/src/lora/lora_model.py +47 -41
  24. omnigenome/src/metric/classification_metric.py +11 -11
  25. omnigenome/src/metric/metric.py +19 -19
  26. omnigenome/src/metric/ranking_metric.py +15 -15
  27. omnigenome/src/metric/regression_metric.py +18 -18
  28. omnigenome/src/misc/utils.py +40 -36
  29. omnigenome/src/model/augmentation/__init__.py +0 -1
  30. omnigenome/src/model/augmentation/model.py +17 -17
  31. omnigenome/src/model/classification/__init__.py +0 -1
  32. omnigenome/src/model/classification/model.py +28 -32
  33. omnigenome/src/model/embedding/__init__.py +0 -1
  34. omnigenome/src/model/embedding/model.py +35 -35
  35. omnigenome/src/model/mlm/__init__.py +0 -1
  36. omnigenome/src/model/mlm/model.py +13 -13
  37. omnigenome/src/model/module_utils.py +17 -17
  38. omnigenome/src/model/regression/__init__.py +0 -1
  39. omnigenome/src/model/regression/model.py +72 -77
  40. omnigenome/src/model/regression/resnet.py +32 -32
  41. omnigenome/src/model/rna_design/__init__.py +0 -1
  42. omnigenome/src/model/rna_design/model.py +65 -58
  43. omnigenome/src/model/seq2seq/__init__.py +0 -1
  44. omnigenome/src/model/seq2seq/model.py +4 -4
  45. omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
  46. omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
  47. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
  48. omnigenome/src/trainer/accelerate_trainer.py +40 -32
  49. omnigenome/src/trainer/hf_trainer.py +8 -8
  50. omnigenome/src/trainer/trainer.py +37 -25
  51. omnigenome/utility/dataset_hub/__init__.py +0 -1
  52. omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
  53. omnigenome/utility/ensemble.py +26 -26
  54. omnigenome/utility/hub_utils.py +8 -8
  55. omnigenome/utility/model_hub/__init__.py +0 -1
  56. omnigenome/utility/model_hub/model_hub.py +26 -25
  57. omnigenome/utility/pipeline_hub/__init__.py +0 -1
  58. omnigenome/utility/pipeline_hub/pipeline.py +49 -49
  59. omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
  60. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +2 -2
  61. omnigenome-0.3.1a0.dist-info/RECORD +78 -0
  62. omnigenome-0.3.0a1.dist-info/RECORD +0 -78
  63. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
  64. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
  65. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
  66. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -0
--- a/omnigenome/src/metric/metric.py
+++ b/omnigenome/src/metric/metric.py
@@ -20,20 +20,20 @@ from ..abc.abstract_metric import OmniMetric
 def mcrmse(y_true, y_pred):
     """
     Compute Mean Column Root Mean Square Error (MCRMSE).
-
+
     MCRMSE is a multi-target regression metric that computes the RMSE for each target
     column and then takes the mean across all targets.
-
+
     Args:
         y_true (np.ndarray): Ground truth values with shape (n_samples, n_targets)
         y_pred (np.ndarray): Predicted values with shape (n_samples, n_targets)
-
+
     Returns:
         float: Mean Column Root Mean Square Error
-
+
     Raises:
         ValueError: If y_true and y_pred have different shapes
-
+
     Example:
         >>> y_true = np.array([[1, 2], [3, 4], [5, 6]])
         >>> y_pred = np.array([[1.1, 2.1], [2.9, 4.1], [5.2, 5.8]])
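As a quick sanity check of the MCRMSE definition above (per-column RMSE, then the mean across columns), a minimal NumPy sketch — illustrative only, not the package's implementation — reproduces the docstring example:

import numpy as np

def mcrmse_sketch(y_true, y_pred):
    # Per-column RMSE, then the mean across target columns.
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must have the same shape")
    per_column_rmse = np.sqrt(np.mean((y_true - y_pred) ** 2, axis=0))
    return float(np.mean(per_column_rmse))

y_true = np.array([[1, 2], [3, 4], [5, 6]], dtype=float)
y_pred = np.array([[1.1, 2.1], [2.9, 4.1], [5.2, 5.8]])
print(mcrmse_sketch(y_true, y_pred))  # ≈ 0.141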
@@ -57,18 +57,18 @@ class Metric(OmniMetric):
     """
     A flexible metric class that provides access to all scikit-learn metrics
     and custom metrics for evaluation.
-
+
     This class dynamically wraps scikit-learn metrics and provides a unified
     interface for computing various evaluation metrics. It handles different
     input formats including HuggingFace trainer outputs and supports
     custom metric functions.
-
+
     Attributes:
         metric_func: Custom metric function if provided
         ignore_y: Value to ignore in predictions and true values
         kwargs: Additional keyword arguments for metric computation
         metrics: Dictionary of available metrics including custom ones
-
+
     Example:
         >>> from omnigenome.src.metric import Metric
         >>> metric = Metric(ignore_y=-100)
@@ -82,7 +82,7 @@ class Metric(OmniMetric):
     def __init__(self, metric_func=None, ignore_y=-100, *args, **kwargs):
         """
         Initialize the Metric class.
-
+
         Args:
             metric_func (callable, optional): Custom metric function to use
             ignore_y (int, optional): Value to ignore in predictions and true values. Defaults to -100
@@ -98,14 +98,14 @@ class Metric(OmniMetric):
     def __getattribute__(self, name):
         """
         Dynamically create metric computation methods.
-
+
         This method intercepts attribute access and creates wrapper functions
         for scikit-learn metrics, handling different input formats and
         preprocessing the data appropriately.
-
+
         Args:
             name (str): Name of the metric to access
-
+
         Returns:
             callable: Wrapper function for the requested metric
         """
@@ -119,20 +119,20 @@ class Metric(OmniMetric):
         def wrapper(y_true=None, y_score=None, *args, **kwargs):
             """
             Compute the metric, based on the true and predicted values.
-
+
             This wrapper handles different input formats including HuggingFace
             trainer outputs and performs necessary preprocessing.
-
+
             Args:
                 y_true: The true values or HuggingFace EvalPrediction object
                 y_score: The predicted values
                 ignore_y: The value to ignore in the predictions and true values in corresponding positions
                 *args: Additional positional arguments for the metric
                 **kwargs: Additional keyword arguments for the metric
-
+
             Returns:
                 dict: Dictionary containing the metric name and computed value
-
+
             Raises:
                 ValueError: If neither y_true nor y_score is provided
             """
@@ -176,16 +176,16 @@ class Metric(OmniMetric):
     def compute(self, y_true, y_score, *args, **kwargs):
         """
         Compute the metric, based on the true and predicted values.
-
+
         Args:
             y_true: The true values
             y_score: The predicted values
             *args: Additional positional arguments for the metric
             **kwargs: Additional keyword arguments for the metric
-
+
         Returns:
             The computed metric value
-
+
         Raises:
             NotImplementedError: If no metric function is provided and compute is not implemented
         """
--- a/omnigenome/src/metric/ranking_metric.py
+++ b/omnigenome/src/metric/ranking_metric.py
@@ -20,16 +20,16 @@ from ..abc.abstract_metric import OmniMetric
 class RankingMetric(OmniMetric):
     """
     A specialized metric class for ranking tasks and evaluation.
-
+
     This class provides access to ranking-specific metrics from scikit-learn
     and handles different input formats including HuggingFace trainer outputs.
     It dynamically wraps scikit-learn metrics and provides a unified interface
     for computing various ranking evaluation metrics.
-
+
     Attributes:
         metric_func: Custom metric function if provided
         ignore_y: Value to ignore in predictions and true values
-
+
     Example:
         >>> from omnigenome.src.metric import RankingMetric
         >>> metric = RankingMetric(ignore_y=-100)
@@ -43,7 +43,7 @@ class RankingMetric(OmniMetric):
     def __init__(self, *args, **kwargs):
         """
         Initialize the RankingMetric class.
-
+
         Args:
             *args: Additional positional arguments passed to parent class
             **kwargs: Additional keyword arguments passed to parent class
@@ -53,17 +53,17 @@ class RankingMetric(OmniMetric):
     def __getattr__(self, name):
         """
         Dynamically create ranking metric computation methods.
-
+
         This method intercepts attribute access and creates wrapper functions
         for scikit-learn ranking metrics, handling different input formats and
         preprocessing the data appropriately.
-
+
         Args:
             name (str): Name of the ranking metric to access
-
+
         Returns:
             callable: Wrapper function for the requested ranking metric
-
+
         Raises:
             AttributeError: If the requested metric is not found
         """
@@ -74,17 +74,17 @@ class RankingMetric(OmniMetric):
         def wrapper(y_true=None, y_score=None, *args, **kwargs):
             """
             Compute the ranking metric, based on the true and predicted values.
-
+
             This wrapper handles different input formats including HuggingFace
             trainer outputs and performs necessary preprocessing for ranking tasks.
-
+
             Args:
                 y_true: The true values or HuggingFace EvalPrediction object
                 y_score: The predicted values (scores for ranking)
                 ignore_y: The value to ignore in the predictions and true values in corresponding positions
                 *args: Additional positional arguments for the metric
                 **kwargs: Additional keyword arguments for the metric
-
+
             Returns:
                 dict: Dictionary containing the metric name and computed value
             """
@@ -121,19 +121,19 @@ class RankingMetric(OmniMetric):
     def compute(self, y_true, y_score, *args, **kwargs):
         """
         Compute the ranking metric, based on the true and predicted values.
-
+
         This method should be implemented by subclasses to provide specific
         ranking metric computation logic.
-
+
         Args:
             y_true: The true values
             y_score: The predicted values (scores for ranking)
             *args: Additional positional arguments for the metric
             **kwargs: Additional keyword arguments for the metric
-
+
         Returns:
             The computed ranking metric value
-
+
         Raises:
             NotImplementedError: If compute method is not implemented in the child class
         """
--- a/omnigenome/src/metric/regression_metric.py
+++ b/omnigenome/src/metric/regression_metric.py
@@ -20,20 +20,20 @@ from ..abc.abstract_metric import OmniMetric
 def mcrmse(y_true, y_pred):
     """
     Compute Mean Column Root Mean Square Error (MCRMSE).
-
+
     MCRMSE is a multi-target regression metric that computes the RMSE for each target
     column and then takes the mean across all targets.
-
+
     Args:
         y_true (np.ndarray): Ground truth values with shape (n_samples, n_targets)
         y_pred (np.ndarray): Predicted values with shape (n_samples, n_targets)
-
+
     Returns:
         float: Mean Column Root Mean Square Error
-
+
     Raises:
         ValueError: If y_true and y_pred have different shapes
-
+
     Example:
         >>> y_true = np.array([[1, 2], [3, 4], [5, 6]])
        >>> y_pred = np.array([[1.1, 2.1], [2.9, 4.1], [5.2, 5.8]])
@@ -56,18 +56,18 @@ setattr(metrics, "mcrmse", mcrmse)
 class RegressionMetric(OmniMetric):
     """
     A specialized metric class for regression tasks and evaluation.
-
+
     This class provides access to regression-specific metrics from scikit-learn
     and handles different input formats including HuggingFace trainer outputs.
     It dynamically wraps scikit-learn metrics and provides a unified interface
     for computing various regression evaluation metrics.
-
+
     Attributes:
         metric_func: Custom metric function if provided
         ignore_y: Value to ignore in predictions and true values
         kwargs: Additional keyword arguments for metric computation
         metrics: Dictionary of available metrics including custom ones
-
+
     Example:
         >>> from omnigenome.src.metric import RegressionMetric
         >>> metric = RegressionMetric(ignore_y=-100)
@@ -81,7 +81,7 @@ class RegressionMetric(OmniMetric):
     def __init__(self, metric_func=None, ignore_y=-100, *args, **kwargs):
         """
         Initialize the RegressionMetric class.
-
+
         Args:
             metric_func (callable, optional): Custom metric function to use
             ignore_y (int, optional): Value to ignore in predictions and true values. Defaults to -100
@@ -97,14 +97,14 @@ class RegressionMetric(OmniMetric):
     def __getattribute__(self, name):
         """
         Dynamically create regression metric computation methods.
-
+
         This method intercepts attribute access and creates wrapper functions
         for scikit-learn regression metrics, handling different input formats and
         preprocessing the data appropriately.
-
+
         Args:
             name (str): Name of the regression metric to access
-
+
         Returns:
             callable: Wrapper function for the requested regression metric
         """
@@ -118,17 +118,17 @@ class RegressionMetric(OmniMetric):
         def wrapper(y_true=None, y_score=None, *args, **kwargs):
             """
             Compute the regression metric, based on the true and predicted values.
-
+
             This wrapper handles different input formats including HuggingFace
             trainer outputs and performs necessary preprocessing for regression tasks.
-
+
             Args:
                 y_true: The true values or HuggingFace EvalPrediction object
                 y_score: The predicted values
                 ignore_y: The value to ignore in the predictions and true values in corresponding positions
                 *args: Additional positional arguments for the metric
                 **kwargs: Additional keyword arguments for the metric
-
+
             Returns:
                 dict: Dictionary containing the metric name and computed value
             """
@@ -168,16 +168,16 @@ class RegressionMetric(OmniMetric):
     def compute(self, y_true, y_score, *args, **kwargs):
         """
         Compute the regression metric, based on the true and predicted values.
-
+
         Args:
             y_true: The true values
             y_score: The predicted values
             *args: Additional positional arguments for the metric
            **kwargs: Additional keyword arguments for the metric
-
+
         Returns:
             The computed regression metric value
-
+
         Raises:
             NotImplementedError: If no metric function is provided and compute is not implemented
         """
--- a/omnigenome/src/misc/utils.py
+++ b/omnigenome/src/misc/utils.py
@@ -25,13 +25,13 @@ default_omnigenome_repo = (
 def seed_everything(seed=42):
     """
     Sets random seeds for reproducibility across all random number generators.
-
+
     This function sets seeds for Python's random module, NumPy, PyTorch (CPU and CUDA),
     and sets the PYTHONHASHSEED environment variable to ensure reproducible results
     across different runs.
-
+
     Args:
-        seed (int): The seed value to use for all random number generators.
+        seed (int): The seed value to use for all random number generators.
             Defaults to 42.
 
     Example:
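For reference, the behaviour documented for seed_everything (Python random, NumPy, PyTorch CPU/CUDA, and PYTHONHASHSEED) can be sketched as follows; this is an illustration of the documented behaviour, not the package's verbatim code:

import os
import random
import numpy as np
import torch

def seed_everything_sketch(seed=42):
    random.seed(seed)                         # Python's random module
    np.random.seed(seed)                      # NumPy
    torch.manual_seed(seed)                   # PyTorch CPU
    torch.cuda.manual_seed_all(seed)          # PyTorch CUDA (no-op without a GPU)
    os.environ["PYTHONHASHSEED"] = str(seed)  # hash randomization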
@@ -55,11 +55,11 @@ def seed_everything(seed=42):
 class RNA2StructureCache(dict):
     """
     A cache for RNA secondary structure predictions using ViennaRNA.
-
+
     This class provides a caching mechanism for RNA secondary structure predictions
     to avoid redundant computations. It supports both single sequence and batch
     processing with optional multiprocessing for improved performance.
-
+
     Attributes:
         cache (dict): Dictionary storing sequence-structure mappings
         cache_file (str): Path to the cache file on disk
@@ -69,7 +69,7 @@ class RNA2StructureCache(dict):
     def __init__(self, cache_file=None, *args, **kwargs):
         """
         Initialize the RNA structure cache.
-
+
         Args:
             cache_file (str, optional): Path to the cache file. If None, uses
                 a default temporary file.
@@ -112,10 +112,10 @@ class RNA2StructureCache(dict):
     def _fold_single_sequence(self, sequence):
         """
         Predict structure for a single sequence (worker function for multiprocessing).
-
+
         Args:
             sequence (str): RNA sequence to fold
-
+
         Returns:
             tuple: (structure, mfe) tuple
         """
@@ -128,12 +128,12 @@ class RNA2StructureCache(dict):
     def fold(self, sequence, return_mfe=False, num_workers=1):
         """
         Predicts RNA secondary structure for given sequences.
-
+
         This method predicts RNA secondary structures using ViennaRNA. It supports
         both single sequences and batches of sequences. The method uses caching
         to avoid redundant predictions and supports multiprocessing for batch
         processing on non-Windows systems.
-
+
         Args:
             sequence (str or list): A single RNA sequence or a list of sequences.
             return_mfe (bool): Whether to return minimum free energy along with
@@ -150,7 +150,7 @@ class RNA2StructureCache(dict):
             >>> # Predict structure for a single sequence
             >>> structure = cache.fold("GGGAAAUCC")
             >>> print(structure)  # "(((...)))"
-
+
             >>> # Predict structures for multiple sequences
             >>> structures = cache.fold(["GGGAAAUCC", "AUUGCUAA"])
             >>> print(structures)  # ["(((...)))", "........"]
@@ -162,36 +162,40 @@ class RNA2StructureCache(dict):
 
         # Determine if we should use multiprocessing
         use_multiprocessing = (
-            os.name != "nt" and  # Not Windows
-            len(sequences) > 1 and  # Multiple sequences
-            num_workers > 1  # Multiple workers requested
+            os.name != "nt"  # Not Windows
+            and len(sequences) > 1  # Multiple sequences
+            and num_workers > 1  # Multiple workers requested
         )
 
         # Find sequences that need prediction
         sequences_to_predict = [seq for seq in sequences if seq not in self.cache]
-
+
         if sequences_to_predict:
             if use_multiprocessing:
                 # Use multiprocessing for batch prediction
                 if num_workers is None:
                     num_workers = min(os.cpu_count(), len(sequences_to_predict))
-
+
                 try:
                     # Set multiprocessing start method to 'spawn' for better compatibility
-                    if multiprocessing.get_start_method(allow_none=True) != 'spawn':
-                        multiprocessing.set_start_method('spawn', force=True)
-
+                    if multiprocessing.get_start_method(allow_none=True) != "spawn":
+                        multiprocessing.set_start_method("spawn", force=True)
+
                     with multiprocessing.Pool(num_workers) as pool:
                         # Use map instead of apply_async for better error handling
-                        results = pool.map(self._fold_single_sequence, sequences_to_predict)
-
+                        results = pool.map(
+                            self._fold_single_sequence, sequences_to_predict
+                        )
+
                         # Update cache with results
                         for seq, result in zip(sequences_to_predict, results):
                             self.cache[seq] = result
                             self.queue_num += 1
-
+
                 except Exception as e:
-                    warnings.warn(f"Multiprocessing failed, falling back to sequential: {e}")
+                    warnings.warn(
+                        f"Multiprocessing failed, falling back to sequential: {e}"
+                    )
                     # Fallback to sequential processing
                     for seq in sequences_to_predict:
                         self.cache[seq] = self._fold_single_sequence(seq)
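The fold logic above boils down to "look up the sequence in the cache, otherwise call ViennaRNA and store the result". A minimal sketch of that pattern, assuming the ViennaRNA Python bindings (the RNA module, whose RNA.fold returns a (structure, mfe) tuple) are installed; the multiprocessing and on-disk persistence shown in the diff are omitted here:

import RNA  # ViennaRNA Python bindings

cache = {}

def fold_cached(sequence, return_mfe=False):
    if sequence not in cache:                 # only fold unseen sequences
        cache[sequence] = RNA.fold(sequence)  # (structure, mfe)
    structure, mfe = cache[sequence]
    return (structure, mfe) if return_mfe else structure

print(fold_cached("GGGAAAUCC"))  # e.g. "(((...)))"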
@@ -207,7 +211,7 @@ class RNA2StructureCache(dict):
             structures = [self.cache[seq] for seq in sequences]
         else:
             structures = [self.cache[seq][0] for seq in sequences]
-
+
         # Update cache file periodically
         self.update_cache_file(self.cache_file)
 
@@ -220,10 +224,10 @@ class RNA2StructureCache(dict):
     def update_cache_file(self, cache_file=None):
         """
         Updates the cache file on disk.
-
+
         This method saves the in-memory cache to disk. It only saves when
         the queue_num reaches 100 to avoid excessive disk I/O.
-
+
         Args:
             cache_file (str, optional): Path to the cache file. If None, uses
                 the instance's cache_file.
@@ -252,11 +256,11 @@ class RNA2StructureCache(dict):
 def env_meta_info():
     """
     Collects metadata about the current environment and library versions.
-
+
     This function gathers information about the current Python environment,
     including versions of key libraries like PyTorch and Transformers,
     as well as OmniGenome version information.
-
+
     Returns:
         dict: A dictionary containing environment metadata including:
             - library_name: Name of the OmniGenome library
@@ -286,7 +290,7 @@ def env_meta_info():
 def naive_secondary_structure_repair(sequence, structure):
     """
     Repair the secondary structure of a sequence.
-
+
     This function attempts to repair malformed RNA secondary structure
     representations by ensuring proper bracket matching. It handles
     common issues like unmatched brackets by converting them to dots.
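The repair described here (unmatched brackets become dots) is essentially a stack-matching pass; a small illustrative sketch of that idea — not necessarily the package's exact rules — looks like:

def repair_structure_sketch(structure):
    chars = list(structure)
    open_positions = []
    for i, ch in enumerate(chars):
        if ch == "(":
            open_positions.append(i)
        elif ch == ")":
            if open_positions:
                open_positions.pop()
            else:
                chars[i] = "."  # unmatched closing bracket
    for i in open_positions:
        chars[i] = "."          # unmatched opening brackets
    return "".join(chars)

print(repair_structure_sketch("((..)))("))  # "((..)).."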
@@ -324,7 +328,7 @@ def naive_secondary_structure_repair(sequence, structure):
 def save_args(config, save_path):
     """
     Save arguments to a file.
-
+
     This function saves the arguments from a configuration object to a text file.
     It's useful for logging experiment parameters and configurations.
 
@@ -347,7 +351,7 @@ def save_args(config, save_path):
 def print_args(config, logger=None):
     """
     Print the arguments to the console.
-
+
     This function prints the arguments from a configuration object to the console
     or a logger. It's useful for debugging and logging experiment parameters.
 
@@ -373,7 +377,7 @@ def fprint(*objects, sep=" ", end="\n", file=sys.stdout, flush=False):
 def fprint(*objects, sep=" ", end="\n", file=sys.stdout, flush=False):
     """
     Enhanced print function with automatic flushing.
-
+
     This function provides a print-like interface with automatic flushing
     to ensure output is displayed immediately. It's useful for real-time
     logging and progress tracking.
@@ -395,7 +399,7 @@ def fprint(*objects, sep=" ", end="\n", file=sys.stdout, flush=False):
 def clean_temp_checkpoint(days_threshold=7):
     """
     Clean up temporary checkpoint files older than specified days.
-
+
     This function removes temporary checkpoint files that are older than
     the specified threshold to free up disk space.
 
@@ -431,7 +435,7 @@ def clean_temp_checkpoint(days_threshold=7):
 def load_module_from_path(module_name, file_path):
     """
     Load a Python module from a file path.
-
+
     This function dynamically loads a Python module from a file path,
     useful for loading configuration files or custom modules.
 
@@ -457,7 +461,7 @@ def load_module_from_path(module_name, file_path):
 def check_bench_version(bench_version, omnigenome_version):
     """
     Check if benchmark version is compatible with OmniGenome version.
-
+
     This function compares the benchmark version with the OmniGenome version
     to ensure compatibility and warns if there are potential issues.
 
@@ -479,7 +483,7 @@ def check_bench_version(bench_version, omnigenome_version):
 def clean_temp_dir_pt_files():
     """
     Clean up temporary PyTorch files in the current directory.
-
+
     This function removes temporary PyTorch files (like .pt, .pth files)
     that may be left over from previous runs.
 
--- a/omnigenome/src/model/augmentation/__init__.py
+++ b/omnigenome/src/model/augmentation/__init__.py
@@ -9,4 +9,3 @@
 """
 This package contains modules for data augmentation.
 """
-
--- a/omnigenome/src/model/augmentation/model.py
+++ b/omnigenome/src/model/augmentation/model.py
@@ -24,12 +24,12 @@ import autocuda
 class OmniModelForAugmentation(torch.nn.Module):
     """
     Data augmentation model for genomic sequences using masked language modeling.
-
+
     This model uses a pre-trained masked language model to generate augmented
     versions of genomic sequences by randomly masking tokens and predicting
     replacements. It's useful for expanding training datasets and improving
     model generalization.
-
+
     Attributes:
         tokenizer: Tokenizer for processing genomic sequences
         model: Pre-trained masked language model
@@ -38,7 +38,7 @@ class OmniModelForAugmentation(torch.nn.Module):
         max_length: Maximum sequence length for tokenization
         k: Number of augmented instances to generate per sequence
     """
-
+
     def __init__(
         self,
         model_name_or_path=None,
@@ -50,7 +50,7 @@ class OmniModelForAugmentation(torch.nn.Module):
     ):
         """
         Initialize the augmentation model.
-
+
         Args:
             model_name_or_path (str): Path or model name for loading the pre-trained model
             noise_ratio (float): The proportion of tokens to mask in each sequence for augmentation (default: 0.15)
@@ -82,10 +82,10 @@ class OmniModelForAugmentation(torch.nn.Module):
     def load_sequences_from_file(self, input_file):
         """
         Load sequences from a JSON file.
-
+
         Args:
             input_file (str): Path to the input JSON file containing sequences
-
+
         Returns:
             list: List of sequences loaded from the file
         """
@@ -98,10 +98,10 @@ class OmniModelForAugmentation(torch.nn.Module):
     def apply_noise_to_sequence(self, seq):
         """
         Apply noise to a single sequence by randomly masking tokens.
-
+
         Args:
             seq (str): Input genomic sequence
-
+
         Returns:
             str: Sequence with randomly masked tokens
         """
@@ -114,10 +114,10 @@ class OmniModelForAugmentation(torch.nn.Module):
     def augment_sequence(self, seq):
         """
         Perform augmentation on a single sequence by predicting masked tokens.
-
+
         Args:
             seq (str): Input genomic sequence with masked tokens
-
+
         Returns:
             str: Augmented sequence with predicted tokens replacing masked tokens
         """
@@ -145,11 +145,11 @@ class OmniModelForAugmentation(torch.nn.Module):
     def augment(self, seq, k=None):
         """
         Generate multiple augmented instances for a single sequence.
-
+
         Args:
             seq (str): Input genomic sequence
             k (int, optional): Number of augmented instances to generate (default: None, uses self.k)
-
+
         Returns:
             list: List of augmented sequences
         """
@@ -163,10 +163,10 @@ class OmniModelForAugmentation(torch.nn.Module):
     def augment_sequences(self, sequences):
         """
         Augment a list of sequences by applying noise and performing MLM-based predictions.
-
+
         Args:
             sequences (list): List of genomic sequences to augment
-
+
         Returns:
             list: List of all augmented sequences
         """
@@ -179,7 +179,7 @@ class OmniModelForAugmentation(torch.nn.Module):
     def save_augmented_sequences(self, augmented_sequences, output_file):
         """
         Save augmented sequences to a JSON file.
-
+
         Args:
             augmented_sequences (list): List of augmented sequences to save
             output_file (str): Path to the output JSON file
@@ -191,10 +191,10 @@ class OmniModelForAugmentation(torch.nn.Module):
     def augment_from_file(self, input_file, output_file):
         """
         Main function to handle the augmentation process from a file input to a file output.
-
+
         This method loads sequences from an input file, augments them using the MLM model,
         and saves the augmented sequences to an output file.
-
+
         Args:
             input_file (str): Path to the input file containing sequences
             output_file (str): Path to the output file where augmented sequences will be saved
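Putting the documented pieces together, a typical end-to-end call based on the signatures shown above might look like the following; the checkpoint path and file names are placeholders, and only the constructor arguments documented in this diff (model_name_or_path, noise_ratio) are passed:

from omnigenome.src.model.augmentation.model import OmniModelForAugmentation

augmenter = OmniModelForAugmentation(
    model_name_or_path="path/to/pretrained-mlm",  # placeholder MLM checkpoint
    noise_ratio=0.15,  # fraction of tokens masked per sequence
)
augmented = augmenter.augment("ACGUACGUACGU")  # list of augmented sequences
augmenter.augment_from_file("sequences.json", "augmented_sequences.json")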